diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index c18495899..f854d8a49 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -32,7 +32,11 @@
     from codeflash.discovery.functions_to_optimize import FunctionToOptimize
     from codeflash.models.ExperimentMetadata import ExperimentMetadata
-    from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest
+    from codeflash.models.models import (
+        AIServiceAdaptiveOptimizeRequest,
+        AIServiceCodeRepairRequest,
+        AIServiceRefinerRequest,
+    )
     from codeflash.result.explanation import Explanation
@@ -249,6 +253,38 @@ def optimize_python_code_line_profiler(  # noqa: D417
             console.rule()
             return []

+    def adaptive_optimize(self, request: AIServiceAdaptiveOptimizeRequest) -> OptimizedCandidate | None:
+        try:
+            payload = {
+                "trace_id": request.trace_id,
+                "original_source_code": request.original_source_code,
+                "candidates": request.candidates,
+            }
+            response = self.make_ai_service_request("/adaptive_optimize", payload=payload, timeout=120)
+        except (requests.exceptions.RequestException, TypeError) as e:
+            logger.exception(f"Error generating adaptive optimized candidates: {e}")
+            ph("cli-optimize-error-caught", {"error": str(e)})
+            return None
+
+        if response.status_code == 200:
+            fixed_optimization = response.json()
+            console.rule()
+
+            valid_candidates = self._get_valid_candidates([fixed_optimization], OptimizedCandidateSource.ADAPTIVE)
+            if not valid_candidates:
+                logger.error("Adaptive optimization failed to generate a valid candidate.")
+                return None
+
+            return valid_candidates[0]
+
+        try:
+            error = response.json()["error"]
+        except Exception:
+            error = response.text
+        logger.error(f"Error generating adaptive optimized candidates: {response.status_code} - {error}")
+        ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
+        return None
+
     def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
         """Optimize the given python code for performance by making a request to the Django endpoint.

diff --git a/codeflash/code_utils/config_consts.py b/codeflash/code_utils/config_consts.py
index 6abefe3e7..ddb3890a7 100644
--- a/codeflash/code_utils/config_consts.py
+++ b/codeflash/code_utils/config_consts.py
@@ -35,6 +35,11 @@
 REPAIR_UNMATCHED_PERCENTAGE_LIMIT = 0.4  # if the percentage of unmatched tests is greater than this, we won't fix it (lowering this value makes the repair stricter)
 MAX_REPAIRS_PER_TRACE = 4  # maximum number of repairs we will do for each function

+# Adaptive optimization
+# TODO (ali): make this configurable with the effort arg once that PR is merged
+ADAPTIVE_OPTIMIZATION_THRESHOLD = 2  # max adaptive optimizations per single candidate tree (e.g. optimize -> refine -> adaptive -> another adaptive)
+MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE = 4  # maximum number of adaptive optimizations we will do for each function (e.g. 2 adaptive optimizations for each of 2 candidates)
+
 MAX_N_CANDIDATES = 5
 MAX_N_CANDIDATES_LP = 6
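The two constants above act as nested budgets. A standalone sketch of the intended gating (illustrative only; may_run_adaptive, per_trace_count, and chain_sources are hypothetical names that do not exist in the codebase):

# Illustrative sketch of the two adaptive-optimization budgets.
ADAPTIVE_OPTIMIZATION_THRESHOLD = 2
MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE = 4

def may_run_adaptive(per_trace_count: int, chain_sources: list[str]) -> bool:
    """True when both the per-function and per-lineage budgets allow another adaptive pass."""
    if per_trace_count >= MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE:
        return False  # function-wide budget exhausted
    # count adaptive passes already present in this candidate's lineage
    return chain_sources.count("ADAPTIVE") < ADAPTIVE_OPTIMIZATION_THRESHOLD

# optimize -> refine -> adaptive: one adaptive pass in this chain, one more allowed
assert may_run_adaptive(1, ["OPTIMIZE", "REFINE", "ADAPTIVE"])
# two adaptive passes already in the chain: the per-tree threshold stops a third
assert not may_run_adaptive(2, ["OPTIMIZE", "REFINE", "ADAPTIVE", "ADAPTIVE"])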
diff --git a/codeflash/models/models.py b/codeflash/models/models.py
index 822ecffab..1af946da4 100644
--- a/codeflash/models/models.py
+++ b/codeflash/models/models.py
@@ -49,6 +49,24 @@ class AIServiceRefinerRequest:
     call_sequence: int | None = None


+# this should be possible to auto-serialize
+@dataclass(frozen=True)
+class AdaptiveOptimizedCandidate:
+    optimization_id: str
+    source_code: str
+    # TODO: introduce a repair explanation for code-repair candidates to help the LLM understand the full process
+    explanation: str
+    source: OptimizedCandidateSource
+    speedup: str
+
+
+@dataclass(frozen=True)
+class AIServiceAdaptiveOptimizeRequest:
+    trace_id: str
+    original_source_code: str
+    candidates: list[AdaptiveOptimizedCandidate]
+
+
 class TestDiffScope(str, Enum):
     RETURN_VALUE = "return_value"
     STDOUT = "stdout"
@@ -442,6 +460,9 @@ def register_new_candidate(
             "diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
         }

+    def get_speedup_ratio(self, optimization_id: str) -> float | None:
+        return self.speedup_ratios.get(optimization_id)
+

 @dataclass(frozen=True)
 class TestsInFile:
@@ -456,6 +477,7 @@ class OptimizedCandidateSource(str, Enum):
     OPTIMIZE_LP = "OPTIMIZE_LP"
     REFINE = "REFINE"
     REPAIR = "REPAIR"
+    ADAPTIVE = "ADAPTIVE"


 @dataclass(frozen=True)
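Because AdaptiveOptimizedCandidate is a plain frozen dataclass (hence the auto-serialize note above), dataclasses.asdict is one plausible way to make it JSON-ready. This is a hedged sketch using a trimmed stand-in class; the actual wire format is decided inside make_ai_service_request, which this diff does not show:

import json
from dataclasses import asdict, dataclass

@dataclass(frozen=True)
class AdaptiveOptimizedCandidate:  # trimmed stand-in for the class above
    optimization_id: str
    source_code: str
    explanation: str
    source: str  # str-valued enums such as OptimizedCandidateSource dump as their value
    speedup: str

candidate = AdaptiveOptimizedCandidate(
    optimization_id="abc123",
    source_code="def f(): ...",
    explanation="hoisted an invariant computation out of the loop",
    source="REFINE",
    speedup="Performance gain: 24%",
)
print(json.dumps([asdict(candidate)]))  # plain dicts serialize without a custom encoder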
diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py
index 96199be7e..547aefb39 100644
--- a/codeflash/optimization/function_optimizer.py
+++ b/codeflash/optimization/function_optimizer.py
@@ -9,7 +9,7 @@
 import uuid
 from collections import defaultdict
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable

 import libcst as cst
 from rich.console import Group
@@ -43,8 +43,10 @@
     unified_diff_strings,
 )
 from codeflash.code_utils.config_consts import (
+    ADAPTIVE_OPTIMIZATION_THRESHOLD,
     COVERAGE_THRESHOLD,
     INDIVIDUAL_TESTCASE_TIMEOUT,
+    MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE,
     MAX_REPAIRS_PER_TRACE,
     N_TESTS_TO_GENERATE_EFFECTIVE,
     REFINE_ALL_THRESHOLD,
@@ -74,6 +76,8 @@
 from codeflash.lsp.lsp_message import LspCodeMessage, LspMarkdownMessage, LSPMessageId
 from codeflash.models.ExperimentMetadata import ExperimentMetadata
 from codeflash.models.models import (
+    AdaptiveOptimizedCandidate,
+    AIServiceAdaptiveOptimizeRequest,
     AIServiceCodeRepairRequest,
     BestOptimization,
     CandidateEvaluationContext,
@@ -126,19 +130,71 @@
     from codeflash.verification.verification_utils import TestConfig


+class CandidateNode:
+    __slots__ = ("candidate", "children", "parent")
+
+    def __init__(self, candidate: OptimizedCandidate) -> None:
+        self.candidate = candidate
+        self.parent: CandidateNode | None = None
+        self.children: list[CandidateNode] = []
+
+    def is_leaf(self) -> bool:
+        return not self.children
+
+    def path_to_root(self) -> list[OptimizedCandidate]:
+        path = []
+        node: CandidateNode | None = self
+        while node:
+            path.append(node.candidate)
+            node = node.parent
+        return path[::-1]
+
+
+class CandidateForest:
+    def __init__(self) -> None:
+        self.nodes: dict[str, CandidateNode] = {}
+
+    def add(self, candidate: OptimizedCandidate) -> CandidateNode:
+        cid = candidate.optimization_id
+        pid = candidate.parent_id
+
+        node = self.nodes.get(cid)
+        if node is None:
+            node = CandidateNode(candidate)
+            self.nodes[cid] = node
+        else:
+            node.candidate = candidate  # fill in a node previously created as a placeholder
+
+        if pid is not None:
+            parent = self.nodes.get(pid)
+            if parent is None:
+                parent = CandidateNode(candidate=None)  # placeholder until the parent candidate is added
+                self.nodes[pid] = parent
+
+            node.parent = parent
+            parent.children.append(node)
+
+        return node
+
+    def get_node(self, cid: str) -> CandidateNode | None:
+        return self.nodes.get(cid)
+
+
 class CandidateProcessor:
     """Handles candidate processing using a queue-based approach."""

     def __init__(
         self,
-        initial_candidates: list,
+        initial_candidates: list[OptimizedCandidate],
         future_line_profile_results: concurrent.futures.Future,
         all_refinements_data: list[AIServiceRefinerRequest],
         ai_service_client: AiServiceClient,
         executor: concurrent.futures.ThreadPoolExecutor,
         future_all_code_repair: list[concurrent.futures.Future],
+        future_adaptive_optimizations: list[concurrent.futures.Future],
     ) -> None:
         self.candidate_queue = queue.Queue()
+        self.forest = CandidateForest()
         self.line_profiler_done = False
         self.refinement_done = False
         self.candidate_len = len(initial_candidates)
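A usage sketch of the forest introduced above (FakeCandidate is a minimal stand-in; the real OptimizedCandidate also carries source code, an explanation, and a source tag): add links each candidate to its parent via parent_id, and path_to_root recovers the root-first lineage that process_single_candidate later inspects.

from dataclasses import dataclass

from codeflash.optimization.function_optimizer import CandidateForest

@dataclass(frozen=True)
class FakeCandidate:  # minimal stand-in for OptimizedCandidate
    optimization_id: str
    parent_id: str | None = None

forest = CandidateForest()
root_node = forest.add(FakeCandidate("opt-1"))                     # initial candidate
leaf_node = forest.add(FakeCandidate("ref-1", parent_id="opt-1"))  # its refinement

# path_to_root walks parent links and returns the lineage root-first
assert [c.optimization_id for c in leaf_node.path_to_root()] == ["opt-1", "ref-1"]
assert leaf_node.is_leaf() and not root_node.is_leaf()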
@@ -148,51 +202,90 @@ def __init__(

         # Initialize queue with initial candidates
         for candidate in initial_candidates:
+            self.forest.add(candidate)
             self.candidate_queue.put(candidate)

         self.future_line_profile_results = future_line_profile_results
         self.all_refinements_data = all_refinements_data
         self.future_all_code_repair = future_all_code_repair
+        self.future_adaptive_optimizations = future_adaptive_optimizations

     def get_total_llm_calls(self) -> int:
         return self.refinement_calls_count

-    def get_next_candidate(self) -> OptimizedCandidate | None:
+    def get_next_candidate(self) -> CandidateNode | None:
         """Get the next candidate from the queue, handling async results as needed."""
         try:
-            return self.candidate_queue.get_nowait()
+            return self.forest.get_node(self.candidate_queue.get_nowait().optimization_id)
         except queue.Empty:
             return self._handle_empty_queue()

-    def _handle_empty_queue(self) -> OptimizedCandidate | None:
+    def _handle_empty_queue(self) -> CandidateNode | None:
         """Handle empty queue by checking for pending async results."""
         if not self.line_profiler_done:
-            return self._process_line_profiler_results()
+            return self._process_candidates(
+                [self.future_line_profile_results],
+                "all candidates processed, await candidates from line profiler",
+                "Added results from line profiler to candidates, total candidates now: {1}",
+                lambda: setattr(self, "line_profiler_done", True),
+            )
         if len(self.future_all_code_repair) > 0:
-            return self._process_code_repair()
+            return self._process_candidates(
+                self.future_all_code_repair,
+                "Repairing {0} candidates",
+                "Added {0} candidates from repair, total candidates now: {1}",
+                lambda: self.future_all_code_repair.clear(),
+            )
         if self.line_profiler_done and not self.refinement_done:
             return self._process_refinement_results()
+        if len(self.future_adaptive_optimizations) > 0:
+            return self._process_candidates(
+                self.future_adaptive_optimizations,
+                "Applying adaptive optimizations to {0} candidates",
+                "Added {0} candidates from adaptive optimization, total candidates now: {1}",
+                lambda: self.future_adaptive_optimizations.clear(),
+            )
         return None  # All done

-    def _process_line_profiler_results(self) -> OptimizedCandidate | None:
-        """Process line profiler results and add to queue."""
-        logger.debug("all candidates processed, await candidates from line profiler")
-        concurrent.futures.wait([self.future_line_profile_results])
-        line_profile_results = self.future_line_profile_results.result()
+    def _process_candidates(
+        self,
+        future_candidates: list[concurrent.futures.Future],
+        loading_msg: str,
+        success_msg: str,
+        callback: Callable[[], None],
+    ) -> CandidateNode | None:
+        if len(future_candidates) == 0:
+            return None
+        with progress_bar(
+            loading_msg.format(len(future_candidates)), transient=True, revert_to_print=bool(get_pr_number())
+        ):
+            concurrent.futures.wait(future_candidates)
+        candidates: list[OptimizedCandidate] = []
+        for future_c in future_candidates:
+            candidate_result = future_c.result()
+            if not candidate_result:
+                continue

-        for candidate in line_profile_results:
-            self.candidate_queue.put(candidate)
+            if isinstance(candidate_result, list):
+                candidates.extend(candidate_result)
+            else:
+                candidates.append(candidate_result)

-        self.candidate_len += len(line_profile_results)
-        logger.info(f"Added results from line profiler to candidates, total candidates now: {self.candidate_len}")
-        self.line_profiler_done = True
+        for candidate in candidates:
+            self.forest.add(candidate)
+            self.candidate_queue.put(candidate)
+            self.candidate_len += 1

-        return self.get_next_candidate()
+        if len(candidates) > 0:
+            logger.info(success_msg.format(len(candidates), self.candidate_len))
+
+        callback()
+        return self.get_next_candidate()

     def refine_optimizations(self, request: list[AIServiceRefinerRequest]) -> concurrent.futures.Future:
         return self.executor.submit(self.ai_service_client.optimize_python_code_refinement, request=request)

-    def _process_refinement_results(self) -> OptimizedCandidate | None:
+    def _process_refinement_results(self) -> CandidateNode | None:
         """Process refinement results and add to queue. We generate a weighted ranking based on runtime and diff lines and select roughly the best 45% of valid optimizations to be refined."""
         future_refinements: list[concurrent.futures.Future] = []
         refinement_call_index = 0
@@ -226,48 +319,12 @@ def _process_refinement_results(self) -> OptimizedCandidate | None:
         # Track total refinement calls made
         self.refinement_calls_count = refinement_call_index

-        if future_refinements:
-            logger.info("loading|Refining generated code for improved quality and performance...")
-
-            concurrent.futures.wait(future_refinements)
-            refinement_response = []
-
-            for f in future_refinements:
-                possible_refinement = f.result()
-                if len(possible_refinement) > 0:
-                    refinement_response.append(possible_refinement[0])
-
-            for candidate in refinement_response:
-                self.candidate_queue.put(candidate)
-
-            self.candidate_len += len(refinement_response)
-            if len(refinement_response) > 0:
-                logger.info(
-                    f"Added {len(refinement_response)} candidates from refinement, total candidates now: {self.candidate_len}"
-                )
-                console.rule()
-        self.refinement_done = True
-
-        return self.get_next_candidate()
-
-    def _process_code_repair(self) -> OptimizedCandidate | None:
-        logger.info(f"loading|Repairing {len(self.future_all_code_repair)} candidates")
-        concurrent.futures.wait(self.future_all_code_repair)
-        candidates_added = 0
-        for future_code_repair in self.future_all_code_repair:
-            possible_code_repair = future_code_repair.result()
-            if possible_code_repair:
-                self.candidate_queue.put(possible_code_repair)
-                self.candidate_len += 1
-                candidates_added += 1
-
-        if candidates_added > 0:
-            logger.info(
-                f"Added {candidates_added} candidates from code repair, total candidates now: {self.candidate_len}"
-            )
-        self.future_all_code_repair = []
-
-        return self.get_next_candidate()
+        return self._process_candidates(
+            future_refinements,
+            "Refining generated code for improved quality and performance...",
+            "Added {0} candidates from refinement, total candidates now: {1}",
+            lambda: setattr(self, "refinement_done", True),
+        )

     def is_done(self) -> bool:
         """Check if processing is complete."""
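The consolidated _process_candidates has to accept two result shapes: repair and adaptive futures resolve to a single candidate (or None), while line-profiler and refinement futures resolve to lists. A standalone sketch of just that flattening rule (flatten_results is a hypothetical helper, not part of this diff):

import concurrent.futures

def flatten_results(futures: list[concurrent.futures.Future]) -> list:
    concurrent.futures.wait(futures)
    flat: list = []
    for f in futures:
        result = f.result()
        if not result:  # None or an empty list: nothing to enqueue
            continue
        if isinstance(result, list):
            flat.extend(result)  # refinement / line-profiler shape
        else:
            flat.append(result)  # repair / adaptive shape
    return flat

with concurrent.futures.ThreadPoolExecutor() as pool:
    futures = [pool.submit(list, "ab"), pool.submit(str, "c"), pool.submit(lambda: None)]
    assert flatten_results(futures) == ["a", "b", "c"]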
@@ -275,6 +332,7 @@
             self.line_profiler_done
             and self.refinement_done
             and len(self.future_all_code_repair) == 0
+            and len(self.future_adaptive_optimizations) == 0
             and self.candidate_queue.empty()
         )
@@ -327,7 +385,9 @@ def __init__(
         )
         self.optimization_review = ""
         self.future_all_code_repair: list[concurrent.futures.Future] = []
+        self.future_adaptive_optimizations: list[concurrent.futures.Future] = []
         self.repair_counter = 0  # track how many repairs we did for each function
+        self.adaptive_optimization_counter = 0  # track how many adaptive optimizations we did for each function

     def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
         should_run_experiment = self.experiment_id is not None
@@ -763,7 +823,7 @@ def log_evaluation_results(

     def process_single_candidate(
         self,
-        candidate: OptimizedCandidate,
+        candidate_node: CandidateNode,
         candidate_index: int,
         total_candidates: int,
         code_context: CodeOptimizationContext,
@@ -785,6 +845,7 @@ def process_single_candidate(
             get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)

         logger.info(f"h3|Optimization candidate {candidate_index}/{total_candidates}:")
+        candidate = candidate_node.candidate
         code_print(
             candidate.source_code.flat,
             file_name=f"candidate_{candidate_index}.py",
@@ -874,8 +935,22 @@ def process_single_candidate(
             )
             eval_ctx.valid_optimizations.append(best_optimization)

-            # Queue refinement for non-refined candidates
-            if candidate.source != OptimizedCandidateSource.REFINE:
+            current_tree_candidates = candidate_node.path_to_root()
+            is_candidate_refined_before = any(
+                c.source == OptimizedCandidateSource.REFINE for c in current_tree_candidates
+            )
+
+            if is_candidate_refined_before:
+                future_adaptive_optimization = self.call_adaptive_optimize(
+                    trace_id=self.get_trace_id(exp_type),
+                    original_source_code=code_context.read_writable_code.markdown,
+                    prev_candidates=current_tree_candidates,
+                    eval_ctx=eval_ctx,
+                    ai_service_client=self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client,
+                )
+                if future_adaptive_optimization:
+                    self.future_adaptive_optimizations.append(future_adaptive_optimization)
+            else:
                 all_refinements_data.append(
                     AIServiceRefinerRequest(
                         optimization_id=best_optimization.candidate.optimization_id,
@@ -925,7 +1000,9 @@ def determine_best_candidate(
         eval_ctx = CandidateEvaluationContext()
         all_refinements_data: list[AIServiceRefinerRequest] = []
         self.future_all_code_repair.clear()
+        self.future_adaptive_optimizations.clear()
         self.repair_counter = 0
+        self.adaptive_optimization_counter = 0

         ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
         assert ai_service_client is not None, "AI service client must be set for optimization"
@@ -950,20 +1027,21 @@ def determine_best_candidate(
                 self.aiservice_client,
                 self.executor,
                 self.future_all_code_repair,
+                self.future_adaptive_optimizations,
             )

             candidate_index = 0

             # Process candidates using queue-based approach
             while not processor.is_done():
-                candidate = processor.get_next_candidate()
-                if candidate is None:
+                candidate_node = processor.get_next_candidate()
+                if candidate_node is None:
                     logger.debug("everything done, exiting")
                     break

                 try:
                     candidate_index += 1
                     self.process_single_candidate(
-                        candidate=candidate,
+                        candidate_node=candidate_node,
                         candidate_index=candidate_index,
                         total_candidates=processor.candidate_len,
                         code_context=code_context,
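The branch in process_single_candidate above encodes the routing rule for valid candidates: once a lineage contains a REFINE step, further improvement attempts go through adaptive optimization rather than another refinement round. A condensed sketch of the decision (route is illustrative, not part of this diff):

def route(lineage_sources: list[str]) -> str:
    return "adaptive" if "REFINE" in lineage_sources else "refine"

assert route(["OPTIMIZE"]) == "refine"              # first-round candidate: queue refinement
assert route(["OPTIMIZE", "REFINE"]) == "adaptive"  # already refined: try adaptive
assert route(["OPTIMIZE", "REFINE", "ADAPTIVE"]) == "adaptive"  # adaptive output may re-enter, budget permitting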
@@ -1004,6 +1082,47 @@

         return best_optimization

+    def call_adaptive_optimize(
+        self,
+        trace_id: str,
+        original_source_code: str,
+        prev_candidates: list[OptimizedCandidate],
+        eval_ctx: CandidateEvaluationContext,
+        ai_service_client: AiServiceClient,
+    ) -> concurrent.futures.Future[OptimizedCandidate | None] | None:
+        if self.adaptive_optimization_counter >= MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE:
+            logger.debug(
+                f"Max adaptive optimizations reached for {self.function_to_optimize.qualified_name}: {self.adaptive_optimization_counter}"
+            )
+            return None
+
+        adaptive_count = sum(1 for c in prev_candidates if c.source == OptimizedCandidateSource.ADAPTIVE)
+
+        if adaptive_count >= ADAPTIVE_OPTIMIZATION_THRESHOLD:
+            return None
+
+        request_candidates = []
+
+        for c in prev_candidates:
+            speedup = eval_ctx.get_speedup_ratio(c.optimization_id)
+            request_candidates.append(
+                AdaptiveOptimizedCandidate(
+                    optimization_id=c.optimization_id,
+                    source_code=c.source_code.markdown,
+                    explanation=c.explanation,
+                    source=c.source,
+                    speedup=f"Performance gain: {int(speedup * 100 + 0.5)}%"
+                    if speedup is not None
+                    else "Candidate didn't match the behavior of the original code",
+                )
+            )
+
+        request = AIServiceAdaptiveOptimizeRequest(
+            trace_id=trace_id, original_source_code=original_source_code, candidates=request_candidates
+        )
+        self.adaptive_optimization_counter += 1
+        return self.executor.submit(ai_service_client.adaptive_optimize, request=request)
+
     def repair_optimization(
         self,
         original_source_code: str,