38 changes: 22 additions & 16 deletions tests/backends/test_arviz.py
@@ -19,15 +19,14 @@
import pytest
import xarray

from arviz import InferenceData
from arviz.tests.helpers import check_multiple_attrs
from arviz_base.testing import check_multiple_attrs
from numpy import ma
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1

import pymc as pm

from pymc.backends.arviz import (
InferenceDataConverter,
DataTreeConverter,
dataset_to_point_list,
predictions_to_inference_data,
to_inference_data,
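Note: the import swap above tracks the arviz 1.0 reorganization. `check_multiple_attrs` moves from `arviz.tests.helpers` to `arviz_base.testing`, and PyMC's converter class is renamed from `InferenceDataConverter` to `DataTreeConverter`. A minimal compatibility sketch, assuming code may run against either package layout (the `Converter` alias is hypothetical):

```python
# Hedged sketch: prefer the DataTree-era names, fall back to the
# pre-arviz-1.0 equivalents when the new modules are absent.
try:
    from arviz_base.testing import check_multiple_attrs
    from pymc.backends.arviz import DataTreeConverter as Converter
except ImportError:
    from arviz.tests.helpers import check_multiple_attrs
    from pymc.backends.arviz import InferenceDataConverter as Converter
```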
@@ -110,7 +109,7 @@ def get_inference_data(self, data, eight_schools_params):

def get_predictions_inference_data(
self, data, eight_schools_params, inplace
) -> tuple[InferenceData, dict[str, np.ndarray]]:
) -> tuple[xarray.DataTree, dict[str, np.ndarray]]:
with data.model:
prior = pm.sample_prior_predictive(return_inferencedata=False)
posterior_predictive = pm.sample_posterior_predictive(
@@ -123,17 +122,17 @@ def get_predictions_inference_data(
coords={"school": np.arange(eight_schools_params["J"])},
dims={"theta": ["school"], "eta": ["school"]},
)
assert isinstance(idata, InferenceData)
assert isinstance(idata, xarray.DataTree)
extended = predictions_to_inference_data(
posterior_predictive, idata_orig=idata, inplace=inplace
)
assert isinstance(extended, InferenceData)
assert isinstance(extended, xarray.DataTree)
assert (id(idata) == id(extended)) == inplace
return (extended, posterior_predictive)

def make_predictions_inference_data(
self, data, eight_schools_params
) -> tuple[InferenceData, dict[str, np.ndarray]]:
) -> tuple[xarray.DataTree, dict[str, np.ndarray]]:
with data.model:
posterior_predictive = pm.sample_posterior_predictive(
data.obj, return_inferencedata=False
@@ -144,7 +143,7 @@ def make_predictions_inference_data(
coords={"school": np.arange(eight_schools_params["J"])},
dims={"theta": ["school"], "eta": ["school"]},
)
assert isinstance(idata, InferenceData)
assert isinstance(idata, xarray.DataTree)
return idata, posterior_predictive

def test_to_idata(self, data, eight_schools_params, chains, draws):
@@ -166,7 +165,7 @@ def test_to_idata(self, data, eight_schools_params, chains, draws):
assert inference_data.log_likelihood["obs"].shape == (chains, draws, *obs.shape)

def test_predictions_to_idata(self, data, eight_schools_params):
"Test that we can add predictions to a previously-existing InferenceData."
"Test that we can add predictions to a previously-existing xarray.DataTree."
test_dict = {
"posterior": ["mu", "tau", "eta", "theta"],
"sample_stats": ["diverging", "lp"],
@@ -236,7 +235,7 @@ def test_posterior_predictive_thinned(self, data):
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
idata = pm.sample(tune=5, draws=draws, chains=2, return_inferencedata=True)
thinned_idata = idata.sel(draw=slice(None, None, thin_by))
idata.extend(pm.sample_posterior_predictive(thinned_idata))
idata.update(pm.sample_posterior_predictive(thinned_idata))
test_dict = {
"posterior": ["mu", "tau", "eta", "theta"],
"sample_stats": ["diverging", "lp", "~log_likelihood"],
@@ -639,7 +638,12 @@ def test_constant_data_coords_issue_5046(self):
assert len(data[k].shape) == len(dims[k])

ds = pm.backends.arviz.dict_to_dataset(
data=data, library=pm, coords=coords, dims=dims, default_dims=[], index_origin=0
data=data,
inference_library=pm,
coords=coords,
dims=dims,
sample_dims=[],
index_origin=0,
)
for dname, cvals in coords.items():
np.testing.assert_array_equal(ds[dname].values, cvals)
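Note: the keyword renames in this hunk (`library` to `inference_library`, `default_dims` to `sample_dims`) mirror the `arviz_base`-style `dict_to_dataset` signature. A standalone sketch of the updated call, with a hypothetical one-variable payload:

```python
import numpy as np
import pymc as pm

# Hedged sketch of the renamed keywords, per the diff above.
ds = pm.backends.arviz.dict_to_dataset(
    data={"x": np.arange(3)},
    inference_library=pm,               # was: library=pm
    coords={"dim_a": ["a", "b", "c"]},
    dims={"x": ["dim_a"]},
    sample_dims=[],                     # was: default_dims=[]
    index_origin=0,
)
```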
@@ -661,14 +665,14 @@ def test_issue_5043_autoconvert_coord_values(self):
)
# The converter must convert coord values to numpy arrays
# because tuples as coordinate values cause problems with xarray.
converter = InferenceDataConverter(trace=mtrace)
converter = DataTreeConverter(trace=mtrace)
assert isinstance(converter.coords["city"], np.ndarray)
converter.to_inference_data()

# We're not automatically converting things other than tuple,
# so advanced use cases remain supported at the InferenceData level.
# so advanced use cases remain supported at the DataTree level.
# They just can't be used in the model construction already.
converter = InferenceDataConverter(
converter = DataTreeConverter(
trace=mtrace,
coords={
"city": pd.MultiIndex.from_tuples(
@@ -862,11 +866,13 @@ def test_incompatible_coordinate_lengths():
"Incompatible coordinate length of 3 for dimension 'a' of variable 'y'"
),
):
prior = pm.sample_prior_predictive(draws=1).prior.squeeze(("chain", "draw"))
prior = (
pm.sample_prior_predictive(draws=1).prior.to_dataset().squeeze(("chain", "draw"))
)
assert prior.x.dims == prior.y.dims == ("a",)
assert prior.x.shape == prior.y.shape == (3,)
assert np.isnan(prior.y.values[-1])
assert list(prior.coords["a"]) == [0, 1, 2]
assert list(prior.coords["a"]) == [-1, -2, -3]

pm.backends.arviz.RAISE_ON_INCOMPATIBLE_COORD_LENGTHS = True
with pytest.raises(ValueError):
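Note: the extra `.to_dataset()` in this hunk is the recurring migration step in this file. A group accessed off a `DataTree` is a tree node, and materializing it as a `Dataset` first keeps multi-dimension operations like `squeeze(("chain", "draw"))` behaving as they did on `InferenceData` groups. A hedged sketch, with a hypothetical model whose coords match the corrected assertion above:

```python
import pymc as pm

with pm.Model(coords={"a": [-1, -2, -3]}) as model:
    x = pm.Normal("x", dims="a")
    idata = pm.sample_prior_predictive(draws=1)

# Convert the prior group from a tree node to a plain Dataset before
# squeezing out the single chain and draw dimensions.
prior = idata.prior.to_dataset().squeeze(("chain", "draw"))
assert prior["x"].dims == ("a",)
```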
4 changes: 1 addition & 3 deletions tests/backends/test_zarr.py
@@ -20,8 +20,6 @@
import xarray as xr
import zarr

from arviz import InferenceData

import pymc as pm

from pymc.backends.zarr import ZarrTrace
@@ -436,7 +434,7 @@ def test_sample(
assert isinstance(out_trace, ZarrTrace)
assert out_trace.root.store is trace.root.store
else:
assert isinstance(out_trace, InferenceData)
assert isinstance(out_trace, xr.DataTree)

expected_groups = {"posterior", "constant_data", "observed_data", "sample_stats"}
if include_transformed:
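Note: in the zarr backend test the only change is the expected return type when no `ZarrTrace` is threaded through: the sampling output is now checked against `xr.DataTree` instead of `arviz.InferenceData`. A minimal check, assuming this branch's default return container:

```python
import xarray as xr
import pymc as pm

with pm.Model():
    pm.Normal("x")
    out = pm.sample(draws=10, tune=5, chains=1)

# On this branch the default return container is an xarray.DataTree.
assert isinstance(out, xr.DataTree)
```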
16 changes: 8 additions & 8 deletions tests/gp/test_hsgp_approx.py
@@ -215,8 +215,8 @@ def test_prior(self, model, cov_func, X1, parametrization, rng):

idata = pm.sample_prior_predictive(draws=1000, random_seed=rng)

samples1 = az.extract(idata.prior["f1"])["f1"].values.T
samples2 = az.extract(idata.prior["f2"])["f2"].values.T
samples1 = az.extract(idata.prior["f1"]).values.T
samples2 = az.extract(idata.prior["f2"]).values.T

h0, mmd, critical_value, reject = two_sample_test(
samples1, samples2, n_sims=500, alpha=0.01
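Note: the repeated `az.extract` edits in this file all drop the trailing `["f1"]`-style lookup. Under the DataTree layout, `idata.prior["f1"]` is already a single `DataArray`, so `az.extract` hands back the stacked samples directly rather than a `Dataset` to index into. A hedged sketch of the pattern, with a hypothetical prior in place of the HSGP model:

```python
import arviz as az
import pymc as pm

with pm.Model():
    f1 = pm.Normal("f1", shape=3)
    idata = pm.sample_prior_predictive(draws=100)

# idata.prior["f1"] is a DataArray; az.extract stacks chain/draw into a
# single sample dimension, so no extra ["f1"] indexing is needed.
samples1 = az.extract(idata.prior["f1"]).values.T
```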
@@ -242,8 +242,8 @@ def test_conditional(self, model, cov_func, X1, parametrization):

idata = pm.sample_prior_predictive(draws=1000)

samples1 = az.extract(idata.prior["f"])["f"].values.T
samples2 = az.extract(idata.prior["fc"])["fc"].values.T
samples1 = az.extract(idata.prior["f"]).values.T
samples2 = az.extract(idata.prior["fc"]).values.T

h0, mmd, critical_value, reject = two_sample_test(
samples1, samples2, n_sims=500, alpha=0.01
@@ -302,8 +302,8 @@ def test_prior(self, model, cov_func, eta, X1, rng):

idata = pm.sample_prior_predictive(draws=1000, random_seed=rng)

samples1 = az.extract(idata.prior["f1"])["f1"].values.T
samples2 = az.extract(idata.prior["f2"])["f2"].values.T
samples1 = az.extract(idata.prior["f1"]).values.T
samples2 = az.extract(idata.prior["f2"]).values.T

h0, mmd, critical_value, reject = two_sample_test(
samples1, samples2, n_sims=500, alpha=0.01
@@ -323,8 +323,8 @@ def test_conditional_periodic(self, model, cov_func, X1):

idata = pm.sample_prior_predictive(draws=1000)

samples1 = az.extract(idata.prior["f"])["f"].values.T
samples2 = az.extract(idata.prior["fc"])["fc"].values.T
samples1 = az.extract(idata.prior["f"]).values.T
samples2 = az.extract(idata.prior["fc"]).values.T

h0, mmd, critical_value, reject = two_sample_test(
samples1, samples2, n_sims=500, alpha=0.01
8 changes: 5 additions & 3 deletions tests/model/test_core.py
@@ -227,7 +227,7 @@ def test_nested_model_to_netcdf(self, tmp_path):
with pm.Model("scope") as model:
b = pm.Normal("var")
trace = pm.sample(100, tune=0)
az.to_netcdf(trace, tmp_path / "trace.nc")
trace.to_netcdf(tmp_path / "trace.nc")
trace1 = az.from_netcdf(tmp_path / "trace.nc")
assert "scope::var" in trace1.posterior

@@ -1430,8 +1430,10 @@ def test_interval_missing_observations(self):
np.testing.assert_array_equal(trace["theta2"][0][~obs2.mask], obs1[~obs2.mask])

pp_idata = pm.sample_posterior_predictive(trace, random_seed=rng)
pp_trace = pp_idata.posterior_predictive.stack(sample=["chain", "draw"]).transpose(
"sample", ...
pp_trace = (
pp_idata.posterior_predictive.to_dataset()
.stack(sample=["chain", "draw"])
.transpose("sample", ...)
)
assert set(pp_trace.keys()) == {
"theta1",
12 changes: 8 additions & 4 deletions tests/model/transform/test_conditioning.py
@@ -159,9 +159,11 @@ def test_do_posterior_predictive():
# Dummy posterior
idata_m = az.from_dict(
{
"x": np.full((2, 500), 25),
"y": np.full((2, 500), np.nan),
"z": np.full((2, 500), np.nan),
"posterior": {
"x": np.full((2, 500), 25),
"y": np.full((2, 500), np.nan),
"z": np.full((2, 500), np.nan),
}
}
)
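Note: `az.from_dict` changes shape as well. Instead of a flat mapping of posterior variables, the groups are spelled out explicitly as a nested mapping (here only `"posterior"`). A standalone sketch of the new-style call, assuming the arviz-base signature:

```python
import numpy as np
import arviz as az

# Groups are named explicitly; values are per-variable arrays with
# (chain, draw, ...) leading dimensions.
idata_m = az.from_dict(
    {
        "posterior": {
            "x": np.full((2, 500), 25),
            "y": np.full((2, 500), np.nan),
            "z": np.full((2, 500), np.nan),
        }
    }
)
```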

@@ -293,7 +295,9 @@ def test_do_sample_posterior_predictive(make_interventions_shared):
b = pm.Deterministic("b", a * 2)
c = pm.Normal("c", b / 2)

idata = az.from_dict({"a": [[1.0]], "b": [[2.0]], "c": [[1.0]]})
idata = az.from_dict(
{"posterior": {"a": np.array([[1.0]]), "b": np.array([[2.0]]), "c": np.array([[1.0]])}}
)

with do(model, {a: 1000}, make_interventions_shared=make_interventions_shared):
pp = sample_posterior_predictive(idata, var_names=["c"], predictions=True).predictions