Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
e4cbf95
exp
mohammedahmed18 Sep 8, 2025
01ad626
still experimenting
mohammedahmed18 Sep 9, 2025
2b01faf
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 9, 2025
ce89905
reset
mohammedahmed18 Dec 9, 2025
1f367be
dynamic tolerance
mohammedahmed18 Dec 9, 2025
0d819a8
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 11, 2025
5c4a6d9
get the duration from the pytest overriden methods
mohammedahmed18 Dec 12, 2025
ecd21d5
remove debug log
mohammedahmed18 Dec 12, 2025
30c89ce
respect the min loop count -just in case-
mohammedahmed18 Dec 12, 2025
ce2c05b
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 12, 2025
89fc939
more closer method
mohammedahmed18 Dec 15, 2025
cc94694
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 16, 2025
a67dad3
working version
mohammedahmed18 Dec 16, 2025
d52aae4
even better
mohammedahmed18 Dec 16, 2025
244f9ca
better stability algorithm
mohammedahmed18 Dec 17, 2025
a890d4f
should stop metrics
mohammedahmed18 Dec 19, 2025
3159eb6
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 22, 2025
95f22ee
better stability with sum the min of all prev loops
mohammedahmed18 Dec 22, 2025
9f311cd
Optimize should_stop
codeflash-ai[bot] Dec 22, 2025
83dff02
best summed runtime helper
mohammedahmed18 Dec 23, 2025
a8e93c7
Merge branch 'main' of github.com:codeflash-ai/codeflash into codefla…
mohammedahmed18 Dec 23, 2025
e49ba13
linting
mohammedahmed18 Dec 23, 2025
91cbc74
Merge pull request #984 from codeflash-ai/codeflash/optimize-pr967-20…
mohammedahmed18 Dec 23, 2025
0b3be3f
some enhancements from claude pr review
mohammedahmed18 Dec 23, 2025
b57fa1a
Merge branch 'exp/consistent-loop-break' of github.com:codeflash-ai/c…
mohammedahmed18 Dec 23, 2025
46701c7
window percentage
mohammedahmed18 Dec 26, 2025
74520f6
cleanup
mohammedahmed18 Dec 26, 2025
9ab06d3
revert comment
mohammedahmed18 Dec 26, 2025
f1058ea
for unit tests
mohammedahmed18 Dec 26, 2025
56cce15
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 26, 2025
8ea9231
toggle stability by arg and zero tolerance
mohammedahmed18 Dec 28, 2025
9079590
configs and cleaner impl
mohammedahmed18 Dec 29, 2025
70b7627
typo
mohammedahmed18 Dec 29, 2025
dd3707a
Merge branch 'main' of github.com:codeflash-ai/codeflash into exp/con…
mohammedahmed18 Dec 29, 2025
9cae2a1
refactor
mohammedahmed18 Dec 29, 2025
7f1818b
refactoring
mohammedahmed18 Dec 31, 2025
270af89
Merge branch 'main' into exp/consistent-loop-break
KRRT7 Dec 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions codeflash/code_utils/config_consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
N_CANDIDATES_LP = 6

# pytest loop stability
# For now, we use strict thresholds (large windows and low tolerances), since this is still experimental.
STABILITY_WINDOW_SIZE = 0.35 # 35% of total window
STABILITY_CENTER_TOLERANCE = 0.0025 # ±0.25% around median
STABILITY_SPREAD_TOLERANCE = 0.0025 # 0.25% window spread

# Refinement
REFINE_ALL_THRESHOLD = 2 # when valid optimizations count is 2 or less, refine all optimizations
REFINED_CANDIDATE_RANKING_WEIGHTS = (2, 1) # (runtime, diff), runtime is more important than diff by a factor of 2
Expand Down
1 change: 0 additions & 1 deletion codeflash/code_utils/env_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
if not formatter_cmds or formatter_cmds[0] == "disabled":
return True

first_cmd = formatter_cmds[0]
cmd_tokens = shlex.split(first_cmd) if isinstance(first_cmd, str) else [first_cmd]

Expand Down
11 changes: 3 additions & 8 deletions codeflash/code_utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,13 @@ def apply_formatter_cmds(
print_status: bool, # noqa
exit_on_failure: bool = True, # noqa
) -> tuple[Path, str, bool]:
should_make_copy = False
file_path = path

if test_dir_str:
should_make_copy = True
file_path = Path(test_dir_str) / "temp.py"

if not path.exists():
msg = f"File {path} does not exist. Cannot apply formatter commands."
raise FileNotFoundError(msg)

if should_make_copy:
file_path = path
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

related to this PR?

if test_dir_str:
file_path = Path(test_dir_str) / "temp.py"
shutil.copy2(path, file_path)

file_token = "$file" # noqa: S105
Expand Down
1 change: 0 additions & 1 deletion codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,7 +1897,6 @@ def establish_original_code_baseline(
benchmarking_results, self.function_to_optimize.function_name
)
logger.debug(f"Original async function throughput: {async_throughput} calls/second")
console.rule()

if self.args.benchmark:
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
Expand Down
120 changes: 116 additions & 4 deletions codeflash/verification/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,19 @@
import time as _time_module
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Optional
from unittest import TestCase

# PyTest Imports
import pytest
from pluggy import HookspecMarker

from codeflash.code_utils.config_consts import (
STABILITY_CENTER_TOLERANCE,
STABILITY_SPREAD_TOLERANCE,
STABILITY_WINDOW_SIZE,
)

