From 0876ba07ba247bad177f0ea56badc095e8751d36 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Fri, 2 Jan 2026 11:07:45 -0800 Subject: [PATCH] comparator fix --- codeflash/verification/comparator.py | 13 +++++ tests/test_comparator.py | 75 ++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 704d19b3c..20f28292d 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -227,6 +227,19 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 if isinstance(orig, np.dtype): return orig == new + # Handle numpy random generators + if isinstance(orig, np.random.Generator): + # Compare the underlying BitGenerator state + orig_state = orig.bit_generator.state + new_state = new.bit_generator.state + return comparator(orig_state, new_state, superset_obj) + + if isinstance(orig, np.random.RandomState): + # Compare the internal state + orig_state = orig.get_state(legacy=False) + new_state = new.get_state(legacy=False) + return comparator(orig_state, new_state, superset_obj) + if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix): if orig.dtype != new.dtype: return False diff --git a/tests/test_comparator.py b/tests/test_comparator.py index aa556db32..9bd81dfce 100644 --- a/tests/test_comparator.py +++ b/tests/test_comparator.py @@ -360,6 +360,81 @@ def test_numpy(): assert not comparator(a_void, c_void) +def test_numpy_random_generator(): + try: + import numpy as np + except ImportError: + pytest.skip() + + # Test numpy.random.Generator (modern API) + # Same seed should produce equal generators + rng1 = np.random.default_rng(seed=42) + rng2 = np.random.default_rng(seed=42) + assert comparator(rng1, rng2) + + # Different seeds should produce non-equal generators + rng3 = np.random.default_rng(seed=123) + assert not comparator(rng1, rng3) + + # After generating numbers, state changes + rng4 = np.random.default_rng(seed=42) + rng5 = np.random.default_rng(seed=42) + rng4.random() # Advance state + assert not comparator(rng4, rng5) + + # Both advanced by same amount should be equal + rng5.random() + assert comparator(rng4, rng5) + + # Test with different bit generators + from numpy.random import PCG64, MT19937 + rng_pcg1 = np.random.Generator(PCG64(seed=42)) + rng_pcg2 = np.random.Generator(PCG64(seed=42)) + assert comparator(rng_pcg1, rng_pcg2) + + rng_mt1 = np.random.Generator(MT19937(seed=42)) + rng_mt2 = np.random.Generator(MT19937(seed=42)) + assert comparator(rng_mt1, rng_mt2) + + # Different bit generator types should not be equal + assert not comparator(rng_pcg1, rng_mt1) + + +def test_numpy_random_state(): + try: + import numpy as np + except ImportError: + pytest.skip() + + # Test numpy.random.RandomState (legacy API) + # Same seed should produce equal states + rs1 = np.random.RandomState(seed=42) + rs2 = np.random.RandomState(seed=42) + assert comparator(rs1, rs2) + + # Different seeds should produce non-equal states + rs3 = np.random.RandomState(seed=123) + assert not comparator(rs1, rs3) + + # After generating numbers, state changes + rs4 = np.random.RandomState(seed=42) + rs5 = np.random.RandomState(seed=42) + rs4.random() # Advance state + assert not comparator(rs4, rs5) + + # Both advanced by same amount should be equal + rs5.random() + assert comparator(rs4, rs5) + + # Test state restoration + rs6 = np.random.RandomState(seed=42) + state = rs6.get_state() + rs6.random() # Advance state + rs7 = np.random.RandomState(seed=42) + rs7.set_state(state) + # rs6 advanced, rs7 restored to original state + assert not comparator(rs6, rs7) + def test_scipy(): try: