From 70ba791af54bd0095eeb23e9d926ed8caecb337c Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Wed, 3 Dec 2025 08:22:46 +0100 Subject: [PATCH] Implement sampling for v2 prior distributions --- petab/v1/distributions.py | 27 ++++++++++++++++++++++----- tests/v1/test_distributions.py | 5 +++++ 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/petab/v1/distributions.py b/petab/v1/distributions.py index 411add56..de7a638c 100644 --- a/petab/v1/distributions.py +++ b/petab/v1/distributions.py @@ -508,6 +508,9 @@ def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float: def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float: return cauchy.ppf(q, loc=self._loc, scale=self._scale) + def _sample(self, shape=None) -> np.ndarray | float: + return cauchy.rvs(loc=self._loc, scale=self._scale, size=shape) + @property def loc(self) -> float: """The location parameter of the underlying distribution.""" @@ -541,14 +544,16 @@ class ChiSquare(Distribution): def __init__( self, - dof: int, + dof: int | float, trunc: tuple[float, float] | None = None, log: bool | float = False, ): - if not dof.is_integer() or dof < 1: - raise ValueError( - f"`dof' must be a positive integer, but was `{dof}'." - ) + if isinstance(dof, float): + if not dof.is_integer() or dof < 1: + raise ValueError( + f"`dof' must be a positive integer, but was `{dof}'." + ) + dof = int(dof) self._dof = dof super().__init__(log=log, trunc=trunc) @@ -565,6 +570,9 @@ def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float: def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float: return chi2.ppf(q, df=self._dof) + def _sample(self, shape=None) -> np.ndarray | float: + return chi2.rvs(df=self._dof, size=shape) + @property def dof(self) -> int: """The degrees of freedom parameter.""" @@ -602,6 +610,9 @@ def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float: def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float: return expon.ppf(q, scale=self._scale) + def _sample(self, shape=None) -> np.ndarray | float: + return expon.rvs(scale=self._scale, size=shape) + @property def scale(self) -> float: """The scale parameter of the underlying distribution.""" @@ -650,6 +661,9 @@ def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float: def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float: return gamma.ppf(q, a=self._shape, scale=self._scale) + def _sample(self, shape=None) -> np.ndarray | float: + return gamma.rvs(a=self._shape, scale=self._scale, size=shape) + @property def shape(self) -> float: """The shape parameter of the underlying distribution.""" @@ -700,6 +714,9 @@ def _cdf_untransformed_untruncated(self, x) -> np.ndarray | float: def _ppf_untransformed_untruncated(self, q) -> np.ndarray | float: return rayleigh.ppf(q, scale=self._scale) + def _sample(self, shape=None) -> np.ndarray | float: + return rayleigh.rvs(scale=self._scale, size=shape) + @property def scale(self) -> float: """The scale parameter of the underlying distribution.""" diff --git a/tests/v1/test_distributions.py b/tests/v1/test_distributions.py index e06d9edc..7b7cd4aa 100644 --- a/tests/v1/test_distributions.py +++ b/tests/v1/test_distributions.py @@ -34,6 +34,11 @@ Normal(2, 1, log=10), Laplace(1, 2, trunc=(1, 2)), Laplace(1, 0.5, log=True, trunc=(0.5, 8)), + Cauchy(2, 1), + ChiSquare(4), + Exponential(1), + Gamma(3, 5), + Rayleigh(3), ], ) def test_sample_matches_pdf(distribution):