Pre-sampling validation for shape mismatches in PyMC #8020
Introduces a _validate_idata_conversion function to check for shape and dimension mismatches between model variables and their declared dims/coords before sampling begins, preventing wasted computation. Adds comprehensive tests to ensure mismatches are caught early and correct models proceed as expected (fixes pymc-devs#7891).
Pull request overview
This PR implements pre-sampling validation to catch shape mismatches between model coordinates and variable dimensions before sampling begins, addressing issue #7891. The goal is to prevent hours of wasted computation when inconsistencies exist between model coords and variable dims that would only be discovered during final InferenceData conversion.
Key changes:
- Added a `_validate_idata_conversion()` function that uses PyTensor shape compilation to evaluate variable shapes and compare them against declared dimensions
- Integrated validation into the `sample()` function to run automatically when `return_inferencedata=True`
- Added a comprehensive test suite covering shape mismatches, valid models, scalar variables, and multi-dimensional arrays
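The kind of comparison described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the PR's implementation: `find_dim_mismatches` and its arguments are hypothetical names, and the shapes are passed in as tuples instead of being evaluated through PyTensor.

```python
def find_dim_mismatches(shapes, dims, coords):
    """Compare evaluated shapes against declared dims and coords.

    shapes: variable name -> shape tuple evaluated at a single point
    dims:   variable name -> tuple of declared dimension names
    coords: dimension name -> sequence of coordinate labels
    Returns a list of messages; an empty list means everything is consistent.
    (Hypothetical helper, not part of PyMC.)
    """
    problems = []
    for name, declared in dims.items():
        shape = shapes.get(name)
        if shape is None:
            continue  # no shape was evaluated for this variable
        if len(shape) != len(declared):
            problems.append(
                f"{name}: shape {shape} has {len(shape)} axes, "
                f"but {len(declared)} dims were declared"
            )
            continue
        for axis, (dim, size) in enumerate(zip(declared, shape)):
            if dim in coords and len(coords[dim]) != size:
                problems.append(
                    f"{name}: axis {axis} has length {size}, "
                    f"but dim '{dim}' has {len(coords[dim])} coordinates"
                )
    return problems


coords = {"dim_1": [1, 2], "dim_2": ["a", "b", "c"]}
shapes = {"x": (2, 3), "y": (2, 3)}
dims = {"x": ("dim_1", "dim_2"), "y": ("dim_2", "dim_1")}

for message in find_dim_mismatches(shapes, dims, coords):
    print(message)
```

Here `x` is consistent, while `y` is flagged on both axes because its declared dims are in the wrong order.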
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `pymc/sampling/mcmc.py` | Adds validation function `_validate_idata_conversion()` and integrates it into the sampling workflow to catch shape/dimension mismatches before sampling |
| `tests/sampling/test_issue_7891.py` | Comprehensive test suite validating the pre-sampling shape mismatch detection, including edge cases for scalars and multi-dimensional variables |
    # Validate that InferenceData conversion will work before sampling
    # This catches shape/dimension mismatches early to avoid wasting compute time
    if return_inferencedata:
        _validate_idata_conversion(model, initial_points[0], idata_kwargs or {})
Copilot AI, Dec 23, 2025
The validation is only applied to the PyMC NUTS sampler path, but not to external NUTS samplers (nutpie, numpyro, blackjax) which are handled in the _sample_external_nuts function (lines 801-817). Users with shape mismatches who use external samplers will still experience the issue this PR is trying to fix - hours of sampling followed by failure during InferenceData conversion. Consider adding the validation before the external sampler path as well, after obtaining initial points if needed.
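One way to read this suggestion is to run the check once, before the code branches into a specific sampler. The sketch below is a simplified stand-in, not PyMC's actual control flow: `sample_with_backend`, `backends`, and `failing_check` are hypothetical names used only for illustration.

```python
def sample_with_backend(nuts_sampler, backends, validate=None):
    """Dispatch to a sampling backend, validating once beforehand.

    backends: backend name -> zero-argument callable that runs the sampler
    validate: optional zero-argument callable that raises on a mismatch
    (Hypothetical helper, not part of PyMC.)
    """
    if validate is not None:
        validate()  # runs before the branch, so every backend is covered
    return backends[nuts_sampler]()


def failing_check():
    raise ValueError("Shape mismatch for variable 'y'")


# Stand-ins for the built-in and external sampler paths.
backends = {"pymc": lambda: "pymc draws", "nutpie": lambda: "nutpie draws"}

for name in backends:
    try:
        sample_with_backend(name, backends, validate=failing_check)
    except ValueError as err:
        print(f"{name}: stopped before sampling ({err})")
```

Because the check sits ahead of the dispatch, neither backend starts sampling when validation fails.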
    except Exception as e:
        # If this is already our ValueError, re-raise it as-is
        if isinstance(e, ValueError) and "Shape mismatch for variable" in str(e):
            raise ValueError(
                f"Pre-sampling validation failed: {e}\n\n"
                f"See https://github.com/pymc-devs/pymc/issues/7891 for more information."
            ) from None

        # For other exceptions, just skip validation - don't break existing code
        # The user will get the error during idata conversion anyway
        _log.debug(f"Pre-sampling validation skipped due to error: {type(e).__name__}: {e}")
Copilot AI, Dec 23, 2025
The broad exception handler (catching all Exception types) on line 1192 may inadvertently suppress validation errors that should be raised. Specifically, the check on line 1194 only looks for the substring "Shape mismatch for variable", but other ValueError instances raised in this function (like the dimension count mismatch on line 1170) contain different messages and won't be caught by this check. This means those validation errors will be silently suppressed and logged at debug level, defeating the purpose of pre-sampling validation. Consider checking for both error message patterns or restructuring the error handling to ensure all validation errors are properly propagated.
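A possible restructuring, sketched under assumptions, is to mark validation failures with a dedicated exception type instead of matching message substrings. `ShapeValidationError` and `run_validation` are hypothetical names and do not appear in the PR.

```python
import logging

_log = logging.getLogger(__name__)


class ShapeValidationError(ValueError):
    """Hypothetical marker type for errors raised by the validation itself."""


def run_validation(check):
    """Run `check`; propagate validation failures, skip on anything else."""
    try:
        check()
    except ShapeValidationError:
        raise  # every validation failure propagates, whatever its message
    except Exception as err:
        # Unrelated failures are logged and ignored so existing code keeps working.
        _log.debug(
            "Pre-sampling validation skipped due to error: %s: %s",
            type(err).__name__,
            err,
        )


def dimension_count_check():
    raise ShapeValidationError("Variable 'y' has 0 dimensions but 1 dim was declared")


def unrelated_failure():
    raise RuntimeError("could not compile shape function")


run_validation(unrelated_failure)  # logged at debug level, no exception

try:
    run_validation(dimension_count_check)
except ShapeValidationError as err:
    print(f"Pre-sampling validation failed: {err}")
```

Since `ShapeValidationError` subclasses `ValueError`, callers that already catch `ValueError` would keep working under this design.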
    def test_multiple_dims_validation(self):
        """Test validation with multiple dimensions."""
        coords = {'dim_1': [1, 2], 'dim_2': ['a', 'b', 'c']}
        with pm.Model(coords=coords):
            # Create a 2D variable with correct dims
            x = pm.Normal('x', mu=0, sigma=1, dims=['dim_1', 'dim_2'])

            # Create a deterministic with wrong dims order
            # x has shape (2, 3) with dims=['dim_1', 'dim_2']
            # If we declare dims=['dim_2', 'dim_1'], we'd be saying
            # shape is (3, 2)
            pm.Deterministic('y', x, dims=['dim_2', 'dim_1'])

            # This should catch the mismatch
            match_text = "Pre-sampling validation failed"
            with pytest.raises(ValueError, match=match_text):
                pm.sample(draws=10, tune=10, chains=1, random_seed=42)
Copilot AI, Dec 23, 2025
There is no test coverage for the case where the number of dimensions doesn't match between the variable shape and the declared dims (validated at line 1170 in mcmc.py). Consider adding a test case where a scalar variable (shape ()) incorrectly declares dims=['some_dim'], or where a 1D variable incorrectly declares 2 dimensions. This would ensure the dimension count validation path is tested.
    if "coords" in idata_kwargs:
        coords.update(idata_kwargs["coords"])
    if "dims" in idata_kwargs:
        dims.update(idata_kwargs["dims"])
Copilot AI, Dec 23, 2025
The coords.update() and dims.update() calls directly modify dictionaries returned by coords_and_dims_for_inferencedata(). If these dictionaries are shared or cached by the model, this could lead to unintended side effects. Consider creating copies before updating them, or ensure that coords_and_dims_for_inferencedata() returns new dictionaries each time.
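A minimal sketch of the defensive-copy approach follows. It assumes the model's mappings are plain dictionaries; `merged_coords_and_dims` is a hypothetical helper, not a PyMC function.

```python
def merged_coords_and_dims(model_coords, model_dims, idata_kwargs):
    """Merge user overrides into copies, leaving the model's mappings untouched.

    (Hypothetical helper, not part of PyMC.)
    """
    # dict(...) makes a shallow copy; update() only rebinds keys on the copy,
    # so the dictionaries passed in are never modified.
    coords = dict(model_coords)
    dims = dict(model_dims)
    coords.update(idata_kwargs.get("coords", {}))
    dims.update(idata_kwargs.get("dims", {}))
    return coords, dims


model_coords = {"dim_1": [1, 2]}
model_dims = {"x": ("dim_1",)}
overrides = {"coords": {"dim_2": ["a", "b", "c"]}, "dims": {"y": ("dim_2",)}}

coords, dims = merged_coords_and_dims(model_coords, model_dims, overrides)
print(sorted(coords), sorted(model_coords))
```

A shallow copy is enough here because the merge only adds or replaces top-level keys and never mutates the stored values.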
As linked and discussed in the original issue, the alternative of simply failing gracefully is being investigated, specifically in #7912. The check proposed in this PR seems rather inefficient and will incur an overhead that would at least have to be benchmarked. Also, the LLM "summary" with 100 lines of code is not useful at all. Please write a better summary if the LLM can't do it itself.
Pre-sampling validation for shape mismatches in PyMC
Description
This PR implements pre-sampling validation to catch shape mismatches between model coordinates and variable dimensions before sampling begins, preventing potentially hours of wasted computation time.
Previously, PyMC would happily sample for extended periods only to fail during the final `InferenceData` conversion step when shape inconsistencies existed between model `coords` and variable `dims`.
This PR adds a validation function that runs a smoke test before sampling starts, evaluating variable shapes and comparing them against declared dimensions to catch errors early.
Key Changes
1. Added a `_validate_idata_conversion()` function in `mcmc.py`, which checks declared `dims` against coordinate lengths
2. Integrated validation into the `sample()` function, where it runs when `return_inferencedata=True` (the default)
3. Comprehensive test suite in `test_issue_7891.py`
Benefits
✅ Saves computation time: Catches errors immediately instead of after hours of sampling
✅ Better UX: Clear error messages tell users exactly what’s wrong and how to fix it
✅ Minimal overhead: Fast shape evaluation without actual sampling
✅ Backwards compatible: Only runs when creating `InferenceData`; doesn’t affect existing code
✅ Comprehensive: Works with scalars, multi-dimensional variables, and deterministics
Example
Before this PR, the following model would sample for hours then fail:
After this PR, it fails immediately with a helpful error:
Related Issue
Closes #7891