Skip to content

Conversation

@arec1b0
Copy link

@arec1b0 arec1b0 commented Dec 23, 2025

Pre-sampling validation for shape mismatches in PyMC

Description

This PR implements pre-sampling validation to catch shape mismatches between model coordinates and variable dimensions before sampling begins, preventing potentially hours of wasted computation time.

Previously, PyMC would happily sample for extended periods only to fail during the final InferenceData conversion step when shape inconsistencies existed between model coords and variable dims.
This PR adds a validation function that runs a smoke test before sampling starts, evaluating variable shapes and comparing them against declared dimensions to catch errors early.


Key Changes

1. Added _validate_idata_conversion() function in mcmc.py

  • Uses PyTensor’s shape compilation to evaluate actual variable shapes
  • Compares declared dims against coordinate lengths
  • Provides clear, actionable error messages indicating exactly which variable has the problem and how to fix it
  • Minimal overhead: single shape evaluation, no sampling

2. Integrated validation into sample() function

  • Automatically runs when return_inferencedata=True (the default)
  • Executes after initial points are created but before any sampling begins
  • Gracefully handles errors without breaking existing workflows

3. Comprehensive test suite in test_issue_7891.py

  • Tests the exact scenario from the issue (deterministic with wrong dims)
  • Validates that correct models still work normally
  • Covers edge cases (scalar variables, multi-dimensional arrays)
  • Ensures helpful error messages are produced

Benefits

Saves computation time: Catches errors immediately instead of after hours of sampling
Better UX: Clear error messages tell users exactly what’s wrong and how to fix it
Minimal overhead: Fast shape evaluation without actual sampling
Backwards compatible: Only runs when creating InferenceData; doesn’t affect existing code
Comprehensive: Works with scalars, multi-dimensional variables, and deterministics


Example

Before this PR, the following model would sample for hours then fail:

with pm.Model(coords={"obs": range(10)}) as m:
    μ = pm.Normal("μ")
    y = pm.Deterministic("y", μ, dims="obs")  # shape mismatch!
    trace = pm.sample()

After this PR, it fails immediately with a helpful error:

ValueError: Variable 'y' has shape () but its dimension 'obs' implies length 10.
Verify that the variable's shape matches the coordinate length.

Related Issue

Closes #7891


Checklist

  • Checked that the pre-commit linting/style checks pass
  • Included tests that prove the fix is effective or that the new feature works
  • Added necessary documentation (docstrings in the validation function)
  • Each commit corresponds to a relevant logical change

Type of change

  • New feature / enhancement
  • Bug fix (prevents wasted computation from preventable errors)

