From 0e890b0aafa9930cde57a6d0174c1f4ca473b50b Mon Sep 17 00:00:00 2001
From: eclipse1605
Date: Sat, 20 Dec 2025 16:49:57 +0530
Subject: [PATCH 1/4] attempt to fix warmup bookkeeping

---
 pymc/backends/mcbackend.py        |  3 ---
 pymc/sampling/mcmc.py             | 16 ++++------------
 pymc/step_methods/compound.py     | 10 ++++------
 pymc/step_methods/hmc/base_hmc.py |  1 -
 pymc/step_methods/hmc/hmc.py      |  1 -
 pymc/step_methods/hmc/nuts.py     |  1 -
 pymc/step_methods/metropolis.py   | 31 ++++++-------------------------
 pymc/step_methods/slicer.py       |  7 ++-----
 tests/backends/test_mcbackend.py  |  8 +++-----
 tests/sampling/test_mcmc.py       | 27 +++++++++++++++++++++++++++
 10 files changed, 46 insertions(+), 59 deletions(-)

diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py
index d02a6dbebb..e89ac19cf2 100644
--- a/pymc/backends/mcbackend.py
+++ b/pymc/backends/mcbackend.py
@@ -34,7 +34,6 @@
     BlockedStep,
     CompoundStep,
     StatsBijection,
-    check_step_emits_tune,
     flat_statname,
     flatten_steps,
 )
@@ -210,8 +209,6 @@ def make_runmeta_and_point_fn(
 ) -> tuple[mcb.RunMeta, PointFunc]:
     variables, point_fn = get_variables_and_point_fn(model, initial_point)
 
