diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py
index 63f8370523..6bf7b9c8e3 100644
--- a/pymc/backends/arviz.py
+++ b/pymc/backends/arviz.py
@@ -333,7 +333,7 @@ def sample_stats_to_xarray(self):
         data_warmup = {}
         for stat in self.trace.stat_names:
             name = rename_key.get(stat, stat)
-            if name == "tune":
+            if name in {"tune", "in_warmup"}:
                 continue
             if self.warmup_trace:
                 data_warmup[name] = np.array(
diff --git a/pymc/backends/base.py b/pymc/backends/base.py
index 993acc0df4..528552650f 100644
--- a/pymc/backends/base.py
+++ b/pymc/backends/base.py
@@ -113,7 +113,13 @@ def point(self, idx: int) -> dict[str, np.ndarray]:
         """
         raise NotImplementedError()

-    def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
+    def record(
+        self,
+        draw: Mapping[str, np.ndarray],
+        stats: Sequence[Mapping[str, Any]],
+        *,
+        in_warmup: bool,
+    ):
         """Record results of a sampling iteration.

         Parameters
@@ -122,6 +128,9 @@ def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, An
             Values mapped to variable names
         stats: list of dicts
             The diagnostic values for each sampler
+        in_warmup: bool
+            Whether this draw belongs to the warmup (tuning) phase. The warmup/draw
+            boundary is owned by the sampling driver; backends should persist this flag.
         """
         raise NotImplementedError()

diff --git a/pymc/backends/mcbackend.py b/pymc/backends/mcbackend.py
index d02a6dbebb..891afa51f9 100644
--- a/pymc/backends/mcbackend.py
+++ b/pymc/backends/mcbackend.py
@@ -34,7 +34,6 @@
     BlockedStep,
     CompoundStep,
     StatsBijection,
-    check_step_emits_tune,
     flat_statname,
     flatten_steps,
 )
@@ -106,16 +105,26 @@ def __init__(
             {sname: stats_dtypes[fname] for fname, sname, is_obj in sstats}
             for sstats in stats_bijection._stat_groups
         ]
+        if "in_warmup" in stats_dtypes and self.sampler_vars:
+            # Expose driver-owned warmup marker via the sampler-stats API.
+            self.sampler_vars[0].setdefault("in_warmup", stats_dtypes["in_warmup"])
         self._chain = chain
         self._point_fn = point_fn
         self._statsbj = stats_bijection
         super().__init__()

-    def record(self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]):
+    def record(
+        self,
+        draw: Mapping[str, np.ndarray],
+        stats: Sequence[Mapping[str, Any]],
+        *,
+        in_warmup: bool,
+    ):
         values = self._point_fn(draw)
         value_dict = dict(zip(self.varnames, values))
         stats_dict = self._statsbj.map(stats)
+        stats_dict["in_warmup"] = bool(in_warmup)
         # Apply pickling to objects stats
         for fname in self._statsbj.object_stats.keys():
             val_bytes = pickle.dumps(stats_dict[fname])
@@ -148,6 +157,9 @@ def get_sampler_stats(
         self, stat_name: str, sampler_idx: int | None = None, burn=0, thin=1
     ) -> np.ndarray:
         slc = slice(burn, None, thin)
+        if stat_name in {"in_warmup", "tune"}:
+            # Accept "tune" as a backwards-compatible alias for the driver-owned marker.
+            return self._get_stats("in_warmup", slc)
         # When there's just one sampler, default to remove the sampler dimension
         if sampler_idx is None and self._statsbj.n_samplers == 1:
             sampler_idx = 0
@@ -210,8 +222,6 @@ def make_runmeta_and_point_fn(
 ) -> tuple[mcb.RunMeta, PointFunc]:
     variables, point_fn = get_variables_and_point_fn(model, initial_point)

-    check_step_emits_tune(step)
-
     # In PyMC the sampler stats are grouped by the sampler.
     sample_stats = []
     steps = flatten_steps(step)
@@ -235,6 +245,16 @@ def make_runmeta_and_point_fn(
             )
             sample_stats.append(svar)

+    # Driver-owned warmup marker, stored once per draw.
+    sample_stats.append(
+        mcb.Variable(
+            name="in_warmup",
+            dtype=np.dtype(bool).name,
+            shape=[],
+            undefined_ndim=False,
+        )
+    )
+
     coordinates = [
         mcb.Coordinate(dname, mcb.npproto.utils.ndarray_from_numpy(np.array(cvals)))
         for dname, cvals in model.coords.items()
diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py
index a08fc8f47e..5d8d1be62b 100644
--- a/pymc/backends/ndarray.py
+++ b/pymc/backends/ndarray.py
@@ -97,7 +97,7 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
                     new = np.zeros(draws, dtype=dtype)
                     data[varname] = np.concatenate([old, new])

-    def record(self, point, sampler_stats=None) -> None:
+    def record(self, point, sampler_stats=None, *, in_warmup: bool) -> None:
         """Record results of a sampling iteration.

         Parameters
@@ -238,5 +238,5 @@ def point_fun(point):
         chain.fn = point_fun
         for point in point_list:
-            chain.record(point)
+            chain.record(point, in_warmup=False)
     return MultiTrace([chain])
diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py
index 9b7664c504..d5be2a2776 100644
--- a/pymc/backends/zarr.py
+++ b/pymc/backends/zarr.py
@@ -159,7 +159,11 @@ def buffer(self, group, var_name, value):
         buffer[var_name].append(value)

     def record(
-        self, draw: Mapping[str, np.ndarray], stats: Sequence[Mapping[str, Any]]
+        self,
+        draw: Mapping[str, np.ndarray],
+        stats: Sequence[Mapping[str, Any]],
+        *,
+        in_warmup: bool,
     ) -> bool | None:
         """Record the step method's returned draw and stats.

         Parameters
@@ -185,6 +189,7 @@ def record(
             self.buffer(group="posterior", var_name=var_name, value=var_value)
         for var_name, var_value in self.stats_bijection.map(stats).items():
             self.buffer(group="sample_stats", var_name=var_name, value=var_value)
+        self.buffer(group="sample_stats", var_name="in_warmup", value=bool(in_warmup))
         self._buffered_draws += 1
         if self._buffered_draws == self.draws_until_flush:
             self.flush()
@@ -525,6 +530,7 @@ def init_trace(
         stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps(
             [step] if isinstance(step, BlockedStep) else step.methods
         )
+        stats_dtypes_shapes = {"in_warmup": (bool, [])} | stats_dtypes_shapes
         self.init_group_with_empty(
             group=self.root.create_group(name="sample_stats", overwrite=True),
             var_dtype_and_shape=stats_dtypes_shapes,
diff --git a/pymc/progress_bar.py b/pymc/progress_bar.py
index 26d2ab6fb5..62c5677442 100644
--- a/pymc/progress_bar.py
+++ b/pymc/progress_bar.py
@@ -285,6 +285,7 @@ def __init__(
         self._show_progress = show_progress
         self.completed_draws = 0
+        self.tune = tune
         self.total_draws = draws + tune
         self.desc = "Sampling chain"
         self.chains = chains
@@ -308,6 +309,7 @@ def _initialize_tasks(self):
                 draws=0,
                 total=self.total_draws * self.chains - 1,
                 chain_idx=0,
+                in_warmup=self.tune > 0,
                 sampling_speed=0,
                 speed_unit="draws/s",
                 failing=False,
@@ -323,6 +325,7 @@ def _initialize_tasks(self):
                     draws=0,
                     total=self.total_draws - 1,
                     chain_idx=chain_idx,
+                    in_warmup=self.tune > 0,
                     sampling_speed=0,
                     speed_unit="draws/s",
                     failing=False,
@@ -381,6 +384,7 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
             self.tasks[chain_idx],
             completed=draw,
             draws=draw,
+            in_warmup=tuning,
             sampling_speed=speed,
             speed_unit=unit,
             failing=failing,
@@ -391,13 +395,17 @@ def update(self, chain_idx, is_last, draw, tuning, stats):
             self._progress.update(
                 self.tasks[chain_idx],
                 draws=draw + 1 if not self.combined_progress else draw,
+                in_warmup=False,
                 failing=failing,
                 **all_step_stats,
                 refresh=True,
             )

     def create_progress_bar(self, step_columns, progressbar, progressbar_theme):
-        columns = [TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1))]
+        columns = [
+            TextColumn("{task.fields[draws]}", table_column=Column("Draws", ratio=1)),
+            TextColumn("{task.fields[in_warmup]}", table_column=Column("Warmup", ratio=1)),
+        ]

         if self.full_stats:
             columns += step_columns
diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index de341c68cd..c03c3ce5fc 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -1043,18 +1043,10 @@ def _sample_return(
     else:
         traces, length = _choose_chains(traces, 0)
         mtrace = MultiTrace(traces)[:length]
-        # count the number of tune/draw iterations that happened
-        # ideally via the "tune" statistic, but not all samplers record it!
-        if "tune" in mtrace.stat_names:
-            # Get the tune stat directly from chain 0, sampler 0
-            stat = mtrace._straces[0].get_sampler_stats("tune", sampler_idx=0)
-            stat = tuple(stat)
-            n_tune = stat.count(True)
-            n_draws = stat.count(False)
-        else:
-            # these may be wrong when KeyboardInterrupt happened, but they're better than nothing
-            n_tune = min(tune, len(mtrace))
-            n_draws = max(0, len(mtrace) - n_tune)
+        # Count the number of tune/draw iterations that happened.
+        # The warmup/draw boundary is owned by the sampling driver.
+        n_tune = min(tune, len(mtrace))
+        n_draws = max(0, len(mtrace) - n_tune)

         if discard_tuned_samples:
             mtrace = mtrace[n_tune:]
@@ -1221,7 +1213,7 @@ def _sample(
     try:
         for it, stats in enumerate(sampling_gen):
             progress_manager.update(
-                chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it > tune
+                chain_idx=chain, is_last=False, draw=it, stats=stats, tuning=it < tune
             )

             if not progress_manager.combined_progress or chain == progress_manager.chains - 1:
@@ -1292,7 +1284,7 @@ def _iter_sample(
                 step.stop_tuning()
             point, stats = step.step(point)
-            trace.record(point, stats)
+            trace.record(point, stats, in_warmup=i < tune)
             log_warning_stats(stats)

             if callback is not None:
@@ -1405,7 +1397,7 @@ def _mp_sample(
             strace = traces[draw.chain]
             if not zarr_recording:
                 # Zarr recording happens in each process
-                strace.record(draw.point, draw.stats)
+                strace.record(draw.point, draw.stats, in_warmup=draw.tuning)
             log_warning_stats(draw.stats)

             if callback is not None:
diff --git a/pymc/sampling/parallel.py b/pymc/sampling/parallel.py
index 6e229b9606..1e0ebfb1d2 100644
--- a/pymc/sampling/parallel.py
+++ b/pymc/sampling/parallel.py
@@ -219,7 +219,7 @@ def _start_loop(self):
                 raise KeyboardInterrupt()
             elif msg[0] == "write_next":
                 if zarr_recording:
-                    self._zarr_chain.record(point, stats)
+                    self._zarr_chain.record(point, stats, in_warmup=tuning)
                 self._write_point(point)
                 is_last = draw + 1 == self._draws + self._tune
                 self._msg_pipe.send(("writing_done", is_last, draw, tuning, stats))
diff --git a/pymc/sampling/population.py b/pymc/sampling/population.py
index 5bd1771704..7d1d9902f9 100644
--- a/pymc/sampling/population.py
+++ b/pymc/sampling/population.py
@@ -458,7 +458,7 @@ def _iter_population(
             # apply the update to the points and record to the traces
             for c, strace in enumerate(traces):
                 points[c], stats = updates[c]
-                flushed = strace.record(points[c], stats)
+                flushed = strace.record(points[c], stats, in_warmup=i < tune)
                 log_warning_stats(stats)
                 if flushed and isinstance(strace, ZarrChain):
                     sampling_state = popstep.request_sampling_state(c)
diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py
index 567b73f514..aa7ccd73d4 100644
--- a/pymc/smc/kernels.py
+++ b/pymc/smc/kernels.py
@@ -359,7 +359,7 @@ def _posterior_to_trace(self, chain=0) -> NDArray:
                     var_samples = np.round(var_samples).astype(var.dtype)
                 value.append(var_samples.reshape(shape))
                 size += new_size
-            strace.record(point=dict(zip(varnames, value)))
+            strace.record(point=dict(zip(varnames, value)), in_warmup=False)
         return strace
diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py
index a9cae903f0..389bd3b30a 100644
--- a/pymc/step_methods/compound.py
+++ b/pymc/step_methods/compound.py
@@ -92,6 +92,10 @@ def infer_warn_stats_info(
                 sds[sname] = (dtype, None)
     elif sds:
         stats_dtypes.append({sname: dtype for sname, (dtype, _) in sds.items()})
+
+    # Even when a step method emits no stats, downstream components still assume one stats "slot" per step method; represent that case with a single empty dict.
+    if not stats_dtypes:
+        stats_dtypes.append({})
     return stats_dtypes, sds
@@ -351,16 +355,6 @@ def flatten_steps(step: BlockedStep | CompoundStep) -> list[BlockedStep]:
     return steps


-def check_step_emits_tune(step: CompoundStep | BlockedStep):
-    if isinstance(step, BlockedStep) and "tune" not in step.stats_dtypes_shapes:
-        raise TypeError(f"{type(step)} does not emit the required 'tune' stat.")
-    elif isinstance(step, CompoundStep):
-        for sstep in step.methods:
-            if "tune" not in sstep.stats_dtypes_shapes:
-                raise TypeError(f"{type(sstep)} does not emit the required 'tune' stat.")
-    return
-
-
 class StatsBijection:
     """Map between a `list` of stats to `dict` of stats."""
diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py
index 297b095e23..c3e6d75e5c 100644
--- a/pymc/step_methods/hmc/base_hmc.py
+++ b/pymc/step_methods/hmc/base_hmc.py
@@ -273,7 +273,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         self.iter_count += 1

         stats: dict[str, Any] = {
-            "tune": self.tune,
             "diverging": diverging,
             "divergences": self.divergences,
             "perf_counter_diff": perf_end - perf_start,
diff --git a/pymc/step_methods/hmc/hmc.py b/pymc/step_methods/hmc/hmc.py
index 1697341bc8..57fd5219b1 100644
--- a/pymc/step_methods/hmc/hmc.py
+++ b/pymc/step_methods/hmc/hmc.py
@@ -53,7 +53,6 @@ class HamiltonianMC(BaseHMC):
     stats_dtypes_shapes = {
         "step_size": (np.float64, []),
         "n_steps": (np.int64, []),
-        "tune": (bool, []),
         "step_size_bar": (np.float64, []),
         "accept": (np.float64, []),
         "diverging": (bool, []),
diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py
index c927d57e31..f674e852ee 100644
--- a/pymc/step_methods/hmc/nuts.py
+++ b/pymc/step_methods/hmc/nuts.py
@@ -110,7 +110,6 @@ class NUTS(BaseHMC):
     stats_dtypes_shapes = {
         "depth": (np.int64, []),
         "step_size": (np.float64, []),
-        "tune": (bool, []),
         "mean_tree_accept": (np.float64, []),
         "step_size_bar": (np.float64, []),
         "tree_size": (np.float64, []),
diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py
index c042bc1f3d..eb154b0016 100644
--- a/pymc/step_methods/metropolis.py
+++ b/pymc/step_methods/metropolis.py
@@ -146,7 +146,6 @@ class Metropolis(ArrayStepShared):
     stats_dtypes_shapes = {
         "accept": (np.float64, []),
         "accepted": (np.float64, []),
-        "tune": (bool, []),
         "scaling": (np.float64, []),
     }

@@ -316,7 +315,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         self.steps_until_tune -= 1

         stats = {
-            "tune": self.tune,
             "scaling": np.mean(self.scaling),
             "accept": np.mean(np.exp(self.accept_rate_iter)),
             "accepted": np.mean(self.accepted_iter),
@@ -331,7 +329,6 @@ def competence(var, has_grad):
     @staticmethod
     def _progressbar_config(n_chains=1):
         columns = [
-            TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
             TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)),
             TextColumn(
                 "{task.fields[accept_rate]:0.2f}", table_column=Column("Accept Rate", ratio=1)
@@ -339,7 +336,6 @@ def _progressbar_config(n_chains=1):
         ]

         stats = {
-            "tune": [True] * n_chains,
             "scaling": [0] * n_chains,
             "accept_rate": [0.0] * n_chains,
         }
@@ -351,7 +347,7 @@ def _make_progressbar_update_functions():
         def update_stats(step_stats):
             return {
                 "accept_rate" if key == "accept" else key: step_stats[key]
-                for key in ("tune", "accept", "scaling")
+                for key in ("accept", "scaling")
             }

         return (update_stats,)
@@ -448,7 +444,6 @@ class BinaryMetropolis(ArrayStep):

     stats_dtypes_shapes = {
         "accept": (np.float64, []),
-        "tune": (bool, []),
         "p_jump": (np.float64, []),
     }

@@ -505,7 +500,6 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         self.accepted += accepted

         stats = {
-            "tune": self.tune,
             "accept": np.exp(accept),
             "p_jump": p_jump,
         }
@@ -537,7 +531,6 @@ def competence(var):

 @dataclass_state
 class BinaryGibbsMetropolisState(StepMethodState):
-    tune: bool
     transit_p: int
     shuffle_dims: bool
     order: list
@@ -574,9 +567,7 @@ class BinaryGibbsMetropolis(ArrayStep):

     name = "binary_gibbs_metropolis"

-    stats_dtypes_shapes = {
-        "tune": (bool, []),
-    }
+    stats_dtypes_shapes = {}

     _state_class = BinaryGibbsMetropolisState
@@ -594,9 +585,6 @@ def __init__(
     ):
         model = pm.modelcontext(model)

-        # Doesn't actually tune, but it's required to emit a sampler stat
-        # that indicates whether a draw was done in a tuning phase.
-        self.tune = True
         # transition probabilities
         self.transit_p = transit_p
@@ -649,10 +637,7 @@ def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
             if accepted:
                 logp_curr = logp_prop

-        stats = {
-            "tune": self.tune,
-        }
-        return q, [stats]
+        return q, [{}]

     @staticmethod
     def competence(var):
@@ -695,9 +680,7 @@ class CategoricalGibbsMetropolis(ArrayStep):

     name = "categorical_gibbs_metropolis"

-    stats_dtypes_shapes = {
-        "tune": (bool, []),
-    }
+    stats_dtypes_shapes = {}

     _state_class = CategoricalGibbsMetropolisState
@@ -793,7 +776,7 @@ def astep_unif(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType
                 logp_curr = logp_prop

         # This step doesn't have any tunable parameters
-        return q, [{"tune": False}]
+        return q, [{}]

     def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         logp = args[0]
@@ -811,7 +794,7 @@ def astep_prop(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType
             logp_curr = self.metropolis_proportional(q, logp, logp_curr, dim, k)

         # This step doesn't have any tunable parameters
-        return q, [{"tune": False}]
+        return q, [{}]

     def astep(self, apoint: RaveledVars, *args) -> tuple[RaveledVars, StatsType]:
         raise NotImplementedError()
@@ -919,7 +902,6 @@ class DEMetropolis(PopulationArrayStepShared):
     stats_dtypes_shapes = {
         "accept": (np.float64, []),
         "accepted": (bool, []),
-        "tune": (bool, []),
         "scaling": (np.float64, []),
         "lambda": (np.float64, []),
     }
@@ -1011,7 +993,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         self.steps_until_tune -= 1

         stats = {
-            "tune": self.tune,
             "scaling": self.scaling,
             "lambda": self.lamb,
             "accept": np.exp(accept),
@@ -1090,7 +1071,6 @@ class DEMetropolisZ(ArrayStepShared):
     stats_dtypes_shapes = {
         "accept": (np.float64, []),
         "accepted": (bool, []),
-        "tune": (bool, []),
         "scaling": (np.float64, []),
         "lambda": (np.float64, []),
     }
@@ -1213,7 +1193,6 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]:
         self.steps_until_tune -= 1

         stats = {
-            "tune": self.tune,
             "scaling": np.mean(self.scaling),
             "lambda": self.lamb,
             "accept": np.exp(accept),
diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py
index 180ac1c882..5ea92fc916 100644
--- a/pymc/step_methods/slicer.py
+++ b/pymc/step_methods/slicer.py
@@ -72,7 +72,6 @@ class Slice(ArrayStepShared):
     name = "slice"
     default_blocked = False
     stats_dtypes_shapes = {
-        "tune": (bool, []),
         "nstep_out": (int, []),
         "nstep_in": (int, []),
     }
@@ -184,7 +183,6 @@ def astep(self, apoint: RaveledVars) -> tuple[RaveledVars, StatsType]:
             self.n_tunes += 1

         stats = {
-            "tune": self.tune,
             "nstep_out": nstep_out,
             "nstep_in": nstep_in,
         }
@@ -202,18 +200,17 @@ def competence(var, has_grad):
     @staticmethod
     def _progressbar_config(n_chains=1):
         columns = [
-            TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)),
             TextColumn("{task.fields[nstep_out]}", table_column=Column("Steps out", ratio=1)),
             TextColumn("{task.fields[nstep_in]}", table_column=Column("Steps in", ratio=1)),
         ]

-        stats = {"tune": [True] * n_chains, "nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains}
+        stats = {"nstep_out": [0] * n_chains, "nstep_in": [0] * n_chains}

         return columns, stats

     @staticmethod
     def _make_progressbar_update_functions():
         def update_stats(step_stats):
-            return {key: step_stats[key] for key in {"tune", "nstep_out", "nstep_in"}}
+            return {key: step_stats[key] for key in {"nstep_out", "nstep_in"}}

         return (update_stats,)
diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py
index 3cd5cc3dcf..1f0ce2e0b8 100644
--- a/pymc/variational/opvi.py
+++ b/pymc/variational/opvi.py
@@ -1592,7 +1592,7 @@ def sample(
         try:
             trace.setup(draws=draws, chain=0)
             for point in points:
-                trace.record(point)
+                trace.record(point, in_warmup=False)
         finally:
             trace.close()
diff --git a/tests/backends/fixtures.py b/tests/backends/fixtures.py
index a4f28a1262..a6a4699fe1 100644
--- a/tests/backends/fixtures.py
+++ b/tests/backends/fixtures.py
@@ -195,11 +195,11 @@ def setup_class(cls):
                 stats2 = [
                     {key: val[idx] for key, val in stats.items()} for stats in cls.expected_stats[1]
                 ]
-                strace0.record(point=point0, sampler_stats=stats1)
-                strace1.record(point=point1, sampler_stats=stats2)
+                strace0.record(point=point0, sampler_stats=stats1, in_warmup=False)
+                strace1.record(point=point1, sampler_stats=stats2, in_warmup=False)
             else:
-                strace0.record(point=point0)
-                strace1.record(point=point1)
+                strace0.record(point=point0, in_warmup=False)
+                strace1.record(point=point1, in_warmup=False)
         strace0.close()
         strace1.close()
         cls.mtrace = base.MultiTrace([strace0, strace1])
@@ -244,9 +244,9 @@ def record_point(self, val):
         }
         if self.sampler_vars is not None:
             stats = [{key: dtype(val) for key, dtype in vars.items()} for vars in self.sampler_vars]
-            self.strace.record(point=point, sampler_stats=stats)
+            self.strace.record(point=point, sampler_stats=stats, in_warmup=False)
         else:
-            self.strace.record(point=point)
+            self.strace.record(point=point, in_warmup=False)

     def test_standard_close(self):
         for idx in range(self.draws):
@@ -270,7 +270,7 @@ def test_standard_close(self):
     def test_missing_stats(self):
         if self.sampler_vars is not None:
             with pytest.raises(ValueError):
-                self.strace.record(point=self.test_point)
+                self.strace.record(point=self.test_point, in_warmup=False)

     def test_clean_interrupt(self):
         self.record_point(0)
diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py
index 85c1d9915c..571fa93f8a 100644
--- a/tests/backends/test_arviz.py
+++ b/tests/backends/test_arviz.py
@@ -748,9 +748,9 @@ def test_save_warmup(self, save_warmup, chains, tune, draws):
         post_prefix = "" if draws > 0 else "~"
         test_dict = {
             f"{post_prefix}posterior": ["u1", "n1"],
-            f"{post_prefix}sample_stats": ["~tune", "accept"],
+            f"{post_prefix}sample_stats": ["~in_warmup", "accept"],
             f"{warmup_prefix}warmup_posterior": ["u1", "n1"],
-            f"{warmup_prefix}warmup_sample_stats": ["~tune"],
+            f"{warmup_prefix}warmup_sample_stats": ["~in_warmup"],
             "~warmup_log_likelihood": [],
             "~log_likelihood": [],
         }
@@ -785,9 +785,9 @@ def test_save_warmup_issue_1208_after_3_9(self):
         idata = to_inference_data(trace, save_warmup=True)
         test_dict = {
             "posterior": ["u1", "n1"],
-            "sample_stats": ["~tune", "accept"],
+            "sample_stats": ["~in_warmup", "accept"],
             "warmup_posterior": ["u1", "n1"],
-            "warmup_sample_stats": ["~tune", "accept"],
+            "warmup_sample_stats": ["~in_warmup", "accept"],
         }
         fails = check_multiple_attrs(test_dict, idata)
         assert not fails
@@ -799,7 +799,7 @@ def test_save_warmup_issue_1208_after_3_9(self):
         idata = to_inference_data(trace[-30:], save_warmup=True)
         test_dict = {
             "posterior": ["u1", "n1"],
-            "sample_stats": ["~tune", "accept"],
+            "sample_stats": ["~in_warmup", "accept"],
             "~warmup_posterior": [],
             "~warmup_sample_stats": [],
         }
diff --git a/tests/backends/test_base.py b/tests/backends/test_base.py
index 0f450119a7..bbfbedab32 100644
--- a/tests/backends/test_base.py
+++ b/tests/backends/test_base.py
@@ -43,7 +43,7 @@ def test_init_trace_continuation_unsupported(self):
             B = pm.Uniform("B")
         strace = pm.backends.ndarray.NDArray(vars=[A, B])
         strace.setup(10, 0)
-        strace.record({"A": 2, "B_interval__": 0.1})
+        strace.record({"A": 2, "B_interval__": 0.1}, in_warmup=False)
         assert len(strace) == 1
         with pytest.raises(ValueError, match="Continuation of traces"):
             _init_trace(
diff --git a/tests/backends/test_mcbackend.py b/tests/backends/test_mcbackend.py
index e72731af6b..89fa2ccba0 100644
--- a/tests/backends/test_mcbackend.py
+++ b/tests/backends/test_mcbackend.py
@@ -119,7 +119,8 @@ def test_make_runmeta_and_point_fn(simple_model):
     assert not vars["vector"].is_deterministic
     assert not vars["vector_interval__"].is_deterministic
     assert vars["matrix"].is_deterministic
-    assert len(rmeta.sample_stats) == len(step.stats_dtypes[0])
+    assert "in_warmup" in {s.name for s in rmeta.sample_stats}
+    assert len(rmeta.sample_stats) == len(step.stats_dtypes[0]) + 1

     with simple_model:
         step = pm.NUTS()
@@ -201,7 +202,7 @@ def test_get_sampler_stats(self):
         for i in range(N):
             draw = {"a": rng.normal(), "b_interval__": rng.normal()}
             stats = [{"tune": (i <= 5), "s1": i, "accepted": bool(rng.randint(0, 2))}]
-            cra.record(draw, stats)
+            cra.record(draw, stats, in_warmup=i <= 5)

         # Check final state of the chain
         assert len(cra) == N
@@ -254,7 +255,7 @@ def test_get_sampler_stats_compound(self, caplog):
                 {"tune": tune, "s1": i, "accepted": bool(rng.randint(0, 2))},
                 {"tune": tune, "s2": i, "accepted": bool(rng.randint(0, 2))},
             ]
-            cra.record(draw, stats)
+            cra.record(draw, stats, in_warmup=tune)

         # The 'accepted' stat was emitted by both samplers
         assert cra.get_sampler_stats("accepted", sampler_idx=None).shape == (N, 2)
@@ -293,13 +294,20 @@ def test_return_multitrace(self, simple_model, discard_warmup):
             return_inferencedata=False,
         )
         assert isinstance(mtrace, pm.backends.base.MultiTrace)
-        tune = mtrace._straces[0].get_sampler_stats("tune")
-        assert isinstance(tune, np.ndarray)
+        in_warmup = mtrace.get_sampler_stats("in_warmup", combine=False, squeeze=False)
+        assert len(in_warmup) == 3
+        assert all(s.dtype == np.dtype(bool) for s in in_warmup)
+
+        # Warmup is tracked by the sampling driver and persisted via `in_warmup`.
         if discard_warmup:
-            assert tune.shape == (7, 3)
+            assert len(mtrace) == 7
+            assert all(len(s) == 7 for s in in_warmup)
+            assert all(not np.any(s) for s in in_warmup)
         else:
-            assert tune.shape == (12, 3)
-            pass
+            assert len(mtrace) == 12
+            assert all(len(s) == 12 for s in in_warmup)
+            assert all(np.all(s[:5]) for s in in_warmup)
+            assert all(not np.any(s[5:]) for s in in_warmup)

     @pytest.mark.parametrize("cores", [1, 3])
     def test_return_inferencedata(self, simple_model, cores):
diff --git a/tests/backends/test_zarr.py b/tests/backends/test_zarr.py
index af9c9e0a06..b90834f367 100644
--- a/tests/backends/test_zarr.py
+++ b/tests/backends/test_zarr.py
@@ -132,7 +132,7 @@ def test_record(model, model_step, include_transformed, draws_per_chunk):
         else:
             manually_collected_draws.append(point)
             manually_collected_stats.append(stats)
-        trace.straces[0].record(point, stats)
+        trace.straces[0].record(point, stats, in_warmup=tuning)
         trace.straces[0].record_sampling_state(model_step)

     assert {group_name for group_name, _ in trace.root.groups()} == expected_groups
diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py
index 090b76130b..fcacad7a95 100644
--- a/tests/sampling/test_mcmc.py
+++ b/tests/sampling/test_mcmc.py
@@ -414,6 +414,33 @@ def test_sample_return_lengths(self):
         assert idata.posterior.sizes["draw"] == 100
         assert idata.posterior.sizes["chain"] == 3

+    def test_categorical_gibbs_respects_driver_tune_boundary(self):
+        with pm.Model():
+            pm.Categorical("x", p=np.array([0.2, 0.3, 0.5]))
+            sample_kwargs = {
+                "tune": 5,
+                "draws": 7,
+                "chains": 1,
+                "cores": 1,
+                "return_inferencedata": False,
+                "compute_convergence_checks": False,
+                "progressbar": False,
+                "random_seed": 123,
+            }
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                mtrace = pm.sample(discard_tuned_samples=True, **sample_kwargs)
+            assert len(mtrace) == 7
+            assert mtrace.report.n_tune == 5
+            assert mtrace.report.n_draws == 7
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                with pytest.warns(UserWarning, match="will be included"):
+                    mtrace_warmup = pm.sample(discard_tuned_samples=False, **sample_kwargs)
+            assert len(mtrace_warmup) == 12
+            assert mtrace_warmup.report.n_tune == 5
+            assert mtrace_warmup.report.n_draws == 7
+
     @pytest.mark.parametrize("cores", [1, 2])
     def test_logs_sampler_warnings(self, caplog, cores):
         """Asserts that "warning" sampler stats are logged during sampling."""
diff --git a/tests/step_methods/hmc/test_nuts.py b/tests/step_methods/hmc/test_nuts.py
index 8d497f3011..c453652079 100644
--- a/tests/step_methods/hmc/test_nuts.py
+++ b/tests/step_methods/hmc/test_nuts.py
@@ -157,7 +157,6 @@ def test_sampler_stats(self):
             "step_size",
             "step_size_bar",
             "tree_size",
-            "tune",
             "perf_counter_diff",
             "perf_counter_start",
             "process_time_diff",
diff --git a/tests/step_methods/test_compound.py b/tests/step_methods/test_compound.py
index 6c8957f9b3..a7a08bd500 100644
--- a/tests/step_methods/test_compound.py
+++ b/tests/step_methods/test_compound.py
@@ -36,11 +36,11 @@
 from tests.models import simple_2model_continuous


-def test_all_stepmethods_emit_tune_stat():
+def test_stepmethods_do_not_require_tune_stat():
     step_types = pm.step_methods.STEP_METHODS
     assert len(step_types) > 5
     for cls in step_types:
-        assert "tune" in cls.stats_dtypes_shapes
+        assert "tune" not in cls.stats_dtypes_shapes


 class TestCompoundStep:
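
Note for downstream trace backends (not part of the patch itself): after this change the warmup/draw boundary is decided by the sampling driver, which passes it to every backend through the new keyword-only `in_warmup` argument of `record`; step methods no longer emit a "tune" stat. The sketch below is illustrative only (the class name `WarmupAwareTrace` and the `_warmup_flags` attribute are hypothetical, not part of PyMC) and shows how a third-party backend built on `pymc.backends.base.BaseTrace` might adapt to the new signature.

    from collections.abc import Mapping, Sequence
    from typing import Any

    import numpy as np

    from pymc.backends.base import BaseTrace


    class WarmupAwareTrace(BaseTrace):
        """Illustrative sketch of a backend that persists the driver-owned warmup marker."""

        def setup(self, draws, chain, sampler_vars=None) -> None:
            super().setup(draws, chain, sampler_vars)
            # One boolean per recorded draw (hypothetical storage, not a PyMC attribute).
            self._warmup_flags: list[bool] = []

        def record(
            self,
            draw: Mapping[str, np.ndarray],
            stats: Sequence[Mapping[str, Any]],
            *,
            in_warmup: bool,
        ):
            # The flag comes from the driver (e.g. _iter_sample passes in_warmup=i < tune).
            self._warmup_flags.append(bool(in_warmup))
            ...  # persist `draw` and `stats` as before

Backends that keep the old two-argument `record` signature will raise a TypeError as soon as the driver passes `in_warmup`, which is why the built-in NDArray, Zarr, and McBackend chains are all updated in this patch.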