diff --git a/petab/v2/core.py b/petab/v2/core.py
index 37797610..1ee74ace 100644
--- a/petab/v2/core.py
+++ b/petab/v2/core.py
@@ -73,8 +73,11 @@ def _not_nan(v: float, info: ValidationInfo) -> float:


 def _convert_nan_to_none(v):
+    """Convert NaN or "" to None."""
     if isinstance(v, float) and np.isnan(v):
         return None
+    if isinstance(v, str) and v == "":
+        return None
     return v


@@ -503,9 +506,17 @@ class ExperimentPeriod(BaseModel):
     @field_validator("condition_ids", mode="before")
     @classmethod
     def _validate_ids(cls, condition_ids):
+        if condition_ids in [None, "", [], [""]]:
+            # unspecified, or "use-model-as-is"
+            return []
+
         for condition_id in condition_ids:
+            # The empty condition ID for "use-model-as-is" has been handled
+            # above. Having a combination of empty and non-empty IDs is an
+            # error, since the targets of conditions to be combined must be
+            # disjoint.
             if not is_valid_identifier(condition_id):
-                raise ValueError(f"Invalid ID: {condition_id}")
+                raise ValueError(f"Invalid {C.CONDITION_ID}: `{condition_id}'")
         return condition_ids

@@ -854,17 +865,23 @@ class Parameter(BaseModel):
     #: Parameter ID.
     id: str = Field(alias=C.PARAMETER_ID)
     #: Lower bound.
-    lb: float | None = Field(alias=C.LOWER_BOUND, default=None)
+    lb: Annotated[float | None, BeforeValidator(_convert_nan_to_none)] = Field(
+        alias=C.LOWER_BOUND, default=None
+    )
     #: Upper bound.
-    ub: float | None = Field(alias=C.UPPER_BOUND, default=None)
+    ub: Annotated[float | None, BeforeValidator(_convert_nan_to_none)] = Field(
+        alias=C.UPPER_BOUND, default=None
+    )
     #: Nominal value.
-    nominal_value: float | None = Field(alias=C.NOMINAL_VALUE, default=None)
+    nominal_value: Annotated[
+        float | None, BeforeValidator(_convert_nan_to_none)
+    ] = Field(alias=C.NOMINAL_VALUE, default=None)
     #: Is the parameter to be estimated?
     estimate: bool = Field(alias=C.ESTIMATE, default=True)
     #: Type of parameter prior distribution.
-    prior_distribution: PriorDistribution | None = Field(
-        alias=C.PRIOR_DISTRIBUTION, default=None
-    )
+    prior_distribution: Annotated[
+        PriorDistribution | None, BeforeValidator(_convert_nan_to_none)
+    ] = Field(alias=C.PRIOR_DISTRIBUTION, default=None)
     #: Prior distribution parameters.
     prior_parameters: list[float] = Field(
         alias=C.PRIOR_PARAMETERS, default_factory=list
@@ -889,8 +906,18 @@ def _validate_id(cls, v):
     @field_validator("prior_parameters", mode="before")
     @classmethod
-    def _validate_prior_parameters(cls, v):
+    def _validate_prior_parameters(
+        cls, v: str | list[str] | float | None | np.ndarray
+    ):
+        if v is None:
+            return []
+
+        if isinstance(v, float) and np.isnan(v):
+            return []
+
         if isinstance(v, str):
+            if v == "":
+                return []
             v = v.split(C.PARAMETER_SEPARATOR)
         elif not isinstance(v, Sequence):
             v = [v]

@@ -899,7 +926,7 @@ def _validate_prior_parameters(cls, v):
     @field_validator("estimate", mode="before")
     @classmethod
-    def _validate_estimate_before(cls, v):
+    def _validate_estimate_before(cls, v: bool | str):
         if isinstance(v, bool):
             return v

@@ -918,12 +945,17 @@ def _validate_estimate_before(cls, v):
     def _serialize_estimate(self, estimate: bool, _info):
         return str(estimate).lower()

-    @field_validator("lb", "ub", "nominal_value")
-    @classmethod
-    def _convert_nan_to_none(cls, v):
-        if isinstance(v, float) and np.isnan(v):
-            return None
-        return v
+    @field_serializer("prior_distribution")
+    def _serialize_prior_distribution(
+        self, prior_distribution: PriorDistribution | None, _info
+    ):
+        if prior_distribution is None:
+            return ""
+        return str(prior_distribution)
+
+    @field_serializer("prior_parameters")
+    def _serialize_prior_parameters(
+        self, prior_parameters: list[float], _info
+    ):
+        return C.PARAMETER_SEPARATOR.join(str(p) for p in prior_parameters)

     @model_validator(mode="after")
     def _validate(self) -> Self:
@@ -952,7 +984,7 @@ def _validate(self) -> Self:

     @property
     def prior_dist(self) -> Distribution:
-        """Get the pior distribution of the parameter."""
+        """Get the prior distribution of the parameter."""
         if self.estimate is False:
             raise ValueError(f"Parameter `{self.id}' is not estimated.")

@@ -980,6 +1012,13 @@ def prior_dist(self) -> Distribution:
                     "transformation."
                 )
             return cls(*self.prior_parameters, trunc=[self.lb, self.ub])
+
+        if cls == Uniform:
+            # `Uniform.__init__` does not accept the `trunc` parameter
+            low = max(self.prior_parameters[0], self.lb)
+            high = min(self.prior_parameters[1], self.ub)
+            return cls(low, high, log=log)
+
         return cls(*self.prior_parameters, log=log, trunc=[self.lb, self.ub])
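Note on the `Annotated[..., BeforeValidator(_convert_nan_to_none)]` fields
above: the before-validator runs prior to pydantic's type coercion, so NaN
or empty table cells become None instead of failing float validation. A
minimal, self-contained sketch of the pattern (hypothetical `Row` model;
assumes pydantic v2 and numpy):

    from typing import Annotated

    import numpy as np
    from pydantic import BaseModel, BeforeValidator

    def _to_none(v):
        """Map NaN or "" (empty table cells) to None."""
        if isinstance(v, float) and np.isnan(v):
            return None
        if isinstance(v, str) and v == "":
            return None
        return v

    class Row(BaseModel):
        # Runs before coercion: NaN/"" become None instead of a
        # float-validation error; other values pass through unchanged.
        lb: Annotated[float | None, BeforeValidator(_to_none)] = None

    assert Row(lb=float("nan")).lb is None
    assert Row(lb="").lb is None
    assert Row(lb="0.5").lb == 0.5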
diff --git a/petab/v2/lint.py b/petab/v2/lint.py
index 0fb055e8..2558ea3c 100644
--- a/petab/v2/lint.py
+++ b/petab/v2/lint.py
@@ -14,6 +14,8 @@
 import pandas as pd
 import sympy as sp

+from ..v2.C import *
+from .core import PriorDistribution
 from .problem import Problem

 logger = logging.getLogger(__name__)
@@ -37,6 +39,8 @@
     "CheckUnusedExperiments",
     "CheckObservablesDoNotShadowModelEntities",
     "CheckUnusedConditions",
+    "CheckAllObservablesDefined",
+    "CheckPriorDistribution",
     "lint_problem",
     "default_validation_tasks",
 ]
@@ -77,8 +81,12 @@ def __post_init__(self):
     def __str__(self):
         return f"{self.level.name}: {self.message}"

-    def _get_task_name(self):
-        """Get the name of the ValidationTask that raised this error."""
+    @staticmethod
+    def _get_task_name() -> str | None:
+        """Get the name of the ValidationTask that raised this error.
+
+        Expected to be called from within a `ValidationTask.run` call.
+        """
         import inspect

         # walk up the stack until we find the ValidationTask.run method
@@ -88,6 +96,7 @@ def _get_task_name(self):
             task = frame.f_locals["self"]
             if isinstance(task, ValidationTask):
                 return task.__class__.__name__
+        return None


 @dataclass
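The `_get_task_name` change above makes the helper a staticmethod that walks
the interpreter stack to find the task that produced the issue. A simplified,
standalone sketch of the same technique (hypothetical `Task` class, not the
actual `ValidationTask` API):

    import inspect

    class Task:
        def run(self):
            return _calling_task_name()

    def _calling_task_name() -> str | None:
        """Return the class name of the first enclosing frame whose
        `self` is a Task instance, or None if there is none."""
        frame = inspect.currentframe()
        while frame := frame.f_back:
            candidate = frame.f_locals.get("self")
            if isinstance(candidate, Task):
                return type(candidate).__name__
        return None

    assert Task().run() == "Task"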
+ """ import inspect # walk up the stack until we find the ValidationTask.run method @@ -88,6 +96,7 @@ def _get_task_name(self): task = frame.f_locals["self"] if isinstance(task, ValidationTask): return task.__class__.__name__ + return None @dataclass @@ -222,6 +231,8 @@ def run(self, problem: Problem) -> ValidationIssue | None: f"Missing files: {', '.join(missing_files)}" ) + return None + class CheckModel(ValidationTask): """A task to validate the model of a PEtab problem.""" @@ -234,6 +245,8 @@ def run(self, problem: Problem) -> ValidationIssue | None: # TODO get actual model validation messages return ValidationError("Model is invalid.") + return None + class CheckMeasuredObservablesDefined(ValidationTask): """A task to check that all observables referenced by the measurements @@ -252,10 +265,13 @@ def run(self, problem: Problem) -> ValidationIssue | None: "measurement table but not defined in observable table." ) + return None + class CheckOverridesMatchPlaceholders(ValidationTask): """A task to check that the number of observable/noise parameters - in the measurements match the number of placeholders in the observables.""" + in the measurements matches the number of placeholders in the observables. + """ def run(self, problem: Problem) -> ValidationIssue | None: observable_parameters_count = { @@ -320,18 +336,20 @@ def run(self, problem: Problem) -> ValidationIssue | None: if messages: return ValidationError("\n".join(messages)) + return None + class CheckPosLogMeasurements(ValidationTask): """Check that measurements for observables with log-transformation are positive.""" def run(self, problem: Problem) -> ValidationIssue | None: - from .core import NoiseDistribution as nd + from .core import NoiseDistribution as ND # noqa: N813 log_observables = { o.id for o in problem.observable_table.observables - if o.noise_distribution in [nd.LOG_NORMAL, nd.LOG_LAPLACE] + if o.noise_distribution in [ND.LOG_NORMAL, ND.LOG_LAPLACE] } if log_observables: for m in problem.measurement_table.measurements: @@ -342,6 +360,8 @@ def run(self, problem: Problem) -> ValidationIssue | None: f"positive, but {m.measurement} <= 0 for {m}" ) + return None + class CheckMeasuredExperimentsDefined(ValidationTask): """A task to check that all experiments referenced by measurements @@ -369,6 +389,8 @@ def run(self, problem: Problem) -> ValidationIssue | None: + str(missing_experiments) ) + return None + class CheckValidConditionTargets(ValidationTask): """Check that all condition table targets are valid.""" @@ -418,6 +440,32 @@ def run(self, problem: Problem) -> ValidationIssue | None: f"{invalid} at time {period.time}." ) period_targets |= condition_targets + return None + + +class CheckAllObservablesDefined(ValidationTask): + """A task to validate that all observables in the measurement table are + defined in the observable table.""" + + def run(self, problem: Problem) -> ValidationIssue | None: + if problem.measurement_df is None: + return None + + measurement_df = problem.measurement_df + observable_df = problem.observable_df + used_observables = set(measurement_df[OBSERVABLE_ID].values) + defined_observables = ( + set(observable_df.index.values) + if observable_df is not None + else set() + ) + if undefined_observables := (used_observables - defined_observables): + return ValidationError( + f"Observables {undefined_observables} are used in the" + "measurements table but are not defined in observables table." 
@@ -429,7 +477,7 @@ def run(self, problem: Problem) -> ValidationIssue | None:
         # check for uniqueness of all primary keys
         counter = Counter(c.id for c in problem.condition_table.conditions)
-        duplicates = {id for id, count in counter.items() if count > 1}
+        duplicates = {id_ for id_, count in counter.items() if count > 1}

         if duplicates:
             return ValidationError(
@@ -437,7 +485,7 @@ def run(self, problem: Problem) -> ValidationIssue | None:
             )

         counter = Counter(o.id for o in problem.observable_table.observables)
-        duplicates = {id for id, count in counter.items() if count > 1}
+        duplicates = {id_ for id_, count in counter.items() if count > 1}

         if duplicates:
             return ValidationError(
@@ -445,7 +493,7 @@ def run(self, problem: Problem) -> ValidationIssue | None:
             )

         counter = Counter(e.id for e in problem.experiment_table.experiments)
-        duplicates = {id for id, count in counter.items() if count > 1}
+        duplicates = {id_ for id_, count in counter.items() if count > 1}

         if duplicates:
             return ValidationError(
@@ -453,13 +501,15 @@ def run(self, problem: Problem) -> ValidationIssue | None:
             )

         counter = Counter(p.id for p in problem.parameter_table.parameters)
-        duplicates = {id for id, count in counter.items() if count > 1}
+        duplicates = {id_ for id_, count in counter.items() if count > 1}

         if duplicates:
             return ValidationError(
                 f"Parameter table contains duplicate IDs: {duplicates}"
             )

+        return None
+

 class CheckObservablesDoNotShadowModelEntities(ValidationTask):
     """A task to check that observable IDs do not shadow model entities."""
@@ -479,6 +529,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
                 f"Observable IDs {shadowed_entities} shadow model entities."
             )

+        return None
+

 class CheckExperimentTable(ValidationTask):
     """A task to validate the experiment table of a PEtab problem."""
@@ -498,6 +550,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
         if messages:
             return ValidationError("\n".join(messages))

+        return None
+

 class CheckExperimentConditionsExist(ValidationTask):
     """A task to validate that all conditions in the experiment table exist
@@ -526,6 +580,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
         if messages:
             return ValidationError("\n".join(messages))

+        return None
+

 class CheckAllParametersPresentInParameterTable(ValidationTask):
     """Ensure all required parameters are contained in the parameter table
@@ -573,6 +629,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
             + str(extraneous)
         )

+        return None
+

 class CheckValidParameterInConditionOrParameterTable(ValidationTask):
     """A task to check that all required and only allowed model parameters are
@@ -646,9 +704,11 @@ def run(self, problem: Problem) -> ValidationIssue | None:
             "the condition table and the parameter table."
         )

+        return None
+

 class CheckUnusedExperiments(ValidationTask):
-    """A task to check for experiments that are not used in the measurements
+    """A task to check for experiments that are not used in the measurement
     table."""

     def run(self, problem: Problem) -> ValidationIssue | None:
@@ -668,9 +728,11 @@ def run(self, problem: Problem) -> ValidationIssue | None:
             "measurements table."
         )

+        return None
+

 class CheckUnusedConditions(ValidationTask):
-    """A task to check for conditions that are not used in the experiments
+    """A task to check for conditions that are not used in the experiment
     table."""

     def run(self, problem: Problem) -> ValidationIssue | None:
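The `id` -> `id_` renames above avoid shadowing the `id` builtin; the
duplicate detection itself is unchanged. For reference, the Counter-based
pattern in isolation:

    from collections import Counter

    ids = ["k1", "k2", "k1", "k3", "k3", "k3"]
    # Any ID seen more than once is a duplicate primary key.
    duplicates = {id_ for id_, count in Counter(ids).items() if count > 1}
    assert duplicates == {"k1", "k3"}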
@@ -692,6 +754,8 @@ def run(self, problem: Problem) -> ValidationIssue | None:
             "experiments table."
         )

+        return None
+

 class CheckVisualizationTable(ValidationTask):
     """A task to validate the visualization table of a PEtab problem."""
@@ -708,6 +772,68 @@ def run(self, problem: Problem) -> ValidationIssue | None:
                 message="Visualization table is invalid.",
             )

+        return None
+
+
+class CheckPriorDistribution(ValidationTask):
+    """A task to validate the prior distributions of a PEtab problem."""
+
+    _num_pars = {
+        PriorDistribution.CAUCHY: 2,
+        PriorDistribution.CHI_SQUARED: 1,
+        PriorDistribution.EXPONENTIAL: 1,
+        PriorDistribution.GAMMA: 2,
+        PriorDistribution.LAPLACE: 2,
+        PriorDistribution.LOG10_NORMAL: 2,
+        PriorDistribution.LOG_LAPLACE: 2,
+        PriorDistribution.LOG_NORMAL: 2,
+        PriorDistribution.LOG_UNIFORM: 2,
+        PriorDistribution.NORMAL: 2,
+        PriorDistribution.RAYLEIGH: 1,
+        PriorDistribution.UNIFORM: 2,
+    }
+
+    def run(self, problem: Problem) -> ValidationIssue | None:
+        messages = []
+        for parameter in problem.parameter_table.parameters:
+            if parameter.prior_distribution is None:
+                continue
+
+            if parameter.prior_distribution not in PRIOR_DISTRIBUTIONS:
+                messages.append(
+                    f"Prior distribution `{parameter.prior_distribution}' "
+                    f"for parameter `{parameter.id}' is not valid."
+                )
+                continue
+
+            if (
+                exp_num_par := self._num_pars[parameter.prior_distribution]
+            ) != len(parameter.prior_parameters):
+                messages.append(
+                    f"Prior distribution `{parameter.prior_distribution}' "
+                    f"for parameter `{parameter.id}' requires "
+                    f"{exp_num_par} parameters, but got "
+                    f"{len(parameter.prior_parameters)} "
+                    f"({parameter.prior_parameters})."
+                )
+
+            # TODO: check distribution parameter domains more specifically
+            try:
+                if parameter.estimate:
+                    # `.prior_dist` fails for non-estimated parameters
+                    _ = parameter.prior_dist.sample(1)
+            except Exception as e:
+                messages.append(
+                    f"Prior parameters `{parameter.prior_parameters}' "
+                    f"for parameter `{parameter.id}' are invalid "
+                    f"(hint: {e})."
+                )
+
+        if messages:
+            return ValidationError("\n".join(messages))
+
+        return None
+

 def get_valid_parameters_for_parameter_table(
     problem: Problem,
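`CheckPriorDistribution` combines a per-distribution arity lookup with a
sampling smoke test via `.prior_dist.sample(1)`. A condensed sketch of the
arity check (stand-in enum and table, not the real `PriorDistribution`):

    from enum import Enum

    class PriorDistribution(str, Enum):
        NORMAL = "normal"
        EXPONENTIAL = "exponential"

    # Expected number of distribution parameters per prior type.
    NUM_PARS = {PriorDistribution.NORMAL: 2, PriorDistribution.EXPONENTIAL: 1}

    def check_prior(dist: PriorDistribution, pars: list[float]) -> str | None:
        """Return an error message if `pars` has the wrong arity."""
        if (expected := NUM_PARS[dist]) != len(pars):
            return (
                f"Prior `{dist.value}' requires {expected} parameters, "
                f"but got {len(pars)} ({pars})."
            )
        return None

    assert check_prior(PriorDistribution.NORMAL, [0.0, 1.0]) is None
    assert check_prior(PriorDistribution.EXPONENTIAL, [1.0, 2.0]) is not None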
@@ -752,7 +878,7 @@
         if mapping.model_id and mapping.model_id in parameter_ids.keys():
             parameter_ids[mapping.petab_id] = None

-    # add output parameters from observables table
+    # add output parameters from observable table
     output_parameters = get_output_parameters(problem)
     for p in output_parameters:
         if p not in invalid:
@@ -781,7 +907,7 @@ def get_required_parameters_for_parameter_table(
     problem: Problem,
 ) -> Set[str]:
     """
-    Get set of parameters which need to go into the parameter table
+    Get the set of parameters that need to go into the parameter table

     Arguments:
         problem: The PEtab problem
@@ -965,4 +1091,9 @@ def get_placeholders(
     # TODO: atomize checks, update to long condition table, re-enable
     # CheckVisualizationTable(),
     # TODO validate mapping table
+    CheckValidParameterInConditionOrParameterTable(),
+    CheckAllObservablesDefined(),
+    CheckAllParametersPresentInParameterTable(),
+    CheckValidConditionTargets(),
+    CheckPriorDistribution(),
 ]
diff --git a/petab/v2/petab1to2.py b/petab/v2/petab1to2.py
index c788f116..bc7398fc 100644
--- a/petab/v2/petab1to2.py
+++ b/petab/v2/petab1to2.py
@@ -455,4 +455,21 @@ def update_prior(row):
         errors="ignore",
     )

+    # if uniform, we need to explicitly set the parameters
+    def update_prior_pars(row):
+        prior_type = row.get(v2.C.PRIOR_DISTRIBUTION)
+        prior_pars = row.get(v2.C.PRIOR_PARAMETERS)
+
+        if prior_type in (v2.C.UNIFORM, v2.C.LOG_UNIFORM) and pd.isna(
+            prior_pars
+        ):
+            return (
+                f"{row[v2.C.LOWER_BOUND]}{v2.C.PARAMETER_SEPARATOR}"
+                f"{row[v2.C.UPPER_BOUND]}"
+            )
+
+        return prior_pars
+
+    df[v2.C.PRIOR_PARAMETERS] = df.apply(update_prior_pars, axis=1)
+
     return df
diff --git a/petab/v2/problem.py b/petab/v2/problem.py
index 01903b16..52baf724 100644
--- a/petab/v2/problem.py
+++ b/petab/v2/problem.py
@@ -1121,8 +1121,8 @@ def model_dump(self, **kwargs) -> dict[str, Any]:
           'id': 'par',
           'lb': 0.0,
           'nominal_value': None,
-          'prior_distribution': None,
-          'prior_parameters': [],
+          'prior_distribution': '',
+          'prior_parameters': '',
           'ub': 1.0}]}
         """
         res = {
diff --git a/tests/v2/test_core.py b/tests/v2/test_core.py
index 074c0d2d..2aba25e4 100644
--- a/tests/v2/test_core.py
+++ b/tests/v2/test_core.py
@@ -212,7 +212,7 @@ def test_period():
     with pytest.raises(ValidationError, match="got inf"):
         ExperimentPeriod(time="inf", condition_ids=["p1"])

-    with pytest.raises(ValidationError, match="Invalid ID"):
+    with pytest.raises(ValidationError, match="Invalid conditionId"):
         ExperimentPeriod(time=1, condition_ids=["1_condition"])

     with pytest.raises(ValidationError, match="type=missing"):
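To illustrate the v1-to-v2 conversion step in `petab1to2.py` above: uniform
priors without explicit parameters get the bounds filled in. A toy pandas
sketch (assumes ";" is the value of `v2.C.PARAMETER_SEPARATOR`; column names
follow PEtab conventions):

    import pandas as pd

    SEP = ";"  # assumed value of v2.C.PARAMETER_SEPARATOR

    df = pd.DataFrame(
        {
            "parameterId": ["p1", "p2"],
            "lowerBound": [0.1, 1.0],
            "upperBound": [10.0, 100.0],
            "priorDistribution": ["uniform", "normal"],
            "priorParameters": [float("nan"), "0;1"],
        }
    )

    def update_prior_pars(row):
        # Uniform priors without explicit parameters default to the bounds.
        if row["priorDistribution"] == "uniform" and pd.isna(
            row["priorParameters"]
        ):
            return f"{row['lowerBound']}{SEP}{row['upperBound']}"
        return row["priorParameters"]

    df["priorParameters"] = df.apply(update_prior_pars, axis=1)
    assert df["priorParameters"].tolist() == ["0.1;10.0", "0;1"]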