Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,31 +446,45 @@ def get_function_sources_from_jedi(
definition_path = definition.module_path

# The definition is part of this project and not defined within the original function
if (
is_valid_definition = (
str(definition_path).startswith(str(project_root_path) + os.sep)
and not path_belongs_to_site_packages(definition_path)
and definition.full_name
and definition.type == "function"
and not belongs_to_function_qualified(definition, qualified_function_name)
and definition.full_name.startswith(definition.module_name)
)
if is_valid_definition and definition.type == "function":
qualified_name = get_qualified_name(definition.module_name, definition.full_name)
# Avoid nested functions or classes. Only class.function is allowed
and len(
(qualified_name := get_qualified_name(definition.module_name, definition.full_name)).split(
"."
if len(qualified_name.split(".")) <= 2:
function_source = FunctionSource(
file_path=definition_path,
qualified_name=qualified_name,
fully_qualified_name=definition.full_name,
only_function_name=definition.name,
source_code=definition.get_line_code(),
jedi_definition=definition,
)
file_path_to_function_source[definition_path].add(function_source)
function_source_list.append(function_source)
# When a class is instantiated (e.g., MyClass()), track its __init__ as a helper
# This ensures the class definition with constructor is included in testgen context
elif is_valid_definition and definition.type == "class":
init_qualified_name = get_qualified_name(
definition.module_name, f"{definition.full_name}.__init__"
)
<= 2
):
function_source = FunctionSource(
file_path=definition_path,
qualified_name=qualified_name,
fully_qualified_name=definition.full_name,
only_function_name=definition.name,
source_code=definition.get_line_code(),
jedi_definition=definition,
)
file_path_to_function_source[definition_path].add(function_source)
function_source_list.append(function_source)
# Only include if it's a top-level class (not nested)
if len(init_qualified_name.split(".")) <= 2:
function_source = FunctionSource(
file_path=definition_path,
qualified_name=init_qualified_name,
fully_qualified_name=f"{definition.full_name}.__init__",
only_function_name="__init__",
source_code=definition.get_line_code(),
jedi_definition=definition,
)
file_path_to_function_source[definition_path].add(function_source)
function_source_list.append(function_source)

return file_path_to_function_source, function_source_list

Expand Down Expand Up @@ -647,7 +661,10 @@ def prune_cst_for_code_hashing( # noqa: PLR0911

if isinstance(node, cst.FunctionDef):
qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
if qualified_name in target_functions:
# For hashing, exclude __init__ methods even if in target_functions
# because they don't affect the semantic behavior being hashed
# But include other dunder methods like __call__ which do affect behavior
if qualified_name in target_functions and node.name.value != "__init__":
new_body = remove_docstring_from_body(node.body) if isinstance(node.body, cst.IndentedBlock) else node.body
return node.with_changes(body=new_body), True
return None, False
Expand All @@ -666,7 +683,9 @@ def prune_cst_for_code_hashing( # noqa: PLR0911
for stmt in node.body.body:
if isinstance(stmt, cst.FunctionDef):
qualified_name = f"{class_prefix}.{stmt.name.value}"
if qualified_name in target_functions:
# For hashing, exclude __init__ methods even if in target_functions
# but include other methods like __call__ which affect behavior
if qualified_name in target_functions and stmt.name.value != "__init__":
stmt_with_changes = stmt.with_changes(
body=remove_docstring_from_body(cast("cst.IndentedBlock", stmt.body))
)
Expand Down
12 changes: 12 additions & 0 deletions codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,18 @@ def visit_ClassDef(self, node: cst.ClassDef) -> None:
self.current_class = class_name
self.current_top_level_name = class_name

# Track base classes as dependencies
for base in node.bases:
if isinstance(base.value, cst.Name):
base_name = base.value.value
if base_name in self.definitions and class_name in self.definitions:
self.definitions[class_name].dependencies.add(base_name)
elif isinstance(base.value, cst.Attribute):
# Handle cases like module.ClassName
attr_name = base.value.attr.value
if attr_name in self.definitions and class_name in self.definitions:
self.definitions[class_name].dependencies.add(attr_name)

self.class_depth += 1

def leave_ClassDef(self, original_node: cst.ClassDef) -> None: # noqa: ARG002
Expand Down
170 changes: 168 additions & 2 deletions tests/test_code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def test_code_replacement10() -> None:

code_ctx = get_code_optimization_context(function_to_optimize=func_top_optimize, project_root_path=file_path.parent)
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
assert qualified_names == {"HelperClass.helper_method"} # Nested method should not be in here
# HelperClass.__init__ is now tracked because HelperClass(self.name) instantiates the class
assert qualified_names == {"HelperClass.helper_method", "HelperClass.__init__"} # Nested method should not be in here
read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code
hashing_context = code_ctx.hashing_code_context

Expand Down Expand Up @@ -570,6 +571,8 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R:
class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]):
"""Interface for cache backends used by the persistent cache decorator."""

def __init__(self) -> None: ...

def hash_key(
self,
*,
Expand Down Expand Up @@ -1296,6 +1299,8 @@ def __repr__(self) -> str:
```
```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None

def transform(self, data):
self.data = data
Expand Down Expand Up @@ -1599,7 +1604,11 @@ def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={{self.default_prefix!r}})"
```

```python:{path_to_transform_utils.relative_to(project_root)}
class DataTransformer:
def __init__(self):
self.data = None
```
"""
expected_hashing_context = f"""
```python:utils.py
Expand Down Expand Up @@ -1705,13 +1714,19 @@ def test_direct_module_import() -> None:

expected_read_only_context = """
```python:utils.py
import math
from transform_utils import DataTransformer

class DataProcessor:
\"\"\"A class for processing data.\"\"\"

number = 1

def __init__(self, default_prefix: str = "PREFIX_"):
\"\"\"Initialize the DataProcessor with a default prefix.\"\"\"
self.default_prefix = default_prefix
self.number += math.log(self.number)

def __repr__(self) -> str:
\"\"\"Return a string representation of the DataProcessor.\"\"\"
return f"DataProcessor(default_prefix={self.default_prefix!r})"
Expand Down Expand Up @@ -2727,3 +2742,154 @@ async def async_function():
# Verify correct order
expected_order = ["GLOBAL_CONSTANT", "ANOTHER_CONSTANT", "FINAL_ASSIGNMENT"]
assert collector.assignment_order == expected_order


def test_class_instantiation_includes_init_as_helper(tmp_path: Path) -> None:
"""Test that when a class is instantiated, its __init__ method is tracked as a helper.

This test verifies the fix for the bug where class constructors were not
included in the context when only the class instantiation was called
(not any other methods). This caused LLMs to not know the constructor
signatures when generating tests.
"""
code = '''
class DataDumper:
"""A class that dumps data."""

def __init__(self, data):
"""Initialize with data."""
self.data = data

def dump(self):
"""Dump the data."""
return self.data


def target_function():
# Only instantiates DataDumper, doesn't call any other methods
dumper = DataDumper({"key": "value"})
return dumper
'''
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="target_function",
file_path=file_path,
parents=[],
starting_line=None,
ending_line=None,
)

code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)

# The __init__ method should be tracked as a helper since DataDumper() instantiates the class
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}
assert "DataDumper.__init__" in qualified_names, (
"DataDumper.__init__ should be tracked as a helper when the class is instantiated"
)

# The testgen context should contain the class with __init__ (critical for LLM to know constructor)
testgen_context = code_ctx.testgen_context.markdown
assert "class DataDumper:" in testgen_context, "DataDumper class should be in testgen context"
assert "def __init__(self, data):" in testgen_context, (
"__init__ method should be included in testgen context"
)

# The hashing context should NOT contain __init__ (excluded for stability)
hashing_context = code_ctx.hashing_code_context
assert "__init__" not in hashing_context, (
"__init__ should NOT be in hashing context (excluded for hash stability)"
)


def test_class_instantiation_preserves_full_class_in_testgen(tmp_path: Path) -> None:
"""Test that instantiated classes are fully preserved in testgen context.

This is specifically for the unstructured LayoutDumper bug where helper classes
that were instantiated but had no other methods called were being excluded
from the testgen context.
"""
code = '''
class LayoutDumper:
"""Base class for layout dumpers."""
layout_source: str = "unknown"

def __init__(self, layout):
self._layout = layout

def dump(self) -> dict:
raise NotImplementedError()


class ObjectDetectionLayoutDumper(LayoutDumper):
"""Specific dumper for object detection layouts."""

def __init__(self, layout):
super().__init__(layout)

def dump(self) -> dict:
return {"type": "object_detection", "layout": self._layout}


def dump_layout(layout_type, layout):
"""Dump a layout based on its type."""
if layout_type == "object_detection":
dumper = ObjectDetectionLayoutDumper(layout)
else:
dumper = LayoutDumper(layout)
return dumper.dump()
'''
file_path = tmp_path / "test_code.py"
file_path.write_text(code, encoding="utf-8")
opt = Optimizer(
Namespace(
project_root=file_path.parent.resolve(),
disable_telemetry=True,
tests_root="tests",
test_framework="pytest",
pytest_cmd="pytest",
experiment_id=None,
test_project_root=Path().resolve(),
)
)
function_to_optimize = FunctionToOptimize(
function_name="dump_layout",
file_path=file_path,
parents=[],
starting_line=None,
ending_line=None,
)

code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root)
qualified_names = {func.qualified_name for func in code_ctx.helper_functions}

# Both class __init__ methods should be tracked as helpers
assert "ObjectDetectionLayoutDumper.__init__" in qualified_names, (
"ObjectDetectionLayoutDumper.__init__ should be tracked"
)
assert "LayoutDumper.__init__" in qualified_names, (
"LayoutDumper.__init__ should be tracked"
)

# The testgen context should include both classes with their __init__ methods
testgen_context = code_ctx.testgen_context.markdown
assert "class LayoutDumper:" in testgen_context, "LayoutDumper should be in testgen context"
assert "class ObjectDetectionLayoutDumper" in testgen_context, (
"ObjectDetectionLayoutDumper should be in testgen context"
)

# Both __init__ methods should be in the testgen context (so LLM knows constructor signatures)
assert testgen_context.count("def __init__") >= 2, (
"Both __init__ methods should be in testgen context"
)
6 changes: 5 additions & 1 deletion tests/test_instrument_line_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def hi():


class BubbleSortClass:
@codeflash_line_profile
def __init__(self):
pass

Expand Down Expand Up @@ -117,7 +118,9 @@ def sort_classmethod(x):
return y.sorter(x)
"""
assert code_path.read_text("utf-8") == expected_code_main
assert code_context.helper_functions.__len__() == 0
# WrapperClass.__init__ is now detected as a helper since WrapperClass.BubbleSortClass() instantiates it
assert len(code_context.helper_functions) == 1
assert code_context.helper_functions[0].qualified_name == "WrapperClass.__init__"
finally:
func_optimizer.write_code_and_helpers(
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path
Expand Down Expand Up @@ -283,6 +286,7 @@ def sorter(arr):
ans = helper(arr)
return ans
class helper:
@codeflash_line_profile
def __init__(self, arr):
return arr.sort()
"""
Expand Down
Loading
Loading