From d1f692f733b957d3cceff4844f6182aeef893c7f Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Thu, 3 Apr 2025 13:38:00 +0200 Subject: [PATCH 1/2] v2: Allow applying multiple conditions simultaneously --- petab/v2/core.py | 41 +++++++++++++++++++------------- petab/v2/lint.py | 55 +++++++++++++++++++++++++++++++++---------- petab/v2/problem.py | 7 +++++- tests/v2/test_core.py | 16 ++++++------- tests/v2/test_lint.py | 18 +++++++++++++- 5 files changed, 98 insertions(+), 39 deletions(-) diff --git a/petab/v2/core.py b/petab/v2/core.py index 72eaec31..43a5ab3d 100644 --- a/petab/v2/core.py +++ b/petab/v2/core.py @@ -475,7 +475,7 @@ def free_symbols(self) -> set[sp.Symbol]: class ExperimentPeriod(BaseModel): """A period of a timecourse or experiment defined by a start time - and a condition ID. + and a list of condition IDs. This corresponds to a row of the PEtab experiments table. """ @@ -484,20 +484,19 @@ class ExperimentPeriod(BaseModel): time: Annotated[float, AfterValidator(_is_finite_or_neg_inf)] = Field( alias=C.TIME ) - #: The ID of the condition to be applied at the start time. - condition_id: str | None = Field(alias=C.CONDITION_ID, default=None) + #: The IDs of the conditions to be applied at the start time. + condition_ids: list[str] = [] #: :meta private: model_config = ConfigDict(populate_by_name=True, extra="allow") - @field_validator("condition_id", mode="before") + @field_validator("condition_ids", mode="before") @classmethod - def _validate_id(cls, condition_id): - if pd.isna(condition_id) or not condition_id: - return None - if not is_valid_identifier(condition_id): - raise ValueError(f"Invalid ID: {condition_id}") - return condition_id + def _validate_ids(cls, condition_ids): + for condition_id in condition_ids: + if not is_valid_identifier(condition_id): + raise ValueError(f"Invalid ID: {condition_id}") + return condition_ids class Experiment(BaseModel): @@ -548,12 +547,20 @@ def from_df(cls, df: pd.DataFrame) -> ExperimentTable: experiments = [] for experiment_id, cur_exp_df in df.groupby(C.EXPERIMENT_ID): - periods = [ - ExperimentPeriod( - time=row[C.TIME], condition_id=row[C.CONDITION_ID] + periods = [] + for timepoint in cur_exp_df[C.TIME].unique(): + condition_ids = [ + cid + for cid in cur_exp_df.loc[ + cur_exp_df[C.TIME] == timepoint, C.CONDITION_ID + ] + if not pd.isna(cid) + ] + periods.append( + ExperimentPeriod( + time=timepoint, condition_ids=condition_ids + ) ) - for _, row in cur_exp_df.iterrows() - ] experiments.append(Experiment(id=experiment_id, periods=periods)) return cls(experiments=experiments) @@ -563,10 +570,12 @@ def to_df(self) -> pd.DataFrame: records = [ { C.EXPERIMENT_ID: experiment.id, - **period.model_dump(by_alias=True), + C.TIME: period.time, + C.CONDITION_ID: condition_id, } for experiment in self.experiments for period in experiment.periods + for condition_id in period.condition_ids or [""] ] return ( pd.DataFrame(records) diff --git a/petab/v2/lint.py b/petab/v2/lint.py index 71d655dd..6e7fc161 100644 --- a/petab/v2/lint.py +++ b/petab/v2/lint.py @@ -8,6 +8,7 @@ from collections.abc import Set from dataclasses import dataclass, field from enum import IntEnum +from itertools import chain from pathlib import Path import pandas as pd @@ -373,8 +374,10 @@ class CheckValidConditionTargets(ValidationTask): """Check that all condition table targets are valid.""" def run(self, problem: Problem) -> ValidationIssue | None: - allowed_targets = set( - problem.model.get_valid_ids_for_condition_table() + allowed_targets = ( + set(problem.model.get_valid_ids_for_condition_table()) + if problem.model + else set() ) allowed_targets |= set(get_output_parameters(problem)) allowed_targets |= { @@ -394,6 +397,28 @@ def run(self, problem: Problem) -> ValidationIssue | None: f"Condition table contains invalid targets: {invalid}" ) + # Check that changes of simultaneously applied conditions don't + # intersect + for experiment in problem.experiment_table.experiments: + for period in experiment.periods: + if not period.condition_ids: + continue + period_targets = set() + for condition_id in period.condition_ids: + condition_targets = { + change.target_id + for cond in problem.condition_table.conditions + if cond.id == condition_id + for change in cond.changes + } + if invalid := (period_targets & condition_targets): + return ValidationError( + "Simultaneously applied conditions for experiment " + f"{experiment.id} have overlapping targets " + f"{invalid} at time {period.time}." + ) + period_targets |= condition_targets + class CheckUniquePrimaryKeys(ValidationTask): """Check that all primary keys are unique.""" @@ -484,11 +509,14 @@ def run(self, problem: Problem) -> ValidationIssue | None: c.id for c in problem.condition_table.conditions } for experiment in problem.experiment_table.experiments: - missing_conditions = { - period.condition_id - for period in experiment.periods - if period.condition_id is not None - } - available_conditions + missing_conditions = ( + set( + chain.from_iterable( + period.condition_ids for period in experiment.periods + ) + ) + - available_conditions + ) if missing_conditions: messages.append( f"Experiment {experiment.id} requires conditions that are " @@ -646,12 +674,13 @@ class CheckUnusedConditions(ValidationTask): table.""" def run(self, problem: Problem) -> ValidationIssue | None: - used_conditions = { - p.condition_id - for e in problem.experiment_table.experiments - for p in e.periods - if p.condition_id is not None - } + used_conditions = set( + chain.from_iterable( + p.condition_ids + for e in problem.experiment_table.experiments + for p in e.periods + ) + ) available_conditions = { c.id for c in problem.condition_table.conditions } diff --git a/petab/v2/problem.py b/petab/v2/problem.py index b0b76aa9..c79b2f68 100644 --- a/petab/v2/problem.py +++ b/petab/v2/problem.py @@ -1062,7 +1062,12 @@ def add_experiment(self, id_: str, *args): ) periods = [ - core.ExperimentPeriod(time=args[i], condition_id=args[i + 1]) + core.ExperimentPeriod( + time=args[i], + condition_ids=[cond] + if isinstance((cond := args[i + 1]), str) + else cond, + ) for i in range(0, len(args), 2) ] diff --git a/tests/v2/test_core.py b/tests/v2/test_core.py index 181f5523..074c0d2d 100644 --- a/tests/v2/test_core.py +++ b/tests/v2/test_core.py @@ -39,9 +39,9 @@ def test_experiment_add_periods(): exp = Experiment(id="exp1") assert exp.periods == [] - p1 = ExperimentPeriod(time=0, condition_id="p1") - p2 = ExperimentPeriod(time=1, condition_id="p2") - p3 = ExperimentPeriod(time=2, condition_id="p3") + p1 = ExperimentPeriod(time=0, condition_ids=["p1"]) + p2 = ExperimentPeriod(time=1, condition_ids=["p2"]) + p3 = ExperimentPeriod(time=2, condition_ids=["p3"]) exp += p1 exp += p2 @@ -201,8 +201,8 @@ def test_change(): def test_period(): ExperimentPeriod(time=0) - ExperimentPeriod(time=1, condition_id="p1") - ExperimentPeriod(time="-inf", condition_id="p1") + ExperimentPeriod(time=1, condition_ids=["p1"]) + ExperimentPeriod(time="-inf", condition_ids=["p1"]) assert ( ExperimentPeriod(time="1", condition_id="p1", non_petab=1).non_petab @@ -210,13 +210,13 @@ def test_period(): ) with pytest.raises(ValidationError, match="got inf"): - ExperimentPeriod(time="inf", condition_id="p1") + ExperimentPeriod(time="inf", condition_ids=["p1"]) with pytest.raises(ValidationError, match="Invalid ID"): - ExperimentPeriod(time=1, condition_id="1_condition") + ExperimentPeriod(time=1, condition_ids=["1_condition"]) with pytest.raises(ValidationError, match="type=missing"): - ExperimentPeriod(condition_id="condition") + ExperimentPeriod(condition_ids=["condition"]) def test_parameter(): diff --git a/tests/v2/test_lint.py b/tests/v2/test_lint.py index 33cdb300..74aaaa29 100644 --- a/tests/v2/test_lint.py +++ b/tests/v2/test_lint.py @@ -3,8 +3,8 @@ from copy import deepcopy from petab.v2 import Problem -from petab.v2.C import * from petab.v2.lint import * +from petab.v2.models.sbml_model import SbmlModel def test_check_experiments(): @@ -21,3 +21,19 @@ def test_check_experiments(): tmp_problem = deepcopy(problem) tmp_problem["e1"].periods[0].time = tmp_problem["e1"].periods[1].time assert check.run(tmp_problem) is not None + + +def test_check_incompatible_targets(): + """Multiple conditions with overlapping targets cannot be applied + at the same time.""" + problem = Problem() + problem.model = SbmlModel.from_antimony("p1 = 1; p2 = 2") + problem.add_experiment("e1", 0, "c1", 1, "c2") + problem.add_condition("c1", p1="1") + problem.add_condition("c2", p1="2", p2="2") + check = CheckValidConditionTargets() + assert check.run(problem) is None + + problem["e1"].periods[0].condition_ids.append("c2") + assert (error := check.run(problem)) is not None + assert "overlapping targets {'p1'}" in error.message From ba1996f130fc1b0811c18c66f6ec827d8518a07d Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Thu, 24 Apr 2025 18:16:28 +0200 Subject: [PATCH 2/2] Update petab/v2/core.py Co-authored-by: Dilan Pathirana <59329744+dilpath@users.noreply.github.com> --- petab/v2/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/petab/v2/core.py b/petab/v2/core.py index 43a5ab3d..41abfb36 100644 --- a/petab/v2/core.py +++ b/petab/v2/core.py @@ -485,7 +485,7 @@ class ExperimentPeriod(BaseModel): alias=C.TIME ) #: The IDs of the conditions to be applied at the start time. - condition_ids: list[str] = [] + condition_ids: list[str] = Field(default_factory=list) #: :meta private: model_config = ConfigDict(populate_by_name=True, extra="allow")