diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py
index 88758455e..6abefe3e7 100644
--- a/codeflash/code_utils/config_consts.py
+++ b/codeflash/code_utils/config_consts.py
@@ -14,6 +14,12 @@
 DEFAULT_IMPORTANCE_THRESHOLD = 0.001
 N_CANDIDATES_LP = 6
 
+# pytest loop stability
+# For now, we use strict thresholds (large windows and low tolerances), since this is still experimental.
+STABILITY_WINDOW_SIZE = 0.35  # 35% of total window
+STABILITY_CENTER_TOLERANCE = 0.0025  # ±0.25% around median
+STABILITY_SPREAD_TOLERANCE = 0.0025  # 0.25% window spread
+
 # Refinement
 REFINE_ALL_THRESHOLD = 2  # when valid optimizations count is 2 or less, refine all optimizations
 REFINED_CANDIDATE_RANKING_WEIGHTS = (2, 1)  # (runtime, diff), runtime is more important than diff by a factor of 2
diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py
index 450c023b3..76621327c 100644
--- a/codeflash/code_utils/env_utils.py
+++ b/codeflash/code_utils/env_utils.py
@@ -19,7 +19,6 @@
 def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool:  # noqa
     if not formatter_cmds or formatter_cmds[0] == "disabled":
         return True
-
     first_cmd = formatter_cmds[0]
     cmd_tokens = shlex.split(first_cmd) if isinstance(first_cmd, str) else [first_cmd]
 
diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py
index 7b78dabda..c64b316fc 100644
--- a/codeflash/code_utils/formatter.py
+++ b/codeflash/code_utils/formatter.py
@@ -46,18 +46,13 @@ def apply_formatter_cmds(
     print_status: bool,  # noqa
     exit_on_failure: bool = True,  # noqa
 ) -> tuple[Path, str, bool]:
-    should_make_copy = False
-    file_path = path
-
-    if test_dir_str:
-        should_make_copy = True
-        file_path = Path(test_dir_str) / "temp.py"
-
     if not path.exists():
         msg = f"File {path} does not exist. Cannot apply formatter commands."
         raise FileNotFoundError(msg)
 
-    if should_make_copy:
+    file_path = path
+    if test_dir_str:
+        file_path = Path(test_dir_str) / "temp.py"
         shutil.copy2(path, file_path)
 
     file_token = "$file"  # noqa: S105
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 7a0a5f510..96199be7e 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -1897,7 +1897,6 @@ def establish_original_code_baseline(
                 benchmarking_results, self.function_to_optimize.function_name
             )
             logger.debug(f"Original async function throughput: {async_throughput} calls/second")
-            console.rule()
 
         if self.args.benchmark:
             replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(
diff --git a/codeflash/verification/pytest_plugin.py b/codeflash/verification/pytest_plugin.py
index 20ef8624a..ff3fcb4c5 100644
--- a/codeflash/verification/pytest_plugin.py
+++ b/codeflash/verification/pytest_plugin.py
@@ -12,13 +12,19 @@
 import time as _time_module
 import warnings
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Optional
 from unittest import TestCase
 
 # PyTest Imports
 import pytest
 from pluggy import HookspecMarker
 
+from codeflash.code_utils.config_consts import (
+    STABILITY_CENTER_TOLERANCE,
+    STABILITY_SPREAD_TOLERANCE,
+    STABILITY_WINDOW_SIZE,
+)
+
 if TYPE_CHECKING:
     from _pytest.config import Config, Parser
     from _pytest.main import Session
@@ -77,6 +83,7 @@ class UnexpectedError(Exception):
 
 # Store references to original functions before any patching
 _ORIGINAL_TIME_TIME = _time_module.time
 _ORIGINAL_PERF_COUNTER = _time_module.perf_counter
+_ORIGINAL_PERF_COUNTER_NS = _time_module.perf_counter_ns
 _ORIGINAL_TIME_SLEEP = _time_module.sleep
 
@@ -249,6 +256,14 @@ def pytest_addoption(parser: Parser) -> None:
         choices=("function", "class", "module", "session"),
         help="Scope for looping tests",
     )
+    pytest_loops.addoption(
+        "--codeflash_stability_check",
+        action="store",
+        default="false",
+        type=str,
+        choices=("true", "false"),
+        help="Enable stability checks for the loops",
+    )
 
 
 @pytest.hookimpl(trylast=True)
@@ -260,6 +275,70 @@ def pytest_configure(config: Config) -> None:
     _apply_deterministic_patches()
 
 
+def get_runtime_from_stdout(stdout: str) -> Optional[int]:
+    marker_start = "!######"
+    marker_end = "######!"
+
+    if not stdout:
+        return None
+
+    end = stdout.rfind(marker_end)
+    if end == -1:
+        return None
+
+    start = stdout.rfind(marker_start, 0, end)
+    if start == -1:
+        return None
+
+    payload = stdout[start + len(marker_start) : end]
+    last_colon = payload.rfind(":")
+    if last_colon == -1:
+        return None
+    try:
+        return int(payload[last_colon + 1 :])
+    except ValueError:
+        return None
+
+
+_NODEID_BRACKET_PATTERN = re.compile(r"\s*\[\s*\d+\s*\]\s*$")
+
+
+def should_stop(
+    runtimes: list[int],
+    window: int,
+    min_window_size: int,
+    center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
+    spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
+) -> bool:
+    if len(runtimes) < window:
+        return False
+
+    if len(runtimes) < min_window_size:
+        return False
+
+    recent = runtimes[-window:]
+
+    # Use sorted array for faster median and min/max operations
+    recent_sorted = sorted(recent)
+    mid = window // 2
+    m = recent_sorted[mid] if window % 2 else (recent_sorted[mid - 1] + recent_sorted[mid]) / 2
+
+    # 1) All recent points close to the median
+    centered = True
+    for r in recent:
+        if abs(r - m) / m > center_rel_tol:
+            centered = False
+            break
+
+    # 2) Window spread is small
+    r_min, r_max = recent_sorted[0], recent_sorted[-1]
+    if r_min == 0:
+        return False
+    spread_ok = (r_max - r_min) / r_min <= spread_rel_tol
+
+    return centered and spread_ok
+
+
 class PytestLoops:
     name: str = "pytest-loops"
 
@@ -268,6 +347,20 @@ def __init__(self, config: Config) -> None:
         level = logging.DEBUG if config.option.verbose > 1 else logging.INFO
         logging.basicConfig(level=level)
         self.logger = logging.getLogger(self.name)
+        self.runtime_data_by_test_case: dict[str, list[int]] = {}
+        self.enable_stability_check: bool = (
+            str(getattr(config.option, "codeflash_stability_check", "false")).lower() == "true"
+        )
+
+    @pytest.hookimpl
+    def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
+        if not self.enable_stability_check:
+            return
+        if report.when == "call" and report.passed:
+            duration_ns = get_runtime_from_stdout(report.capstdout)
+            if duration_ns:
+                clean_id = _NODEID_BRACKET_PATTERN.sub("", report.nodeid)
+                self.runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns)
 
     @hookspec(firstresult=True)
     def pytest_runtestloop(self, session: Session) -> bool:
@@ -283,11 +376,12 @@ def pytest_runtestloop(self, session: Session) -> bool:
         total_time: float = self._get_total_time(session)
         count: int = 0
+        runtimes = []
+        elapsed_ns = 0
 
         while total_time >= SHORTEST_AMOUNT_OF_TIME:  # need to run at least one for normal tests
             count += 1
-            total_time = self._get_total_time(session)
-
+            loop_start = _ORIGINAL_PERF_COUNTER_NS()
             for index, item in enumerate(session.items):
                 item: pytest.Item = item  # noqa: PLW0127, PLW2901
                 item._report_sections.clear()  # clear reports for new test  # noqa: SLF001
 
@@ -304,8 +398,26 @@
                     raise session.Failed(session.shouldfail)
                 if session.shouldstop:
                     raise session.Interrupted(session.shouldstop)
+
+            if self.enable_stability_check:
+                elapsed_ns += _ORIGINAL_PERF_COUNTER_NS() - loop_start
+                best_runtime_until_now = sum([min(data) for data in self.runtime_data_by_test_case.values()])
+                if best_runtime_until_now > 0:
+                    runtimes.append(best_runtime_until_now)
+
+                estimated_total_loops = 0
+                if elapsed_ns > 0:
+                    rate = count / elapsed_ns
+                    total_time_ns = total_time * 1e9
+                    estimated_total_loops = int(rate * total_time_ns)
+
+                window_size = int(STABILITY_WINDOW_SIZE * estimated_total_loops + 0.5)
+                if should_stop(runtimes, window_size, session.config.option.codeflash_min_loops):
+                    break
+
             if self._timed_out(session, start_time, count):
-                break  # exit loop
+                break
+
             _ORIGINAL_TIME_SLEEP(self._get_delay_time(session))
 
         return True
diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py
index 1860e0321..ba43577b1 100644
--- a/codeflash/verification/test_runner.py
+++ b/codeflash/verification/test_runner.py
@@ -212,6 +212,7 @@ def run_benchmarking_tests(
         f"--codeflash_min_loops={pytest_min_loops}",
         f"--codeflash_max_loops={pytest_max_loops}",
         f"--codeflash_seconds={pytest_target_runtime_seconds}",
+        "--codeflash_stability_check=true",
     ]
     if pytest_timeout is not None:
         pytest_args.append(f"--timeout={pytest_timeout}")
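
Note on the stdout parsing in `pytest_plugin.py`: `get_runtime_from_stdout` scans the captured output backwards for the last `!######…######!` marker pair and interprets everything after the final `:` in the payload as the per-test runtime in nanoseconds. A minimal sanity check is sketched below; the payload layout between the markers is fabricated for illustration — the parser only relies on the trailing `:<nanoseconds>` field.

```python
# Assumes get_runtime_from_stdout as defined in pytest_plugin.py above.
# The payload between the markers is made up for illustration; only the
# final ":<ns>" field matters to the parser.
captured = "test output...\n!######tests/test_foo.py::test_bar:3:1042337######!\n"
assert get_runtime_from_stdout(captured) == 1042337

# rfind() means only the *last* marker pair in the stream wins:
assert get_runtime_from_stdout(captured + "!######x:55######!") == 55

# Missing or malformed markers degrade to None rather than raising:
assert get_runtime_from_stdout("no markers here") is None
```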
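The loop integration sizes the stability window adaptively: it estimates how many loops would fit in the time budget (loops so far ÷ elapsed ns, scaled by the budget in ns) and takes 35% of that as the window. For example, 120 loops in 1.2×10⁹ ns against a 10 s budget gives a rate of 10⁻⁷ loops/ns, hence ~1000 estimated loops and a window of 350. Below is a rough standalone rendering of the `should_stop` criterion run on synthetic runtimes; the numbers are made up purely to show when the rule fires.

```python
# Mirrors the should_stop criterion from pytest_plugin.py, with the
# tolerances from config_consts.py inlined.
CENTER_TOL = 0.0025  # mirrors STABILITY_CENTER_TOLERANCE
SPREAD_TOL = 0.0025  # mirrors STABILITY_SPREAD_TOLERANCE

def should_stop(runtimes, window, min_window_size):
    if len(runtimes) < window or len(runtimes) < min_window_size:
        return False
    recent = runtimes[-window:]
    s = sorted(recent)
    mid = window // 2
    median = s[mid] if window % 2 else (s[mid - 1] + s[mid]) / 2
    # Every point in the window must sit within ±0.25% of the window median...
    centered = all(abs(r - median) / median <= CENTER_TOL for r in recent)
    # ...and the window's min-to-max spread must stay within 0.25%.
    if s[0] == 0:
        return False
    return centered and (s[-1] - s[0]) / s[0] <= SPREAD_TOL

# Synthetic best-runtime sums per loop: noisy warm-up, then settling
# near 1,000,000 ns.
series = [1_050_000, 1_020_000, 1_005_000, 1_001_000, 1_000_400,
          1_000_200, 1_000_100, 1_000_050, 1_000_030, 1_000_020]
for n in range(1, len(series) + 1):
    if should_stop(series[:n], window=5, min_window_size=5):
        print(f"stable after {n} loops")  # prints: stable after 8 loops
        break
```

Because each loop appends the sum of per-test minimum runtimes, the series is non-increasing once every test has seen its best run, so the window check effectively asks whether further looping still improves the best-case measurement.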