38 changes: 37 additions & 1 deletion codeflash/api/aiservice.py
@@ -32,7 +32,11 @@

from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import AIServiceCodeRepairRequest, AIServiceRefinerRequest
from codeflash.models.models import (
AIServiceAdaptiveOptimizeRequest,
AIServiceCodeRepairRequest,
AIServiceRefinerRequest,
)
from codeflash.result.explanation import Explanation


@@ -249,6 +253,38 @@ def optimize_python_code_line_profiler( # noqa: D417
console.rule()
return []

def adaptive_optimize(self, request: AIServiceAdaptiveOptimizeRequest) -> OptimizedCandidate | None:
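"""Generate an improved candidate via the /adaptive_optimize endpoint.

Sends the original source code and the candidates produced so far; returns the
first valid candidate from the response, or None if the request fails or no
candidate passes validation.
"""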
try:
payload = {
"trace_id": request.trace_id,
"original_source_code": request.original_source_code,
"candidates": request.candidates,
}
response = self.make_ai_service_request("/adaptive_optimize", payload=payload, timeout=120)
except (requests.exceptions.RequestException, TypeError) as e:
logger.exception(f"Error generating adaptive optimized candidates: {e}")
ph("cli-optimize-error-caught", {"error": str(e)})
return None

if response.status_code == 200:
fixed_optimization = response.json()
console.rule()

valid_candidates = self._get_valid_candidates([fixed_optimization], OptimizedCandidateSource.ADAPTIVE)
if not valid_candidates:
logger.error("Adaptive optimization failed to generate a valid candidate.")
return None

return valid_candidates[0]

try:
error = response.json()["error"]
except Exception:
error = response.text
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
return None

def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]) -> list[OptimizedCandidate]:
"""Optimize the given python code for performance by making a request to the Django endpoint.

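As a rough illustration of how the new wrapper could be called, here is a minimal sketch; the AiServiceClient name, the trace id, and the candidate values are assumptions made for the example, not the optimizer's actual wiring:

    from codeflash.api.aiservice import AiServiceClient  # assumed client class name in this module
    from codeflash.models.models import (
        AdaptiveOptimizedCandidate,
        AIServiceAdaptiveOptimizeRequest,
        OptimizedCandidateSource,
    )

    client = AiServiceClient()  # hypothetical standalone construction; the optimizer reuses its existing client
    request = AIServiceAdaptiveOptimizeRequest(
        trace_id="example-trace-id",
        original_source_code="def dedupe(xs):\n    return [x for i, x in enumerate(xs) if x not in xs[:i]]",
        candidates=[
            AdaptiveOptimizedCandidate(
                optimization_id="candidate-1",
                source_code="def dedupe(xs):\n    return list(dict.fromkeys(xs))",
                explanation="Use dict.fromkeys to deduplicate while preserving order.",
                source=OptimizedCandidateSource.OPTIMIZE,
                speedup="1.80x",
            )
        ],
    )

    candidate = client.adaptive_optimize(request)  # OptimizedCandidate | None
    if candidate is not None:
        print(candidate.source_code)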
5 changes: 5 additions & 0 deletions codeflash/code_utils/config_consts.py
@@ -35,6 +35,11 @@
REPAIR_UNMATCHED_PERCENTAGE_LIMIT = 0.4 # if the percentage of unmatched tests is greater than this, we won't fix it (lowering this value makes the repair stricter)
MAX_REPAIRS_PER_TRACE = 4 # maximum number of repairs we will do for each function

# Adaptive optimization
# TODO (ali): make this configurable with effort arg once the PR is merged
ADAPTIVE_OPTIMIZATION_THRESHOLD = 2 # max adaptive optimizations per single candidate tree (for example: optimize -> refine -> adaptive -> another adaptive)
MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE = 4 # maximum number of adaptive optimizations we will do for each function (for example, 2 adaptive optimizations on each of 2 candidates)

MAX_N_CANDIDATES = 5
MAX_N_CANDIDATES_LP = 6

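A short sketch of how the two limits above might be checked together before another adaptive pass is attempted; the counter names and helper are illustrative assumptions rather than the optimizer's actual bookkeeping:

    from codeflash.code_utils.config_consts import (
        ADAPTIVE_OPTIMIZATION_THRESHOLD,
        MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE,
    )

    # Hypothetical counters: one per function trace, one per candidate tree.
    adaptive_runs_in_trace = 0
    adaptive_depth_in_candidate_tree = 0

    def can_run_adaptive_optimization() -> bool:
        # Allow another adaptive pass only while both budgets have headroom:
        # at most ADAPTIVE_OPTIMIZATION_THRESHOLD passes deep within one candidate tree,
        # and at most MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE passes for the whole function.
        return (
            adaptive_depth_in_candidate_tree < ADAPTIVE_OPTIMIZATION_THRESHOLD
            and adaptive_runs_in_trace < MAX_ADAPTIVE_OPTIMIZATIONS_PER_TRACE
        )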
22 changes: 22 additions & 0 deletions codeflash/models/models.py
@@ -49,6 +49,24 @@ class AIServiceRefinerRequest:
call_sequence: int | None = None


# This dataclass should be auto-serializable when included in the AI service request payload.
@dataclass(frozen=True)
class AdaptiveOptimizedCandidate:
optimization_id: str
source_code: str
# TODO: introduce a repair explanation for code-repair candidates to help the LLM understand the full process
explanation: str
source: OptimizedCandidateSource
speedup: str


@dataclass(frozen=True)
class AIServiceAdaptiveOptimizeRequest:
trace_id: str
original_source_code: str
candidates: list[AdaptiveOptimizedCandidate]
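Below is a small sketch of the auto-serialization that the comment above AdaptiveOptimizedCandidate alludes to; dataclasses.asdict is shown as one plausible mechanism, not necessarily what the service client actually uses:

    import dataclasses
    import json

    from codeflash.models.models import AdaptiveOptimizedCandidate, OptimizedCandidateSource

    candidate = AdaptiveOptimizedCandidate(
        optimization_id="candidate-1",
        source_code="def double(x):\n    return x * 2",
        explanation="Replace the loop with a direct multiplication.",
        source=OptimizedCandidateSource.OPTIMIZE,
        speedup="1.50x",
    )

    # Frozen dataclasses convert cleanly to plain dicts, and the str-based Enum
    # member serializes as its string value, so the payload can be built generically.
    payload_fragment = dataclasses.asdict(candidate)
    print(json.dumps(payload_fragment))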


class TestDiffScope(str, Enum):
RETURN_VALUE = "return_value"
STDOUT = "stdout"
@@ -442,6 +460,9 @@ def register_new_candidate(
"diff_len": diff_length(candidate.source_code.flat, code_context.read_writable_code.flat),
}

def get_speedup_ratio(self, optimization_id: str) -> float | None:
return self.speedup_ratios.get(optimization_id)
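One way the stored ratio could feed a display string such as AdaptiveOptimizedCandidate.speedup; the tracker variable and the formatting are illustrative assumptions:

    # `tracker` stands in for whichever object owns speedup_ratios (the one defining get_speedup_ratio).
    ratio = tracker.get_speedup_ratio(candidate.optimization_id)
    # The ratio is None for candidates that were never benchmarked, so guard before formatting.
    label = f"{ratio:.2f}" if ratio is not None else "unknown"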


@dataclass(frozen=True)
class TestsInFile:
@@ -456,6 +477,7 @@ class OptimizedCandidateSource(str, Enum):
OPTIMIZE_LP = "OPTIMIZE_LP"
REFINE = "REFINE"
REPAIR = "REPAIR"
ADAPTIVE = "ADAPTIVE"


@dataclass(frozen=True)