From 24306cabb98720cf842dfeecc429afb873056e73 Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Sun, 30 Nov 2025 17:37:35 +0500 Subject: [PATCH 1/4] Add arviz_base dependency --- conda-envs/environment-alternative-backends.yml | 1 + conda-envs/environment-dev.yml | 1 + conda-envs/environment-docs.yml | 1 + conda-envs/environment-test.yml | 1 + conda-envs/windows-environment-dev.yml | 1 + conda-envs/windows-environment-test.yml | 1 + requirements-dev.txt | 1 + 7 files changed, 7 insertions(+) diff --git a/conda-envs/environment-alternative-backends.yml b/conda-envs/environment-alternative-backends.yml index c4c70c3955..f61e18dc2e 100644 --- a/conda-envs/environment-alternative-backends.yml +++ b/conda-envs/environment-alternative-backends.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base>=0.7.0 - blas - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index e58cc3c89f..c56125106a 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base - blas - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index c29e4bc1a4..0b06bb125b 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base - cachetools>=4.2.1 - cloudpickle - numpy>=1.25.0 diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 1c8af58fc9..33ac729220 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 +- arviz-base>=0.7.0 - blas - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 750bd2b034..0ee47f1bd0 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies (see install guide for Windows) - arviz>=0.13.0 +- arviz-base>=0.7.0 - blas - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 1f65504a32..07d873a780 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -6,6 +6,7 @@ channels: dependencies: # Base dependencies (see install guide for Windows) - arviz>=0.13.0 +- arviz-base>=0.7.0 - blas - cachetools>=4.2.1 - cloudpickle diff --git a/requirements-dev.txt b/requirements-dev.txt index 82a9b8aca1..9563ebeb02 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,7 @@ # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify. # See that file for comments about the need/usage of each dependency. +arviz-base arviz>=0.13.0 cachetools>=4.2.1 cloudpickle From 5363790caf62686f37b1bb47dafb5051eb7d864d Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Tue, 23 Dec 2025 00:03:27 +0500 Subject: [PATCH 2/4] Transferred imports of rcParams and requires from arviz to arviz_base --- conda-envs/environment-dev.yml | 2 +- conda-envs/environment-docs.yml | 2 +- pymc/backends/arviz.py | 9 ++++++--- requirements-dev.txt | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index c56125106a..ecb990e2d9 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -6,7 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 -- arviz-base +- arviz-base>=0.7.0 - blas - cachetools>=4.2.1 - cloudpickle diff --git a/conda-envs/environment-docs.yml b/conda-envs/environment-docs.yml index 0b06bb125b..f932a3ba50 100644 --- a/conda-envs/environment-docs.yml +++ b/conda-envs/environment-docs.yml @@ -6,7 +6,7 @@ channels: dependencies: # Base dependencies - arviz>=0.13.0 -- arviz-base +- arviz-base>=0.7.0 - cachetools>=4.2.1 - cloudpickle - numpy>=1.25.0 diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 63f8370523..606f470a97 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -28,8 +28,11 @@ import numpy as np import xarray -from arviz import InferenceData, concat, rcParams -from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires +from arviz import InferenceData, concat +from arviz.data.base import CoordSpec, DimSpec +from arviz_base import dict_to_dataset +from arviz_base.base import requires +from arviz_base.rcparams import RcParams from pytensor.graph import ancestors from pytensor.tensor.sharedvar import SharedVariable from rich.progress import Console @@ -211,7 +214,7 @@ def __init__( save_warmup: bool | None = None, include_transformed: bool = False, ): - self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup + self.save_warmup = RcParams["data.save_warmup"] if save_warmup is None else save_warmup self.include_transformed = include_transformed self.trace = trace diff --git a/requirements-dev.txt b/requirements-dev.txt index 9563ebeb02..357efc45d8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ # This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify. # See that file for comments about the need/usage of each dependency. -arviz-base +arviz-base>=0.7.0 arviz>=0.13.0 cachetools>=4.2.1 cloudpickle From b450d3c62d7e26ddf4e75e8ee03244998396a62a Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Fri, 26 Dec 2025 23:53:19 +0500 Subject: [PATCH 3/4] Attempt to fix failing CI tests --- pymc/backends/arviz.py | 8 +++++--- tests/backends/test_arviz.py | 7 ++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 606f470a97..b012a20924 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -32,7 +32,7 @@ from arviz.data.base import CoordSpec, DimSpec from arviz_base import dict_to_dataset from arviz_base.base import requires -from arviz_base.rcparams import RcParams +from arviz_base.rcparams import rcParams from pytensor.graph import ancestors from pytensor.tensor.sharedvar import SharedVariable from rich.progress import Console @@ -214,7 +214,7 @@ def __init__( save_warmup: bool | None = None, include_transformed: bool = False, ): - self.save_warmup = RcParams["data.save_warmup"] if save_warmup is None else save_warmup + self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup self.include_transformed = include_transformed self.trace = trace @@ -444,7 +444,9 @@ def constant_data_to_xarray(self): scalars = [var_name for var_name, value in constant_data.items() if np.ndim(value) == 0] for s in scalars: s_dim_0_name = f"{s}_dim_0" - xarray_dataset = xarray_dataset.squeeze(s_dim_0_name, drop=True) + # only remove the scalar if it exists in dims, otherwise you get KeyError + if s_dim_0_name in xarray_dataset.dims: + xarray_dataset = xarray_dataset.squeeze(s_dim_0_name, drop=True) return xarray_dataset diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py index 85c1d9915c..5e7637ac19 100644 --- a/tests/backends/test_arviz.py +++ b/tests/backends/test_arviz.py @@ -639,7 +639,12 @@ def test_constant_data_coords_issue_5046(self): assert len(data[k].shape) == len(dims[k]) ds = pm.backends.arviz.dict_to_dataset( - data=data, library=pm, coords=coords, dims=dims, default_dims=[], index_origin=0 + data=data, + inference_library=pm, + coords=coords, + dims=dims, + sample_dims=[], + index_origin=0, ) for dname, cvals in coords.items(): np.testing.assert_array_equal(ds[dname].values, cvals) From 5f2062748d5bef46f4f7eed364e60bd4fda63c95 Mon Sep 17 00:00:00 2001 From: Moazzam Shahzad Date: Tue, 30 Dec 2025 16:35:44 +0500 Subject: [PATCH 4/4] Final update to function calls in backends/arviz and smc/sampling --- pymc/backends/arviz.py | 34 +++++++++++++++++++++------------- pymc/smc/sampling.py | 4 +++- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index b012a20924..37d41fbb0a 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -308,14 +308,14 @@ def posterior_to_xarray(self): return ( dict_to_dataset( data, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, attrs=self.attrs, ), dict_to_dataset( data_warmup, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, attrs=self.attrs, @@ -350,14 +350,14 @@ def sample_stats_to_xarray(self): return ( dict_to_dataset( data, - library=pymc, + inference_library=pymc, dims=None, coords=self.coords, attrs=self.attrs, ), dict_to_dataset( data_warmup, - library=pymc, + inference_library=pymc, dims=None, coords=self.coords, attrs=self.attrs, @@ -370,7 +370,11 @@ def posterior_predictive_to_xarray(self): data = self.posterior_predictive dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} return dict_to_dataset( - data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims + data, + inference_library=pymc, + coords=self.coords, + dims=dims, + sample_dims=self.sample_dims, ) @requires(["predictions"]) @@ -379,7 +383,11 @@ def predictions_to_xarray(self): data = self.predictions dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data} return dict_to_dataset( - data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims + data, + inference_library=pymc, + coords=self.coords, + dims=dims, + sample_dims=self.sample_dims, ) def priors_to_xarray(self): @@ -402,7 +410,7 @@ def priors_to_xarray(self): if var_names is None else dict_to_dataset_drop_incompatible_coords( {k: np.expand_dims(self.prior[k], 0) for k in var_names}, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, ) @@ -417,10 +425,10 @@ def observed_data_to_xarray(self): return None return dict_to_dataset( self.observations, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, - default_dims=[], + sample_dims=[], ) @requires("model") @@ -432,10 +440,10 @@ def constant_data_to_xarray(self): xarray_dataset = dict_to_dataset( constant_data, - library=pymc, + inference_library=pymc, coords=self.coords, dims=self.dims, - default_dims=[], + sample_dims=[], ) # provisional handling of scalars in constant @@ -712,9 +720,9 @@ def apply_function_over_dataset( return dict_to_dataset( out_trace, - library=pymc, + inference_library=pymc, dims=dims, coords=coords, - default_dims=list(sample_dims), + sample_dims=list(sample_dims), skip_event_dims=True, ) diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 249f6c5253..08d383db15 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -267,7 +267,9 @@ def _save_sample_stats( sample_stats = dict_to_dataset( sample_stats_dict, attrs=sample_settings_dict, - library=pymc, + inference_library=pymc, + sample_dims=["chain"], + check_conventions=False, ) ikwargs: dict[str, Any] = {"model": model}