From 0d84ab62dd57263dd6115ef3e5a15e8cf45a610d Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Thu, 1 Jan 2026 02:24:51 -0500 Subject: [PATCH 1/3] fix: track base class dependencies in unused definition remover Base classes must be preserved when a derived class inherits from them. --- .../context/unused_definition_remover.py | 12 ++++ tests/test_remove_unused_definitions.py | 59 +++++++++++++++++++ 2 files changed, 71 insertions(+) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 8e6ea057c..823cb735b 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -223,6 +223,18 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None: self.current_class = class_name self.current_top_level_name = class_name + # Track base classes as dependencies + for base in node.bases: + if isinstance(base.value, cst.Name): + base_name = base.value.value + if base_name in self.definitions and class_name in self.definitions: + self.definitions[class_name].dependencies.add(base_name) + elif isinstance(base.value, cst.Attribute): + # Handle cases like module.ClassName + attr_name = base.value.attr.value + if attr_name in self.definitions and class_name in self.definitions: + self.definitions[class_name].dependencies.add(attr_name) + self.class_depth += 1 def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002 diff --git a/tests/test_remove_unused_definitions.py b/tests/test_remove_unused_definitions.py index 86a57bb6d..8d09a95e1 100644 --- a/tests/test_remove_unused_definitions.py +++ b/tests/test_remove_unused_definitions.py @@ -337,6 +337,65 @@ def unused_function(): result = remove_unused_definitions_by_function_names(code, qualified_functions) assert result.strip() == expected.strip() +def test_base_class_inheritance() -> None: + """Test that base classes used only for inheritance are preserved.""" + code = """ +class LayoutDumper: + def 
dump(self): + raise NotImplementedError + +class ObjectDetectionLayoutDumper(LayoutDumper): + def __init__(self, data): + self.data = data + def dump(self): + return self.data + +class ExtractedLayoutDumper(LayoutDumper): + def __init__(self, data): + self.data = data + def dump(self): + return self.data + +class UnusedClass: + pass + +def test_function(): + dumper = ObjectDetectionLayoutDumper({}) + return dumper.dump() +""" + + expected = """ +class LayoutDumper: + def dump(self): + raise NotImplementedError + +class ObjectDetectionLayoutDumper(LayoutDumper): + def __init__(self, data): + self.data = data + def dump(self): + return self.data + +class ExtractedLayoutDumper(LayoutDumper): + def __init__(self, data): + self.data = data + def dump(self): + return self.data + +class UnusedClass: + pass + +def test_function(): + dumper = ObjectDetectionLayoutDumper({}) + return dumper.dump() +""" + + qualified_functions = {"test_function"} + result = remove_unused_definitions_by_function_names(code, qualified_functions) + # LayoutDumper should be preserved because ObjectDetectionLayoutDumper inherits from it + assert "class LayoutDumper" in result + assert "class ObjectDetectionLayoutDumper" in result + + def test_conditional_and_loop_variables() -> None: """Test handling of variables defined in if-else and while loops.""" code = """ From 3048ece4da9ef9585684645a995ec92c8f36539b Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Thu, 1 Jan 2026 02:25:10 -0500 Subject: [PATCH 2/3] fix: track class __init__ as helper when class is instantiated Ensures LLM sees constructor signatures for proper test generation. 
--- codeflash/context/code_context_extractor.py | 57 ++++--- tests/test_code_context_extractor.py | 170 +++++++++++++++++++- 2 files changed, 206 insertions(+), 21 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index 14d549633..a411bafac 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -446,31 +446,45 @@ def get_function_sources_from_jedi( definition_path = definition.module_path # The definition is part of this project and not defined within the original function - if ( + is_valid_definition = ( str(definition_path).startswith(str(project_root_path) + os.sep) and not path_belongs_to_site_packages(definition_path) and definition.full_name - and definition.type == "function" and not belongs_to_function_qualified(definition, qualified_function_name) and definition.full_name.startswith(definition.module_name) + ) + if is_valid_definition and definition.type == "function": + qualified_name = get_qualified_name(definition.module_name, definition.full_name) # Avoid nested functions or classes. Only class.function is allowed - and len( - (qualified_name := get_qualified_name(definition.module_name, definition.full_name)).split( - "." 
+ if len(qualified_name.split(".")) <= 2: + function_source = FunctionSource( + file_path=definition_path, + qualified_name=qualified_name, + fully_qualified_name=definition.full_name, + only_function_name=definition.name, + source_code=definition.get_line_code(), + jedi_definition=definition, ) + file_path_to_function_source[definition_path].add(function_source) + function_source_list.append(function_source) + # When a class is instantiated (e.g., MyClass()), track its __init__ as a helper + # This ensures the class definition with constructor is included in testgen context + elif is_valid_definition and definition.type == "class": + init_qualified_name = get_qualified_name( + definition.module_name, f"{definition.full_name}.__init__" ) - <= 2 - ): - function_source = FunctionSource( - file_path=definition_path, - qualified_name=qualified_name, - fully_qualified_name=definition.full_name, - only_function_name=definition.name, - source_code=definition.get_line_code(), - jedi_definition=definition, - ) - file_path_to_function_source[definition_path].add(function_source) - function_source_list.append(function_source) + # Only include if it's a top-level class (not nested) + if len(init_qualified_name.split(".")) <= 2: + function_source = FunctionSource( + file_path=definition_path, + qualified_name=init_qualified_name, + fully_qualified_name=f"{definition.full_name}.__init__", + only_function_name="__init__", + source_code=definition.get_line_code(), + jedi_definition=definition, + ) + file_path_to_function_source[definition_path].add(function_source) + function_source_list.append(function_source) return file_path_to_function_source, function_source_list @@ -647,7 +661,10 @@ def prune_cst_for_code_hashing( # noqa: PLR0911 if isinstance(node, cst.FunctionDef): qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value - if qualified_name in target_functions: + # For hashing, exclude __init__ methods even if in target_functions + # because they don't 
affect the semantic behavior being hashed + # But include other dunder methods like __call__ which do affect behavior + if qualified_name in target_functions and node.name.value != "__init__": new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body return node.with_changes(body=new_body), True return None, False @@ -666,7 +683,9 @@ def prune_cst_for_code_hashing( # noqa: PLR0911 for stmt in node.body.body: if isinstance(stmt, cst.FunctionDef): qualified_name = f"{class_prefix}.{stmt.name.value}" - if qualified_name in target_functions: + # For hashing, exclude __init__ methods even if in target_functions + # but include other methods like __call__ which affect behavior + if qualified_name in target_functions and stmt.name.value != "__init__": stmt_with_changes = stmt.with_changes( body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body)) ) diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index aa4e2880f..b7cce0869 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -84,7 +84,8 @@ def test_code_replacement10() -> None: code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent) qualified_names = {func.qualified_name for func in code_ctx.helper_functions} - assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here + # HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class + assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"} # Nested method should not be in here read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context @@ -570,6 +571,8 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): 
"""Interface for cache backends used by the persistent cache decorator.""" + def __init__(self) -> None: ... + def hash_key( self, *, @@ -1296,6 +1299,8 @@ def __repr__(self) -> str: ``` ```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: + def __init__(self): + self.data = None def transform(self, data): self.data = data @@ -1599,7 +1604,11 @@ def __repr__(self) -> str: \"\"\"Return a string representation of the DataProcessor.\"\"\" return f"DataProcessor(default_prefix={{self.default_prefix!r}})" ``` - +```python:{path_to_transform_utils.relative_to(project_root)} +class DataTransformer: + def __init__(self): + self.data = None +``` """ expected_hashing_context = f""" ```python:utils.py @@ -1705,6 +1714,7 @@ def test_direct_module_import() -> None: expected_read_only_context = """ ```python:utils.py +import math from transform_utils import DataTransformer class DataProcessor: @@ -1712,6 +1722,11 @@ class DataProcessor: number = 1 + def __init__(self, default_prefix: str = "PREFIX_"): + \"\"\"Initialize the DataProcessor with a default prefix.\"\"\" + self.default_prefix = default_prefix + self.number += math.log(self.number) + def __repr__(self) -> str: \"\"\"Return a string representation of the DataProcessor.\"\"\" return f"DataProcessor(default_prefix={self.default_prefix!r})" @@ -2727,3 +2742,154 @@ async def async_function(): # Verify correct order expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"] assert collector.assignment_order == expected_order + + +def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None: + """Test that when a class is instantiated, its __init__ method is tracked as a helper. + + This test verifies the fix for the bug where class constructors were not + included in the context when only the class instantiation was called + (not any other methods). This caused LLMs to not know the constructor + signatures when generating tests. 
+ """ + code = ''' +class DataDumper: + """A class that dumps data.""" + + def __init__(self, data): + """Initialize with data.""" + self.data = data + + def dump(self): + """Dump the data.""" + return self.data + + +def target_function(): + # Only instantiates DataDumper, doesn't call any other methods + dumper = DataDumper({"key": "value"}) + return dumper +''' + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="target_function", + file_path=file_path, + parents=[], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + + # The __init__ method should be tracked as a helper since DataDumper() instantiates the class + qualified_names = {func.qualified_name for func in code_ctx.helper_functions} + assert "DataDumper.__init__" in qualified_names, ( + "DataDumper.__init__ should be tracked as a helper when the class is instantiated" + ) + + # The testgen context should contain the class with __init__ (critical for LLM to know constructor) + testgen_context = code_ctx.testgen_context.markdown + assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context" + assert "def __init__(self, data):" in testgen_context, ( + "__init__ method should be included in testgen context" + ) + + # The hashing context should NOT contain __init__ (excluded for stability) + hashing_context = code_ctx.hashing_code_context + assert "__init__" not in hashing_context, ( + "__init__ should NOT be in hashing context (excluded for hash stability)" + ) + + +def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None: + """Test 
that instantiated classes are fully preserved in testgen context. + + This is specifically for the unstructured LayoutDumper bug where helper classes + that were instantiated but had no other methods called were being excluded + from the testgen context. + """ + code = ''' +class LayoutDumper: + """Base class for layout dumpers.""" + layout_source: str = "unknown" + + def __init__(self, layout): + self._layout = layout + + def dump(self) -> dict: + raise NotImplementedError() + + +class ObjectDetectionLayoutDumper(LayoutDumper): + """Specific dumper for object detection layouts.""" + + def __init__(self, layout): + super().__init__(layout) + + def dump(self) -> dict: + return {"type": "object_detection", "layout": self._layout} + + +def dump_layout(layout_type, layout): + """Dump a layout based on its type.""" + if layout_type == "object_detection": + dumper = ObjectDetectionLayoutDumper(layout) + else: + dumper = LayoutDumper(layout) + return dumper.dump() +''' + file_path = tmp_path / "test_code.py" + file_path.write_text(code, encoding="utf-8") + opt = Optimizer( + Namespace( + project_root=file_path.parent.resolve(), + disable_telemetry=True, + tests_root="tests", + test_framework="pytest", + pytest_cmd="pytest", + experiment_id=None, + test_project_root=Path().resolve(), + ) + ) + function_to_optimize = FunctionToOptimize( + function_name="dump_layout", + file_path=file_path, + parents=[], + starting_line=None, + ending_line=None, + ) + + code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) + qualified_names = {func.qualified_name for func in code_ctx.helper_functions} + + # Both class __init__ methods should be tracked as helpers + assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, ( + "ObjectDetectionLayoutDumper.__init__ should be tracked" + ) + assert "LayoutDumper.__init__" in qualified_names, ( + "LayoutDumper.__init__ should be tracked" + ) + + # The testgen context should include both classes with their 
__init__ methods + testgen_context = code_ctx.testgen_context.markdown + assert "class LayoutDumper:" in testgen_context, "LayoutDumper should be in testgen context" + assert "class ObjectDetectionLayoutDumper" in testgen_context, ( + "ObjectDetectionLayoutDumper should be in testgen context" + ) + + # Both __init__ methods should be in the testgen context (so LLM knows constructor signatures) + assert testgen_context.count("def __init__") >= 2, ( + "Both __init__ methods should be in testgen context" + ) From a0770e0392e049c68d67f38b8419c998cd400b29 Mon Sep 17 00:00:00 2001 From: Kevin Turcios Date: Thu, 1 Jan 2026 12:34:40 -0500 Subject: [PATCH 3/3] test: update line profiler test expectations for tracked __init__ helpers --- tests/test_instrument_line_profiler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_instrument_line_profiler.py b/tests/test_instrument_line_profiler.py index 71d1005c0..675db5944 100644 --- a/tests/test_instrument_line_profiler.py +++ b/tests/test_instrument_line_profiler.py @@ -55,6 +55,7 @@ def hi(): class BubbleSortClass: + @codeflash_line_profile def __init__(self): pass @@ -117,7 +118,9 @@ def sort_classmethod(x): return y.sorter(x) """ assert code_path.read_text("utf-8") == expected_code_main - assert code_context.helper_functions.__len__() == 0 + # WrapperClass.__init__ is now detected as a helper since WrapperClass.BubbleSortClass() instantiates it + assert len(code_context.helper_functions) == 1 + assert code_context.helper_functions[0].qualified_name == "WrapperClass.__init__" finally: func_optimizer.write_code_and_helpers( func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path @@ -283,6 +286,7 @@ def sorter(arr): ans = helper(arr) return ans class helper: + @codeflash_line_profile def __init__(self, arr): return arr.sort() """