Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 55 additions & 16 deletions petab/v2/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,11 @@ def _not_nan(v: float, info: ValidationInfo) -> float:


def _convert_nan_to_none(v):
"""Convert NaN or "" to None."""
if isinstance(v, float) and np.isnan(v):
return None
if isinstance(v, str) and v == "":
return None
return v


Expand Down Expand Up @@ -503,9 +506,17 @@ class ExperimentPeriod(BaseModel):
@field_validator("condition_ids", mode="before")
@classmethod
def _validate_ids(cls, condition_ids):
if condition_ids in [None, "", [], [""]]:
# unspecified, or "use-model-as-is"
return []

for condition_id in condition_ids:
# The empty condition ID for "use-model-as-is" has been handled
# above. Having a combination of empty and non-empty IDs is an
# error, since the targets of conditions to be combined must be
# disjoint.
if not is_valid_identifier(condition_id):
raise ValueError(f"Invalid ID: {condition_id}")
raise ValueError(f"Invalid {C.CONDITION_ID}: `{condition_id}'")
return condition_ids


Expand Down Expand Up @@ -854,17 +865,23 @@ class Parameter(BaseModel):
#: Parameter ID.
id: str = Field(alias=C.PARAMETER_ID)
#: Lower bound.
lb: float | None = Field(alias=C.LOWER_BOUND, default=None)
lb: Annotated[float | None, BeforeValidator(_convert_nan_to_none)] = Field(
alias=C.LOWER_BOUND, default=None
)
#: Upper bound.
ub: float | None = Field(alias=C.UPPER_BOUND, default=None)
ub: Annotated[float | None, BeforeValidator(_convert_nan_to_none)] = Field(
alias=C.UPPER_BOUND, default=None
)
#: Nominal value.
nominal_value: float | None = Field(alias=C.NOMINAL_VALUE, default=None)
nominal_value: Annotated[
float | None, BeforeValidator(_convert_nan_to_none)
] = Field(alias=C.NOMINAL_VALUE, default=None)
#: Is the parameter to be estimated?
estimate: bool = Field(alias=C.ESTIMATE, default=True)
#: Type of parameter prior distribution.
prior_distribution: PriorDistribution | None = Field(
alias=C.PRIOR_DISTRIBUTION, default=None
)
prior_distribution: Annotated[
PriorDistribution | None, BeforeValidator(_convert_nan_to_none)
] = Field(alias=C.PRIOR_DISTRIBUTION, default=None)
#: Prior distribution parameters.
prior_parameters: list[float] = Field(
alias=C.PRIOR_PARAMETERS, default_factory=list
Expand All @@ -889,8 +906,18 @@ def _validate_id(cls, v):

@field_validator("prior_parameters", mode="before")
@classmethod
def _validate_prior_parameters(cls, v):
def _validate_prior_parameters(
cls, v: str | list[str] | float | None | np.ndarray
):
if v is None:
return []

if isinstance(v, float) and np.isnan(v):
return []

if isinstance(v, str):
if v == "":
return []
v = v.split(C.PARAMETER_SEPARATOR)
elif not isinstance(v, Sequence):
v = [v]
Expand All @@ -899,7 +926,7 @@ def _validate_prior_parameters(cls, v):

@field_validator("estimate", mode="before")
@classmethod
def _validate_estimate_before(cls, v):
def _validate_estimate_before(cls, v: bool | str):
if isinstance(v, bool):
return v

Expand All @@ -918,12 +945,17 @@ def _validate_estimate_before(cls, v):
def _serialize_estimate(self, estimate: bool, _info):
return str(estimate).lower()

@field_validator("lb", "ub", "nominal_value")
@classmethod
def _convert_nan_to_none(cls, v):
if isinstance(v, float) and np.isnan(v):
return None
return v
@field_serializer("prior_distribution")
def _serialize_prior_distribution(
self, prior_distribution: PriorDistribution | None, _info
):
if prior_distribution is None:
return ""
return str(prior_distribution)

@field_serializer("prior_parameters")
def _serialize_prior_parameters(self, prior_parameters: list[str], _info):
return C.PARAMETER_SEPARATOR.join(prior_parameters)

@model_validator(mode="after")
def _validate(self) -> Self:
Expand Down Expand Up @@ -952,7 +984,7 @@ def _validate(self) -> Self:

@property
def prior_dist(self) -> Distribution:
"""Get the pior distribution of the parameter."""
"""Get the prior distribution of the parameter."""
if self.estimate is False:
raise ValueError(f"Parameter `{self.id}' is not estimated.")

Expand Down Expand Up @@ -980,6 +1012,13 @@ def prior_dist(self) -> Distribution:
"transformation."
)
return cls(*self.prior_parameters, trunc=[self.lb, self.ub])

if cls == Uniform:
# `Uniform.__init__` does not accept the `trunc` parameter
low = max(self.prior_parameters[0], self.lb)
high = min(self.prior_parameters[1], self.ub)
return cls(low, high, log=log)

return cls(*self.prior_parameters, log=log, trunc=[self.lb, self.ub])


Expand Down
Loading