-    check_step_emits_tune(step)
-
     # In PyMC the sampler stats are grouped by the sampler.
     sample_stats = []
     steps = flatten_steps(step)
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index de341c68cd..949235bc76 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -1043,18 +1043,10 @@ def _sample_return(
     else:
         traces, length = _choose_chains(traces, 0)
         mtrace = MultiTrace(traces)[:length]
-        # count the number of tune/draw iterations that happened
-        # ideally via the "tune" statistic, but not all samplers record it!
-        if "tune" in mtrace.stat_names:
-            # Get the tune stat directly from chain 0, sampler 0
-            stat = mtrace._straces[0].get_sampler_stats("tune", sampler_idx=0)
-            stat = tuple(stat)
-            n_tune = stat.count(True)
-            n_draws = stat.count(False)
-        else:
-            # these may be wrong when KeyboardInterrupt happened, but they're better than nothing
-            n_tune = min(tune, len(mtrace))
-            n_draws = max(0, len(mtrace) - n_tune)
+        # Count the number of tune/draw iterations that happened.
+        # The warmup/draw boundary is owned by the sampling driver.
+        n_tune = min(tune, len(mtrace))
+        n_draws = max(0, len(mtrace) - n_tune)
 
         if discard_tuned_samples:
             mtrace = mtrace[n_tune:]
diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py
index a9cae903f0..5984446c94 100644
--- a/pymc/step_methods/compound.py
+++ b/pymc/step_methods/compound.py
@@ -92,6 +92,10 @@ def infer_warn_stats_info(
                 sds[sname] = (dtype, None)
     elif sds:
         stats_dtypes.append({sname: dtype for sname, (dtype, _) in sds.items()})
+
+    # Even when a step method does not emit any stats, downstream components still assume one stats "slot" per step method. Represent that with a single empty dict.
+    if not stats_dtypes:
+        stats_dtypes.append({})
 
     return stats_dtypes, sds
 
@@ -352,12 +356,6 @@ def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]:
 
 
 def check_step_emits_tune(step: CompoundStep | BlockedStep):
-    if isinstance(step, BlockedStep) and "tune" not in step.stats_dtypes_shapes:
-        raise TypeError(f"{type(step)} does not emit the required 'tune' stat.")
-    elif isinstance(step, CompoundStep):
-        for sstep in step.methods:
-            if "tune" not in sstep.stats_dtypes_shapes:
-                raise TypeError(f"{type(sstep)} does not emit the required 'tune' stat.")
     return
 
 
diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py
index 297b095e23..c3e6d75e5c 100644
--- a/pymc/step_methods/hmc/base_hmc.py
+++ b/pymc/step_methods/hmc/base_hmc.py
@@ -273,7 +273,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         self.iter_count += 1
 
         stats: dict[str, Any] = {
-            "tune": self.tune,
             "diverging": diverging,
             "divergences": self.divergences,
             "perf_counter_diff": perf_end - perf_start,
diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py
index 1697341bc8..57fd5219b1 100644
--- a/pymc/step_methods/hmc/hmc.py
+++ b/pymc/step_methods/hmc/hmc.py
@@ -53,7 +53,6 @@ class HamiltonianMC(BaseHMC):
     stats_dtypes_shapes = {
         "step_size": (np.float64, []),
         "n_steps": (np.int64, []),
-        "tune": (bool, []),
         "step_size_bar": (np.float64, []),
         "accept": (np.float64, []),
         "diverging": (bool, []),
diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py
index c927d57e31..f674e852ee 100644
--- a/pymc/step_methods/hmc/nuts.py
+++ b/pymc/step_methods/hmc/nuts.py
@@ -110,7 +110,6 @@ class NUTS(BaseHMC):
     stats_dtypes_shapes = {
         "depth": (np.int64, []),
         "step_size": (np.float64, []),
-        "tune": (bool, []),
         "mean_tree_accept": (np.float64, []),
         "step_size_bar": (np.float64, []),
         "tree_size": (np.float64, []),
diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index c042bc1f3d..b371f6dd48 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -146,7 +146,6 @@ class Metropolis(ArrayStepShared):
     stats_dtypes_shapes = {
         "accept": (np.float64, []),
         "accepted": (np.float64, []),
-        "tune": (bool, []),
         "scaling": (np.float64, []),
     }
 
@@ -316,7 +315,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         self.steps_until_tune -= 1
 
         stats = {
-            "tune": self.tune,
             "scaling": np.mean(self.scaling),
             "accept": np.mean(np.exp(self.accept_rate_iter)),
             "accepted": np.mean(self.accepted_iter),
@@ -331,7 +329,6 @@ def competence(var, has_grad):
     @staticmethod
     def _progressbar_config(n_chains=1):
         columns = [
-            TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
             TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)),
             TextColumn(
                 "{task.fields[accept_rate]:0.2f}", table_column=Column("Accept Rate", ratio=1)
             ),
         ]
 
         stats = {
-            "tune": [True] * n_chains,
             "scaling": [0] * n_chains,
             "accept_rate": [0.0] * n_chains,
         }
@@ -351,7 +347,7 @@ def _make_progressbar_update_functions():
         def update_stats(step_stats):
             return {
                 "accept_rate" if key == "accept" else key: step_stats[key]
-                for key in ("tune", "accept", "scaling")
+                for key in ("accept", "scaling")
             }
 
         return (update_stats,)
@@ -448,7 +444,6 @@ class BinaryMetropolis(ArrayStep):
 
     stats_dtypes_shapes = {
         "accept": (np.float64, []),
-        "tune": (bool, []),
         "p_jump": (np.float64, []),
     }
 
@@ -505,7 +500,6 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         self.accepted += accepted
 
         stats = {
-            "tune": self.tune,
             "accept": np.exp(accept),
             "p_jump": p_jump,
         }
@@ -574,9 +568,7 @@ class BinaryGibbsMetropolis(ArrayStep):
 
     name = "binary_gibbs_metropolis"
 
-    stats_dtypes_shapes = {
-        "tune": (bool, []),
-    }
+    stats_dtypes_shapes = {}
 
     _state_class = BinaryGibbsMetropolisState
 
@@ -594,8 +586,6 @@ def __init__(
     ):
         model = pm.modelcontext(model)
 
-        # Doesn't actually tune, but it's required to emit a sampler stat
-        # that indicates whether a draw was done in a tuning phase.
         self.tune = True
         # transition probabilities
         self.transit_p = transit_p
@@ -649,10 +639,7 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
             if accepted:
                 logp_curr = logp_prop
 
-        stats = {
-            "tune": self.tune,
-        }
-        return q, [stats]
+        return q, [{}]
 
     @staticmethod
     def competence(var):
@@ -695,9 +682,7 @@ class CategoricalGibbsMetropolis(ArrayStep):
 
     name = "categorical_gibbs_metropolis"
 
-    stats_dtypes_shapes = {
-        "tune": (bool, []),
-    }
+    stats_dtypes_shapes = {}
 
     _state_class = CategoricalGibbsMetropolisState
 
@@ -793,7 +778,7 @@ def astep_unif(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
                 logp_curr = logp_prop
 
         # This step doesn't have any tunable parameters
-        return q, [{"tune": False}]
+        return q, [{}]
 
     def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         logp = args[0]
@@ -811,7 +796,7 @@ def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
             logp_curr = self.metropolis_proportional(q, logp, logp_curr, dim, k)
 
         # This step doesn't have any tunable parameters
-        return q, [{"tune": False}]
+        return q, [{}]
 
     def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         raise NotImplementedError()
@@ -919,7 +904,6 @@ class DEMetropolis(PopulationArrayStepShared):
     stats_dtypes_shapes = {
         "accept": (np.float64, []),
         "accepted": (bool, []),
-        "tune": (bool, []),
         "scaling": (np.float64, []),
         "lambda": (np.float64, []),
     }
@@ -1011,7 +995,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         self.steps_until_tune -= 1
 
         stats = {
-            "tune": self.tune,
             "scaling": self.scaling,
             "lambda": self.lamb,
             "accept": np.exp(accept),
@@ -1090,7 +1073,6 @@ class DEMetropolisZ(ArrayStepShared):
     stats_dtypes_shapes = {
         "accept": (np.float64, []),
         "accepted": (bool, []),
-        "tune": (bool, []),
         "scaling": (np.float64, []),
         "lambda": (np.float64, []),
     }
@@ -1213,7 +1195,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         self.steps_until_tune -= 1
 
         stats = {
-            "tune": self.tune,
             "scaling": np.mean(self.scaling),
             "lambda": self.lamb,
             "accept": np.exp(accept),
diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py
index 180ac1c882..5ea92fc916 100644
--- a/pymc/step_methods/slicer.py
+++ b/pymc/step_methods/slicer.py
@@ -72,7 +72,6 @@ class Slice(ArrayStepShared):
     name = "slice"
     default_blocked = False
     stats_dtypes_shapes = {
-        "tune": (bool, []),
        "nstep_out": (int, []),
         "nstep_in": (int, []),
     }
@@ -184,7 +183,6 @@ def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]:
                 self.n_tunes += 1
 
         stats = {
-            "tune": self.tune,
             "nstep_out": nstep_out,
             "nstep_in": nstep_in,
         }
@@ -202,18 +200,17 @@ def competence(var, has_grad):
     @staticmethod
     def _progressbar_config(n_chains=1):
         columns = [
-            TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
             TextColumn("{task.fields[nstep_out]}", table_column=Column("Steps out", ratio=1)),
             TextColumn("{task.fields[nstep_in]}", table_column=Column("Steps in", ratio=1)),
         ]
 
-        stats = {"tune": [True] * n_chains, "nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains}
+        stats = {"nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains}
 
         return columns, stats
 
     @staticmethod
     def _make_progressbar_update_functions():
         def update_stats(step_stats):
-            return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}}
+            return {key: step_stats[key] for key in {"nstep_out", "nstep_in"}}
 
         return (update_stats,)
diff --git a/tests/backends/test_mcbackend.py b/tests/backends/test_mcbackend.py
index e72731af6b..64ad927454 100644
--- a/tests/backends/test_mcbackend.py
+++ b/tests/backends/test_mcbackend.py
@@ -293,13 +293,11 @@ def test_return_multitrace(self, simple_model, discard_warmup):
             return_inferencedata=False,
         )
         assert isinstance(mtrace, pm.backends.base.MultiTrace)
-        tune = mtrace._straces[0].get_sampler_stats("tune")
-        assert isinstance(tune, np.ndarray)
+        # Warmup is tracked by the sampling driver
         if discard_warmup:
-            assert tune.shape == (7, 3)
+            assert len(mtrace) == 7
         else:
-            assert tune.shape == (12, 3)
-            pass
+            assert len(mtrace) == 12
 
     @pytest.mark.parametrize("cores", [1, 3])
     def test_return_inferencedata(self, simple_model, cores):
diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py
index 090b76130b..fcacad7a95 100644
--- a/tests/sampling/test_mcmc.py
+++ b/tests/sampling/test_mcmc.py
@@ -414,6 +414,33 @@ def test_sample_return_lengths(self):
             assert idata.posterior.sizes["draw"] == 100
             assert idata.posterior.sizes["chain"] == 3
 
+    def test_categorical_gibbs_respects_driver_tune_boundary(self):
+        with pm.Model():
+            pm.Categorical("x", p=np.array([0.2, 0.3, 0.5]))
+            sample_kwargs = {
+                "tune": 5,
+                "draws": 7,
+                "chains": 1,
+                "cores": 1,
+                "return_inferencedata": False,
+                "compute_convergence_checks": False,
+                "progressbar": False,
+                "random_seed": 123,
+            }
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                mtrace = pm.sample(discard_tuned_samples=True, **sample_kwargs)
+            assert len(mtrace) == 7
+            assert mtrace.report.n_tune == 5
+            assert mtrace.report.n_draws == 7
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                with pytest.warns(UserWarning, match="will be included"):
+                    mtrace_warmup = pm.sample(discard_tuned_samples=False, **sample_kwargs)
+            assert len(mtrace_warmup) == 12
+            assert mtrace_warmup.report.n_tune == 5
+            assert mtrace_warmup.report.n_draws == 7
+
     @pytest.mark.parametrize("cores", [1, 2])
     def test_logs_sampler_warnings(self, caplog, cores):
         """Asserts that "warning" sampler stats are logged during sampling."""

From 889d7cc6a2247ca28daf10b7e4e9b3d0dd1d3535 Mon Sep 17 00:00:00 2001
From: eclipse1605
Date: Tue, 30 Dec 2025 15:59:57 +0530
Subject: [PATCH 2/4] make warmup/tune driver-owned and persist per-draw tune marker via backends

---
 pymc/backends/base.py       | 11 ++++++++++-
 pymc/backends/mcbackend.py  | 24 +++++++++++++++++++++++-
 pymc/backends/ndarray.py    |  2 +-
 pymc/backends/zarr.py       |  8 +++++++-
 pymc/progress_bar.py        | 10 +++++++++-
 pymc/sampling/mcmc.py       |  6 +++---
 pymc/sampling/parallel.py   |  2 +-
 pymc/sampling/population.py |  2 +-
 8 files changed, 55 insertions(+), 10 deletions(-)

diff --git a/pymc/backends/base.py b/pymc/backends/base.py
index 993acc0df4..d185bf03d0 100644
--- a/pymc/backends/base.py
+++ b/pymc/backends/base.py
@@ -113,7 +113,13 @@ def point(self, idx: int) -> dict[str, np.ndarray]:
         """
         raise NotImplementedError()
 
-    def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
+    def record(
+        self,
+        draw: Mapping[str, np.ndarray],
+        stats: Sequence[Mapping[str, Any]],
+        *,
+        tune: bool | None = None,
+    ):
         """Record results of a sampling iteration.
 
         Parameters
@@ -122,6 +128,9 @@ def record(
             Values mapped to variable names
         stats: list of dicts
             The diagnostic values for each sampler
+        tune: bool | None
+            Whether this draw belongs to the tuning/warmup phase. This is a driver-owned
+            concept and is intended for storage/backends to persist warmup information.
         """
         raise NotImplementedError()
 
diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py
index e89ac19cf2..52e47c9c61 100644
--- a/pymc/backends/mcbackend.py
+++ b/pymc/backends/mcbackend.py
@@ -105,16 +105,26 @@ def __init__(
             {sname: stats_dtypes[fname] for fname, sname, is_obj in sstats}
             for sstats in stats_bijection._stat_groups
         ]
+        if "tune" in stats_dtypes and self.sampler_vars:
+            # expose driver-owned warmup marker via the sampler-stats API.
+            self.sampler_vars[0].setdefault("tune", stats_dtypes["tune"])
 
         self._chain = chain
         self._point_fn = point_fn
         self._statsbj = stats_bijection
         super().__init__()
 
-    def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
+    def record(
+        self,
+        draw: Mapping[str, np.ndarray],
+        stats: Sequence[Mapping[str, Any]],
+        *,
+        tune: bool | None = None,
+    ):
         values = self._point_fn(draw)
         value_dict = dict(zip(self.varnames, values))
         stats_dict = self._statsbj.map(stats)
+        stats_dict["tune"] = bool(tune)
         # Apply pickling to objects stats
         for fname in self._statsbj.object_stats.keys():
             val_bytes = pickle.dumps(stats_dict[fname])
@@ -147,6 +157,8 @@ def get_sampler_stats(
         self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1
     ) -> np.ndarray:
         slc = slice(burn, None, thin)
+        if stat_name == "tune":
+            return self._get_stats("tune", slc)
         # When there's just one sampler, default to remove the sampler dimension
         if sampler_idx is None and self._statsbj.n_samplers == 1:
             sampler_idx = 0
@@ -232,6 +244,16 @@ def make_runmeta_and_point_fn(
         )
         sample_stats.append(svar)
 
+    # Driver-owned warmup marker, stored once per draw.
+    sample_stats.append(
+        mcb.Variable(
+            name="tune",
+            dtype=np.dtype(bool).name,
+            shape=[],
+            undefined_ndim=False,
+        )
+    )
+
     coordinates = [
         mcb.Coordinate(dname, mcb.npproto.utils.ndarray_from_numpy(np.array(cvals)))
         for dname, cvals in model.coords.items()
diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py
index a08fc8f47e..d7142ea109 100644
--- a/pymc/backends/ndarray.py
+++ b/pymc/backends/ndarray.py
@@ -97,7 +97,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
                 new = np.zeros(draws, dtype=dtype)
                 data[varname] = np.concatenate([old, new])
 
-    def record(self, point, sampler_stats=None) -> None:
+    def record(self, point, sampler_stats=None, *, tune: bool | None = None) -> None:
         """Record results of a sampling iteration.
 
         Parameters
diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py
index 9b7664c504..a5760a6a46 100644
--- a/pymc/backends/zarr.py
+++ b/pymc/backends/zarr.py
@@ -159,7 +159,11 @@ def buffer(self, group, var_name, value):
         buffer[var_name].append(value)
 
     def record(
-        self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]
+        self,
+        draw: Mapping[str, np.ndarray],
+        stats: Sequence[Mapping[str, Any]],
+        *,
+        tune: bool | None = None,
     ) -> bool | None:
         """Record the step method's returned draw and stats.
@@ -185,6 +189,7 @@ def record(
             self.buffer(group="posterior", var_name=var_name, value=var_value)
         for var_name, var_value in self.stats_bijection.map(stats).items():
             self.buffer(group="sample_stats", var_name=var_name, value=var_value)
+        self.buffer(group="sample_stats", var_name="tune", value=bool(tune))
         self._buffered_draws += 1
         if self._buffered_draws == self.draws_until_flush:
             self.flush()
@@ -525,6 +530,7 @@ def init_trace(
         stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
             [step] if isinstance(step, BlockedStep) else step.methods
         )
+        stats_dtypes_shapes = {"tune": (bool, [])} | stats_dtypes_shapes
         self.init_group_with_empty(
             group=self.root.create_group(name="sample_stats", overwrite=True),
             var_dtype_and_shape=stats_dtypes_shapes,
diff --git a/pymc/progress_bar.py b/pymc/progress_bar.py
index 26d2ab6fb5..e8fe8f9c5b 100644
--- a/pymc/progress_bar.py
+++ b/pymc/progress_bar.py
@@ -285,6 +285,7 @@ def __init__(
         self._show_progress = show_progress
         self.completed_draws = 0
+        self.tune = tune
         self.total_draws = draws + tune
         self.desc = "Sampling chain"
         self.chains = chains
@@ -308,6 +309,7 @@ def _initialize_tasks(self):
                 draws=0,
                 total=self.total_draws * self.chains - 1,
                 chain_idx=0,
+                tune=self.tune > 0,
                 sampling_speed=0,
                 speed_unit="draws/s",
                 failing=False,
@@ -323,6 +325,7 @@ def _initialize_tasks(self):
                     draws=0,
                     total=self.total_draws - 1,
                     chain_idx=chain_idx,
+                    tune=self.tune > 0,
                     sampling_speed=0,
                     speed_unit="draws/s",
                     failing=False,
@@ -381,6 +384,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
                 self.tasks[chain_idx],
                 completed=draw,
                 draws=draw,
+                tune=tuning,
                 sampling_speed=speed,
                 speed_unit=unit,
                 failing=failing,
@@ -391,13 +395,17 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
             self._progress.update(
                 self.tasks[chain_idx],
                 draws=draw + 1 if not self.combined_progress else draw,
+                tune=False,
                 failing=failing,
                 **all_step_stats,
                 refresh=True,
             )
 
     def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
-        columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))]
+        columns = [
+            TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1)),
+            TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
+        ]
 
         if self.full_stats:
             columns += step_columns
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index 949235bc76..e0c790bc34 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -1213,7 +1213,7 @@ def _sample(
     try:
         for it, stats in enumerate(sampling_gen):
             progress_manager.update(
-                chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune
+                chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it < tune
             )
 
             if not progress_manager.combined_progress or chain == progress_manager.chains - 1:
@@ -1284,7 +1284,7 @@ def _iter_sample(
             step.stop_tuning()
         point, stats = step.step(point)
-        trace.record(point, stats)
+        trace.record(point, stats, tune=i < tune)
         log_warning_stats(stats)
 
         if callback is not None:
@@ -1397,7 +1397,7 @@ def _mp_sample(
             strace = traces[draw.chain]
             if not zarr_recording:
                 # Zarr recording happens in each process
-                strace.record(draw.point, draw.stats)
+                strace.record(draw.point, draw.stats, tune=draw.tuning)
             log_warning_stats(draw.stats)
 
             if callback is not None:
diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py
index 6e229b9606..4ed07e3b20 100644
--- a/pymc/sampling/parallel.py
+++ b/pymc/sampling/parallel.py
@@ -219,7 +219,7 @@ def _start_loop(self):
                 raise KeyboardInterrupt()
             elif msg[0] == "write_next":
                 if zarr_recording:
-                    self._zarr_chain.record(point, stats)
+                    self._zarr_chain.record(point, stats, tune=tuning)
                 self._write_point(point)
                 is_last = draw + 1 == self._draws + self._tune
                 self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats))
diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py
index 5bd1771704..842aebb045 100644
--- a/pymc/sampling/population.py
+++ b/pymc/sampling/population.py
@@ -458,7 +458,7 @@ def _iter_population(
             # apply the update to the points and record to the traces
             for c, strace in enumerate(traces):
                 points[c], stats = updates[c]
-                flushed = strace.record(points[c], stats)
+                flushed = strace.record(points[c], stats, tune=i < tune)
                 log_warning_stats(stats)
                 if flushed and isinstance(strace, ZarrChain):
                     sampling_state = popstep.request_sampling_state(c)

From 9a8e1cb82ba61387f3291be7e31498a96f11132c Mon Sep 17 00:00:00 2001
From: eclipse1605
Date: Tue, 30 Dec 2025 21:27:44 +0530
Subject: [PATCH 3/4] persist driver-owned warmup marker as in_warmup

---
 pymc/backends/arviz.py       |  2 +-
 pymc/backends/base.py        | 60 ++++++++++++++++++++++++++++++++++--
 pymc/backends/mcbackend.py   | 17 +++++-----
 pymc/backends/ndarray.py     |  2 +-
 pymc/backends/zarr.py        |  6 ++--
 pymc/progress_bar.py         | 10 +++---
 pymc/sampling/mcmc.py        |  8 +++--
 pymc/sampling/parallel.py    |  2 +-
 pymc/sampling/population.py  |  4 +--
 tests/backends/test_arviz.py | 10 +++---
 10 files changed, 89 insertions(+), 32 deletions(-)

diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py
index 63f8370523..6bf7b9c8e3 100644
--- a/pymc/backends/arviz.py
+++ b/pymc/backends/arviz.py
@@ -333,7 +333,7 @@ def sample_stats_to_xarray(self):
         data_warmup = {}
         for stat in self.trace.stat_names:
             name = rename_key.get(stat, stat)
-            if name == "tune":
+            if name in {"tune", "in_warmup"}:
                 continue
             if self.warmup_trace:
                 data_warmup[name] = np.array(
diff --git a/pymc/backends/base.py b/pymc/backends/base.py
index d185bf03d0..9559a28fa4 100644
--- a/pymc/backends/base.py
+++ b/pymc/backends/base.py
@@ -17,12 +17,14 @@
 See the docstring for pymc.backends for more information
 """
 
+import inspect
 import itertools as itl
 import logging
 import warnings
 
 from abc import ABC
 from collections.abc import Mapping, Sequence, Sized
+from functools import cache
 from typing import (
     Any,
     TypeVar,
@@ -40,6 +42,58 @@
 logger = logging.getLogger(__name__)
 
 
+@cache
+def _record_supports_in_warmup(trace_type: type) -> str:
+    """Return how to call `trace.record` for this backend type.
+
+    Returns
+    -------
+    mode : {"kw", "no", "try"}
+        - "kw": safe to pass `in_warmup=` (parameter present or **kwargs supported)
+        - "no": do not pass `in_warmup=`
+        - "try": signature introspection failed; try `in_warmup=` and fallback on
+          unexpected-keyword errors.
+    """
+    try:
+        sig = inspect.signature(trace_type.record)
+    except (TypeError, ValueError):
+        return "try"
+
+    parameters = sig.parameters
+    if "in_warmup" in parameters:
+        return "kw"
+    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()):
+        return "kw"
+    return "no"
+
+
+def _record_with_in_warmup(
+    trace: "IBaseTrace",
+    draw: Mapping[str, np.ndarray],
+    stats: Sequence[Mapping[str, Any]],
+    *,
+    in_warmup: bool,
+):
+    """Record a draw, passing `in_warmup` when the backend supports it.
+
+    This keeps compatibility with custom traces that predate the `in_warmup` kwarg.
+ """ + mode = _record_supports_in_warmup(type(trace)) + if mode == "kw": + return trace.record(draw, stats, in_warmup=in_warmup) + if mode == "no": + return trace.record(draw, stats) + + # fallback for backends we can't introspect reliably. + try: + return trace.record(draw, stats, in_warmup=in_warmup) + except TypeError as err: + message = str(err) + if "unexpected keyword argument" in message and "in_warmup" in message: + return trace.record(draw, stats) + raise + + class BackendError(Exception): pass @@ -118,7 +172,7 @@ def record( draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]], *, - tune: bool | None = None, + in_warmup: bool, ): """Record results of a sampling iteration. @@ -128,8 +182,8 @@ def record( Values mapped to variable names stats: list of dicts The diagnostic values for each sampler - tune: bool | None - Whether this draw belongs to the tuning/warmup phase. This is a driver-owned + in_warmup: bool + Whether this draw belongs to the warmup phase. This is a driver-owned concept and is intended for storage/backends to persist warmup information. """ raise NotImplementedError() diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py index 52e47c9c61..891afa51f9 100644 --- a/pymc/backends/mcbackend.py +++ b/pymc/backends/mcbackend.py @@ -105,9 +105,9 @@ def __init__( {sname: stats_dtypes[fname] for fname, sname, is_obj in sstats} for sstats in stats_bijection._stat_groups ] - if "tune" in stats_dtypes and self.sampler_vars: - # expose driver-owned warmup marker via the sampler-stats API. - self.sampler_vars[0].setdefault("tune", stats_dtypes["tune"]) + if "in_warmup" in stats_dtypes and self.sampler_vars: + # Expose driver-owned warmup marker via the sampler-stats API. + self.sampler_vars[0].setdefault("in_warmup", stats_dtypes["in_warmup"]) self._chain = chain self._point_fn = point_fn @@ -119,12 +119,12 @@ def record( draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]], *, - tune: bool | None = None, + in_warmup: bool, ): values = self._point_fn(draw) value_dict = dict(zip(self.varnames, values)) stats_dict = self._statsbj.map(stats) - stats_dict["tune"] = bool(tune) + stats_dict["in_warmup"] = bool(in_warmup) # Apply pickling to objects stats for fname in self._statsbj.object_stats.keys(): val_bytes = pickle.dumps(stats_dict[fname]) @@ -157,8 +157,9 @@ def get_sampler_stats( self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1 ) -> np.ndarray: slc = slice(burn, None, thin) - if stat_name == "tune": - return self._get_stats("tune", slc) + if stat_name in {"in_warmup", "tune"}: + # Backwards-friendly alias for users that might try "tune". + return self._get_stats("in_warmup", slc) # When there's just one sampler, default to remove the sampler dimension if sampler_idx is None and self._statsbj.n_samplers == 1: sampler_idx = 0 @@ -247,7 +248,7 @@ def make_runmeta_and_point_fn( # driver owned warmup marker. stored once per draw. 
     sample_stats.append(
         mcb.Variable(
-            name="tune",
+            name="in_warmup",
             dtype=np.dtype(bool).name,
             shape=[],
             undefined_ndim=False,
diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py
index d7142ea109..77ac9a4fcb 100644
--- a/pymc/backends/ndarray.py
+++ b/pymc/backends/ndarray.py
@@ -97,7 +97,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
                 new = np.zeros(draws, dtype=dtype)
                 data[varname] = np.concatenate([old, new])
 
-    def record(self, point, sampler_stats=None, *, tune: bool | None = None) -> None:
+    def record(self, point, sampler_stats=None, *, in_warmup: bool) -> None:
         """Record results of a sampling iteration.
 
         Parameters
diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py
index a5760a6a46..d5be2a2776 100644
--- a/pymc/backends/zarr.py
+++ b/pymc/backends/zarr.py
@@ -163,7 +163,7 @@ def record(
         draw: Mapping[str, np.ndarray],
         stats: Sequence[Mapping[str, Any]],
         *,
-        tune: bool | None = None,
+        in_warmup: bool,
     ) -> bool | None:
         """Record the step method's returned draw and stats.
 
@@ -189,7 +189,7 @@ def record(
             self.buffer(group="posterior", var_name=var_name, value=var_value)
         for var_name, var_value in self.stats_bijection.map(stats).items():
             self.buffer(group="sample_stats", var_name=var_name, value=var_value)
-        self.buffer(group="sample_stats", var_name="tune", value=bool(tune))
+        self.buffer(group="sample_stats", var_name="in_warmup", value=bool(in_warmup))
         self._buffered_draws += 1
         if self._buffered_draws == self.draws_until_flush:
             self.flush()
@@ -530,7 +530,7 @@ def init_trace(
         stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
             [step] if isinstance(step, BlockedStep) else step.methods
         )
-        stats_dtypes_shapes = {"tune": (bool, [])} | stats_dtypes_shapes
+        stats_dtypes_shapes = {"in_warmup": (bool, [])} | stats_dtypes_shapes
         self.init_group_with_empty(
             group=self.root.create_group(name="sample_stats", overwrite=True),
             var_dtype_and_shape=stats_dtypes_shapes,
diff --git a/pymc/progress_bar.py b/pymc/progress_bar.py
index e8fe8f9c5b..62c5677442 100644
--- a/pymc/progress_bar.py
+++ b/pymc/progress_bar.py
@@ -309,7 +309,7 @@ def _initialize_tasks(self):
                 draws=0,
                 total=self.total_draws * self.chains - 1,
                 chain_idx=0,
-                tune=self.tune > 0,
+                in_warmup=self.tune > 0,
                 sampling_speed=0,
                 speed_unit="draws/s",
                 failing=False,
@@ -325,7 +325,7 @@ def _initialize_tasks(self):
                     draws=0,
                     total=self.total_draws - 1,
                     chain_idx=chain_idx,
-                    tune=self.tune > 0,
+                    in_warmup=self.tune > 0,
                     sampling_speed=0,
                     speed_unit="draws/s",
                     failing=False,
@@ -384,7 +384,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
                 self.tasks[chain_idx],
                 completed=draw,
                 draws=draw,
-                tune=tuning,
+                in_warmup=tuning,
                 sampling_speed=speed,
                 speed_unit=unit,
                 failing=failing,
@@ -395,7 +395,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
             self._progress.update(
                 self.tasks[chain_idx],
                 draws=draw + 1 if not self.combined_progress else draw,
-                tune=False,
+                in_warmup=False,
                 failing=failing,
                 **all_step_stats,
                 refresh=True,
@@ -404,7 +404,7 @@ def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
         columns = [
             TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1)),
-            TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
+            TextColumn("{task.fields[in_warmup]}", table_column=Column("Warmup", ratio=1)),
         ]
 
         if self.full_stats:
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index e0c790bc34..bedf684139 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -48,7 +48,7 @@
     find_constants,
     find_observations,
 )
-from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
+from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains, _record_with_in_warmup
 from pymc.backends.zarr import ZarrChain, ZarrTrace
 from pymc.blocking import DictToArrayBijection
 from pymc.exceptions import SamplingError
@@ -1284,7 +1284,7 @@ def _iter_sample(
             step.stop_tuning()
         point, stats = step.step(point)
-        trace.record(point, stats, tune=i < tune)
+        _record_with_in_warmup(trace, point, stats, in_warmup=i < tune)
         log_warning_stats(stats)
 
         if callback is not None:
@@ -1397,7 +1397,9 @@ def _mp_sample(
             strace = traces[draw.chain]
             if not zarr_recording:
                 # Zarr recording happens in each process
-                strace.record(draw.point, draw.stats, tune=draw.tuning)
+                _record_with_in_warmup(
+                    strace, draw.point, draw.stats, in_warmup=draw.tuning
+                )
             log_warning_stats(draw.stats)
 
             if callback is not None:
diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py
index 4ed07e3b20..1e0ebfb1d2 100644
--- a/pymc/sampling/parallel.py
+++ b/pymc/sampling/parallel.py
@@ -219,7 +219,7 @@ def _start_loop(self):
                 raise KeyboardInterrupt()
             elif msg[0] == "write_next":
                 if zarr_recording:
-                    self._zarr_chain.record(point, stats, tune=tuning)
+                    self._zarr_chain.record(point, stats, in_warmup=tuning)
                 self._write_point(point)
                 is_last = draw + 1 == self._draws + self._tune
                 self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats))
diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py
index 842aebb045..f7ab2232ca 100644
--- a/pymc/sampling/population.py
+++ b/pymc/sampling/population.py
@@ -26,7 +26,7 @@
 
 from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
 
-from pymc.backends.base import BaseTrace
+from pymc.backends.base import BaseTrace, _record_with_in_warmup
 from pymc.backends.zarr import ZarrChain
 from pymc.initial_point import PointType
 from pymc.model import Model, modelcontext
@@ -458,7 +458,7 @@ def _iter_population(
             # apply the update to the points and record to the traces
             for c, strace in enumerate(traces):
                 points[c], stats = updates[c]
-                flushed = strace.record(points[c], stats, tune=i < tune)
+                flushed = _record_with_in_warmup(strace, points[c], stats, in_warmup=i < tune)
                 log_warning_stats(stats)
                 if flushed and isinstance(strace, ZarrChain):
                     sampling_state = popstep.request_sampling_state(c)
diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py
index 85c1d9915c..571fa93f8a 100644
--- a/tests/backends/test_arviz.py
+++ b/tests/backends/test_arviz.py
@@ -748,9 +748,9 @@ def test_save_warmup(self, save_warmup, chains, tune, draws):
         post_prefix = "" if draws > 0 else "~"
         test_dict = {
             f"{post_prefix}posterior": ["u1", "n1"],
-            f"{post_prefix}sample_stats": ["~tune", "accept"],
+            f"{post_prefix}sample_stats": ["~in_warmup", "accept"],
             f"{warmup_prefix}warmup_posterior": ["u1", "n1"],
-            f"{warmup_prefix}warmup_sample_stats": ["~tune"],
+            f"{warmup_prefix}warmup_sample_stats": ["~in_warmup"],
             "~warmup_log_likelihood": [],
             "~log_likelihood": [],
         }
@@ -785,9 +785,9 @@ def test_save_warmup_issue_1208_after_3_9(self):
         idata = to_inference_data(trace, save_warmup=True)
         test_dict = {
             "posterior": ["u1", "n1"],
-            "sample_stats": ["~tune", "accept"],
+            "sample_stats": ["~in_warmup", "accept"],
             "warmup_posterior": ["u1", "n1"],
-            "warmup_sample_stats": ["~tune", "accept"],
+            "warmup_sample_stats": ["~in_warmup", "accept"],
         }
         fails = check_multiple_attrs(test_dict, idata)
         assert not fails
@@ -799,7 +799,7 @@ def test_save_warmup_issue_1208_after_3_9(self):
         idata = to_inference_data(trace[-30:], save_warmup=True)
         test_dict = {
             "posterior": ["u1", "n1"],
-            "sample_stats": ["~tune", "accept"],
+            "sample_stats": ["~in_warmup", "accept"],
             "~warmup_posterior": [],
             "~warmup_sample_stats": [],
         }

From 444c4aa1cc0a349cd8c3003a29d27b872803ab37 Mon Sep 17 00:00:00 2001
From: eclipse1605
Date: Fri, 2 Jan 2026 11:25:18 +0530
Subject: [PATCH 4/4] drop compatibility shim

---
 pymc/backends/base.py       | 54 -------------------------------------
 pymc/sampling/mcmc.py       |  8 +++---
 pymc/sampling/population.py |  4 +--
 3 files changed, 5 insertions(+), 61 deletions(-)

diff --git a/pymc/backends/base.py b/pymc/backends/base.py
index 9559a28fa4..528552650f 100644
--- a/pymc/backends/base.py
+++ b/pymc/backends/base.py
@@ -17,14 +17,12 @@
 See the docstring for pymc.backends for more information
 """
 
-import inspect
 import itertools as itl
 import logging
 import warnings
 
 from abc import ABC
 from collections.abc import Mapping, Sequence, Sized
-from functools import cache
 from typing import (
     Any,
     TypeVar,
@@ -42,58 +40,6 @@
 logger = logging.getLogger(__name__)
 
 
-@cache
-def _record_supports_in_warmup(trace_type: type) -> str:
-    """Return how to call `trace.record` for this backend type.
-
-    Returns
-    -------
-    mode : {"kw", "no", "try"}
-        - "kw": safe to pass `in_warmup=` (parameter present or **kwargs supported)
-        - "no": do not pass `in_warmup=`
-        - "try": signature introspection failed; try `in_warmup=` and fallback on
-          unexpected-keyword errors.
-    """
-    try:
-        sig = inspect.signature(trace_type.record)
-    except (TypeError, ValueError):
-        return "try"
-
-    parameters = sig.parameters
-    if "in_warmup" in parameters:
-        return "kw"
-    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()):
-        return "kw"
-    return "no"
-
-
-def _record_with_in_warmup(
-    trace: "IBaseTrace",
-    draw: Mapping[str, np.ndarray],
-    stats: Sequence[Mapping[str, Any]],
-    *,
-    in_warmup: bool,
-):
-    """Record a draw, passing `in_warmup` when the backend supports it.
-
-    This keeps compatibility with custom traces that predate the `in_warmup` kwarg.
-    """
-    mode = _record_supports_in_warmup(type(trace))
-    if mode == "kw":
-        return trace.record(draw, stats, in_warmup=in_warmup)
-    if mode == "no":
-        return trace.record(draw, stats)
-
-    # Fallback for backends we can't introspect reliably.
-    try:
-        return trace.record(draw, stats, in_warmup=in_warmup)
-    except TypeError as err:
-        message = str(err)
-        if "unexpected keyword argument" in message and "in_warmup" in message:
-            return trace.record(draw, stats)
-        raise
-
-
 class BackendError(Exception):
     pass
 
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index bedf684139..c03c3ce5fc 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -48,7 +48,7 @@
     find_constants,
     find_observations,
 )
-from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains, _record_with_in_warmup
+from pymc.backends.base import IBaseTrace, MultiTrace, _choose_chains
 from pymc.backends.zarr import ZarrChain, ZarrTrace
 from pymc.blocking import DictToArrayBijection
 from pymc.exceptions import SamplingError
@@ -1284,7 +1284,7 @@ def _iter_sample(
             step.stop_tuning()
         point, stats = step.step(point)
-        _record_with_in_warmup(trace, point, stats, in_warmup=i < tune)
+        trace.record(point, stats, in_warmup=i < tune)
         log_warning_stats(stats)
 
         if callback is not None:
@@ -1397,9 +1397,7 @@ def _mp_sample(
             strace = traces[draw.chain]
             if not zarr_recording:
                 # Zarr recording happens in each process
-                _record_with_in_warmup(
-                    strace, draw.point, draw.stats, in_warmup=draw.tuning
-                )
+                strace.record(draw.point, draw.stats, in_warmup=draw.tuning)
             log_warning_stats(draw.stats)
 
             if callback is not None:
diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py
index f7ab2232ca..7d1d9902f9 100644
--- a/pymc/sampling/population.py
+++ b/pymc/sampling/population.py
@@ -26,7 +26,7 @@
 
 from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
 
-from pymc.backends.base import BaseTrace, _record_with_in_warmup
+from pymc.backends.base import BaseTrace
 from pymc.backends.zarr import ZarrChain
 from pymc.initial_point import PointType
 from pymc.model import Model, modelcontext
@@ -458,7 +458,7 @@ def _iter_population(
             # apply the update to the points and record to the traces
             for c, strace in enumerate(traces):
                 points[c], stats = updates[c]
-                flushed = _record_with_in_warmup(strace, points[c], stats, in_warmup=i < tune)
+                flushed = strace.record(points[c], stats, in_warmup=i < tune)
                 log_warning_stats(stats)
                 if flushed and isinstance(strace, ZarrChain):
                     sampling_state = popstep.request_sampling_state(c)