From 2e0c5788caccfa897ced8bc58cbfc0fc90fa6d38 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Tue, 30 Dec 2025 09:44:29 -0800 Subject: [PATCH 1/8] need to test now --- .../code_utils/instrument_existing_tests.py | 315 +++++++++++++++++- 1 file changed, 313 insertions(+), 2 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 6184782b3..43d195ece 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -671,6 +671,8 @@ def inject_profiling_into_existing_test( function_to_optimize: FunctionToOptimize, tests_project_root: Path, mode: TestingMode = TestingMode.BEHAVIOR, + *, + jit_warmup: bool = False, ) -> tuple[bool, str | None]: if function_to_optimize.is_async: return inject_async_profiling_into_existing_test( @@ -704,13 +706,277 @@ def inject_profiling_into_existing_test( ast.Import(names=[ast.alias(name="dill", asname="pickle")]), ] ) - additional_functions = [create_wrapper_function(mode)] + additional_functions = [create_wrapper_function(mode, jit_warmup=jit_warmup)] + if jit_warmup: + additional_functions.insert(0, create_jit_sync_helper()) tree.body = [*new_imports, *additional_functions, *tree.body] return True, sort_imports(ast.unparse(tree), float_to_top=True) -def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.FunctionDef: +def create_jit_sync_helper() -> ast.FunctionDef: + """Create a helper function that synchronizes JIT-compiled frameworks (PyTorch, TensorFlow, JAX, MLX). + + This function generates AST for: + def _codeflash_jit_sync(): + try: + import torch + if torch.cuda.is_available(): + torch.cuda.synchronize() + if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + torch.mps.synchronize() + except ImportError: + pass + try: + import jax + # Block until all JAX computations are complete + jax.effects_barrier() + except ImportError: + pass + try: + import mlx.core as mx + mx.synchronize() + except ImportError: + pass + # Note: TensorFlow in eager mode auto-syncs; Numba JIT is CPU-based and doesn't need sync + """ + lineno = 1 + + # PyTorch sync block + pytorch_sync = ast.Try( + body=[ + ast.Import(names=[ast.alias(name="torch")], lineno=lineno), + # if torch.cuda.is_available(): torch.cuda.synchronize() + ast.If( + test=ast.Call( + func=ast.Attribute( + value=ast.Attribute(value=ast.Name(id="torch", ctx=ast.Load()), attr="cuda", ctx=ast.Load()), + attr="is_available", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="torch", ctx=ast.Load()), attr="cuda", ctx=ast.Load() + ), + attr="synchronize", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ) + ], + orelse=[], + lineno=lineno, + ), + # if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): torch.mps.synchronize() + ast.If( + test=ast.BoolOp( + op=ast.And(), + values=[ + ast.Call( + func=ast.Name(id="hasattr", ctx=ast.Load()), + args=[ + ast.Attribute( + value=ast.Name(id="torch", ctx=ast.Load()), attr="backends", ctx=ast.Load() + ), + ast.Constant(value="mps"), + ], + keywords=[], + ), + ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="torch", ctx=ast.Load()), attr="backends", ctx=ast.Load() + ), + attr="mps", + ctx=ast.Load(), + ), + attr="is_available", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ), + ], + ), + body=[ + ast.Expr( + 
value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="torch", ctx=ast.Load()), attr="mps", ctx=ast.Load() + ), + attr="synchronize", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ) + ], + orelse=[], + lineno=lineno, + ), + ], + handlers=[ + ast.ExceptHandler( + type=ast.Name(id="ImportError", ctx=ast.Load()), + name=None, + body=[ast.Pass(lineno=lineno)], + lineno=lineno, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno, + ) + + # JAX sync block - use effects_barrier() to wait for all computations + jax_sync = ast.Try( + body=[ + ast.Import(names=[ast.alias(name="jax")], lineno=lineno), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="jax", ctx=ast.Load()), attr="effects_barrier", ctx=ast.Load() + ), + args=[], + keywords=[], + ) + ), + ], + handlers=[ + ast.ExceptHandler( + type=ast.Name(id="ImportError", ctx=ast.Load()), + name=None, + body=[ast.Pass(lineno=lineno)], + lineno=lineno, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno, + ) + + # MLX sync block + mlx_sync = ast.Try( + body=[ + ast.Import(names=[ast.alias(name="mlx.core", asname="mx")], lineno=lineno), + ast.Expr( + value=ast.Call( + func=ast.Attribute(value=ast.Name(id="mx", ctx=ast.Load()), attr="synchronize", ctx=ast.Load()), + args=[], + keywords=[], + ) + ), + ], + handlers=[ + ast.ExceptHandler( + type=ast.Name(id="ImportError", ctx=ast.Load()), + name=None, + body=[ast.Pass(lineno=lineno)], + lineno=lineno, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno, + ) + + # TensorFlow sync block - sync XLA/TPU devices + tensorflow_sync = ast.Try( + body=[ + ast.Import(names=[ast.alias(name="tensorflow", asname="tf")], lineno=lineno), + # For TPU: tf.tpu.experimental.initialize_tpu_system if available + # For GPU: operations complete synchronously in eager mode but we can force sync + ast.If( + test=ast.Call( + func=ast.Name(id="hasattr", ctx=ast.Load()), + args=[ + ast.Attribute(value=ast.Name(id="tf", ctx=ast.Load()), attr="config", ctx=ast.Load()), + ast.Constant(value="experimental"), + ], + keywords=[], + ), + body=[ + # Get all physical devices and sync GPUs + ast.For( + target=ast.Name(id="_device", ctx=ast.Store()), + iter=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="tf", ctx=ast.Load()), attr="config", ctx=ast.Load() + ), + attr="list_physical_devices", + ctx=ast.Load(), + ), + args=[ast.Constant(value="GPU")], + keywords=[], + ), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="tf", ctx=ast.Load()), attr="test", ctx=ast.Load() + ), + attr="experimental", + ctx=ast.Load(), + ), + attr="sync_devices", + ctx=ast.Load(), + ), + args=[], + keywords=[], + ) + ) + ], + orelse=[], + lineno=lineno, + ) + ], + orelse=[], + lineno=lineno, + ), + ], + handlers=[ + ast.ExceptHandler( + type=ast.Tuple( + elts=[ast.Name(id="ImportError", ctx=ast.Load()), ast.Name(id="AttributeError", ctx=ast.Load())], + ctx=ast.Load(), + ), + name=None, + body=[ast.Pass(lineno=lineno)], + lineno=lineno, + ) + ], + orelse=[], + finalbody=[], + lineno=lineno, + ) + + return ast.FunctionDef( + name="_codeflash_jit_sync", + args=ast.arguments( + args=[], vararg=None, kwarg=None, posonlyargs=[], kwonlyargs=[], kw_defaults=[], defaults=[] + ), + body=[pytorch_sync, jax_sync, mlx_sync, tensorflow_sync], + decorator_list=[], + returns=None, + lineno=lineno, + ) + + +def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, *, jit_warmup: bool = False) -> 
ast.FunctionDef: lineno = 1 wrapper_body: list[ast.stmt] = [ ast.Assign( @@ -871,6 +1137,25 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun ast.Assign( targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10 ), + # JIT warmup: call function once to trigger JIT compilation before timing + *( + [ + ast.Expr( + value=ast.Call( + func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()), + args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())], + keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], + ), + lineno=lineno + 10, + ), + ast.Expr( + value=ast.Call(func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[]), + lineno=lineno + 10, + ), + ] + if jit_warmup + else [] + ), ast.Expr( value=ast.Call( func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()), @@ -881,6 +1166,19 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun ), ast.Try( body=[ + # Sync before starting timer (ensure previous operations are complete) + *( + [ + ast.Expr( + value=ast.Call( + func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[] + ), + lineno=lineno + 11, + ) + ] + if jit_warmup + else [] + ), ast.Assign( targets=[ast.Name(id="counter", ctx=ast.Store())], value=ast.Call( @@ -901,6 +1199,19 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun ), lineno=lineno + 12, ), + # Sync after function call to ensure all GPU/async operations complete before stopping timer + *( + [ + ast.Expr( + value=ast.Call( + func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[] + ), + lineno=lineno + 12, + ) + ] + if jit_warmup + else [] + ), ast.Assign( targets=[ast.Name(id="codeflash_duration", ctx=ast.Store())], value=ast.BinOp( From fed05952bb1ce06de5b480bbebc0d849a74b79ef Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Tue, 30 Dec 2025 16:20:26 -0800 Subject: [PATCH 2/8] wip --- .../code_utils/instrument_existing_tests.py | 19 ------------------- codeflash/optimization/function_optimizer.py | 2 ++ tests/test_instrument_tests.py | 2 +- 3 files changed, 3 insertions(+), 20 deletions(-) diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 43d195ece..d6c5f913b 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -1137,25 +1137,6 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR, *, jit_war ast.Assign( targets=[ast.Name(id="exception", ctx=ast.Store())], value=ast.Constant(value=None), lineno=lineno + 10 ), - # JIT warmup: call function once to trigger JIT compilation before timing - *( - [ - ast.Expr( - value=ast.Call( - func=ast.Name(id="codeflash_wrapped", ctx=ast.Load()), - args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())], - keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))], - ), - lineno=lineno + 10, - ), - ast.Expr( - value=ast.Call(func=ast.Name(id="_codeflash_jit_sync", ctx=ast.Load()), args=[], keywords=[]), - lineno=lineno + 10, - ), - ] - if jit_warmup - else [] - ), ast.Expr( value=ast.Call( func=ast.Attribute(value=ast.Name(id="gc", ctx=ast.Load()), attr="disable", ctx=ast.Load()), diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 416bdc8df..5cab36eda 
100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -1230,6 +1230,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
                     call_positions=[test.position for test in tests_in_file_list],
                     function_to_optimize=self.function_to_optimize,
                     tests_project_root=self.test_cfg.tests_project_rootdir,
+                    jit_warmup=True,
                 )
                 if not success:
                     continue
@@ -1239,6 +1240,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio
                     call_positions=[test.position for test in tests_in_file_list],
                     function_to_optimize=self.function_to_optimize,
                     tests_project_root=self.test_cfg.tests_project_rootdir,
+                    jit_warmup=True,
                 )
                 if not success:
                     continue
diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py
index a74f41533..d6439550d 100644
--- a/tests/test_instrument_tests.py
+++ b/tests/test_instrument_tests.py
@@ -193,7 +193,7 @@ def test_sort(self):
             Path(f.name),
             [CodePosition(9, 17), CodePosition(13, 17), CodePosition(17, 17)],
             func,
-            Path(f.name).parent,
+            Path(f.name).parent, jit_warmup=True
         )
         os.chdir(original_cwd)
         assert success

From e38df6282a6394f9d8a556b147c1ec30d75b5036 Mon Sep 17 00:00:00 2001
From: Codeflash Bot
Date: Tue, 30 Dec 2025 16:25:31 -0800
Subject: [PATCH 3/8] wip

---
 codeflash/verification/comparator.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py
index 7737900df..d95c45012 100644
--- a/codeflash/verification/comparator.py
+++ b/codeflash/verification/comparator.py
@@ -24,6 +24,7 @@
 HAS_JAX = find_spec("jax") is not None
 HAS_XARRAY = find_spec("xarray") is not None
 HAS_TENSORFLOW = find_spec("tensorflow") is not None
+HAS_MLX = find_spec("mlx") is not None


 def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001, ANN401, FBT002, PLR0911
@@ -138,6 +139,17 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:  # noqa: ANN001
             return False
         return comparator(orig.to_list(), new.to_list(), superset_obj)

+    if HAS_MLX:
+        import mlx.core as mx  # type: ignore  # noqa: PGH003
+
+        if isinstance(orig, mx.array):
+            if orig.dtype != new.dtype:
+                return False
+            if orig.shape != new.shape:
+                return False
+            # MLX allclose handles NaN comparison via equal_nan parameter
+            return bool(mx.allclose(orig, new, equal_nan=True))
+
     if HAS_SQLALCHEMY:
         import sqlalchemy  # type: ignore  # noqa: PGH003

From 17042a2040f0c65cbd3b59967b9dd35effd6e72e Mon Sep 17 00:00:00 2001
From: Codeflash Bot
Date: Tue, 30 Dec 2025 16:30:32 -0800
Subject: [PATCH 4/8] wip

---
 code_to_optimize/discrete_riccati.py          | 170 +++++++++++
 .../tests/pytest/test_gridmake2.py            | 216 ++++++++++++++
 .../tests/pytest/test_gridmake2_torch.py      | 267 ++++++++++++++++++
 codeflash/api/aiservice.py                    |  69 ++++-
 codeflash/optimization/function_optimizer.py  |  20 ++
 5 files changed, 741 insertions(+), 1 deletion(-)
 create mode 100644 code_to_optimize/discrete_riccati.py
 create mode 100644 code_to_optimize/tests/pytest/test_gridmake2.py
 create mode 100644 code_to_optimize/tests/pytest/test_gridmake2_torch.py

diff --git a/code_to_optimize/discrete_riccati.py b/code_to_optimize/discrete_riccati.py
new file mode 100644
index 000000000..53fe30891
--- /dev/null
+++ b/code_to_optimize/discrete_riccati.py
@@ -0,0 +1,170 @@
+"""
+Utility functions used in CompEcon
+
+Based on routines found in the CompEcon toolbox by Miranda and Fackler.
+
+References
+----------
+Miranda, Mario J, and Paul L Fackler. Applied Computational Economics
+and Finance, MIT Press, 2002.
+
+"""
+from functools import reduce
+import numpy as np
+import torch
+
+
+def ckron(*arrays):
+    """
+    Repeatedly applies the np.kron function to an arbitrary number of
+    input arrays
+
+    Parameters
+    ----------
+    *arrays : tuple/list of np.ndarray
+
+    Returns
+    -------
+    out : np.ndarray
+        The result of repeated kronecker products.
+
+    Notes
+    -----
+    Based on original function `ckron` in CompEcon toolbox by Miranda
+    and Fackler.
+
+    References
+    ----------
+    Miranda, Mario J, and Paul L Fackler. Applied Computational
+    Economics and Finance, MIT Press, 2002.
+
+    """
+    return reduce(np.kron, arrays)
+
+
+def gridmake(*arrays):
+    """
+    Expands one or more vectors (or matrices) into a matrix where rows span the
+    cartesian product of combinations of the input arrays. Each column of the
+    input arrays will correspond to one column of the output matrix.
+
+    Parameters
+    ----------
+    *arrays : tuple/list of np.ndarray
+        Tuple/list of vectors to be expanded.
+
+    Returns
+    -------
+    out : np.ndarray
+        The cartesian product of combinations of the input arrays.
+
+    Notes
+    -----
+    Based on original function ``gridmake`` in CompEcon toolbox by
+    Miranda and Fackler.
+
+    References
+    ----------
+    Miranda, Mario J, and Paul L Fackler. Applied Computational Economics
+    and Finance, MIT Press, 2002.
+
+    """
+    if all([i.ndim == 1 for i in arrays]):
+        d = len(arrays)
+        if d == 2:
+            out = _gridmake2(*arrays)
+        else:
+            out = _gridmake2(arrays[0], arrays[1])
+            for arr in arrays[2:]:
+                out = _gridmake2(out, arr)
+
+        return out
+    else:
+        raise NotImplementedError("Come back here")
+
+
+def _gridmake2(x1, x2):
+    """
+    Expands two vectors (or matrices) into a matrix where rows span the
+    cartesian product of combinations of the input arrays. Each column of the
+    input arrays will correspond to one column of the output matrix.
+
+    Parameters
+    ----------
+    x1 : np.ndarray
+        First vector to be expanded.
+
+    x2 : np.ndarray
+        Second vector to be expanded.
+
+    Returns
+    -------
+    out : np.ndarray
+        The cartesian product of combinations of the input arrays.
+
+    Notes
+    -----
+    Based on original function ``gridmake2`` in CompEcon toolbox by
+    Miranda and Fackler.
+
+    References
+    ----------
+    Miranda, Mario J, and Paul L Fackler. Applied Computational Economics
+    and Finance, MIT Press, 2002.
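+
+    Examples
+    --------
+    An illustrative doctest; the expected output mirrors the pytest
+    cases added for this function below:
+
+    >>> import numpy as np
+    >>> _gridmake2(np.array([1, 2, 3]), np.array([10, 20]))
+    array([[ 1, 10],
+           [ 2, 10],
+           [ 3, 10],
+           [ 1, 20],
+           [ 2, 20],
+           [ 3, 20]])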
+
+    """
+    if x1.ndim == 1 and x2.ndim == 1:
+        return np.column_stack([np.tile(x1, x2.shape[0]),
+                                np.repeat(x2, x1.shape[0])])
+    elif x1.ndim > 1 and x2.ndim == 1:
+        first = np.tile(x1, (x2.shape[0], 1))
+        second = np.repeat(x2, x1.shape[0])
+        return np.column_stack([first, second])
+    else:
+        raise NotImplementedError("Come back here")
+
+
+def _gridmake2_torch(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+    """
+    PyTorch version of _gridmake2.
+
+    Expands two tensors into a matrix where rows span the cartesian product
+    of combinations of the input tensors. Each column of the input tensors
+    will correspond to one column of the output matrix.
+
+    Parameters
+    ----------
+    x1 : torch.Tensor
+        First tensor to be expanded.
+
+    x2 : torch.Tensor
+        Second tensor to be expanded.
+
+    Returns
+    -------
+    out : torch.Tensor
+        The cartesian product of combinations of the input tensors.
+
+    Notes
+    -----
+    Based on original function ``gridmake2`` in CompEcon toolbox by
+    Miranda and Fackler.
+
+    References
+    ----------
+    Miranda, Mario J, and Paul L Fackler. Applied Computational Economics
+    and Finance, MIT Press, 2002.
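+
+    Examples
+    --------
+    An illustrative doctest for CPU tensors; the expected output mirrors
+    the unit tests added for this function below:
+
+    >>> import torch
+    >>> _gridmake2_torch(torch.tensor([1, 2, 3]), torch.tensor([10, 20]))
+    tensor([[ 1, 10],
+            [ 2, 10],
+            [ 3, 10],
+            [ 1, 20],
+            [ 2, 20],
+            [ 3, 20]])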
+ + """ + if x1.dim() == 1 and x2.dim() == 1: + # tile x1 by x2.shape[0] times, repeat_interleave x2 by x1.shape[0] + first = x1.tile(x2.shape[0]) + second = x2.repeat_interleave(x1.shape[0]) + return torch.column_stack([first, second]) + elif x1.dim() > 1 and x2.dim() == 1: + # tile x1 along first dimension + first = x1.tile(x2.shape[0], 1) + second = x2.repeat_interleave(x1.shape[0]) + return torch.column_stack([first, second]) + else: + raise NotImplementedError("Come back here") diff --git a/code_to_optimize/tests/pytest/test_gridmake2.py b/code_to_optimize/tests/pytest/test_gridmake2.py new file mode 100644 index 000000000..60d7bfe56 --- /dev/null +++ b/code_to_optimize/tests/pytest/test_gridmake2.py @@ -0,0 +1,216 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_equal + +from code_to_optimize.discrete_riccati import _gridmake2 + + +class TestGridmake2With1DArrays: + """Tests for _gridmake2 with two 1D arrays.""" + + def test_basic_two_element_arrays(self): + """Test basic cartesian product of two 2-element arrays.""" + x1 = np.array([1, 2]) + x2 = np.array([3, 4]) + result = _gridmake2(x1, x2) + + # Expected: x1 is tiled len(x2) times, x2 is repeated len(x1) times + expected = np.array([ + [1, 3], + [2, 3], + [1, 4], + [2, 4] + ]) + assert_array_equal(result, expected) + + def test_different_length_arrays(self): + """Test cartesian product with arrays of different lengths.""" + x1 = np.array([1, 2, 3]) + x2 = np.array([10, 20]) + result = _gridmake2(x1, x2) + + # Result should have len(x1) * len(x2) = 6 rows + expected = np.array([ + [1, 10], + [2, 10], + [3, 10], + [1, 20], + [2, 20], + [3, 20] + ]) + assert_array_equal(result, expected) + assert result.shape == (6, 2) + + def test_single_element_arrays(self): + """Test with single-element arrays.""" + x1 = np.array([5]) + x2 = np.array([7]) + result = _gridmake2(x1, x2) + + expected = np.array([[5, 7]]) + assert_array_equal(result, expected) + assert result.shape == (1, 2) + + def test_single_element_with_multi_element(self): + """Test single-element array with multi-element array.""" + x1 = np.array([1]) + x2 = np.array([10, 20, 30]) + result = _gridmake2(x1, x2) + + expected = np.array([ + [1, 10], + [1, 20], + [1, 30] + ]) + assert_array_equal(result, expected) + + def test_float_arrays(self): + """Test with float arrays.""" + x1 = np.array([1.5, 2.5]) + x2 = np.array([0.1, 0.2]) + result = _gridmake2(x1, x2) + + expected = np.array([ + [1.5, 0.1], + [2.5, 0.1], + [1.5, 0.2], + [2.5, 0.2] + ]) + assert_array_equal(result, expected) + + def test_negative_values(self): + """Test with negative values.""" + x1 = np.array([-1, 0, 1]) + x2 = np.array([-10, 10]) + result = _gridmake2(x1, x2) + + expected = np.array([ + [-1, -10], + [0, -10], + [1, -10], + [-1, 10], + [0, 10], + [1, 10] + ]) + assert_array_equal(result, expected) + + def test_result_shape(self): + """Test that result shape is (len(x1)*len(x2), 2).""" + x1 = np.array([1, 2, 3, 4]) + x2 = np.array([5, 6, 7]) + result = _gridmake2(x1, x2) + + assert result.shape == (12, 2) + + def test_larger_arrays(self): + """Test with larger arrays.""" + x1 = np.arange(10) + x2 = np.arange(5) + result = _gridmake2(x1, x2) + + assert result.shape == (50, 2) + # Verify first column is x1 tiled 5 times + assert_array_equal(result[:10, 0], x1) + assert_array_equal(result[10:20, 0], x1) + # Verify second column is x2 repeated 10 times each + assert all(result[:10, 1] == 0) + assert all(result[10:20, 1] == 1) + + +class TestGridmake2With2DFirst: + """Tests for 
_gridmake2 when x1 is 2D and x2 is 1D.""" + + def test_2d_first_1d_second(self): + """Test with 2D first array and 1D second array.""" + x1 = np.array([[1, 2], [3, 4]]) # 2 rows, 2 cols + x2 = np.array([10, 20]) + result = _gridmake2(x1, x2) + + # x1 is tiled len(x2) times vertically + # x2 is repeated len(x1) times (2 rows) + expected = np.array([ + [1, 2, 10], + [3, 4, 10], + [1, 2, 20], + [3, 4, 20] + ]) + assert_array_equal(result, expected) + + def test_2d_single_column(self): + """Test with 2D array having single column.""" + x1 = np.array([[1], [2], [3]]) # 3 rows, 1 col + x2 = np.array([10, 20]) + result = _gridmake2(x1, x2) + + expected = np.array([ + [1, 10], + [2, 10], + [3, 10], + [1, 20], + [2, 20], + [3, 20] + ]) + assert_array_equal(result, expected) + + def test_2d_multiple_columns(self): + """Test with 2D array having multiple columns.""" + x1 = np.array([[1, 2, 3], [4, 5, 6]]) # 2 rows, 3 cols + x2 = np.array([100]) + result = _gridmake2(x1, x2) + + expected = np.array([ + [1, 2, 3, 100], + [4, 5, 6, 100] + ]) + assert_array_equal(result, expected) + + +class TestGridmake2EdgeCases: + """Edge case tests for _gridmake2.""" + + def test_empty_arrays_raise_or_return_empty(self): + """Test behavior with empty arrays.""" + x1 = np.array([]) + x2 = np.array([1, 2]) + result = _gridmake2(x1, x2) + # Empty x1 should result in empty output + assert result.shape[0] == 0 + + def test_both_empty_arrays(self): + """Test with both empty arrays.""" + x1 = np.array([]) + x2 = np.array([]) + result = _gridmake2(x1, x2) + assert result.shape[0] == 0 + + def test_integer_dtype_preserved(self): + """Test that integer dtype is handled correctly.""" + x1 = np.array([1, 2], dtype=np.int64) + x2 = np.array([3, 4], dtype=np.int64) + result = _gridmake2(x1, x2) + assert result.dtype == np.int64 + + def test_float_dtype_preserved(self): + """Test that float dtype is handled correctly.""" + x1 = np.array([1.0, 2.0], dtype=np.float64) + x2 = np.array([3.0, 4.0], dtype=np.float64) + result = _gridmake2(x1, x2) + assert result.dtype == np.float64 + + +class TestGridmake2NotImplemented: + """Tests for NotImplementedError cases.""" + + def test_both_2d_raises(self): + """Test that two 2D arrays raises NotImplementedError.""" + x1 = np.array([[1, 2], [3, 4]]) + x2 = np.array([[5, 6], [7, 8]]) + with pytest.raises(NotImplementedError): + _gridmake2(x1, x2) + + def test_1d_first_2d_second_raises(self): + """Test that 1D first and 2D second raises NotImplementedError.""" + x1 = np.array([1, 2]) + x2 = np.array([[5, 6], [7, 8]]) + with pytest.raises(NotImplementedError): + _gridmake2(x1, x2) \ No newline at end of file diff --git a/code_to_optimize/tests/pytest/test_gridmake2_torch.py b/code_to_optimize/tests/pytest/test_gridmake2_torch.py new file mode 100644 index 000000000..f2ee737a2 --- /dev/null +++ b/code_to_optimize/tests/pytest/test_gridmake2_torch.py @@ -0,0 +1,267 @@ +import pytest +import torch + +from code_to_optimize.discrete_riccati import _gridmake2_torch + + +class TestGridmake2TorchCPU: + """Tests for _gridmake2_torch with CPU tensors.""" + + def test_both_1d_simple(self): + """Test with two simple 1D tensors.""" + x1 = torch.tensor([1, 2, 3]) + x2 = torch.tensor([10, 20]) + + result = _gridmake2_torch(x1, x2) + + # Expected: x1 tiled x2.shape[0] times, x2 repeat_interleaved x1.shape[0] + # x1 tiled: [1, 2, 3, 1, 2, 3] + # x2 repeated: [10, 10, 10, 20, 20, 20] + expected = torch.tensor([ + [1, 10], + [2, 10], + [3, 10], + [1, 20], + [2, 20], + [3, 20], + ]) + assert torch.equal(result, 
expected) + + def test_both_1d_single_element(self): + """Test with single element tensors.""" + x1 = torch.tensor([5]) + x2 = torch.tensor([10]) + + result = _gridmake2_torch(x1, x2) + + expected = torch.tensor([[5, 10]]) + assert torch.equal(result, expected) + + def test_both_1d_float_tensors(self): + """Test with float tensors.""" + x1 = torch.tensor([1.5, 2.5]) + x2 = torch.tensor([0.1, 0.2, 0.3]) + + result = _gridmake2_torch(x1, x2) + + assert result.shape == (6, 2) + assert result.dtype == torch.float32 + + def test_2d_and_1d_simple(self): + """Test with 2D x1 and 1D x2.""" + x1 = torch.tensor([[1, 2], [3, 4]]) + x2 = torch.tensor([10, 20]) + + result = _gridmake2_torch(x1, x2) + + # x1 tiled along first dim: [[1, 2], [3, 4], [1, 2], [3, 4]] + # x2 repeated: [10, 10, 20, 20] + # column_stack: [[1, 2, 10], [3, 4, 10], [1, 2, 20], [3, 4, 20]] + expected = torch.tensor([ + [1, 2, 10], + [3, 4, 10], + [1, 2, 20], + [3, 4, 20], + ]) + assert torch.equal(result, expected) + + def test_2d_and_1d_single_column(self): + """Test with 2D x1 having a single column and 1D x2.""" + x1 = torch.tensor([[1], [2], [3]]) + x2 = torch.tensor([10, 20]) + + result = _gridmake2_torch(x1, x2) + + expected = torch.tensor([ + [1, 10], + [2, 10], + [3, 10], + [1, 20], + [2, 20], + [3, 20], + ]) + assert torch.equal(result, expected) + + def test_output_shape_1d_1d(self): + """Test output shape for two 1D tensors.""" + x1 = torch.tensor([1, 2, 3, 4, 5]) + x2 = torch.tensor([10, 20, 30]) + + result = _gridmake2_torch(x1, x2) + + # Shape should be (len(x1) * len(x2), 2) + assert result.shape == (15, 2) + + def test_output_shape_2d_1d(self): + """Test output shape for 2D and 1D tensors.""" + x1 = torch.tensor([[1, 2, 3], [4, 5, 6]]) # Shape (2, 3) + x2 = torch.tensor([10, 20, 30, 40]) # Shape (4,) + + result = _gridmake2_torch(x1, x2) + + # Shape should be (2 * 4, 3 + 1) = (8, 4) + assert result.shape == (8, 4) + + def test_not_implemented_for_2d_2d(self): + """Test that NotImplementedError is raised for two 2D tensors.""" + x1 = torch.tensor([[1, 2], [3, 4]]) + x2 = torch.tensor([[10, 20], [30, 40]]) + + with pytest.raises(NotImplementedError, match="Come back here"): + _gridmake2_torch(x1, x2) + + def test_not_implemented_for_1d_2d(self): + """Test that NotImplementedError is raised for 1D and 2D tensors.""" + x1 = torch.tensor([1, 2, 3]) + x2 = torch.tensor([[10, 20], [30, 40]]) + + with pytest.raises(NotImplementedError, match="Come back here"): + _gridmake2_torch(x1, x2) + + def test_preserves_dtype_int(self): + """Test that integer dtype is preserved.""" + x1 = torch.tensor([1, 2, 3], dtype=torch.int32) + x2 = torch.tensor([10, 20], dtype=torch.int32) + + result = _gridmake2_torch(x1, x2) + + assert result.dtype == torch.int32 + + def test_preserves_dtype_float64(self): + """Test that float64 dtype is preserved.""" + x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64) + x2 = torch.tensor([10.0, 20.0], dtype=torch.float64) + + result = _gridmake2_torch(x1, x2) + + assert result.dtype == torch.float64 + + def test_large_tensors(self): + """Test with larger tensors.""" + x1 = torch.arange(100) + x2 = torch.arange(50) + + result = _gridmake2_torch(x1, x2) + + assert result.shape == (5000, 2) + # Verify first and last elements + assert result[0, 0] == 0 and result[0, 1] == 0 + assert result[-1, 0] == 99 and result[-1, 1] == 49 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestGridmake2TorchCUDA: + """Tests for _gridmake2_torch with CUDA tensors.""" + + def 
test_both_1d_simple_cuda(self): + """Test with two simple 1D CUDA tensors.""" + x1 = torch.tensor([1, 2, 3], device="cuda") + x2 = torch.tensor([10, 20], device="cuda") + + result = _gridmake2_torch(x1, x2) + + expected = torch.tensor([ + [1, 10], + [2, 10], + [3, 10], + [1, 20], + [2, 20], + [3, 20], + ], device="cuda") + assert result.device.type == "cuda" + assert torch.equal(result, expected) + + def test_both_1d_matches_cpu(self): + """Test that CUDA version matches CPU version.""" + x1_cpu = torch.tensor([1.0, 2.0, 3.0, 4.0]) + x2_cpu = torch.tensor([10.0, 20.0, 30.0]) + + x1_cuda = x1_cpu.cuda() + x2_cuda = x2_cpu.cuda() + + result_cpu = _gridmake2_torch(x1_cpu, x2_cpu) + result_cuda = _gridmake2_torch(x1_cuda, x2_cuda) + + assert result_cuda.device.type == "cuda" + torch.testing.assert_close(result_cpu, result_cuda.cpu()) + + def test_2d_and_1d_cuda(self): + """Test with 2D x1 and 1D x2 on CUDA.""" + x1 = torch.tensor([[1, 2], [3, 4]], device="cuda") + x2 = torch.tensor([10, 20], device="cuda") + + result = _gridmake2_torch(x1, x2) + + expected = torch.tensor([ + [1, 2, 10], + [3, 4, 10], + [1, 2, 20], + [3, 4, 20], + ], device="cuda") + assert result.device.type == "cuda" + assert torch.equal(result, expected) + + def test_2d_and_1d_matches_cpu(self): + """Test that CUDA version matches CPU version for 2D, 1D inputs.""" + x1_cpu = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + x2_cpu = torch.tensor([10.0, 20.0]) + + x1_cuda = x1_cpu.cuda() + x2_cuda = x2_cpu.cuda() + + result_cpu = _gridmake2_torch(x1_cpu, x2_cpu) + result_cuda = _gridmake2_torch(x1_cuda, x2_cuda) + + assert result_cuda.device.type == "cuda" + torch.testing.assert_close(result_cpu, result_cuda.cpu()) + + def test_output_stays_on_cuda(self): + """Test that output tensor stays on CUDA device.""" + x1 = torch.tensor([1, 2, 3], device="cuda") + x2 = torch.tensor([10, 20], device="cuda") + + result = _gridmake2_torch(x1, x2) + + assert result.is_cuda + + def test_preserves_dtype_float32_cuda(self): + """Test that float32 dtype is preserved on CUDA.""" + x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32, device="cuda") + x2 = torch.tensor([10.0, 20.0], dtype=torch.float32, device="cuda") + + result = _gridmake2_torch(x1, x2) + + assert result.dtype == torch.float32 + assert result.device.type == "cuda" + + def test_preserves_dtype_float64_cuda(self): + """Test that float64 dtype is preserved on CUDA.""" + x1 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64, device="cuda") + x2 = torch.tensor([10.0, 20.0], dtype=torch.float64, device="cuda") + + result = _gridmake2_torch(x1, x2) + + assert result.dtype == torch.float64 + assert result.device.type == "cuda" + + def test_large_tensors_cuda(self): + """Test with larger tensors on CUDA.""" + x1 = torch.arange(100, device="cuda") + x2 = torch.arange(50, device="cuda") + + result = _gridmake2_torch(x1, x2) + + assert result.shape == (5000, 2) + assert result.device.type == "cuda" + # Verify first and last elements + assert result[0, 0].item() == 0 and result[0, 1].item() == 0 + assert result[-1, 0].item() == 99 and result[-1, 1].item() == 49 + + def test_not_implemented_for_2d_2d_cuda(self): + """Test that NotImplementedError is raised for two 2D CUDA tensors.""" + x1 = torch.tensor([[1, 2], [3, 4]], device="cuda") + x2 = torch.tensor([[10, 20], [30, 40]], device="cuda") + + with pytest.raises(NotImplementedError, match="Come back here"): + _gridmake2_torch(x1, x2) + diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 2eedb9fae..9718f1756 
100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -46,7 +46,7 @@ def get_aiservice_base_url(self) -> str:
             logger.info("Using local AI Service at http://localhost:8000")
             console.rule()
             return "http://localhost:8000"
-        return "https://app.codeflash.ai"
+        return "http://localhost:8000"

     def make_ai_service_request(
         self,
@@ -177,6 +177,73 @@ def optimize_python_code(  # noqa: D417
         console.rule()
         return []

+    def get_jit_rewritten_code(  # noqa: D417
+        self,
+        source_code: str,
+        dependency_code: str,
+        trace_id: str,
+        num_candidates: int = 1,
+        experiment_metadata: ExperimentMetadata | None = None,
+        *,
+        is_async: bool = False,
+    ) -> list[OptimizedCandidate]:
+        """Request a JIT-framework-aware rewrite of the given python code by making a request to the Django endpoint.
+
+        Parameters
+        ----------
+        - source_code (str): The python code to rewrite.
+        - dependency_code (str): The dependency code used as read-only context for the rewrite
+        - trace_id (str): Trace id of optimization run
+        - num_candidates (int): Number of optimization variants to generate. Default is 1.
+        - experiment_metadata (Optional[ExperimentMetadata]): Any available experiment metadata for this optimization
+        - is_async (bool): Whether the function being rewritten is async.
+
+        Returns
+        -------
+        - list[OptimizedCandidate]: A list of optimization candidates.
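+
+        Example
+        -------
+        A hypothetical call, where ``client`` is an instance of this
+        service class and the inputs come from an optimization run:
+
+        >>> candidates = client.get_jit_rewritten_code(  # doctest: +SKIP
+        ...     source_code=code, dependency_code=deps, trace_id=trace_id
+        ... )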
+
+        """
+        start_time = time.perf_counter()
+        git_repo_owner, git_repo_name = safe_get_repo_owner_and_name()
+
+        payload = {
+            "source_code": source_code,
+            "dependency_code": dependency_code,
+            "num_variants": num_candidates,
+            "trace_id": trace_id,
+            "python_version": platform.python_version(),
+            "experiment_metadata": experiment_metadata,
+            "codeflash_version": codeflash_version,
+            "current_username": get_last_commit_author_if_pr_exists(None),
+            "repo_owner": git_repo_owner,
+            "repo_name": git_repo_name,
+            "n_candidates": N_CANDIDATES_EFFECTIVE,
+            "is_async": is_async,
+        }
+
+        logger.info("!lsp|Generating optimized candidates…")
+        console.rule()
+        try:
+            response = self.make_ai_service_request("/rewrite_jit", payload=payload, timeout=60)
+        except requests.exceptions.RequestException as e:
+            logger.exception(f"Error generating optimized candidates: {e}")
+            ph("cli-optimize-error-caught", {"error": str(e)})
+            return []
+
+        if response.status_code == 200:
+            optimizations_json = response.json()["optimizations"]
+            console.rule()
+            end_time = time.perf_counter()
+            logger.debug(f"!lsp|Generating possible optimizations took {end_time - start_time:.2f} seconds.")
+            return self._get_valid_candidates(optimizations_json, OptimizedCandidateSource.OPTIMIZE)
+        try:
+            error = response.json()["error"]
+        except Exception:
+            error = response.text
+        logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
+        ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
+        console.rule()
+        return []
+
     def optimize_python_code_line_profiler(  # noqa: D417
         self,
         source_code: str,
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 5cab36eda..2bdc0b05c 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -453,6 +453,26 @@ def optimize_function(self) -> Result[BestOptimization, str]:
             revert_to_print=bool(get_pr_number()),
         ):
             console.rule()
+            # get new opt candidate
+
+            jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code(
+                code_context.read_writable_code.markdown, code_context.read_only_context_code, self.function_trace_id
+            )
+            # write files
+            # Try to replace function with optimized code
+            self.replace_function_and_helpers_with_optimized_code(
+                code_context=code_context,
+                optimized_code=jit_compiled_opt_candidate[0].source_code,
+                original_helper_code=original_helper_code,
+            )
+            # get codecontext
+            new_code_context = self.get_code_optimization_context().unwrap()
+            # unwrite files
+            self.write_code_and_helpers(
+                self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
+            )
+            # Generate tests and optimizations in parallel
+            future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context)
             # Generate tests and optimizations in parallel
             future_tests = self.executor.submit(self.generate_and_instrument_tests, code_context)
             future_optimizations = self.executor.submit(

From f945fefaecadfb9511b5a497d8016c105289e589 Mon Sep 17 00:00:00 2001
From: Codeflash Bot
Date: Tue, 30 Dec 2025 17:28:19 -0800
Subject: [PATCH 5/8] bugfix

---
 codeflash/optimization/function_optimizer.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 2bdc0b05c..037460624 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -473,8 +473,6 @@ def optimize_function(self) -> Result[BestOptimization, str]:
             )
             # Generate tests and optimizations in parallel
             future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context)
-            # Generate tests and optimizations in parallel
-            future_tests = self.executor.submit(self.generate_and_instrument_tests, code_context)
             future_optimizations = self.executor.submit(
                 self.generate_optimizations,
                 read_writable_code=code_context.read_writable_code,

From d555e8bd164ab2ecc5a4fdec6f0c1263aac4f9aa Mon Sep 17 00:00:00 2001
From: Codeflash Bot
Date: Tue, 30 Dec 2025 18:42:46 -0800
Subject: [PATCH 6/8] mlx is problematic

---
 .../code_utils/instrument_existing_tests.py   | 51 ++++++++++---------
 1 file changed, 26 insertions(+), 25 deletions(-)

diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py
index d6c5f913b..e259edb8e 100644
--- a/codeflash/code_utils/instrument_existing_tests.py
+++ b/codeflash/code_utils/instrument_existing_tests.py
@@ -866,30 +866,30 @@ def _codeflash_jit_sync():
         lineno=lineno,
     )

-    # MLX sync block
-    mlx_sync = ast.Try(
-        body=[
-            ast.Import(names=[ast.alias(name="mlx.core", asname="mx")], lineno=lineno),
-            ast.Expr(
-                value=ast.Call(
-                    func=ast.Attribute(value=ast.Name(id="mx", ctx=ast.Load()), attr="synchronize", ctx=ast.Load()),
-                    args=[],
-                    keywords=[],
-                )
-            ),
-        ],
-        handlers=[
-            ast.ExceptHandler(
-                type=ast.Name(id="ImportError", ctx=ast.Load()),
-                name=None,
-                body=[ast.Pass(lineno=lineno)],
-                lineno=lineno,
-            )
-        ],
-        orelse=[],
-        finalbody=[],
-        lineno=lineno,
-    )
+    # # MLX sync block
+    # mlx_sync = ast.Try(
+    #     body=[
+    #         ast.Import(names=[ast.alias(name="mlx.core", asname="mx")], lineno=lineno),
+    #         ast.Expr(
+    #             value=ast.Call(
+    #                 func=ast.Attribute(value=ast.Name(id="mx", ctx=ast.Load()), attr="synchronize", ctx=ast.Load()),
+    #                 args=[],
+    #                 keywords=[],
+    #             )
+    #         ),
+    #     ],
+    #     handlers=[
+    #         ast.ExceptHandler(
+    #             type=ast.Name(id="ImportError", ctx=ast.Load()),
+    #             name=None,
+    #             body=[ast.Pass(lineno=lineno)],
+    #             lineno=lineno,
+    #         )
+    #     ],
+    #     orelse=[],
+    #     finalbody=[],
+    #     lineno=lineno,
+    # )

@@ -969,7 +969,8 @@ 
def _codeflash_jit_sync(): args=ast.arguments( args=[], vararg=None, kwarg=None, posonlyargs=[], kwonlyargs=[], kw_defaults=[], defaults=[] ), - body=[pytorch_sync, jax_sync, mlx_sync, tensorflow_sync], + # body=[pytorch_sync, jax_sync, mlx_sync, tensorflow_sync], + body=[pytorch_sync, jax_sync, tensorflow_sync], decorator_list=[], returns=None, lineno=lineno, From 7bf6681ce13d372caff855ad33886d4f19f16aaa Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Wed, 31 Dec 2025 18:27:49 -0800 Subject: [PATCH 7/8] failsafe --- codeflash/optimization/function_optimizer.py | 29 +++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 037460624..51c0c43bf 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -458,19 +458,22 @@ def optimize_function(self) -> Result[BestOptimization, str]: jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code( code_context.read_writable_code.markdown, code_context.read_only_context_code, self.function_trace_id ) - # write files - # Try to replace function with optimized code - self.replace_function_and_helpers_with_optimized_code( - code_context=code_context, - optimized_code=jit_compiled_opt_candidate[0].source_code, - original_helper_code=original_helper_code, - ) - # get codecontext - new_code_context = self.get_code_optimization_context().unwrap() - # unwrite files - self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path - ) + if len(jit_compiled_opt_candidate) > 0: + # write files + # Try to replace function with optimized code + self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, + optimized_code=jit_compiled_opt_candidate[0].source_code, + original_helper_code=original_helper_code, + ) + # get codecontext + new_code_context = self.get_code_optimization_context().unwrap() + # unwrite files + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + else: + new_code_context = code_context # Generate tests and optimizations in parallel future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context) future_optimizations = self.executor.submit( From f1e473576197a97b737d52b8df939f375acb2616 Mon Sep 17 00:00:00 2001 From: Codeflash Bot Date: Thu, 1 Jan 2026 14:25:34 -0800 Subject: [PATCH 8/8] comparator fix --- codeflash/verification/comparator.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index d95c45012..704d19b3c 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -24,7 +24,6 @@ HAS_JAX = find_spec("jax") is not None HAS_XARRAY = find_spec("xarray") is not None HAS_TENSORFLOW = find_spec("tensorflow") is not None -HAS_MLX = find_spec("mlx") is not None def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911 @@ -139,17 +138,6 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return comparator(orig.to_list(), new.to_list(), superset_obj) - if HAS_MLX: - import mlx.core as mx # type: ignore # noqa: PGH003 - - if isinstance(orig, mx.array): - if orig.dtype != new.dtype: - return False - if orig.shape != new.shape: - 
return False - # MLX allclose handles NaN comparison via equal_nan parameter - return bool(mx.allclose(orig, new, equal_nan=True)) - if HAS_SQLALCHEMY: import sqlalchemy # type: ignore # noqa: PGH003 @@ -235,6 +223,10 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 return False return all(comparator(orig[field], new[field], superset_obj) for field in orig.dtype.fields) + # Handle np.dtype instances (including numpy.dtypes.* classes like Float64DType, Int64DType, etc.) + if isinstance(orig, np.dtype): + return orig == new + if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix): if orig.dtype != new.dtype: return False