if TYPE_CHECKING:
from _pytest.config import Config, Parser
from _pytest.main import Session
Expand Down Expand Up @@ -77,6 +83,7 @@ class UnexpectedError(Exception):
# Store references to original functions before any patching
_ORIGINAL_TIME_TIME = _time_module.time
_ORIGINAL_PERF_COUNTER = _time_module.perf_counter
_ORIGINAL_PERF_COUNTER_NS = _time_module.perf_counter_ns
_ORIGINAL_TIME_SLEEP = _time_module.sleep


Expand Down Expand Up @@ -249,6 +256,14 @@ def pytest_addoption(parser: Parser) -> None:
choices=("function", "class", "module", "session"),
help="Scope for looping tests",
)
pytest_loops.addoption(
"--codeflash_stability_check",
action="store",
default="false",
type=str,
choices=("true", "false"),
help="Enable stability checks for the loops",
)


@pytest.hookimpl(trylast=True)
Expand All @@ -260,6 +275,70 @@ def pytest_configure(config: Config) -> None:
_apply_deterministic_patches()


def get_runtime_from_stdout(stdout: str) -> Optional[int]:
marker_start = "!######"
marker_end = "######!"

if not stdout:
return None

end = stdout.rfind(marker_end)
if end == -1:
return None

start = stdout.rfind(marker_start, 0, end)
if start == -1:
return None

payload = stdout[start + len(marker_start) : end]
last_colon = payload.rfind(":")
if last_colon == -1:
return None
try:
return int(payload[last_colon + 1 :])
except ValueError:
return None


_NODEID_BRACKET_PATTERN = re.compile(r"\s*\[\s*\d+\s*\]\s*$")


def should_stop(
runtimes: list[int],
window: int,
min_window_size: int,
center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
) -> bool:
if len(runtimes) < window:
return False

if len(runtimes) < min_window_size:
return False

recent = runtimes[-window:]

# Use sorted array for faster median and min/max operations
recent_sorted = sorted(recent)
mid = window // 2
m = recent_sorted[mid] if window % 2 else (recent_sorted[mid - 1] + recent_sorted[mid]) / 2

# 1) All recent points close to the median
centered = True
for r in recent:
if abs(r - m) / m > center_rel_tol:
centered = False
break

# 2) Window spread is small
r_min, r_max = recent_sorted[0], recent_sorted[-1]
if r_min == 0:
return False
spread_ok = (r_max - r_min) / r_min <= spread_rel_tol
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

r_min could be 0?


return centered and spread_ok


class PytestLoops:
name: str = "pytest-loops"

Expand All @@ -268,6 +347,20 @@ def __init__(self, config: Config) -> None:
level = logging.DEBUG if config.option.verbose > 1 else logging.INFO
logging.basicConfig(level=level)
self.logger = logging.getLogger(self.name)
self.runtime_data_by_test_case: dict[str, list[int]] = {}
self.enable_stability_check: bool = (
str(getattr(config.option, "codeflash_stability_check", "false")).lower() == "true"
)

@pytest.hookimpl
def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
if not self.enable_stability_check:
return
if report.when == "call" and report.passed:
duration_ns = get_runtime_from_stdout(report.capstdout)
if duration_ns:
clean_id = _NODEID_BRACKET_PATTERN.sub("", report.nodeid)
self.runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns)

@hookspec(firstresult=True)
def pytest_runtestloop(self, session: Session) -> bool:
Expand All @@ -283,11 +376,12 @@ def pytest_runtestloop(self, session: Session) -> bool:
total_time: float = self._get_total_time(session)

count: int = 0
runtimes = []
elapsed_ns = 0

while total_time >= SHORTEST_AMOUNT_OF_TIME: # need to run at least one for normal tests
count += 1
total_time = self._get_total_time(session)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see if total_time changes inside this loop

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't, I checked


loop_start = _ORIGINAL_PERF_COUNTER_NS()
for index, item in enumerate(session.items):
item: pytest.Item = item # noqa: PLW0127, PLW2901
item._report_sections.clear() # clear reports for new test # noqa: SLF001
Expand All @@ -304,8 +398,26 @@ def pytest_runtestloop(self, session: Session) -> bool:
raise session.Failed(session.shouldfail)
if session.shouldstop:
raise session.Interrupted(session.shouldstop)

if self.enable_stability_check:
elapsed_ns += _ORIGINAL_PERF_COUNTER_NS() - loop_start
best_runtime_until_now = sum([min(data) for data in self.runtime_data_by_test_case.values()])
if best_runtime_until_now > 0:
runtimes.append(best_runtime_until_now)

estimated_total_loops = 0
if elapsed_ns > 0:
rate = count / elapsed_ns
total_time_ns = total_time * 1e9
estimated_total_loops = int(rate * total_time_ns)

window_size = int(STABILITY_WINDOW_SIZE * estimated_total_loops + 0.5)
if should_stop(runtimes, window_size, session.config.option.codeflash_min_loops):
break

if self._timed_out(session, start_time, count):
break # exit loop
break

_ORIGINAL_TIME_SLEEP(self._get_delay_time(session))
return True

Expand Down
1 change: 1 addition & 0 deletions codeflash/verification/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def run_benchmarking_tests(
f"--codeflash_min_loops={pytest_min_loops}",
f"--codeflash_max_loops={pytest_max_loops}",
f"--codeflash_seconds={pytest_target_runtime_seconds}",
"--codeflash_stability_check=true",
]
if pytest_timeout is not None:
pytest_args.append(f"--timeout={pytest_timeout}")
Expand Down
Loading