Introduces a _validate_idata_conversion function to check for shape and dimension mismatches between model variables and their declared dims/coords before sampling begins, preventing wasted computation. Adds comprehensive tests to ensure mismatches are caught early and correct models proceed as expected (fixes pymc-devs#7891).
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR implements pre-sampling validation to catch shape mismatches between model coordinates and variable dimensions before sampling begins, addressing issue #7891. The goal is to prevent hours of wasted computation when inconsistencies exist between model coords and variable dims that would only be discovered during final InferenceData conversion.

Key changes:

  • Added _validate_idata_conversion() function that uses PyTensor shape compilation to evaluate variable shapes and compare them against declared dimensions
  • Integrated validation into the sample() function to run automatically when return_inferencedata=True
  • Added comprehensive test suite covering shape mismatches, valid models, scalar variables, and multi-dimensional arrays

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File Description
pymc/sampling/mcmc.py Adds validation function _validate_idata_conversion() and integrates it into the sampling workflow to catch shape/dimension mismatches before sampling
tests/sampling/test_issue_7891.py Comprehensive test suite validating the pre-sampling shape mismatch detection, including edge cases for scalars and multi-dimensional variables

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +860 to +863
# Validate that InferenceData conversion will work before sampling
# This catches shape/dimension mismatches early to avoid wasting compute time
if return_inferencedata:
_validate_idata_conversion(model, initial_points[0], idata_kwargs or {})
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validation is only applied to the PyMC NUTS sampler path, but not to external NUTS samplers (nutpie, numpyro, blackjax) which are handled in the _sample_external_nuts function (lines 801-817). Users with shape mismatches who use external samplers will still experience the issue this PR is trying to fix - hours of sampling followed by failure during InferenceData conversion. Consider adding the validation before the external sampler path as well, after obtaining initial points if needed.

Copilot uses AI. Check for mistakes.
Comment on lines +1192 to +1202
except Exception as e:
# If this is already our ValueError, re-raise it as-is
if isinstance(e, ValueError) and "Shape mismatch for variable" in str(e):
raise ValueError(
f"Pre-sampling validation failed: {e}\n\n"
f"See https://github.com/pymc-devs/pymc/issues/7891 for more information."
) from None

# For other exceptions, just skip validation - don't break existing code
# The user will get the error during idata conversion anyway
_log.debug(f"Pre-sampling validation skipped due to error: {type(e).__name__}: {e}")
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The broad exception handler (catching all Exception types) on line 1192 may inadvertently suppress validation errors that should be raised. Specifically, the check on line 1194 only looks for the substring "Shape mismatch for variable", but other ValueError instances raised in this function (like the dimension count mismatch on line 1170) contain different messages and won't be caught by this check. This means those validation errors will be silently suppressed and logged at debug level, defeating the purpose of pre-sampling validation. Consider checking for both error message patterns or restructuring the error handling to ensure all validation errors are properly propagated.

Copilot uses AI. Check for mistakes.
Comment on lines +109 to +125
def test_multiple_dims_validation(self):
"""Test validation with multiple dimensions."""
coords = {'dim_1': [1, 2], 'dim_2': ['a', 'b', 'c']}
with pm.Model(coords=coords):
# Create a 2D variable with correct dims
x = pm.Normal('x', mu=0, sigma=1, dims=['dim_1', 'dim_2'])

# Create a deterministic with wrong dims order
# x has shape (2, 3) with dims=['dim_1', 'dim_2']
# If we declare dims=['dim_2', 'dim_1'], we'd be saying
# shape is (3, 2)
pm.Deterministic('y', x, dims=['dim_2', 'dim_1'])

# This should catch the mismatch
match_text = "Pre-sampling validation failed"
with pytest.raises(ValueError, match=match_text):
pm.sample(draws=10, tune=10, chains=1, random_seed=42)
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no test coverage for the case where the number of dimensions doesn't match between the variable shape and the declared dims (validated at line 1170 in mcmc.py). Consider adding a test case where a scalar variable (shape ()) incorrectly declares dims=['some_dim'], or where a 1D variable incorrectly declares 2 dimensions. This would ensure the dimension count validation path is tested.

Copilot uses AI. Check for mistakes.
Comment on lines +1130 to +1133
if "coords" in idata_kwargs:
coords.update(idata_kwargs["coords"])
if "dims" in idata_kwargs:
dims.update(idata_kwargs["dims"])
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The coords.update() and dims.update() calls directly modify dictionaries returned by coords_and_dims_for_inferencedata(). If these dictionaries are shared or cached by the model, this could lead to unintended side effects. Consider creating copies before updating them, or ensure that coords_and_dims_for_inferencedata() returns new dictionaries each time.

Copilot uses AI. Check for mistakes.
@ricardoV94
Copy link
Member

As linked and discussed to in the original issue the alternative to just fail gracefully is being investigated, specifically #7912

The check proposed in this PR seems rather inefficient and will incur an overhead that would at least have to be benchmarked.

Also the LLM "summary" with 100 lines of code is not useful at all. Please summarize better if the LLM can't do it itself.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Check idata dims/coords for consistency before sampling begins

2 participants