From 59e7b3417fd900d0ea4476e01c4b7eb72a29b069 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Thu, 24 Jul 2025 11:59:41 +0200 Subject: [PATCH] Refactor v2.*Tables DRY --- doc/modules.rst | 1 - petab/v2/converters.py | 2 +- petab/v2/core.py | 443 +++++++++++++++++------------------------ tests/v2/test_core.py | 4 +- 4 files changed, 181 insertions(+), 269 deletions(-) diff --git a/doc/modules.rst b/doc/modules.rst index 6dacba5a..1eb0220c 100644 --- a/doc/modules.rst +++ b/doc/modules.rst @@ -37,5 +37,4 @@ API Reference petab.v2.experiments petab.v2.lint petab.v2.models - petab.v2.problem petab.v2.petab1to2 diff --git a/petab/v2/converters.py b/petab/v2/converters.py index f6d185b5..ae4f5888 100644 --- a/petab/v2/converters.py +++ b/petab/v2/converters.py @@ -401,7 +401,7 @@ def _add_indicators_to_conditions(self) -> None: # removed. Only keep the conditions setting our indicators. problem.condition_tables = [ ConditionTable( - conditions=[ + [ condition for condition in problem.conditions if condition.id.startswith("_petab") diff --git a/petab/v2/core.py b/petab/v2/core.py index 8170bd53..ff1efb28 100644 --- a/petab/v2/core.py +++ b/petab/v2/core.py @@ -6,13 +6,14 @@ import os import tempfile import traceback +from abc import abstractmethod from collections.abc import Sequence from enum import Enum from itertools import chain from math import nan from numbers import Number from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any +from typing import TYPE_CHECKING, Annotated, Any, Generic, TypeVar, get_args import numpy as np import pandas as pd @@ -112,7 +113,7 @@ def _valid_petab_id(v: str) -> str: class ParameterScale(str, Enum): """Parameter scales. - Parameter scales as used in the PEtab parameters table. + Parameter scales as used in the PEtab parameter table. """ LIN = C.LIN @@ -123,7 +124,7 @@ class ParameterScale(str, Enum): class NoiseDistribution(str, Enum): """Noise distribution types. - Noise distributions as used in the PEtab observables table. + Noise distributions as used in the PEtab observable table. """ #: Normal distribution @@ -141,7 +142,7 @@ class NoiseDistribution(str, Enum): class PriorDistribution(str, Enum): """Prior types. - Prior types as used in the PEtab parameters table. + Prior types as used in the PEtab parameter table. """ #: Cauchy distribution. @@ -196,6 +197,89 @@ class PriorDistribution(str, Enum): ) +T = TypeVar("T", bound=BaseModel) + + +class BaseTable(BaseModel, Generic[T]): + """Base class for PEtab tables.""" + + elements: list[T] + + def __init__(self, elements: list[T] = None) -> None: + """Initialize the BaseTable with a list of elements.""" + if elements is None: + elements = [] + super().__init__(elements=elements) + + def __getitem__(self, id_: str) -> T: + """Get an element by ID. + + :param id_: The ID of the element to retrieve. + :return: The element with the given ID. + :raises KeyError: If no element with the given ID exists. + :raises NotImplementedError: + If the element type does not have an ID attribute. + """ + if "id" not in self._element_class().model_fields: + raise NotImplementedError( + f"__getitem__ is not implemented for {self.__class__.__name__}" + ) + + for element in self.elements: + if element.id == id_: + return element + + raise KeyError(f"{T.__name__} ID {id_} not found") + + @classmethod + @abstractmethod + def from_df(cls, df: pd.DataFrame) -> BaseTable[T]: + """Create a table from a DataFrame.""" + pass + + @abstractmethod + def to_df(self) -> pd.DataFrame: + """Convert the table to a DataFrame.""" + pass + + @classmethod + def from_tsv(cls, file_path: str | Path) -> BaseTable[T]: + """Create table from a TSV file.""" + df = pd.read_csv(file_path, sep="\t") + return cls.from_df(df) + + def to_tsv(self, file_path: str | Path) -> None: + """Write the table to a TSV file.""" + df = self.to_df() + df.to_csv( + file_path, sep="\t", index=not isinstance(df.index, pd.RangeIndex) + ) + + @classmethod + def _element_class(cls) -> type[T]: + """Get the class of the elements in the table.""" + return get_args(cls.model_fields["elements"].annotation)[0] + + def __add__(self, other: T) -> BaseTable[T]: + """Add an item to the table.""" + if not isinstance(other, self._element_class()): + raise TypeError( + f"Can only add {self._element_class().__name__} " + f"to {self.__class__.__name__}" + ) + return self.__class__(elements=self.elements + [other]) + + def __iadd__(self, other: T) -> BaseTable[T]: + """Add an item to the table in place.""" + if not isinstance(other, self._element_class()): + raise TypeError( + f"Can only add {self._element_class().__name__} " + f"to {self.__class__.__name__}" + ) + self.elements.append(other) + return self + + class Observable(BaseModel): """Observable definition.""" @@ -273,24 +357,19 @@ def _sympify_id_list(cls, v): return [sympify_petab(_valid_petab_id(pid)) for pid in v if pid] -class ObservableTable(BaseModel): +class ObservableTable(BaseTable[Observable]): """PEtab observable table.""" - #: List of observables. - observables: list[Observable] - - def __getitem__(self, observable_id: str) -> Observable: - """Get an observable by ID.""" - for observable in self.observables: - if observable.id == observable_id: - return observable - raise KeyError(f"Observable ID {observable_id} not found") + @property + def observables(self) -> list[Observable]: + """List of observables.""" + return self.elements @classmethod def from_df(cls, df: pd.DataFrame) -> ObservableTable: """Create an ObservableTable from a DataFrame.""" if df is None: - return cls(observables=[]) + return cls() df = get_observable_df(df) observables = [ @@ -298,11 +377,11 @@ def from_df(cls, df: pd.DataFrame) -> ObservableTable: for _, row in df.reset_index().iterrows() ] - return cls(observables=observables) + return cls(observables) def to_df(self) -> pd.DataFrame: """Convert the ObservableTable to a DataFrame.""" - records = self.model_dump(by_alias=True)["observables"] + records = self.model_dump(by_alias=True)["elements"] for record in records: obs = record[C.OBSERVABLE_FORMULA] noise = record[C.NOISE_FORMULA] @@ -316,30 +395,6 @@ def to_df(self) -> pd.DataFrame: ) return pd.DataFrame(records).set_index([C.OBSERVABLE_ID]) - @classmethod - def from_tsv(cls, file_path: str | Path) -> ObservableTable: - """Create an ObservableTable from a TSV file.""" - df = pd.read_csv(file_path, sep="\t") - return cls.from_df(df) - - def to_tsv(self, file_path: str | Path) -> None: - """Write the ObservableTable to a TSV file.""" - df = self.to_df() - df.to_csv(file_path, sep="\t", index=True) - - def __add__(self, other: Observable) -> ObservableTable: - """Add an observable to the table.""" - if not isinstance(other, Observable): - raise TypeError("Can only add Observable to ObservableTable") - return ObservableTable(observables=self.observables + [other]) - - def __iadd__(self, other: Observable) -> ObservableTable: - """Add an observable to the table in place.""" - if not isinstance(other, Observable): - raise TypeError("Can only add Observable to ObservableTable") - self.observables.append(other) - return self - class Change(BaseModel): """A change to the model or model state. @@ -386,7 +441,7 @@ class Condition(BaseModel): A set of simultaneously occurring changes to the model or model state, corresponding to a perturbation of the underlying system. This corresponds - to all rows of the PEtab conditions table with the same condition ID. + to all rows of the PEtab condition table with the same condition ID. >>> Condition( ... id="condition1", @@ -427,31 +482,26 @@ def __iadd__(self, other: Change) -> Condition: return self -class ConditionTable(BaseModel): - """PEtab conditions table.""" +class ConditionTable(BaseTable[Condition]): + """PEtab condition table.""" - #: List of conditions. - conditions: list[Condition] = [] - - def __getitem__(self, condition_id: str) -> Condition: - """Get a condition by ID.""" - for condition in self.conditions: - if condition.id == condition_id: - return condition - raise KeyError(f"Condition ID {condition_id} not found") + @property + def conditions(self) -> list[Condition]: + """List of conditions.""" + return self.elements @classmethod def from_df(cls, df: pd.DataFrame) -> ConditionTable: """Create a ConditionTable from a DataFrame.""" if df is None or df.empty: - return cls(conditions=[]) + return cls() conditions = [] for condition_id, sub_df in df.groupby(C.CONDITION_ID): changes = [Change(**row) for row in sub_df.to_dict("records")] conditions.append(Condition(id=condition_id, changes=changes)) - return cls(conditions=conditions) + return cls(conditions) def to_df(self) -> pd.DataFrame: """Convert the ConditionTable to a DataFrame.""" @@ -472,30 +522,6 @@ def to_df(self) -> pd.DataFrame: else pd.DataFrame(columns=C.CONDITION_DF_REQUIRED_COLS) ) - @classmethod - def from_tsv(cls, file_path: str | Path) -> ConditionTable: - """Create a ConditionTable from a TSV file.""" - df = pd.read_csv(file_path, sep="\t") - return cls.from_df(df) - - def to_tsv(self, file_path: str | Path) -> None: - """Write the ConditionTable to a TSV file.""" - df = self.to_df() - df.to_csv(file_path, sep="\t", index=False) - - def __add__(self, other: Condition) -> ConditionTable: - """Add a condition to the table.""" - if not isinstance(other, Condition): - raise TypeError("Can only add Condition to ConditionTable") - return ConditionTable(conditions=self.conditions + [other]) - - def __iadd__(self, other: Condition) -> ConditionTable: - """Add a condition to the table in place.""" - if not isinstance(other, Condition): - raise TypeError("Can only add Condition to ConditionTable") - self.conditions.append(other) - return self - @property def free_symbols(self) -> set[sp.Symbol]: """Get all free symbols in the condition table. @@ -518,7 +544,7 @@ class ExperimentPeriod(BaseModel): """A period of a timecourse or experiment defined by a start time and a list of condition IDs. - This corresponds to a row of the PEtab experiments table. + This corresponds to a row of the PEtab experiment table. """ #: The start time of the period in time units as defined in the model. @@ -559,7 +585,7 @@ class Experiment(BaseModel): """An experiment or a timecourse defined by an ID and a set of different periods. - Corresponds to a group of rows of the PEtab experiments table with the same + Corresponds to a group of rows of the PEtab experiment table with the same experiment ID. """ @@ -601,17 +627,19 @@ def sort_periods(self) -> None: self.periods.sort(key=lambda period: period.time) -class ExperimentTable(BaseModel): - """PEtab experiments table.""" +class ExperimentTable(BaseTable[Experiment]): + """PEtab experiment table.""" - #: List of experiments. - experiments: list[Experiment] + @property + def experiments(self) -> list[Experiment]: + """List of experiments.""" + return self.elements @classmethod def from_df(cls, df: pd.DataFrame) -> ExperimentTable: """Create an ExperimentTable from a DataFrame.""" if df is None: - return cls(experiments=[]) + return cls() experiments = [] for experiment_id, cur_exp_df in df.groupby(C.EXPERIMENT_ID): @@ -631,7 +659,7 @@ def from_df(cls, df: pd.DataFrame) -> ExperimentTable: ) experiments.append(Experiment(id=experiment_id, periods=periods)) - return cls(experiments=experiments) + return cls(experiments) def to_df(self) -> pd.DataFrame: """Convert the ExperimentTable to a DataFrame.""" @@ -651,37 +679,6 @@ def to_df(self) -> pd.DataFrame: else pd.DataFrame(columns=C.EXPERIMENT_DF_REQUIRED_COLS) ) - @classmethod - def from_tsv(cls, file_path: str | Path) -> ExperimentTable: - """Create an ExperimentTable from a TSV file.""" - df = pd.read_csv(file_path, sep="\t") - return cls.from_df(df) - - def to_tsv(self, file_path: str | Path) -> None: - """Write the ExperimentTable to a TSV file.""" - df = self.to_df() - df.to_csv(file_path, sep="\t", index=False) - - def __add__(self, other: Experiment) -> ExperimentTable: - """Add an experiment to the table.""" - if not isinstance(other, Experiment): - raise TypeError("Can only add Experiment to ExperimentTable") - return ExperimentTable(experiments=self.experiments + [other]) - - def __iadd__(self, other: Experiment) -> ExperimentTable: - """Add an experiment to the table in place.""" - if not isinstance(other, Experiment): - raise TypeError("Can only add Experiment to ExperimentTable") - self.experiments.append(other) - return self - - def __getitem__(self, item): - """Get an experiment by ID.""" - for experiment in self.experiments: - if experiment.id == item: - return experiment - raise KeyError(f"Experiment ID {item} not found") - class Measurement(BaseModel): """A measurement. @@ -761,11 +758,13 @@ def _sympify_list(cls, v): return [sympify_petab(x) for x in v] -class MeasurementTable(BaseModel): +class MeasurementTable(BaseTable[Measurement]): """PEtab measurement table.""" - #: List of measurements. - measurements: list[Measurement] + @property + def measurements(self) -> list[Measurement]: + """List of measurements.""" + return self.elements @classmethod def from_df( @@ -774,7 +773,7 @@ def from_df( ) -> MeasurementTable: """Create a MeasurementTable from a DataFrame.""" if df is None: - return cls(measurements=[]) + return cls() measurements = [ Measurement( @@ -783,11 +782,11 @@ def from_df( for _, row in df.reset_index().iterrows() ] - return cls(measurements=measurements) + return cls(measurements) def to_df(self) -> pd.DataFrame: """Convert the MeasurementTable to a DataFrame.""" - records = self.model_dump(by_alias=True)["measurements"] + records = self.model_dump(by_alias=True)["elements"] for record in records: record[C.OBSERVABLE_PARAMETERS] = C.PARAMETER_SEPARATOR.join( map(str, record[C.OBSERVABLE_PARAMETERS]) @@ -798,30 +797,6 @@ def to_df(self) -> pd.DataFrame: return pd.DataFrame(records) - @classmethod - def from_tsv(cls, file_path: str | Path) -> MeasurementTable: - """Create a MeasurementTable from a TSV file.""" - df = pd.read_csv(file_path, sep="\t") - return cls.from_df(df) - - def to_tsv(self, file_path: str | Path) -> None: - """Write the MeasurementTable to a TSV file.""" - df = self.to_df() - df.to_csv(file_path, sep="\t", index=False) - - def __add__(self, other: Measurement) -> MeasurementTable: - """Add a measurement to the table.""" - if not isinstance(other, Measurement): - raise TypeError("Can only add Measurement to MeasurementTable") - return MeasurementTable(measurements=self.measurements + [other]) - - def __iadd__(self, other: Measurement) -> MeasurementTable: - """Add a measurement to the table in place.""" - if not isinstance(other, Measurement): - raise TypeError("Can only add Measurement to MeasurementTable") - self.measurements.append(other) - return self - class Mapping(BaseModel): """Mapping PEtab entities to model entities.""" @@ -845,57 +820,35 @@ class Mapping(BaseModel): ) -class MappingTable(BaseModel): +class MappingTable(BaseTable[Mapping]): """PEtab mapping table.""" - #: List of mappings. - mappings: list[Mapping] + @property + def mappings(self) -> list[Mapping]: + """List of mappings.""" + return self.elements @classmethod def from_df(cls, df: pd.DataFrame) -> MappingTable: """Create a MappingTable from a DataFrame.""" if df is None: - return cls(mappings=[]) + return cls() mappings = [ Mapping(**row.to_dict()) for _, row in df.reset_index().iterrows() ] - return cls(mappings=mappings) + return cls(mappings) def to_df(self) -> pd.DataFrame: """Convert the MappingTable to a DataFrame.""" res = ( - pd.DataFrame(self.model_dump(by_alias=True)["mappings"]) + pd.DataFrame(self.model_dump(by_alias=True)["elements"]) if self.mappings else pd.DataFrame(columns=C.MAPPING_DF_REQUIRED_COLS) ) return res.set_index([C.PETAB_ENTITY_ID]) - @classmethod - def from_tsv(cls, file_path: str | Path) -> MappingTable: - """Create a MappingTable from a TSV file.""" - df = pd.read_csv(file_path, sep="\t") - return cls.from_df(df) - - def to_tsv(self, file_path: str | Path) -> None: - """Write the MappingTable to a TSV file.""" - df = self.to_df() - df.to_csv(file_path, sep="\t", index=False) - - def __add__(self, other: Mapping) -> MappingTable: - """Add a mapping to the table.""" - if not isinstance(other, Mapping): - raise TypeError("Can only add Mapping to MappingTable") - return MappingTable(mappings=self.mappings + [other]) - - def __iadd__(self, other: Mapping) -> MappingTable: - """Add a mapping to the table in place.""" - if not isinstance(other, Mapping): - raise TypeError("Can only add Mapping to MappingTable") - self.mappings.append(other) - return self - def __getitem__(self, petab_id: str) -> Mapping: """Get a mapping by PEtab ID.""" for mapping in self.mappings: @@ -1075,71 +1028,39 @@ def prior_dist(self) -> Distribution: return cls(*self.prior_parameters, log=log, trunc=[self.lb, self.ub]) -class ParameterTable(BaseModel): +class ParameterTable(BaseTable[Parameter]): """PEtab parameter table.""" - #: List of parameters. - parameters: list[Parameter] + @property + def parameters(self) -> list[Parameter]: + """List of parameters.""" + return self.elements @classmethod def from_df(cls, df: pd.DataFrame) -> ParameterTable: """Create a ParameterTable from a DataFrame.""" if df is None: - return cls(parameters=[]) + return cls() parameters = [ Parameter(**row.to_dict()) for _, row in df.reset_index().iterrows() ] - return cls(parameters=parameters) + return cls(parameters) def to_df(self) -> pd.DataFrame: """Convert the ParameterTable to a DataFrame.""" return pd.DataFrame( - self.model_dump(by_alias=True)["parameters"] + self.model_dump(by_alias=True)["elements"] ).set_index([C.PARAMETER_ID]) - @classmethod - def from_tsv(cls, file_path: str | Path) -> ParameterTable: - """Create a ParameterTable from a TSV file.""" - df = pd.read_csv(file_path, sep="\t") - return cls.from_df(df) - - def to_tsv(self, file_path: str | Path) -> None: - """Write the ParameterTable to a TSV file.""" - df = self.to_df() - df.to_csv(file_path, sep="\t", index=False) - - def __add__(self, other: Parameter) -> ParameterTable: - """Add a parameter to the table.""" - if not isinstance(other, Parameter): - raise TypeError("Can only add Parameter to ParameterTable") - return ParameterTable(parameters=self.parameters + [other]) - - def __iadd__(self, other: Parameter) -> ParameterTable: - """Add a parameter to the table in place.""" - if not isinstance(other, Parameter): - raise TypeError("Can only add Parameter to ParameterTable") - self.parameters.append(other) - return self - - def __getitem__(self, item) -> Parameter: - """Get a parameter by ID.""" - for parameter in self.parameters: - if parameter.id == item: - return parameter - raise KeyError(f"Parameter ID {item} not found") - @property def n_estimated(self) -> int: """Number of estimated parameters.""" return sum(p.estimate for p in self.parameters) -"""PEtab v2 problems.""" - - class Problem: """ PEtab parameter estimation problem @@ -1176,22 +1097,12 @@ def __init__( default_validation_tasks.copy() ) - self.observable_tables = observable_tables or [ - ObservableTable(observables=[]) - ] - self.condition_tables = condition_tables or [ - ConditionTable(conditions=[]) - ] - self.experiment_tables = experiment_tables or [ - ExperimentTable(experiments=[]) - ] - self.measurement_tables = measurement_tables or [ - MeasurementTable(measurements=[]) - ] - self.mapping_tables = mapping_tables or [MappingTable(mappings=[])] - self.parameter_tables = parameter_tables or [ - ParameterTable(parameters=[]) - ] + self.observable_tables = observable_tables or [ObservableTable()] + self.condition_tables = condition_tables or [ConditionTable()] + self.experiment_tables = experiment_tables or [ExperimentTable()] + self.measurement_tables = measurement_tables or [MeasurementTable()] + self.mapping_tables = mapping_tables or [MappingTable()] + self.parameter_tables = parameter_tables or [ParameterTable()] def __str__(self): model = f"with model ({self.model})" if self.model else "without model" @@ -1235,7 +1146,7 @@ def __getitem__(self, key): for table in table_list: try: return table[key] - except KeyError: + except (KeyError, NotImplementedError): pass raise KeyError( @@ -1483,10 +1394,9 @@ def get_problem(problem: str | Path | Problem) -> Problem: @property def condition_df(self) -> pd.DataFrame | None: """Combined condition tables as DataFrame.""" - conditions = self.conditions return ( - ConditionTable(conditions=conditions).to_df() - if conditions + ConditionTable(conditions).to_df() + if (conditions := self.conditions) else None ) @@ -1498,7 +1408,7 @@ def condition_df(self, value: pd.DataFrame): def experiment_df(self) -> pd.DataFrame | None: """Experiment table as DataFrame.""" return ( - ExperimentTable(experiments=experiments).to_df() + ExperimentTable(experiments).to_df() if (experiments := self.experiments) else None ) @@ -1510,10 +1420,9 @@ def experiment_df(self, value: pd.DataFrame): @property def measurement_df(self) -> pd.DataFrame | None: """Combined measurement tables as DataFrame.""" - measurements = self.measurements return ( - MeasurementTable(measurements=measurements).to_df() - if measurements + MeasurementTable(measurements).to_df() + if (measurements := self.measurements) else None ) @@ -1524,10 +1433,9 @@ def measurement_df(self, value: pd.DataFrame): @property def parameter_df(self) -> pd.DataFrame | None: """Combined parameter tables as DataFrame.""" - parameters = self.parameters return ( - ParameterTable(parameters=parameters).to_df() - if parameters + ParameterTable(parameters).to_df() + if (parameters := self.parameters) else None ) @@ -1538,10 +1446,9 @@ def parameter_df(self, value: pd.DataFrame): @property def observable_df(self) -> pd.DataFrame | None: """Combined observable tables as DataFrame.""" - observables = self.observables return ( - ObservableTable(observables=observables).to_df() - if observables + ObservableTable(observables).to_df() + if (observables := self.observables) else None ) @@ -1552,8 +1459,11 @@ def observable_df(self, value: pd.DataFrame): @property def mapping_df(self) -> pd.DataFrame | None: """Combined mapping tables as DataFrame.""" - mappings = self.mappings - return MappingTable(mappings=mappings).to_df() if mappings else None + return ( + MappingTable(mappings).to_df() + if (mappings := self.mappings) + else None + ) @mapping_df.setter def mapping_df(self, value: pd.DataFrame): @@ -1888,7 +1798,7 @@ def add_condition( for target_id, target_value in kwargs.items() ] if not self.condition_tables: - self.condition_tables.append(ConditionTable(conditions=[])) + self.condition_tables.append(ConditionTable()) self.condition_tables[-1].conditions.append( Condition(id=id_, changes=changes) ) @@ -1939,7 +1849,7 @@ def add_observable( record.update(kwargs) if not self.observable_tables: - self.observable_tables.append(ObservableTable(observables=[])) + self.observable_tables.append(ObservableTable()) self.observable_tables[-1] += Observable(**record) @@ -1991,7 +1901,7 @@ def add_parameter( record.update(kwargs) if not self.parameter_tables: - self.parameter_tables.append(ParameterTable(parameters=[])) + self.parameter_tables.append(ParameterTable()) self.parameter_tables[-1] += Parameter(**record) @@ -2027,7 +1937,7 @@ def add_measurement( noise_parameters = [noise_parameters] if not self.measurement_tables: - self.measurement_tables.append(MeasurementTable(measurements=[])) + self.measurement_tables.append(MeasurementTable()) self.measurement_tables[-1].measurements.append( Measurement( @@ -2054,7 +1964,7 @@ def add_mapping( name: A name (any string) for the entity referenced by `petab_id`. """ if not self.mapping_tables: - self.mapping_tables.append(MappingTable(mappings=[])) + self.mapping_tables.append(MappingTable()) self.mapping_tables[-1].mappings.append( Mapping(petab_id=petab_id, model_id=model_id, name=name) ) @@ -2085,7 +1995,7 @@ def add_experiment(self, id_: str, *args): ] if not self.experiment_tables: - self.experiment_tables.append(ExperimentTable(experiments=[])) + self.experiment_tables.append(ExperimentTable()) self.experiment_tables[-1].experiments.append( Experiment(id=id_, periods=periods) ) @@ -2102,25 +2012,23 @@ def __iadd__(self, other): if isinstance(other, Observable): if not self.observable_tables: - self.observable_tables.append(ObservableTable(observables=[])) + self.observable_tables.append(ObservableTable()) self.observable_tables[-1] += other elif isinstance(other, Parameter): if not self.parameter_tables: - self.parameter_tables.append(ParameterTable(parameters=[])) + self.parameter_tables.append(ParameterTable()) self.parameter_tables[-1] += other elif isinstance(other, Measurement): if not self.measurement_tables: - self.measurement_tables.append( - MeasurementTable(measurements=[]) - ) + self.measurement_tables.append(MeasurementTable()) self.measurement_tables[-1] += other elif isinstance(other, Condition): if not self.condition_tables: - self.condition_tables.append(ConditionTable(conditions=[])) + self.condition_tables.append(ConditionTable()) self.condition_tables[-1] += other elif isinstance(other, Experiment): if not self.experiment_tables: - self.experiment_tables.append(ExperimentTable(experiments=[])) + self.experiment_tables.append(ExperimentTable()) self.experiment_tables[-1] += other else: raise ValueError( @@ -2155,7 +2063,7 @@ def model_dump(self, **kwargs) -> dict[str, Any]: 'measurement_files': [], 'model_files': {}, 'observable_files': [], - 'parameter_file': []}, + 'parameter_files': []}, 'experiments': [], 'mappings': [], 'measurements': [], @@ -2182,7 +2090,12 @@ def model_dump(self, **kwargs) -> dict[str, Any]: ("mappings", self.mapping_tables), ): res[field] = ( - [table.model_dump(**kwargs) for table in table_list] + list( + chain.from_iterable( + table.model_dump(**kwargs)["elements"] + for table in table_list + ) + ) if table_list else [] ) diff --git a/tests/v2/test_core.py b/tests/v2/test_core.py index 643f9172..da75dccd 100644 --- a/tests/v2/test_core.py +++ b/tests/v2/test_core.py @@ -273,7 +273,7 @@ def test_condition_table(): assert ( ConditionTable( - conditions=[ + [ Condition( id="condition1", changes=[Change(target_id="k1", target_value="true")], @@ -284,7 +284,7 @@ def test_condition_table(): ) assert ConditionTable( - conditions=[ + [ Condition( id="condition1", changes=[Change(target_id="k1", target_value=x / y)],