-
Notifications
You must be signed in to change notification settings - Fork 21
[Enhancement] Stop looping when runtime is stable (CF-934) #967
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e4cbf95
01ad626
2b01faf
ce89905
1f367be
0d819a8
5c4a6d9
ecd21d5
30c89ce
ce2c05b
89fc939
cc94694
a67dad3
d52aae4
244f9ca
a890d4f
3159eb6
95f22ee
9f311cd
83dff02
a8e93c7
e49ba13
91cbc74
0b3be3f
b57fa1a
46701c7
74520f6
9ab06d3
f1058ea
56cce15
8ea9231
9079590
70b7627
dd3707a
9cae2a1
7f1818b
270af89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,13 +12,19 @@ | |
| import time as _time_module | ||
| import warnings | ||
| from pathlib import Path | ||
| from typing import TYPE_CHECKING, Any, Callable | ||
| from typing import TYPE_CHECKING, Any, Callable, Optional | ||
| from unittest import TestCase | ||
|
|
||
| # PyTest Imports | ||
| import pytest | ||
| from pluggy import HookspecMarker | ||
|
|
||
| from codeflash.code_utils.config_consts import ( | ||
| STABILITY_CENTER_TOLERANCE, | ||
| STABILITY_SPREAD_TOLERANCE, | ||
| STABILITY_WINDOW_SIZE, | ||
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from _pytest.config import Config, Parser | ||
| from _pytest.main import Session | ||
|
|
@@ -77,6 +83,7 @@ class UnexpectedError(Exception): | |
| # Store references to original functions before any patching | ||
| _ORIGINAL_TIME_TIME = _time_module.time | ||
| _ORIGINAL_PERF_COUNTER = _time_module.perf_counter | ||
| _ORIGINAL_PERF_COUNTER_NS = _time_module.perf_counter_ns | ||
| _ORIGINAL_TIME_SLEEP = _time_module.sleep | ||
|
|
||
|
|
||
|
|
@@ -249,6 +256,14 @@ def pytest_addoption(parser: Parser) -> None: | |
| choices=("function", "class", "module", "session"), | ||
| help="Scope for looping tests", | ||
| ) | ||
| pytest_loops.addoption( | ||
| "--codeflash_stability_check", | ||
| action="store", | ||
| default="false", | ||
| type=str, | ||
| choices=("true", "false"), | ||
| help="Enable stability checks for the loops", | ||
| ) | ||
|
|
||
|
|
||
| @pytest.hookimpl(trylast=True) | ||
|
|
@@ -260,6 +275,70 @@ def pytest_configure(config: Config) -> None: | |
| _apply_deterministic_patches() | ||
|
|
||
|
|
||
| def get_runtime_from_stdout(stdout: str) -> Optional[int]: | ||
| marker_start = "!######" | ||
| marker_end = "######!" | ||
|
|
||
| if not stdout: | ||
| return None | ||
|
|
||
| end = stdout.rfind(marker_end) | ||
| if end == -1: | ||
| return None | ||
|
|
||
| start = stdout.rfind(marker_start, 0, end) | ||
| if start == -1: | ||
| return None | ||
|
|
||
| payload = stdout[start + len(marker_start) : end] | ||
| last_colon = payload.rfind(":") | ||
| if last_colon == -1: | ||
| return None | ||
| try: | ||
| return int(payload[last_colon + 1 :]) | ||
| except ValueError: | ||
| return None | ||
|
|
||
|
|
||
| _NODEID_BRACKET_PATTERN = re.compile(r"\s*\[\s*\d+\s*\]\s*$") | ||
|
|
||
|
|
||
| def should_stop( | ||
| runtimes: list[int], | ||
| window: int, | ||
| min_window_size: int, | ||
| center_rel_tol: float = STABILITY_CENTER_TOLERANCE, | ||
| spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE, | ||
| ) -> bool: | ||
| if len(runtimes) < window: | ||
| return False | ||
|
|
||
| if len(runtimes) < min_window_size: | ||
| return False | ||
|
|
||
| recent = runtimes[-window:] | ||
|
|
||
| # Use sorted array for faster median and min/max operations | ||
| recent_sorted = sorted(recent) | ||
| mid = window // 2 | ||
| m = recent_sorted[mid] if window % 2 else (recent_sorted[mid - 1] + recent_sorted[mid]) / 2 | ||
|
|
||
| # 1) All recent points close to the median | ||
| centered = True | ||
| for r in recent: | ||
| if abs(r - m) / m > center_rel_tol: | ||
| centered = False | ||
| break | ||
|
|
||
| # 2) Window spread is small | ||
| r_min, r_max = recent_sorted[0], recent_sorted[-1] | ||
| if r_min == 0: | ||
| return False | ||
| spread_ok = (r_max - r_min) / r_min <= spread_rel_tol | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. r_min could be 0? |
||
|
|
||
| return centered and spread_ok | ||
|
|
||
|
|
||
| class PytestLoops: | ||
| name: str = "pytest-loops" | ||
|
|
||
|
|
@@ -268,6 +347,20 @@ def __init__(self, config: Config) -> None: | |
| level = logging.DEBUG if config.option.verbose > 1 else logging.INFO | ||
| logging.basicConfig(level=level) | ||
| self.logger = logging.getLogger(self.name) | ||
| self.runtime_data_by_test_case: dict[str, list[int]] = {} | ||
| self.enable_stability_check: bool = ( | ||
| str(getattr(config.option, "codeflash_stability_check", "false")).lower() == "true" | ||
| ) | ||
|
|
||
| @pytest.hookimpl | ||
| def pytest_runtest_logreport(self, report: pytest.TestReport) -> None: | ||
| if not self.enable_stability_check: | ||
| return | ||
| if report.when == "call" and report.passed: | ||
| duration_ns = get_runtime_from_stdout(report.capstdout) | ||
| if duration_ns: | ||
| clean_id = _NODEID_BRACKET_PATTERN.sub("", report.nodeid) | ||
| self.runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns) | ||
|
|
||
| @hookspec(firstresult=True) | ||
| def pytest_runtestloop(self, session: Session) -> bool: | ||
|
|
@@ -283,11 +376,12 @@ def pytest_runtestloop(self, session: Session) -> bool: | |
| total_time: float = self._get_total_time(session) | ||
|
|
||
| count: int = 0 | ||
| runtimes = [] | ||
| elapsed_ns = 0 | ||
|
|
||
| while total_time >= SHORTEST_AMOUNT_OF_TIME: # need to run at least one for normal tests | ||
| count += 1 | ||
| total_time = self._get_total_time(session) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see if total_time changes inside this loop
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it doesn't, I checked |
||
|
|
||
| loop_start = _ORIGINAL_PERF_COUNTER_NS() | ||
| for index, item in enumerate(session.items): | ||
| item: pytest.Item = item # noqa: PLW0127, PLW2901 | ||
| item._report_sections.clear() # clear reports for new test # noqa: SLF001 | ||
|
|
@@ -304,8 +398,26 @@ def pytest_runtestloop(self, session: Session) -> bool: | |
| raise session.Failed(session.shouldfail) | ||
| if session.shouldstop: | ||
| raise session.Interrupted(session.shouldstop) | ||
|
|
||
| if self.enable_stability_check: | ||
| elapsed_ns += _ORIGINAL_PERF_COUNTER_NS() - loop_start | ||
| best_runtime_until_now = sum([min(data) for data in self.runtime_data_by_test_case.values()]) | ||
| if best_runtime_until_now > 0: | ||
| runtimes.append(best_runtime_until_now) | ||
|
|
||
| estimated_total_loops = 0 | ||
| if elapsed_ns > 0: | ||
| rate = count / elapsed_ns | ||
| total_time_ns = total_time * 1e9 | ||
| estimated_total_loops = int(rate * total_time_ns) | ||
|
|
||
| window_size = int(STABILITY_WINDOW_SIZE * estimated_total_loops + 0.5) | ||
| if should_stop(runtimes, window_size, session.config.option.codeflash_min_loops): | ||
| break | ||
|
|
||
| if self._timed_out(session, start_time, count): | ||
| break # exit loop | ||
| break | ||
|
|
||
| _ORIGINAL_TIME_SLEEP(self._get_delay_time(session)) | ||
| return True | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
related to this PR?