Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,19 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
if isinstance(orig, np.dtype):
return orig == new

# Handle numpy random generators
if isinstance(orig, np.random.Generator):
# Compare the underlying BitGenerator state
orig_state = orig.bit_generator.state
new_state = new.bit_generator.state
return comparator(orig_state, new_state, superset_obj)

if isinstance(orig, np.random.RandomState):
# Compare the internal state
orig_state = orig.get_state(legacy=False)
new_state = new.get_state(legacy=False)
return comparator(orig_state, new_state, superset_obj)

if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
if orig.dtype != new.dtype:
return False
Expand Down
75 changes: 75 additions & 0 deletions tests/test_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,81 @@ def test_numpy():
assert not comparator(a_void, c_void)


def test_numpy_random_generator():
try:
import numpy as np
except ImportError:
pytest.skip()

# Test numpy.random.Generator (modern API)
# Same seed should produce equal generators
rng1 = np.random.default_rng(seed=42)
rng2 = np.random.default_rng(seed=42)
assert comparator(rng1, rng2)

# Different seeds should produce non-equal generators
rng3 = np.random.default_rng(seed=123)
assert not comparator(rng1, rng3)

# After generating numbers, state changes
rng4 = np.random.default_rng(seed=42)
rng5 = np.random.default_rng(seed=42)
rng4.random() # Advance state
assert not comparator(rng4, rng5)

# Both advanced by same amount should be equal
rng5.random()
assert comparator(rng4, rng5)

# Test with different bit generators
from numpy.random import PCG64, MT19937
rng_pcg1 = np.random.Generator(PCG64(seed=42))
rng_pcg2 = np.random.Generator(PCG64(seed=42))
assert comparator(rng_pcg1, rng_pcg2)

rng_mt1 = np.random.Generator(MT19937(seed=42))
rng_mt2 = np.random.Generator(MT19937(seed=42))
assert comparator(rng_mt1, rng_mt2)

# Different bit generator types should not be equal
assert not comparator(rng_pcg1, rng_mt1)


def test_numpy_random_state():
try:
import numpy as np
except ImportError:
pytest.skip()

# Test numpy.random.RandomState (legacy API)
# Same seed should produce equal states
rs1 = np.random.RandomState(seed=42)
rs2 = np.random.RandomState(seed=42)
assert comparator(rs1, rs2)

# Different seeds should produce non-equal states
rs3 = np.random.RandomState(seed=123)
assert not comparator(rs1, rs3)

# After generating numbers, state changes
rs4 = np.random.RandomState(seed=42)
rs5 = np.random.RandomState(seed=42)
rs4.random() # Advance state
assert not comparator(rs4, rs5)

# Both advanced by same amount should be equal
rs5.random()
assert comparator(rs4, rs5)

# Test state restoration
rs6 = np.random.RandomState(seed=42)
state = rs6.get_state()
rs6.random() # Advance state
rs7 = np.random.RandomState(seed=42)
rs7.set_state(state)
# rs6 advanced, rs7 restored to original state
assert not comparator(rs6, rs7)


def test_scipy():
try:
Expand Down
Loading