Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions petab/v2/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def free_symbols(self) -> set[sp.Symbol]:

class ExperimentPeriod(BaseModel):
"""A period of a timecourse or experiment defined by a start time
and a condition ID.
and a list of condition IDs.

This corresponds to a row of the PEtab experiments table.
"""
Expand All @@ -484,20 +484,19 @@ class ExperimentPeriod(BaseModel):
time: Annotated[float, AfterValidator(_is_finite_or_neg_inf)] = Field(
alias=C.TIME
)
#: The ID of the condition to be applied at the start time.
condition_id: str | None = Field(alias=C.CONDITION_ID, default=None)
#: The IDs of the conditions to be applied at the start time.
condition_ids: list[str] = Field(default_factory=list)

#: :meta private:
model_config = ConfigDict(populate_by_name=True, extra="allow")

@field_validator("condition_id", mode="before")
@field_validator("condition_ids", mode="before")
@classmethod
def _validate_id(cls, condition_id):
if pd.isna(condition_id) or not condition_id:
return None
if not is_valid_identifier(condition_id):
raise ValueError(f"Invalid ID: {condition_id}")
return condition_id
def _validate_ids(cls, condition_ids):
for condition_id in condition_ids:
if not is_valid_identifier(condition_id):
raise ValueError(f"Invalid ID: {condition_id}")
return condition_ids


class Experiment(BaseModel):
Expand Down Expand Up @@ -548,12 +547,20 @@ def from_df(cls, df: pd.DataFrame) -> ExperimentTable:

experiments = []
for experiment_id, cur_exp_df in df.groupby(C.EXPERIMENT_ID):
periods = [
ExperimentPeriod(
time=row[C.TIME], condition_id=row[C.CONDITION_ID]
periods = []
for timepoint in cur_exp_df[C.TIME].unique():
condition_ids = [
cid
for cid in cur_exp_df.loc[
cur_exp_df[C.TIME] == timepoint, C.CONDITION_ID
]
if not pd.isna(cid)
]
Comment on lines +551 to +558
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might need to add a dropna

Suggested change
for timepoint in cur_exp_df[C.TIME].unique():
condition_ids = [
cid
for cid in cur_exp_df.loc[
cur_exp_df[C.TIME] == timepoint, C.CONDITION_ID
]
if not pd.isna(cid)
]
for timepoint, cur_exp_time_df in cur_exp_df.groupby(C.TIME):
condition_ids = cur_exp_time_df[C.CONDITION_ID].unique()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to avoid the overhead of the additional groupy / creating the additional dataframes. Usually there will only be a single condition ID.

I'd avoid the unique otherwise this will hide errors in case of duplicated conditions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

creating the additional dataframes

Ah, you're right, I thought it was some view instead

I'd avoid the unique

Yes, not sure why I did that...

No issue with your code but I find the groupby faster to understand.
Independently of this, again this might be something that tools want in a library function

periods.append(
ExperimentPeriod(
time=timepoint, condition_ids=condition_ids
)
)
for _, row in cur_exp_df.iterrows()
]
experiments.append(Experiment(id=experiment_id, periods=periods))

return cls(experiments=experiments)
Expand All @@ -563,10 +570,12 @@ def to_df(self) -> pd.DataFrame:
records = [
{
C.EXPERIMENT_ID: experiment.id,
**period.model_dump(by_alias=True),
C.TIME: period.time,
C.CONDITION_ID: condition_id,
}
for experiment in self.experiments
for period in experiment.periods
for condition_id in period.condition_ids or [""]
]
return (
pd.DataFrame(records)
Expand Down
55 changes: 42 additions & 13 deletions petab/v2/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections.abc import Set
from dataclasses import dataclass, field
from enum import IntEnum
from itertools import chain
from pathlib import Path

import pandas as pd
Expand Down Expand Up @@ -373,8 +374,10 @@ class CheckValidConditionTargets(ValidationTask):
"""Check that all condition table targets are valid."""

def run(self, problem: Problem) -> ValidationIssue | None:
allowed_targets = set(
problem.model.get_valid_ids_for_condition_table()
allowed_targets = (
set(problem.model.get_valid_ids_for_condition_table())
if problem.model
else set()
)
allowed_targets |= set(get_output_parameters(problem))
allowed_targets |= {
Expand All @@ -394,6 +397,28 @@ def run(self, problem: Problem) -> ValidationIssue | None:
f"Condition table contains invalid targets: {invalid}"
)

# Check that changes of simultaneously applied conditions don't
# intersect
for experiment in problem.experiment_table.experiments:
for period in experiment.periods:
if not period.condition_ids:
continue
period_targets = set()
for condition_id in period.condition_ids:
condition_targets = {
change.target_id
for cond in problem.condition_table.conditions
if cond.id == condition_id
for change in cond.changes
}
Comment on lines +408 to +413
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this useful for tools? I guess tools might like some problem.condition_table.get_updates(condition_id) that provides a dictionary of updates to be applied. Then the keys can be used here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's already in a dict, the keys will be unique and we can't check for duplicates any. It works if we do the checking every time in such a function.
I think something like that will come when progressing with the importer.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rather meant like this. It checks all conditions at once though, rather than one at a time, but there's no issue with identifying duplicates.

from collections import Counter
from itertools import chain

all_updates = [
    problem.condition_table.get_updates(condition_id)
    for condition_id in period.condition_ids
]
duplicates = {
    target_id
    for target_id, count in Counter(chain.from_iterable(all_updates)).items()
    if count > 1
}

Or a method that provides all_updates as a condition_id => updates dict.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not that it simplifies the code here, just that it might be useful to tools. But fine to keep as is until a tool dev requests something...

if invalid := (period_targets & condition_targets):
return ValidationError(
"Simultaneously applied conditions for experiment "
f"{experiment.id} have overlapping targets "
f"{invalid} at time {period.time}."
)
period_targets |= condition_targets


class CheckUniquePrimaryKeys(ValidationTask):
"""Check that all primary keys are unique."""
Expand Down Expand Up @@ -484,11 +509,14 @@ def run(self, problem: Problem) -> ValidationIssue | None:
c.id for c in problem.condition_table.conditions
}
for experiment in problem.experiment_table.experiments:
missing_conditions = {
period.condition_id
for period in experiment.periods
if period.condition_id is not None
} - available_conditions
missing_conditions = (
set(
chain.from_iterable(
period.condition_ids for period in experiment.periods
)
)
- available_conditions
)
Comment on lines +512 to +519
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same thing for here and below, is it useful to add this to the classes? Looks like a lot of additional interface though so fine as is

problem.experiment_table.get_condition_ids()
problem.experiment_table.get_experiment(experiment_id).get_condition_ids()

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure yet. I think, usually one only needs the conditions for a single period, not for the full experiment at once.

if missing_conditions:
messages.append(
f"Experiment {experiment.id} requires conditions that are "
Expand Down Expand Up @@ -646,12 +674,13 @@ class CheckUnusedConditions(ValidationTask):
table."""

def run(self, problem: Problem) -> ValidationIssue | None:
used_conditions = {
p.condition_id
for e in problem.experiment_table.experiments
for p in e.periods
if p.condition_id is not None
}
used_conditions = set(
chain.from_iterable(
p.condition_ids
for e in problem.experiment_table.experiments
for p in e.periods
)
)
available_conditions = {
c.id for c in problem.condition_table.conditions
}
Expand Down
7 changes: 6 additions & 1 deletion petab/v2/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,12 @@ def add_experiment(self, id_: str, *args):
)

periods = [
core.ExperimentPeriod(time=args[i], condition_id=args[i + 1])
core.ExperimentPeriod(
time=args[i],
condition_ids=[cond]
if isinstance((cond := args[i + 1]), str)
else cond,
)
for i in range(0, len(args), 2)
]

Expand Down
16 changes: 8 additions & 8 deletions tests/v2/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def test_experiment_add_periods():
exp = Experiment(id="exp1")
assert exp.periods == []

p1 = ExperimentPeriod(time=0, condition_id="p1")
p2 = ExperimentPeriod(time=1, condition_id="p2")
p3 = ExperimentPeriod(time=2, condition_id="p3")
p1 = ExperimentPeriod(time=0, condition_ids=["p1"])
p2 = ExperimentPeriod(time=1, condition_ids=["p2"])
p3 = ExperimentPeriod(time=2, condition_ids=["p3"])
exp += p1
exp += p2

Expand Down Expand Up @@ -201,22 +201,22 @@ def test_change():

def test_period():
ExperimentPeriod(time=0)
ExperimentPeriod(time=1, condition_id="p1")
ExperimentPeriod(time="-inf", condition_id="p1")
ExperimentPeriod(time=1, condition_ids=["p1"])
ExperimentPeriod(time="-inf", condition_ids=["p1"])

assert (
ExperimentPeriod(time="1", condition_id="p1", non_petab=1).non_petab
== 1
)

with pytest.raises(ValidationError, match="got inf"):
ExperimentPeriod(time="inf", condition_id="p1")
ExperimentPeriod(time="inf", condition_ids=["p1"])

with pytest.raises(ValidationError, match="Invalid ID"):
ExperimentPeriod(time=1, condition_id="1_condition")
ExperimentPeriod(time=1, condition_ids=["1_condition"])

with pytest.raises(ValidationError, match="type=missing"):
ExperimentPeriod(condition_id="condition")
ExperimentPeriod(condition_ids=["condition"])


def test_parameter():
Expand Down
18 changes: 17 additions & 1 deletion tests/v2/test_lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from copy import deepcopy

from petab.v2 import Problem
from petab.v2.C import *
from petab.v2.lint import *
from petab.v2.models.sbml_model import SbmlModel


def test_check_experiments():
Expand All @@ -21,3 +21,19 @@ def test_check_experiments():
tmp_problem = deepcopy(problem)
tmp_problem["e1"].periods[0].time = tmp_problem["e1"].periods[1].time
assert check.run(tmp_problem) is not None


def test_check_incompatible_targets():
"""Multiple conditions with overlapping targets cannot be applied
at the same time."""
problem = Problem()
problem.model = SbmlModel.from_antimony("p1 = 1; p2 = 2")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤯

Didn't know an SBML test model could be specified so simply.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💪

problem.add_experiment("e1", 0, "c1", 1, "c2")
problem.add_condition("c1", p1="1")
problem.add_condition("c2", p1="2", p2="2")
check = CheckValidConditionTargets()
assert check.run(problem) is None

problem["e1"].periods[0].condition_ids.append("c2")
assert (error := check.run(problem)) is not None
assert "overlapping targets {'p1'}" in error.message