diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..a91e47a
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,28 @@
+name: CI - pre-commit & run tests
+
+on:
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  lint-and-test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11.7'  # quoted so YAML keeps it a string; pin to the project's supported version
+
+      - name: Install dependencies
+        run: |
+          ./setup.sh
+
+      - name: Run pre-commit
+        run: uv run pre-commit run --all-files
+
+      - name: Run pytest
+        run: uv run pytest
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..ee3e3bf
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,42 @@
+repos:
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.5.0  # Use the latest stable version
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+      - id: check-symlinks
+      - id: check-added-large-files
+      - id: check-case-conflict
+      - id: check-json
+
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.11.2
+    hooks:
+      # Run the linter.
+      - id: ruff
+        args: [ --fix ]
+      # Run the formatter.
+      - id: ruff-format
+      # Same hooks again, runnable on demand via `pre-commit run --hook-stage manual`.
+      - id: ruff
+        args: [ --fix ]
+        stages: [manual]
+      - id: ruff-format
+        stages: [manual]
+
+  - repo: https://github.com/Yelp/detect-secrets
+    rev: v1.5.0
+    hooks:
+      - id: detect-secrets
+        args: ['--baseline', '.secrets.baseline']  # hook positionals are filenames; `audit` is a separate CLI subcommand
+
+  - repo: local
+    hooks:
+      - id: pyright
+        name: pyright
+        entry: pyright
+        language: system
+        types: [python]
+        args: [--stats, -p, pyrightconfig.ci.json]
diff --git a/.secrets.baseline b/.secrets.baseline
new file mode 100644
index 0000000..73ee1a1
--- /dev/null
+++ b/.secrets.baseline
@@ -0,0 +1,137 @@
+{
+ "version": "1.5.0",
+ "plugins_used": [
+ {
+ "name": "ArtifactoryDetector"
+ },
+ {
+ "name": "AWSKeyDetector"
+ },
+ {
+ "name": "AzureStorageKeyDetector"
+ },
+ {
+ "name": "Base64HighEntropyString",
+ "limit": 4.5
+ },
+ {
+ "name": "BasicAuthDetector"
+ },
+ {
+ "name": "CloudantDetector"
+ },
+ {
+ "name": "DiscordBotTokenDetector"
+ },
+ {
+ "name": "GitHubTokenDetector"
+ },
+ {
+ "name": "GitLabTokenDetector"
+ },
+ {
+ "name": "HexHighEntropyString",
+ "limit": 3.0
+ },
+ {
+ "name": "IbmCloudIamDetector"
+ },
+ {
+ "name": "IbmCosHmacDetector"
+ },
+ {
+ "name": "IPPublicDetector"
+ },
+ {
+ "name": "JwtTokenDetector"
+ },
+ {
+ "name": "KeywordDetector",
+ "keyword_exclude": ""
+ },
+ {
+ "name": "MailchimpDetector"
+ },
+ {
+ "name": "NpmDetector"
+ },
+ {
+ "name": "OpenAIDetector"
+ },
+ {
+ "name": "PrivateKeyDetector"
+ },
+ {
+ "name": "PypiTokenDetector"
+ },
+ {
+ "name": "SendGridDetector"
+ },
+ {
+ "name": "SlackDetector"
+ },
+ {
+ "name": "SoftlayerDetector"
+ },
+ {
+ "name": "SquareOAuthDetector"
+ },
+ {
+ "name": "StripeDetector"
+ },
+ {
+ "name": "TelegramBotTokenDetector"
+ },
+ {
+ "name": "TwilioKeyDetector"
+ }
+ ],
+ "filters_used": [
+ {
+ "path": "detect_secrets.filters.allowlist.is_line_allowlisted"
+ },
+ {
+ "path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies",
+ "min_level": 2
+ },
+ {
+ "path": "detect_secrets.filters.heuristic.is_indirect_reference"
+ },
+ {
+ "path": "detect_secrets.filters.heuristic.is_likely_id_string"
+ },
+ {
+ "path": "detect_secrets.filters.heuristic.is_lock_file"
+ },
+ {
+ "path": "detect_secrets.filters.heuristic.is_not_alphanumeric_string"
+ },
+ {
+ "path": "detect_secrets.filters.heuristic.is_potential_uuid"
+ },
+ {
+ "path": "detect_secrets.filters.heuristic.is_prefixed_with_dollar_sign"
+ },
+ {
+ "path": "detect_secrets.filters.heuristic.is_sequential_string"
+ },
+ {
+ "path": "detect_secrets.filters.heuristic.is_swagger_file"
+ },
+ {
+ "path": "detect_secrets.filters.heuristic.is_templated_secret"
+ }
+ ],
+ "results": {
+ "setup.sh": [
+ {
+ "type": "Hex High Entropy String",
+ "filename": "setup.sh",
+ "hashed_secret": "7431e16af96558909e41438950a6ffe7ee811465",
+ "is_verified": false,
+ "line_number": 31
+ }
+ ]
+ },
+ "generated_at": "2025-03-31T20:57:53Z"
+}
diff --git a/LICENSE b/LICENSE
index 76caa86..677d5fd 100644
--- a/LICENSE
+++ b/LICENSE
@@ -30,4 +30,4 @@ licensed under the MIT License. The original code has been modified.
Original copyright:
© 2024 Anthropic, PBC
-Original license: https://github.com/modelcontextprotocol/servers/blob/main/LICENSE
\ No newline at end of file
+Original license: https://github.com/modelcontextprotocol/servers/blob/main/LICENSE
diff --git a/README.md b/README.md
index a240183..7f20051 100644
--- a/README.md
+++ b/README.md
@@ -120,7 +120,7 @@ You can increase `--num-examples` and `--num-candidate-solutions` to run on more
### Running on more examples.
-There are 500 examples total in SWE-bench Verified. Note that this can take awhile, so there are a few levels of parallelism this repository supports.
+There are 500 examples total in SWE-bench Verified. Note that this can take a while, so there are a few levels of parallelism this repository supports.
- Firstly, we suggest running 8 processes. This is the `--num-processes` flag. Beyond this, Docker hits issues.
- Secondly, we support a notion of breaking up the dataset into shards. This is the `--shard-ct` and `--shard-id` flags. This makes it relatively easy to split up the work across multiple machines, which circumnvents the issues with scaling Docker byeond 8 processes.
@@ -166,7 +166,7 @@ python majority_vote_ensembler.py example_ensembler_data.jsonl --output_path exa
#### Input Format
-The input JSONL file should contain a list of problem objects, each with the following structure:
+The input JSONL file should contain a list of problem objects, each with the following structure. The `diffs` are the candidate solutions generated by the agent. The `eval_outcomes` are the results of running the eval harness on each candidate solution, where the index corresponds to the index in the `diffs` array.
```json
{
diff --git a/cli.py b/cli.py
index 147a977..36872bf 100644
--- a/cli.py
+++ b/cli.py
@@ -16,7 +16,6 @@
from rich.panel import Panel
from prompt_toolkit import prompt
from prompt_toolkit.history import InMemoryHistory
-from termcolor import colored
from tools.agent import Agent
from utils.workspace_manager import WorkspaceManager
@@ -26,6 +25,7 @@
MAX_OUTPUT_TOKENS_PER_TURN = 32768
MAX_TURNS = 200
+
def main():
"""Main entry point for the CLI."""
# Parse command-line arguments
@@ -84,7 +84,7 @@ def main():
if not args.minimize_stdout_logs:
logger_for_agent_logs.addHandler(logging.StreamHandler())
else:
- logger_for_agent_logs.propagate = False
+ logger_for_agent_logs.propagate = False
# Check if ANTHROPIC_API_KEY is set
if "ANTHROPIC_API_KEY" not in os.environ:
@@ -108,7 +108,9 @@ def main():
)
)
else:
- logger_for_agent_logs.info("Agent CLI started. Waiting for user input. Press Ctrl+C to exit. Type 'exit' or 'quit' to end the session.")
+ logger_for_agent_logs.info(
+ "Agent CLI started. Waiting for user input. Press Ctrl+C to exit. Type 'exit' or 'quit' to end the session."
+ )
# Initialize LLM client
client = get_client(
@@ -119,7 +121,9 @@ def main():
# Initialize workspace manager
workspace_path = Path(args.workspace).resolve()
- workspace_manager = WorkspaceManager(root=workspace_path, container_workspace=args.use_container_workspace)
+ workspace_manager = WorkspaceManager(
+ root=workspace_path, container_workspace=args.use_container_workspace
+ )
# Initialize agent
agent = Agent(
@@ -135,8 +139,12 @@ def main():
if args.problem_statement is not None:
instruction = INSTRUCTION_PROMPT.format(
- location=workspace_path if args.use_container_workspace is None else args.use_container_workspace,
- pr_description=args.problem_statement
+ location=(
+ workspace_path
+ if args.use_container_workspace is None
+ else args.use_container_workspace
+ ),
+ pr_description=args.problem_statement,
)
else:
instruction = None
@@ -157,7 +165,9 @@ def main():
break
else:
user_input = instruction
- logger_for_agent_logs.info(f"User instruction:\n{user_input}\n-------------")
+ logger_for_agent_logs.info(
+ f"User instruction:\n{user_input}\n-------------"
+ )
# Run the agent with the user input
logger_for_agent_logs.info("\nAgent is thinking...")
@@ -167,7 +177,7 @@ def main():
except Exception as e:
logger_for_agent_logs.info(f"Error: {str(e)}")
- logger_for_agent_logs.info("\n" + '-' * 40 + "\n")
+ logger_for_agent_logs.info("\n" + "-" * 40 + "\n")
if instruction is not None:
break
@@ -179,4 +189,4 @@ def main():
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/example_ensembler_results.json b/example_ensembler_results.json
index c866ce0..3ab0e37 100644
--- a/example_ensembler_results.json
+++ b/example_ensembler_results.json
@@ -23,4 +23,4 @@
"selected_diff": "@@ -45,3 +45,12 @@ def is_palindrome(text):\n cleaned_text = ''.join(c.lower() for c in text if c.isalnum())\n return cleaned_text == cleaned_text[::-1]\n \n+def is_valid_email(email):\n+ \"\"\"\n+ Check if a string is a valid email address.\n+ \"\"\"\n+ import re\n+ \n+ # Simple regex pattern for email validation\n+ pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$'\n+ return bool(re.match(pattern, email))\n",
"is_eval_success": true
}
-]
\ No newline at end of file
+]
diff --git a/majority_vote_ensembler.py b/majority_vote_ensembler.py
index 898e999..146c3dd 100755
--- a/majority_vote_ensembler.py
+++ b/majority_vote_ensembler.py
@@ -5,7 +5,7 @@
candidate diffs. It then uses the ensembler prompt to generate prompts and submits them
to the specified LLM (Claude or OpenAI) for results.
-To see example input, see `example_ensembler_data.jsonl`. To see example output,
+To see example input, see `example_ensembler_data.jsonl`. To see example output,
run `python majority_vote_ensembler.py example_ensembler_data.jsonl --output_path example_ensembler_results.json`.
"""
@@ -15,15 +15,16 @@
import os
import re
import sys
-from typing import Dict, List, Any, Optional, Tuple
+from typing import Dict, List, Any, Optional
from tqdm import tqdm
from prompts.ensembler_prompt import build_ensembler_prompt
-from utils.llm_client import get_client, TextPrompt, LLMMessages
+from utils.llm_client import get_client, TextPrompt
MAX_TOKENS = 16384
TEMPERATURE = 0.0
+
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Majority Vote Ensembler CLI Tool")
@@ -68,42 +69,46 @@ def extract_solution_index(response_text: str) -> Optional[int]:
return None
-def process_problem(problem: Dict[str, Any], problem_index: int, total_problems: int) -> Dict[str, Any]:
+def process_problem(
+ problem: Dict[str, Any], problem_index: int, total_problems: int
+) -> Dict[str, Any]:
"""Process a single problem using the LLM.
-
+
Args:
problem: The problem to process
problem_index: The index of the problem (for logging)
total_problems: The total number of problems (for logging)
-
+
Returns:
A dictionary containing the result for the problem
"""
# Create a client for this thread
client = get_client("openai-direct", model_name="o1-2024-12-17", cot_model=True)
-
- print(f"Processing problem {problem_index+1}/{total_problems}: {problem.get('id', f'Problem {problem_index+1}')}")
-
+
+ print(
+ f"Processing problem {problem_index + 1}/{total_problems}: {problem.get('id', f'Problem {problem_index + 1}')}"
+ )
+
instruction = problem.get("instruction", "")
diffs = problem.get("diffs", [])
eval_outcomes = problem.get("eval_outcomes", {})
-
+
if not diffs:
- print(f" Warning: No diffs found for problem {problem_index+1}, skipping")
+ print(f" Warning: No diffs found for problem {problem_index + 1}, skipping")
return {
- "id": problem.get("id", f"Problem {problem_index+1}"),
+ "id": problem.get("id", f"Problem {problem_index + 1}"),
"instruction": instruction,
"error": "No diffs provided",
"selected_diff_index": None,
- "selected_diff": None
+ "selected_diff": None,
}
-
+
# Build the ensembler prompt
prompt = build_ensembler_prompt(instruction, diffs)
-
+
# Prepare the message for the LLM
messages = [[TextPrompt(text=prompt)]]
-
+
# Submit to the LLM
try:
response, metadata = client.generate(
@@ -111,103 +116,115 @@ def process_problem(problem: Dict[str, Any], problem_index: int, total_problems:
max_tokens=MAX_TOKENS,
temperature=TEMPERATURE,
)
-
+
# Extract the response text
- response_text = response[0].text if hasattr(response[0], "text") else str(response[0])
-
+ response_text = (
+ response[0].text if hasattr(response[0], "text") else str(response[0]) # pyright: ignore[reportAttributeAccessIssue]
+ )
+
# Extract the solution index
solution_index = extract_solution_index(response_text)
-
+
if solution_index is not None and 0 <= solution_index < len(diffs):
selected_diff = diffs[solution_index] # Convert to 0-indexed
else:
selected_diff = None
- print(f" Warning: Invalid solution index {solution_index} for problem {problem_index+1}")
-
+ print(
+ f" Warning: Invalid solution index {solution_index} for problem {problem_index + 1}"
+ )
+
result = {
- "id": problem.get("id", f"Problem {problem_index+1}"),
+ "id": problem.get("id", f"Problem {problem_index + 1}"),
"instruction": instruction,
"response": response_text,
"selected_diff_index": solution_index,
"selected_diff": selected_diff,
- "is_eval_success": eval_outcomes[solution_index]['is_success']
+ "is_eval_success": eval_outcomes[solution_index]["is_success"],
}
-
+
print(f" Selected solution index: {solution_index}")
return result
-
+
except Exception as e:
- print(f" Error processing problem {problem_index+1}: {e}")
+ print(f" Error processing problem {problem_index + 1}: {e}")
return {
- "id": problem.get("id", f"Problem {problem_index+1}"),
+ "id": problem.get("id", f"Problem {problem_index + 1}"),
"instruction": instruction,
"error": str(e),
"selected_diff_index": None,
"selected_diff": None,
- "is_eval_success": False
+ "is_eval_success": False,
}
def ensemble_problems(
- problems: List[Dict[str, Any]],
- num_workers: int = 8
+ problems: List[Dict[str, Any]], num_workers: int = 8
) -> List[Dict[str, Any]]:
"""Ensemble problems using a thread pool for parallel processing.
-
+
Args:
problems: List of problems to process
num_workers: Number of worker threads to use
-
+
Returns:
List of results for each problem
"""
# Adjust number of workers based on the number of problems
effective_workers = min(num_workers, len(problems))
-
- print(f"Processing {len(problems)} problems using {effective_workers} worker threads")
+
+ print(
+ f"Processing {len(problems)} problems using {effective_workers} worker threads"
+ )
# Create a thread pool and process problems in parallel
- with concurrent.futures.ThreadPoolExecutor(max_workers=effective_workers) as executor:
+ with concurrent.futures.ThreadPoolExecutor(
+ max_workers=effective_workers
+ ) as executor:
# Create a list of (problem, index, total) tuples to pass to the worker function
- problem_data = [(problem, i, len(problems)) for i, problem in enumerate(problems)]
-
+ problem_data = [
+ (problem, i, len(problems)) for i, problem in enumerate(problems)
+ ]
+
# Map the worker function over the problems with tqdm progress bar
# This preserves the order of the results
- results = list(tqdm(executor.map(
- lambda x: process_problem(*x),
- problem_data
- ), total=len(problems), desc="Processing problems"))
-
+ results = list(
+ tqdm(
+ executor.map(lambda x: process_problem(*x), problem_data),
+ total=len(problems),
+ desc="Processing problems",
+ )
+ )
+
return results
def main():
"""Main function."""
args = parse_args()
-
+
if not os.environ.get("OPENAI_API_KEY"):
print("Error: OPENAI_API_KEY environment variable is not set")
sys.exit(1)
-
+
# Load problems from JSON file
problems = load_problems(args.input_jsonl_path)
-
+
# Determine output path
output_path = args.output_path or "ensembler_results.json"
-
+
# Ensemble problems using thread pool
results = ensemble_problems(problems, num_workers=args.workers)
# get success rate
success_rate = sum([result["is_eval_success"] for result in results]) / len(results)
print(f"Success rate: {success_rate:.2f}")
-
+
# Save results to output file in JSONL format
with open(output_path, "w") as f:
json.dump(results, f, indent=2)
-
+
print(f"Results saved to {output_path}")
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/merge_shards.py b/merge_shards.py
index e512905..f6d4d80 100755
--- a/merge_shards.py
+++ b/merge_shards.py
@@ -15,7 +15,7 @@
def merge_jsonl_files(input_files, output_file):
"""
Merge multiple JSONL files into a single JSONL file.
-
+
Args:
input_files (list): List of paths to input JSONL files
output_file (str): Path to the output JSONL file
@@ -23,43 +23,54 @@ def merge_jsonl_files(input_files, output_file):
# Create output directory if it doesn't exist
output_path = Path(output_file)
output_path.parent.mkdir(parents=True, exist_ok=True)
-
+
# Count total lines for reporting
total_lines = 0
-
- with open(output_file, 'w') as outfile:
+
+ with open(output_file, "w") as outfile:
for input_file in input_files:
try:
- with open(input_file, 'r') as infile:
+ with open(input_file, "r") as infile:
for line in infile:
line = line.strip()
if not line: # Skip empty lines
continue
-
+
# Verify it's valid JSON before writing
try:
json.loads(line)
- outfile.write(line + '\n') # Ensure each JSON object is on its own line
+ outfile.write(
+ line + "\n"
+ ) # Ensure each JSON object is on its own line
total_lines += 1
except json.JSONDecodeError:
- print(f"Warning: Skipping invalid JSON line in {input_file}", file=sys.stderr)
+ print(
+ f"Warning: Skipping invalid JSON line in {input_file}",
+ file=sys.stderr,
+ )
print(f"Processed: {input_file}")
except FileNotFoundError:
print(f"Error: File not found: {input_file}", file=sys.stderr)
continue
-
- print(f"Merged {len(input_files)} files with {total_lines} total records into {output_file}")
+
+ print(
+ f"Merged {len(input_files)} files with {total_lines} total records into {output_file}"
+ )
def main():
- parser = argparse.ArgumentParser(description="Merge multiple JSONL files into a single JSONL file")
- parser.add_argument("--input", "-i", nargs="+", required=True, help="Input JSONL files to merge")
+ parser = argparse.ArgumentParser(
+ description="Merge multiple JSONL files into a single JSONL file"
+ )
+ parser.add_argument(
+ "--input", "-i", nargs="+", required=True, help="Input JSONL files to merge"
+ )
parser.add_argument("--output", "-o", required=True, help="Output JSONL file path")
-
+
args = parser.parse_args()
-
+
merge_jsonl_files(args.input, args.output)
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/prompts/ensembler_prompt.py b/prompts/ensembler_prompt.py
index 843f6f0..53b4835 100644
--- a/prompts/ensembler_prompt.py
+++ b/prompts/ensembler_prompt.py
@@ -1,8 +1,7 @@
"""Majority Vote Ensembler Prompt"""
-def build_ensembler_prompt(
- instruction: str, diffs: list[str]
-) -> str:
+
+def build_ensembler_prompt(instruction: str, diffs: list[str]) -> str:
prompt = f"""\
I am a software engineer. I am working on a task in my codebase. Here is the task:
@@ -19,9 +18,9 @@ def build_ensembler_prompt(
for i, diff in enumerate(diffs):
prompt += f"""\
-
+
{diff}
-
+
"""
prompt += """\
@@ -33,4 +32,4 @@ def build_ensembler_prompt(
"""
- return prompt
\ No newline at end of file
+ return prompt
diff --git a/prompts/system_prompt.py b/prompts/system_prompt.py
index c4da476..cbcba49 100644
--- a/prompts/system_prompt.py
+++ b/prompts/system_prompt.py
@@ -15,4 +15,4 @@
- You should run relevant tests to verify that your changes work.
Make sure to call the complete tool when you are done with the task, or when you have an answer to the question.
-"""
\ No newline at end of file
+"""
diff --git a/pyproject.toml b/pyproject.toml
index 21e1cae..08f2f0b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,8 +18,15 @@ dependencies = [
"numpy>=2.2.4",
"openai==1.59.*",
"pexpect>=4.9.0",
+ "pre-commit>=4.2.0",
"prompt-toolkit>=3.0.50",
+ "pyright>=1.1.398",
"pytest==7.4.3",
"rich>=13.9.4",
"termcolor>=2.5.0",
]
+
+[dependency-groups]
+dev = [
+ "ruff>=0.11.2",
+]
diff --git a/pyrightconfig.ci.json b/pyrightconfig.ci.json
new file mode 100644
index 0000000..6a13fd6
--- /dev/null
+++ b/pyrightconfig.ci.json
@@ -0,0 +1,17 @@
+{
+ "include": [
+ "."
+ ],
+ "strict": [],
+ "ignore": [],
+ "executionEnvironments": [
+ {
+ "root": "."
+ }
+ ],
+ "typeCheckingMode": "standard",
+ "reportIncompatibleVariableOverride": "none",
+ "reportIncompatibleMethodOverride": "none",
+ "reportFunctionMemberAccess": "none",
+ "pythonVersion": "3.11"
+}
diff --git a/run_agent_on_swebench_problem.py b/run_agent_on_swebench_problem.py
index 9cf2023..f622632 100755
--- a/run_agent_on_swebench_problem.py
+++ b/run_agent_on_swebench_problem.py
@@ -14,7 +14,6 @@
import json
import argparse
from pathlib import Path
-from typing import Dict, Optional, Tuple, Any, List
from multiprocessing import Pool, Manager
import time
import numpy as np
@@ -27,20 +26,16 @@
from utils.common import generate_patch
from cli import main as cli_main
import uuid
-from prompts.instruction import INSTRUCTION_PROMPT
from utils.swebench_eval_utils import get_dataset_name, run_evaluation
-def run_eval_on_single_problem(
- problem_id: str,
- workspace_path: Path,
- console: Console
-):
+
+def run_eval_on_single_problem(problem_id: str, workspace_path: Path, console: Console):
eval_file = None
eval_outcomes = {
"is_success": False,
}
-
+
try:
run_evaluation(
predictions_file=workspace_path / "predictions.json",
@@ -48,8 +43,10 @@ def run_eval_on_single_problem(
"full"
), # Always use the full dataset for evaluation.
run_id=problem_id,
- swebench_venv_path=Path(f"{os.environ['HOME']}/swebench_eval_tools_env/bin/python"),
- console=console
+ swebench_venv_path=Path(
+ f"{os.environ['HOME']}/swebench_eval_tools_env/bin/python"
+ ),
+ console=console,
)
eval_file = workspace_path / f"augment-agent.{problem_id}.json"
eval_dict = json.loads(eval_file.read_text())
@@ -60,23 +57,24 @@ def run_eval_on_single_problem(
console.print(exc)
return eval_outcomes
+
def run_agent_on_single_problem(
problem_id: str,
problem_statement: str,
rollout_idx: int,
workspace_base_path: Path,
lock: threading.Lock,
- semaphore: threading.Semaphore
+ semaphore: threading.Semaphore,
) -> tuple[str, float, dict]:
"""
Run the agent on a single SWE-bench problem.
-
+
Args:
problem_id: The ID of the problem
problem_statement: The problem statement
lock: Threading lock for Docker operations
semaphore: Threading semaphore for Docker operations
-
+
Returns:
dict: The diff data generated by the agent
float: The time taken to generate the diff
@@ -87,13 +85,13 @@ def run_agent_on_single_problem(
workspace_path = workspace_base_path / problem_id / f"rollout_{rollout_idx}"
output_file = workspace_path / "agent_logs.txt"
-
+
# Ensure workspace directory exists
workspace_path.mkdir(parents=True, exist_ok=True)
-
+
# Start the Docker container
container_id = None
-
+
try:
env, container_id = setup_workspace(workspace_path, problem_id, lock, semaphore)
console.print(f"{logs_prefix} Docker container started with ID: {container_id}")
@@ -108,89 +106,101 @@ def run_agent_on_single_problem(
# Create new sys.argv for cli.py
cli_args = [
"cli.py",
- "--workspace", str(workspace_path / problem_id),
- "--problem-statement", problem_statement,
- "--docker-container-id", container_id,
- "--use-container-workspace", "/testbed",
- "--minimize-stdout-logs"
+ "--workspace",
+ str(workspace_path / problem_id),
+ "--problem-statement",
+ problem_statement,
+ "--docker-container-id",
+ container_id,
+ "--use-container-workspace",
+ "/testbed",
+ "--minimize-stdout-logs",
]
-
+
# Set logs path if output_file is specified
if output_file:
cli_args.extend(["--logs-path", str(output_file)])
-
+
# Replace sys.argv with our custom arguments
sys.argv = cli_args
-
+
# Run the agent via cli.py
console.print(f"{logs_prefix} Starting agent run...")
start_time = time.time()
cli_main()
agent_duration = time.time() - start_time
console.print(f"{logs_prefix} Agent run completed in {agent_duration:.2f}s.")
-
+
# Restore original sys.argv
sys.argv = original_argv
# Generate patch after the agent has completed its work
repo_path = str(workspace_path / problem_id)
- diff = generate_patch(repo_path)
+ diff = generate_patch(repo_path)
with (workspace_path / "predictions.json").open("w") as f:
- json.dump([{"instance_id": problem_id, "model_name_or_path": "augment-agent", "model_patch": diff}], f, indent=2)
+ json.dump(
+ [
+ {
+ "instance_id": problem_id,
+ "model_name_or_path": "augment-agent",
+ "model_patch": diff,
+ }
+ ],
+ f,
+ indent=2,
+ )
finally:
# Stop and clean up the Docker container
if container_id is not None:
console.print(f"{logs_prefix} Stopping Docker container...")
stop_container(container_id)
console.print(f"{logs_prefix} Docker container stopped")
-
+
# Evaluate the generated diff
console.print(f"{logs_prefix} Evaluating the generated diff...")
start_time = time.time()
eval_outcomes = run_eval_on_single_problem(problem_id, workspace_path, console)
eval_duration = time.time() - start_time
console.print(f"{logs_prefix} Evaluation completed in {eval_duration:.2f}s.")
-
+
assert diff is not None
return diff, agent_duration, eval_outcomes
+
def main():
"""Main entry point for the script."""
# Parse command-line arguments
parser = argparse.ArgumentParser(description="Run the agent on SWE-bench problems")
parser.add_argument(
- "--num-examples",
- type=int,
+ "--num-examples",
+ type=int,
default=None,
- help="Optionally, specify the number of examples to run on"
+ help="Optionally, specify the number of examples to run on",
)
parser.add_argument(
- "--shard-ct",
- type=int,
- default=1,
- help="Number of shards to split the work into"
+ "--shard-ct",
+ type=int,
+ default=1,
+ help="Number of shards to split the work into",
)
parser.add_argument(
- "--shard-id",
- type=int,
- default=0,
- help="Shard ID to run (0-indexed)"
+ "--shard-id", type=int, default=0, help="Shard ID to run (0-indexed)"
)
parser.add_argument(
- "--num-processes",
- type=int,
- default=8,
- help="Number of processes to use for each example"
+ "--num-processes",
+ type=int,
+ default=8,
+ help="Number of processes to use for each example",
)
parser.add_argument(
- "--num-candidate-solutions",
- type=int,
- default=8,
- help="Number of candidate solutions to generate for each example"
+ "--num-candidate-solutions",
+ type=int,
+ default=8,
+ help="Number of candidate solutions to generate for each example",
)
args = parser.parse_args()
-
+
# Set up logging
logging.basicConfig(
level=logging.INFO,
@@ -208,20 +218,34 @@ def main():
# Load the SWE-bench dataset
console.print("Loading SWE-bench dataset...")
- swebench_dataset = load_dataset("princeton-nlp/SWE-bench_Verified")["test"].to_pandas() # type: ignore
+ swebench_dataset = load_dataset("princeton-nlp/SWE-bench_Verified")[ # pyright: ignore[reportIndexIssue]
+ "test"
+ ].to_pandas() # pyright: ignore
# Sharding
- num_examples_per_shard = len(swebench_dataset) // args.shard_ct
- examples = swebench_dataset.iloc[args.shard_id * num_examples_per_shard : (args.shard_id + 1) * num_examples_per_shard]
-
+ num_examples_per_shard = len(swebench_dataset) // args.shard_ct # pyright: ignore[reportArgumentType]
+ examples = swebench_dataset.iloc[ # pyright: ignore[reportAttributeAccessIssue]
+ args.shard_id * num_examples_per_shard : (args.shard_id + 1)
+ * num_examples_per_shard
+ ]
+
# Get the number of examples to run
- assert args.num_examples is None or args.num_examples <= len(examples), f"num_examples ({args.num_examples}) is greater than the number of examples in the shard ({len(examples)}). Either decrease num_examples or decrease the number of shards."
+ assert args.num_examples is None or args.num_examples <= len(examples), (
+ f"num_examples ({args.num_examples}) is greater than the number of examples in the shard ({len(examples)}). Either decrease num_examples or decrease the number of shards."
+ )
num_examples = args.num_examples if args.num_examples is not None else len(examples)
- console.print(f"Running on {num_examples} examples from shard {args.shard_id} out of {args.shard_ct} shards.")
- console.print(f"We will generate {args.num_candidate_solutions} candidate solutions for each example with parallelism of {args.num_processes}.")
-
+ console.print(
+ f"Running on {num_examples} examples from shard {args.shard_id} out of {args.shard_ct} shards."
+ )
+ console.print(
+ f"We will generate {args.num_candidate_solutions} candidate solutions for each example with parallelism of {args.num_processes}."
+ )
+
# print out all example ids we'll be processing
- console.print("Selected examples:", '\n - ' + '\n - '.join(examples.iloc[:num_examples]["instance_id"].tolist()))
+ console.print(
+ "Selected examples:",
+ "\n - " + "\n - ".join(examples.iloc[:num_examples]["instance_id"].tolist()),
+ )
# List to store all diff data
all_diff_data = []
@@ -239,8 +263,8 @@ def main():
problem_id = problem["instance_id"]
problem_statement = problem["problem_statement"]
- console.print(f"\nProcessing example {i+1}/{num_examples}")
-
+ console.print(f"\nProcessing example {i + 1}/{num_examples}")
+
# Run the agent on the selected problem
with Manager() as manager:
lock = manager.Lock()
@@ -251,9 +275,12 @@ def main():
run_agent_on_single_problem,
lock=lock,
semaphore=semaphore,
- workspace_base_path=workspace_base_path
+ workspace_base_path=workspace_base_path,
),
- [(problem_id, problem_statement, rollout_idx) for rollout_idx in range(args.num_candidate_solutions)]
+ [
+ (problem_id, problem_statement, rollout_idx)
+ for rollout_idx in range(args.num_candidate_solutions)
+ ],
)
diffs, agent_durations, eval_outcomes = zip(*diffs)
median_duration = np.median(agent_durations)
@@ -263,18 +290,18 @@ def main():
"diffs": diffs,
"agent_durations": agent_durations,
"median_duration": median_duration,
- "eval_outcomes": eval_outcomes
+ "eval_outcomes": eval_outcomes,
}
all_diff_data.append(diff_data)
-
+
# Save the results after each example in case of failures
with open(output_path, "w") as f:
for diff_data in all_diff_data:
f.write(json.dumps(diff_data) + "\n")
-
- console.print(f"Completed example {i+1}/{num_examples}")
+
+ console.print(f"Completed example {i + 1}/{num_examples}")
except Exception as e:
- console.print(f"Error processing example {i+1}: {str(e)}")
+ console.print(f"Error processing example {i + 1}: {str(e)}")
continue
all_durations = [d["median_duration"] for d in all_diff_data]
@@ -299,7 +326,7 @@ def main():
- augment-agent..json: The eval results
- logs/run_evaluation//augment-agent//test_output.txt: The raw output from running tests during eval step
- logs/run_evaluation//augment-agent//report.json: Structured enumeration of what tests failed vs passed duringe eval step.
- - FAIL_TO_PASS tests are testing new functionality.
+ - FAIL_TO_PASS tests are testing new functionality.
- PASS_TO_PASS tests are testing existing functionality to make sure the diff didn't break any existing features.
The user has two next steps:
@@ -312,7 +339,7 @@ def main():
(Path for this shard's output: {output_path})
Run `python merge_shards.py --input --output pre-ensemble_result_all_shards.jsonl`
-
+
Next step for user is to run the ensembling step with command. Make sure to set OPENAI_API_KEY environment variable before running the command!
--------
python majority_vote_ensembler.py pre-ensemble_result_all_shards.jsonl --output_path ensembler_results_all_shards.json
@@ -324,5 +351,6 @@ def main():
)
console.print(ensemble_instruction)
+
if __name__ == "__main__":
main()
diff --git a/setup.sh b/setup.sh
index 7a71d71..7ef1745 100755
--- a/setup.sh
+++ b/setup.sh
@@ -42,6 +42,10 @@ echo "Done installing SWE-bench evaluation tools, including a separate virtual e
# navigate to cwd
cd ${current_dir}
+# Install pre-commit hooks
+echo "Installing pre-commit hooks..."
+uv run pre-commit install-hooks
+
echo "Setup complete! Activate the virtual environment for augment-swebench-agent with:"
echo "source .venv/bin/activate"
diff --git a/swebench_patch.diff b/swebench_patch.diff
index 4a5c10d..3734298 100644
--- a/swebench_patch.diff
+++ b/swebench_patch.diff
@@ -10,15 +10,15 @@ index ccb9f8d..de906e9 100644
+import uuid
from dataclasses import dataclass
from typing import Any, Union, cast
-
+
@@ -106,8 +106,8 @@ class TestSpec:
-
+
def get_instance_container_name(self, run_id=None):
if not run_id:
- return f"sweb.eval.{self.instance_id}"
- return f"sweb.eval.{self.instance_id.lower()}.{run_id}"
+ return f"sweb.eval.{self.instance_id}.{uuid.uuid4().hex[:8]}"
+ return f"sweb.eval.{self.instance_id.lower()}.{run_id}.{uuid.uuid4().hex[:8]}"
-
+
@property
def base_dockerfile(self):
diff --git a/tools/agent.py b/tools/agent.py
index 00e5339..d28b790 100644
--- a/tools/agent.py
+++ b/tools/agent.py
@@ -1,5 +1,3 @@
-
-from pathlib import Path
from copy import deepcopy
from typing import Any, Optional
from tools.bash_tool import create_bash_tool, create_docker_bash_tool
@@ -7,7 +5,6 @@
DialogMessages,
LLMTool,
ToolImplOutput,
- ToolCallParameters,
)
from utils.llm_client import LLMClient, TextResult
from utils.workspace_manager import WorkspaceManager
@@ -20,7 +17,6 @@
import logging
-
class Agent(LLMTool):
name = "general_agent"
description = """\
@@ -79,7 +75,10 @@ def __init__(
self.max_turns = max_turns
self.workspace_manager = workspace_manager
self.interrupted = False
- self.dialog = DialogMessages(logger_for_agent_logs=logger_for_agent_logs,use_prompt_budgeting=use_prompt_budgeting)
+ self.dialog = DialogMessages(
+ logger_for_agent_logs=logger_for_agent_logs,
+ use_prompt_budgeting=use_prompt_budgeting,
+ )
# Create and store the complete tool
self.complete_tool = CompleteTool()
@@ -119,7 +118,7 @@ def run_impl(
user_input_delimiter = "-" * 45 + " USER INPUT " + "-" * 45 + "\n" + instruction
self.logger_for_agent_logs.info(f"\n{user_input_delimiter}\n")
-
+
# print("Agent starting with instruction:", instruction)
# Add instruction to dialog before getting mode
@@ -169,14 +168,15 @@ def run_impl(
tool_result_message="Task completed",
)
-
if len(pending_tool_calls) > 1:
raise ValueError("Only one tool call per turn is supported")
assert len(pending_tool_calls) == 1
tool_call = pending_tool_calls[0]
- text_results = [item for item in model_response if isinstance(item, TextResult)]
+ text_results = [
+ item for item in model_response if isinstance(item, TextResult)
+ ]
if len(text_results) > 0:
text_result = text_results[0]
self.logger_for_agent_logs.info(
@@ -193,7 +193,9 @@ def run_impl(
try:
result = tool.run(tool_call.tool_input, deepcopy(self.dialog))
- tool_input_str = '\n'.join([f' - {k}: {v}' for k, v in tool_call.tool_input.items()])
+ tool_input_str = "\n".join(
+ [f" - {k}: {v}" for k, v in tool_call.tool_input.items()]
+ )
log_message = f"Calling tool {tool_call.tool_name} with input:\n{tool_input_str}"
log_message += f"\nTool output: \n{result}\n\n"
self.logger_for_agent_logs.info(log_message)
@@ -279,7 +281,6 @@ def run_agent(
self.dialog.clear()
self.interrupted = False
-
tool_input = {
"instruction": instruction,
}
diff --git a/tools/bash_tool.py b/tools/bash_tool.py
index cb95fb1..c7ec2b5 100644
--- a/tools/bash_tool.py
+++ b/tools/bash_tool.py
@@ -8,11 +8,10 @@
It also supports command filters for transforming commands before execution.
"""
-import subprocess
from pathlib import Path
-from typing import Any, Dict, List, Optional, Protocol, Callable
+from typing import Any, Dict, List, Optional
-from utils.common import (
+from utils.common import (
DialogMessages,
LLMTool,
ToolImplOutput,
diff --git a/tools/complete_tool.py b/tools/complete_tool.py
index 713821a..1d396de 100644
--- a/tools/complete_tool.py
+++ b/tools/complete_tool.py
@@ -1,10 +1,13 @@
"""Tool for indicating task completion."""
+
from typing import Any, Optional
from utils.common import (
DialogMessages,
LLMTool,
ToolImplOutput,
)
+
+
class CompleteTool(LLMTool):
name = "complete"
"""The model should call this tool when it is done with the task."""
@@ -42,4 +45,4 @@ def run_impl(
return ToolImplOutput("Task completed", "Task completed")
def get_tool_start_message(self, tool_input: dict[str, Any]) -> str:
- return ""
\ No newline at end of file
+ return ""
diff --git a/tools/sequential_thinking_tool.py b/tools/sequential_thinking_tool.py
index 2b2aaa5..505361d 100644
--- a/tools/sequential_thinking_tool.py
+++ b/tools/sequential_thinking_tool.py
@@ -10,7 +10,6 @@
import json
import logging
-import re
from typing import Any, Dict, List, Optional, TypedDict
from utils.common import (
@@ -197,9 +196,9 @@ def _format_thought(self, thought_data: ThoughtData) -> str:
Returns:
Formatted thought string
"""
- thought_number = thought_data["thoughtNumber"]
- total_thoughts = thought_data["totalThoughts"]
- thought = thought_data["thought"]
+ thought_number = thought_data["thoughtNumber"] # pyright: ignore[reportTypedDictNotRequiredAccess]
+ total_thoughts = thought_data["totalThoughts"] # pyright: ignore[reportTypedDictNotRequiredAccess]
+ thought = thought_data["thought"] # pyright: ignore[reportTypedDictNotRequiredAccess]
is_revision = thought_data.get("isRevision", False)
revises_thought = thought_data.get("revisesThought")
branch_from_thought = thought_data.get("branchFromThought")
@@ -247,8 +246,8 @@ def run_impl(
validated_input = self._validate_thought_data(tool_input)
# Adjust total thoughts if needed
- if validated_input["thoughtNumber"] > validated_input["totalThoughts"]:
- validated_input["totalThoughts"] = validated_input["thoughtNumber"]
+ if validated_input["thoughtNumber"] > validated_input["totalThoughts"]: # pyright: ignore[reportTypedDictNotRequiredAccess]
+ validated_input["totalThoughts"] = validated_input["thoughtNumber"] # pyright: ignore[reportTypedDictNotRequiredAccess]
# Add to thought history
self.thought_history.append(validated_input)
@@ -257,11 +256,10 @@ def run_impl(
if validated_input.get("branchFromThought") and validated_input.get(
"branchId"
):
- branch_id = validated_input["branchId"]
+ branch_id = validated_input["branchId"] # pyright: ignore[reportTypedDictNotRequiredAccess]
if branch_id not in self.branches:
- self.branches[branch_id] = []
- self.branches[branch_id].append(validated_input)
-
+ self.branches[branch_id] = [] # pyright: ignore[reportArgumentType]
+ self.branches[branch_id].append(validated_input) # pyright: ignore[reportArgumentType]
# Format and log the thought
formatted_thought = self._format_thought(validated_input)
@@ -270,16 +268,16 @@ def run_impl(
# Prepare response
response = {
- "thoughtNumber": validated_input["thoughtNumber"],
- "totalThoughts": validated_input["totalThoughts"],
- "nextThoughtNeeded": validated_input["nextThoughtNeeded"],
+ "thoughtNumber": validated_input["thoughtNumber"], # pyright: ignore[reportTypedDictNotRequiredAccess]
+ "totalThoughts": validated_input["totalThoughts"], # pyright: ignore[reportTypedDictNotRequiredAccess]
+ "nextThoughtNeeded": validated_input["nextThoughtNeeded"], # pyright: ignore[reportTypedDictNotRequiredAccess]
"branches": list(self.branches.keys()),
"thoughtHistoryLength": len(self.thought_history),
}
return ToolImplOutput(
tool_output=json.dumps(response, indent=2),
- tool_result_message=f"Processed thought {validated_input['thoughtNumber']}/{validated_input['totalThoughts']}",
+ tool_result_message=f"Processed thought {validated_input['thoughtNumber']}/{validated_input['totalThoughts']}", # pyright: ignore[reportTypedDictNotRequiredAccess]
auxiliary_data={"thought_data": validated_input},
)
except Exception as e:
diff --git a/tools/str_replace_tool.py b/tools/str_replace_tool.py
index c966b45..7026108 100644
--- a/tools/str_replace_tool.py
+++ b/tools/str_replace_tool.py
@@ -32,6 +32,7 @@
"undo_edit",
]
+
def is_path_in_directory(directory: Path, path: Path) -> bool:
directory = directory.resolve()
path = path.resolve()
@@ -269,12 +270,12 @@ def run_impl(
elif command == "undo_edit":
return self.undo_edit(_ws_path)
raise ToolError(
- f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
+ f"Unrecognized command {command}. The allowed commands for the {self.name} tool are: {', '.join(get_args(Command))}"
)
except Exception as e:
return ExtendedToolImplOutput(
- e.message,
- e.message,
+ e.message, # pyright: ignore[reportAttributeAccessIssue]
+ e.message, # pyright: ignore[reportAttributeAccessIssue]
{"success": False},
)
@@ -376,6 +377,7 @@ def _str_replace_ignore_indent(self, path: Path, old_str: str, new_str: str | No
new_str = new_str.expandtabs()
new_str = match_indent(new_str, content)
+ assert new_str is not None, "new_str should not be None after match_indent"
# Split into lines for processing
content_lines = content.splitlines()
@@ -422,6 +424,7 @@ def _str_replace_ignore_indent(self, path: Path, old_str: str, new_str: str | No
indented_new_str = match_indent_by_first_line(
new_str, original_matched_lines[0]
)
+ assert indented_new_str is not None, "indented_new_str should not be None"
# Create new content by replacing the matched lines
new_content = [
diff --git a/tools/test_bash_tool.py b/tools/test_bash_tool.py
index 337f837..3a346e9 100644
--- a/tools/test_bash_tool.py
+++ b/tools/test_bash_tool.py
@@ -19,6 +19,7 @@
create_bash_tool,
)
+
def bash_tool():
return BashTool(
workspace_root=Path("/tmp"),
@@ -32,9 +33,7 @@ def test_successful_command():
workspace_root=Path("/tmp"),
require_confirmation=False,
)
- with patch(
- "tools.bash_tool.run_command"
- ) as mock_run_command:
+ with patch("tools.bash_tool.run_command") as mock_run_command:
# Mock a successful command execution
mock_run_command.return_value = "Command output"
@@ -55,9 +54,7 @@ def test_failed_command():
workspace_root=Path("/tmp"),
require_confirmation=False,
)
- with patch(
- "tools.bash_tool.run_command"
- ) as mock_run_command:
+ with patch("tools.bash_tool.run_command") as mock_run_command:
# Mock a failed command execution that raises an exception
mock_run_command.side_effect = Exception("Command failed")
@@ -82,9 +79,7 @@ def test_command_with_exception():
workspace_root=Path("/tmp"),
require_confirmation=False,
)
- with patch(
- "tools.bash_tool.run_command"
- ) as mock_run_command:
+ with patch("tools.bash_tool.run_command") as mock_run_command:
# Mock an exception during command execution
mock_run_command.side_effect = Exception("Test exception")
@@ -351,9 +346,7 @@ def test_get_tool_start_message(self):
def test_create_bash_tool(self):
"""Test the create_bash_tool factory function."""
- with patch(
- "tools.bash_tool.BashTool"
- ) as mock_bash_tool:
+ with patch("tools.bash_tool.BashTool") as mock_bash_tool:
_ = create_bash_tool(
ask_user_permission=True,
cwd=self.workspace_root,
diff --git a/tools/test_sequential_thinking_tool.py b/tools/test_sequential_thinking_tool.py
index 87cbf27..e8f0a09 100644
--- a/tools/test_sequential_thinking_tool.py
+++ b/tools/test_sequential_thinking_tool.py
@@ -33,10 +33,10 @@ def test_validate_thought_data_valid(self):
"nextThoughtNeeded": True,
}
result = self.tool._validate_thought_data(valid_input)
- self.assertEqual(result["thought"], "This is a test thought")
- self.assertEqual(result["thoughtNumber"], 1)
- self.assertEqual(result["totalThoughts"], 5)
- self.assertTrue(result["nextThoughtNeeded"])
+ self.assertEqual(result["thought"], "This is a test thought") # pyright: ignore[reportTypedDictNotRequiredAccess]
+ self.assertEqual(result["thoughtNumber"], 1) # pyright: ignore[reportTypedDictNotRequiredAccess]
+ self.assertEqual(result["totalThoughts"], 5) # pyright: ignore[reportTypedDictNotRequiredAccess]
+ self.assertTrue(result["nextThoughtNeeded"]) # pyright: ignore[reportTypedDictNotRequiredAccess]
def test_validate_thought_data_invalid(self):
"""Test validation with invalid input."""
@@ -67,7 +67,7 @@ def test_format_thought_regular(self):
"totalThoughts": 5,
"nextThoughtNeeded": True,
}
- formatted = self.tool._format_thought(thought_data)
+ formatted = self.tool._format_thought(thought_data) # pyright: ignore[reportArgumentType]
self.assertIn("💭 Thought 1/5", formatted)
self.assertIn("This is a regular thought", formatted)
@@ -81,7 +81,7 @@ def test_format_thought_revision(self):
"isRevision": True,
"revisesThought": 1,
}
- formatted = self.tool._format_thought(thought_data)
+ formatted = self.tool._format_thought(thought_data) # pyright: ignore[reportArgumentType]
self.assertIn("🔄 Revision 2/5", formatted)
self.assertIn("(revising thought 1)", formatted)
self.assertIn("This is a revision", formatted)
@@ -96,7 +96,7 @@ def test_format_thought_branch(self):
"branchFromThought": 2,
"branchId": "branch-1",
}
- formatted = self.tool._format_thought(thought_data)
+ formatted = self.tool._format_thought(thought_data) # pyright: ignore[reportArgumentType]
self.assertIn("🌿 Branch 3/5", formatted)
self.assertIn("(from thought 2, ID: branch-1)", formatted)
self.assertIn("This is a branch", formatted)
@@ -130,7 +130,8 @@ def test_run_impl_success(self):
# Verify thought was added to history
self.assertEqual(len(self.tool.thought_history), 1)
self.assertEqual(
- self.tool.thought_history[0]["thought"], "This is a test thought"
+ self.tool.thought_history[0]["thought"], # pyright: ignore[reportTypedDictNotRequiredAccess]
+ "This is a test thought",
)
def test_run_impl_with_branch(self):
@@ -197,7 +198,7 @@ def test_adjust_total_thoughts(self):
# Verify totalThoughts was adjusted
self.assertEqual(output_data["totalThoughts"], 10)
- self.assertEqual(self.tool.thought_history[0]["totalThoughts"], 10)
+ self.assertEqual(self.tool.thought_history[0]["totalThoughts"], 10) # pyright: ignore[reportTypedDictNotRequiredAccess]
def test_get_tool_start_message(self):
"""Test the get_tool_start_message method."""
diff --git a/tools/test_str_replace_tool.py b/tools/test_str_replace_tool.py
index 33f1240..a3d9e9e 100644
--- a/tools/test_str_replace_tool.py
+++ b/tools/test_str_replace_tool.py
@@ -1,11 +1,5 @@
-import pytest
-from pathlib import Path
from unittest.mock import MagicMock, patch
-
-from tools.str_replace_tool import (
- StrReplaceEditorTool,
- ToolError,
-)
+from tools.str_replace_tool import StrReplaceEditorTool
def build_ws_manager(root):
@@ -458,13 +452,15 @@ def test_str_replace_with_indentation(tmp_path):
test_file = tmp_path / "test_indentation.py"
# Create a file with indented code
- test_file.write_text("""def main():
+ test_file.write_text(
+ """def main():
if True:
print("Hello")
if True:
print("World")
print("End")
-""")
+"""
+ )
tool = StrReplaceEditorTool(
workspace_manager=workspace_manager,
@@ -494,7 +490,8 @@ def test_str_replace_with_different_indentation_levels(tmp_path):
test_file = tmp_path / "test_multi_indent.py"
# Create a file with multiple indentation levels
- test_file.write_text("""def function():
+ test_file.write_text(
+ """def function():
# Level 1
if condition_1:
# Level 2
@@ -505,7 +502,8 @@ def test_str_replace_with_different_indentation_levels(tmp_path):
process(item)
else:
skip(item)
-""")
+"""
+ )
tool = StrReplaceEditorTool(
workspace_manager=workspace_manager,
@@ -542,7 +540,8 @@ def test_str_replace_with_mixed_indentation(tmp_path):
test_file = tmp_path / "test_mixed_indent.py"
# Create a file with mixed tabs and spaces
- test_file.write_text("""def mixed_indentation():
+ test_file.write_text(
+ """def mixed_indentation():
# 4 spaces
if condition:
# 8 spaces
@@ -552,7 +551,8 @@ def test_str_replace_with_mixed_indentation(tmp_path):
if value > threshold:
# 12 spaces
return value
-""")
+"""
+ )
tool = StrReplaceEditorTool(
workspace_manager=workspace_manager,
@@ -584,7 +584,8 @@ def test_str_replace_indentation_edge_cases(tmp_path):
test_file = tmp_path / "test_indent_edge.py"
# Create a file with some edge cases
- test_file.write_text("""def edge_cases():
+ test_file.write_text(
+ """def edge_cases():
# Empty lines between code
if condition:
@@ -596,7 +597,8 @@ def test_str_replace_indentation_edge_cases(tmp_path):
print("Two spaces")
print("Four spaces")
print("Six spaces")
-""")
+"""
+ )
tool = StrReplaceEditorTool(
workspace_manager=workspace_manager,
@@ -644,11 +646,13 @@ def test_str_replace_no_match_after_indentation_attempts(tmp_path):
workspace_manager = build_ws_manager(tmp_path)
test_file = tmp_path / "test_no_match.py"
- test_file.write_text("""def no_match():
+ test_file.write_text(
+ """def no_match():
print("This is some code")
if condition:
print("More code")
-""")
+"""
+ )
tool = StrReplaceEditorTool(
workspace_manager=workspace_manager,
diff --git a/utils/common.py b/utils/common.py
index fa09014..76f677b 100644
--- a/utils/common.py
+++ b/utils/common.py
@@ -9,12 +9,11 @@
import subprocess
import jsonschema
-import numpy as np
from anthropic import BadRequestError
from termcolor import colored
from typing_extensions import final
-from token_counter import (
+from utils.token_counter import (
ClaudeTokenCounter,
)
from utils.llm_client import (
@@ -72,7 +71,11 @@ class DialogMessages:
An assistant turn consists of a model answer and tool calls.
"""
- def __init__(self, logger_for_agent_logs: logging.Logger, use_prompt_budgeting: bool = False, ):
+ def __init__(
+ self,
+ logger_for_agent_logs: logging.Logger,
+ use_prompt_budgeting: bool = False,
+ ):
self.logger_for_agent_logs = logger_for_agent_logs
self._message_lists: list[list[GeneralContentBlock]] = []
self.token_counter = ClaudeTokenCounter()
@@ -335,9 +338,9 @@ def _assert_user_turn(self):
assert self.is_user_turn(), "Can only add user prompts on user's turn"
def _assert_assistant_turn(self):
- assert (
- self.is_assistant_turn()
- ), "Can only get/replace last user prompt on assistant's turn"
+ assert self.is_assistant_turn(), (
+ "Can only get/replace last user prompt on assistant's turn"
+ )
class Tool:
@@ -455,6 +458,7 @@ def call_tools(
return tool_outputs
+
def generate_patch(git_repo, reverse=False):
"""Generate the patch for the prediction."""
logging.info(f"Generating patch in {git_repo}")
@@ -488,4 +492,4 @@ def generate_patch(git_repo, reverse=False):
logging.error(
f"Failed to decode git diff output after {max_retries} attempts."
)
- raise
\ No newline at end of file
+ raise
diff --git a/utils/console.py b/utils/console.py
deleted file mode 100644
index 4ebe337..0000000
--- a/utils/console.py
+++ /dev/null
@@ -1,151 +0,0 @@
-"""
-Console Utilities Module
-
-This module provides utilities for console interaction in the setup script generator.
-"""
-
-import os
-from pathlib import Path
-from typing import Optional, List
-
-from rich.console import Console
-from rich.panel import Panel
-from rich.syntax import Syntax
-from rich.live import Live
-from prompt_toolkit import prompt
-
-# Initialize console
-console = Console()
-
-
-class Confirm:
- """Utility class for confirmation prompts."""
-
- @staticmethod
- def ask(question):
- """Ask a yes/no question and return the answer.
-
- Args:
- question: The question to ask
-
- Returns:
- True if the answer is yes, False otherwise
- """
- response = prompt(f"{question} [y/n] ").lower().strip()
- return response.startswith("y")
-
-
-def print_welcome_message():
- """Print the welcome message for the setup script generator and wait for user to press Enter."""
- console.print(
- Panel(
- "[bold]Automated environment configuration powered by Augment[/bold]\n\n"
- + "[yellow]Workflow:[/yellow]\n"
- + "[dim]1.[/dim] [cyan]Select Test Command[/cyan] → Identifies appropriate validation command for the project\n"
- + "[dim]2.[/dim] [cyan]Generate Setup Script[/cyan] → Creates setup script and validates it by running the test command\n\n"
- + "[italic green]Press Enter to start...[/italic green]",
- title="[bold blue]AI Setup Script Generator[/bold blue]",
- border_style="blue",
- padding=(1, 2),
- )
- )
- # Wait for user to press Enter
- input()
-
-
-def print_file_saved_message(file_path: Path):
- """Print a message indicating that a file was saved.
-
- Args:
- file_path: Path to the saved file
- """
- console.print(f"\n[bold green]Setup script saved to: {file_path}[/bold green]")
- console.print("You can now use this script to set up your project.")
-
-
-def display_script(script_content: str, file_path: Optional[Path] = None):
- """Display a script with syntax highlighting.
-
- Args:
- script_content: The content of the script
- file_path: Optional path to the script file
- """
- title = f"Setup Script ({file_path})" if file_path else "Setup Script"
- syntax = Syntax(script_content, "bash", theme="monokai", line_numbers=True)
- console.print(Panel(syntax, title=title))
-
-
-def create_live_output_display(title: str = "Output", height: int = 20):
- """Create a live output display with fixed height and syntax highlighting.
-
- Args:
- title: Title for the panel
- height: Height of the output panel in lines
-
- Returns:
- A tuple of (Live, function to update content)
- """
-
- # Create a renderable that will be updated
- class LogDisplay:
- def __init__(self):
- self.log_lines = []
-
- def __rich__(self):
- # Create a syntax object with the current log lines
- # Only show the last 'height' lines if we have more
- display_lines = (
- self.log_lines[-height:]
- if len(self.log_lines) > height
- else self.log_lines
- )
- text = "\n".join(display_lines)
- return Panel(
- Syntax(
- text,
- "bash", # Using bash lexer which works well for logs with commands
- theme="monokai",
- line_numbers=False,
- word_wrap=True,
- ),
- title=title,
- border_style="green",
- )
-
- log_display = LogDisplay()
-
- def update_content(lines: List[str]):
- """Update the display with new content.
-
- Args:
- lines: List of log lines to display
- """
- log_display.log_lines = lines
-
- # Create the live display with a higher refresh rate for smoother updates
- live = Live(log_display, refresh_per_second=10)
-
- return live, update_content
-
-
-def get_script_path(default_path: str = "setup.sh") -> Path:
- """Get the path to save the script.
-
- Args:
- default_path: Default path for the script
-
- Returns:
- Path object for the script
- """
- script_path = prompt(
- f"Where would you like to save the setup script? [default: {default_path}]: "
- )
-
- if not script_path:
- script_path = default_path
-
- # Convert to absolute path if not already
- if not os.path.isabs(script_path):
- script_path = os.path.abspath(script_path)
-
- return Path(script_path)
diff --git a/utils/docker_utils.py b/utils/docker_utils.py
index 25b9742..7a45a6f 100644
--- a/utils/docker_utils.py
+++ b/utils/docker_utils.py
@@ -1,6 +1,4 @@
-
import docker
-import json
import logging
import os
import subprocess
@@ -58,6 +56,7 @@ def set_volume_permissions(container_id, volume_path: Path):
logging.warning(f"Failed to chown {volume_path}: {e}")
raise
+
def start_container(workspace: Path, problem_id: str, semaphore: Any) -> str:
"""Start a docker container for the issue."""
stop_container(f"sweb.augment.{problem_id}")
@@ -106,6 +105,7 @@ def start_container(workspace: Path, problem_id: str, semaphore: Any) -> str:
time.sleep(10)
return container_id
+
def remove_container_image(image_name: str) -> None:
"""Remove a docker image."""
try:
@@ -188,4 +188,4 @@ def setup_workspace(
# Copy the python wrapper into the workspace
container_id = start_container(workspace, problem_id, semaphore)
- return env, container_id
\ No newline at end of file
+ return env, container_id
diff --git a/utils/indent_utils.py b/utils/indent_utils.py
index 353452a..f336ba8 100644
--- a/utils/indent_utils.py
+++ b/utils/indent_utils.py
@@ -73,7 +73,7 @@ def detect_line_indent(line: str) -> Tuple[int, int]:
return (num_tabs, num_spaces)
-def detect_indent_type(code: str) -> IndentType | None:
+def detect_indent_type(code: str | None) -> IndentType | None:
"""Detect the indentation type (spaces or tabs) and size used in the code.
If the code contains mixed indentation, it will return MIXED.
@@ -156,7 +156,7 @@ def force_normalize_indent(code: str) -> str:
return "\n".join(normalized_lines)
-def normalize_indent(code: str, indent_type: IndentType) -> str:
+def normalize_indent(code: str | None, indent_type: IndentType) -> str | None:
"""Normalize indentation in code to use 4 spaces.
Args:
@@ -207,8 +207,10 @@ def normalize_indent(code: str, indent_type: IndentType) -> str:
def apply_indent_type(
- code: str, indent_type: IndentType, original_indent_type: IndentType | None = None
-) -> str:
+ code: str | None,
+ indent_type: IndentType,
+ original_indent_type: IndentType | None = None,
+) -> str | None:
"""Apply the specified indentation type to code.
Args:
@@ -268,7 +270,7 @@ def apply_indent_type(
return "\n".join(modified_lines)
-def match_indent_by_first_line(code: str, line: str) -> str:
+def match_indent_by_first_line(code: str | None, line: str) -> str | None:
"""Match the indentation of the first line in code to the given line.
All subsequent lines will be adjusted to maintain their relative indentation.
@@ -307,7 +309,10 @@ def match_indent_by_first_line(code: str, line: str) -> str:
return "\n".join(modified_lines)
-def match_indent(code: str, code_to_match: str) -> str:
+def match_indent(code: str | None, code_to_match: str) -> str | None:
+ if not code or not isinstance(code, str):
+ return code
+
indent_type = detect_indent_type(code_to_match)
if indent_type is not None and indent_type.is_mixed:
indent_type = indent_type.most_used
diff --git a/utils/llm_client.py b/utils/llm_client.py
index 77ef537..8ba5c81 100644
--- a/utils/llm_client.py
+++ b/utils/llm_client.py
@@ -1,14 +1,11 @@
"""LLM client for Anthropic models."""
-
import json
import os
import random
-import io
-import sys
import time
from dataclasses import dataclass
-from typing import Any, Tuple, cast, Generator
+from typing import Any, Tuple, cast
from dataclasses_json import DataClassJsonMixin
import anthropic
import openai
@@ -25,7 +22,7 @@
RateLimitError as AnthropicRateLimitError,
)
from anthropic._exceptions import (
- OverloadedError as AnthropicOverloadedError,
+ OverloadedError as AnthropicOverloadedError, # pyright: ignore[reportPrivateImportUsage]
)
from anthropic.types import (
TextBlock as AnthropicTextBlock,
@@ -54,11 +51,15 @@
from openai import (
RateLimitError as OpenAI_RateLimitError,
)
-from openai._types import NOT_GIVEN as OpenAI_NOT_GIVEN
+from openai._types import (
+ NOT_GIVEN as OpenAI_NOT_GIVEN, # pyright: ignore[reportPrivateImportUsage]
+)
import logging
+
logging.getLogger("httpx").setLevel(logging.WARNING)
+
@dataclass
class ToolParam(DataClassJsonMixin):
"""Internal representation of LLM tool."""
@@ -116,6 +117,7 @@ class TextResult(DataClassJsonMixin):
GeneralContentBlock = UserContentBlock | AssistantContentBlock
LLMMessages = list[list[GeneralContentBlock]]
+
class LLMClient:
"""A client for LLM APIs for the use in agents."""
@@ -144,6 +146,7 @@ def generate(
"""
raise NotImplementedError
+
def recursively_remove_invoke_tag(obj):
"""Recursively remove the tag from a dictionary or list."""
result_obj = {}
@@ -184,7 +187,7 @@ def __init__(
self.use_caching = use_caching
self.prompt_caching_headers = {"anthropic-beta": "prompt-caching-2024-07-31"}
self.thinking_tokens = thinking_tokens
-
+
def generate(
self,
messages: LLMMessages,
@@ -312,9 +315,9 @@ def generate(
"thinking": {"type": "enabled", "budget_tokens": thinking_tokens}
}
temperature = 1
- assert (
- max_tokens >= 32_000 and thinking_tokens <= 8192
- ), f"As a heuristic, max tokens {max_tokens} must be >= 32k and thinking tokens {thinking_tokens} must be < 8k"
+ assert max_tokens >= 32_000 and thinking_tokens <= 8192, (
+ f"As a heuristic, max tokens {max_tokens} must be >= 32k and thinking tokens {thinking_tokens} must be < 8k"
+ )
else:
extra_body = None
@@ -390,6 +393,7 @@ def generate(
return augment_messages, message_metadata
+
class OpenAIDirectClient(LLMClient):
"""Use OpenAI models via first party API."""
diff --git a/utils/swebench_eval_utils.py b/utils/swebench_eval_utils.py
index 4d1fd28..2c5c6b3 100644
--- a/utils/swebench_eval_utils.py
+++ b/utils/swebench_eval_utils.py
@@ -11,6 +11,7 @@
from utils.docker_utils import stop_container
+
def get_dataset_name(dataset: str) -> str:
"""Get the dataset name for the specified dataset."""
return {
@@ -19,6 +20,7 @@ def get_dataset_name(dataset: str) -> str:
"lite": "princeton-nlp/SWE-bench_Lite",
}[dataset]
+
def run_evaluation(
predictions_file: Path,
dataset: str,
@@ -70,7 +72,9 @@ def run_evaluation(
capture_output=capture_output,
)
except huggingface_hub.errors.HfHubHTTPError as e:
- logging.warning(f"Failed to run evaluation for {instance_id}: {e}. Retrying.")
+ logging.warning(
+ f"Failed to run evaluation for {instance_id}: {e}. Retrying."
+ )
continue
tries += 1
want_retry = False
@@ -113,4 +117,4 @@ def run_evaluation(
logging.warning(
f"Evaluations completed for {predictions_file.name} in {datetime.now() - et} ({(datetime.now() - et).total_seconds():.2f}s) rc {evaluation.returncode}"
)
- return evaluation
\ No newline at end of file
+ return evaluation
diff --git a/utils/test_indent_utils.py b/utils/test_indent_utils.py
index f0286eb..d2a6f01 100644
--- a/utils/test_indent_utils.py
+++ b/utils/test_indent_utils.py
@@ -214,9 +214,9 @@ def test_detect_indent_type_edge_cases(self):
# Code with tabs and spaces on same line but tabs first
code = "def test():\n\t print('mixed')"
indent_type = detect_indent_type(code)
- assert indent_type.is_mixed
+ assert indent_type.is_mixed # pyright: ignore[reportOptionalMemberAccess]
assert (
- indent_type.most_used == IndentType.tab()
+ indent_type.most_used == IndentType.tab() # pyright: ignore[reportOptionalMemberAccess]
) # Tab should be most used since it's the primary indent
# Code with only one indented line but multiple levels
@@ -607,9 +607,9 @@ def test_end_to_end_complex(self):
result.append(i)
\treturn result"""
mixed_type = detect_indent_type(mixed_code)
- assert mixed_type.is_mixed
+ assert mixed_type.is_mixed # pyright: ignore[reportOptionalMemberAccess]
assert (
- mixed_type.most_used == IndentType.tab()
+ mixed_type.most_used == IndentType.tab() # pyright: ignore[reportOptionalMemberAccess]
) # More tab-indented lines (3 vs 2)
# Applying new indent to mixed code should return original
@@ -706,13 +706,13 @@ def test_indenttype_properties(self):
mixed_space = IndentType.mixed(most_used=IndentType.space(4))
assert mixed_space.is_mixed
assert mixed_space.most_used == IndentType.space(4)
- assert mixed_space.most_used.size == 4
+ assert mixed_space.most_used.size == 4 # pyright: ignore[reportOptionalMemberAccess]
# Test mixed indentation with tab as most_used
mixed_tab = IndentType.mixed(most_used=IndentType.tab())
assert mixed_tab.is_mixed
assert mixed_tab.most_used == IndentType.tab()
- assert mixed_tab.most_used.size == 1
+ assert mixed_tab.most_used.size == 1 # pyright: ignore[reportOptionalMemberAccess]
# Test equality
assert IndentType.space(4) == IndentType.space(4)
diff --git a/token_counter.py b/utils/token_counter.py
similarity index 100%
rename from token_counter.py
rename to utils/token_counter.py
diff --git a/utils/workspace_manager.py b/utils/workspace_manager.py
index 202b046..41a6ad2 100644
--- a/utils/workspace_manager.py
+++ b/utils/workspace_manager.py
@@ -30,4 +30,4 @@ def container_path(self, path: Path | str) -> Path:
return self.root / path
if self.container_workspace and path.is_relative_to(self.root):
return self.container_workspace / path.relative_to(self.root)
- return path
\ No newline at end of file
+ return path
diff --git a/uv.lock b/uv.lock
index 9aa93a9..b113241 100644
--- a/uv.lock
+++ b/uv.lock
@@ -151,6 +151,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/38/fc/bce832fd4fd99766c04d1ee0eead6b0ec6486fb100ae5e74c1d91292b982/certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe", size = 166393 },
]
+[[package]]
+name = "cfgv"
+version = "3.4.0"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249 },
+]
+
[[package]]
name = "charset-normalizer"
version = "3.4.1"
@@ -255,6 +264,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252 },
]
+[[package]]
+name = "distlib"
+version = "0.3.9"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/0d/dd/1bec4c5ddb504ca60fc29472f3d27e8d4da1257a854e1d96742f15c1d02d/distlib-0.3.9.tar.gz", hash = "sha256:a60f20dea646b8a33f3e7772f74dc0b2d0772d2837ee1342a00645c81edf9403", size = 613923 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/91/a1/cf2472db20f7ce4a6be1253a81cfdf85ad9c7885ffbed7047fb72c24cf87/distlib-0.3.9-py2.py3-none-any.whl", hash = "sha256:47f8c22fd27c27e25a65601af709b38e4f0a45ea4fc2e710f65755fa8caaaf87", size = 468973 },
+]
+
[[package]]
name = "distro"
version = "1.9.0"
@@ -422,6 +440,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/40/0c/37d380846a2e5c9a3c6a73d26ffbcfdcad5fc3eacf42fdf7cff56f2af634/huggingface_hub-0.29.3-py3-none-any.whl", hash = "sha256:0b25710932ac649c08cdbefa6c6ccb8e88eef82927cacdb048efb726429453aa", size = 468997 },
]
+[[package]]
+name = "identify"
+version = "2.6.9"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/9b/98/a71ab060daec766acc30fb47dfca219d03de34a70d616a79a38c6066c5bf/identify-2.6.9.tar.gz", hash = "sha256:d40dfe3142a1421d8518e3d3985ef5ac42890683e32306ad614a29490abeb6bf", size = 99249 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/07/ce/0845144ed1f0e25db5e7a79c2354c1da4b5ce392b8966449d5db8dca18f1/identify-2.6.9-py2.py3-none-any.whl", hash = "sha256:c98b4322da415a8e5a70ff6e51fbc2d2932c015532d77e9f8537b4ba7813b150", size = 99101 },
+]
+
[[package]]
name = "idna"
version = "3.10"
@@ -641,6 +668,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 },
]
+[[package]]
+name = "nodeenv"
+version = "1.9.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314 },
+]
+
[[package]]
name = "numpy"
version = "2.2.4"
@@ -770,6 +806,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 },
]
+[[package]]
+name = "platformdirs"
+version = "4.3.7"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/b6/2d/7d512a3913d60623e7eb945c6d1b4f0bddf1d0b7ada5225274c87e5b53d1/platformdirs-4.3.7.tar.gz", hash = "sha256:eb437d586b6a0986388f0d6f74aa0cde27b48d0e3d66843640bfb6bdcdb6e351", size = 21291 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/6d/45/59578566b3275b8fd9157885918fcd0c4d74162928a5310926887b856a51/platformdirs-4.3.7-py3-none-any.whl", hash = "sha256:a03875334331946f13c549dbd8f4bac7a13a50a895a0eb1e8c6a8ace80d40a94", size = 18499 },
+]
+
[[package]]
name = "pluggy"
version = "1.5.0"
@@ -779,6 +824,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 },
]
+[[package]]
+name = "pre-commit"
+version = "4.2.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "cfgv" },
+ { name = "identify" },
+ { name = "nodeenv" },
+ { name = "pyyaml" },
+ { name = "virtualenv" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/08/39/679ca9b26c7bb2999ff122d50faa301e49af82ca9c066ec061cfbc0c6784/pre_commit-4.2.0.tar.gz", hash = "sha256:601283b9757afd87d40c4c4a9b2b5de9637a8ea02eaff7adc2d0fb4e04841146", size = 193424 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/88/74/a88bf1b1efeae488a0c0b7bdf71429c313722d1fc0f377537fbe554e6180/pre_commit-4.2.0-py2.py3-none-any.whl", hash = "sha256:a009ca7205f1eb497d10b845e52c838a98b6cdd2102a6c8e4540e94ee75c58bd", size = 220707 },
+]
+
[[package]]
name = "prompt-toolkit"
version = "3.0.50"
@@ -997,6 +1058,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/8a/0b/9fcc47d19c48b59121088dd6da2488a49d5f72dacf8262e2790a1d2c7d15/pygments-2.19.1-py3-none-any.whl", hash = "sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c", size = 1225293 },
]
+[[package]]
+name = "pyright"
+version = "1.1.398"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "nodeenv" },
+ { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/24/d6/48740f1d029e9fc4194880d1ad03dcf0ba3a8f802e0e166b8f63350b3584/pyright-1.1.398.tar.gz", hash = "sha256:357a13edd9be8082dc73be51190913e475fa41a6efb6ec0d4b7aab3bc11638d8", size = 3892675 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/58/e0/5283593f61b3c525d6d7e94cfb6b3ded20b3df66e953acaf7bb4f23b3f6e/pyright-1.1.398-py3-none-any.whl", hash = "sha256:0a70bfd007d9ea7de1cf9740e1ad1a40a122592cfe22a3f6791b06162ad08753", size = 5780235 },
+]
+
[[package]]
name = "pytest"
version = "7.4.3"
@@ -1197,6 +1271,31 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b3/14/c492b9c7d5dd133e13f211ddea6bb9870f99e4f73932f11aa00bc09a9be9/rpds_py-0.24.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:6a727fd083009bc83eb83d6950f0c32b3c94c8b80a9b667c87f4bd1274ca30ba", size = 560885 },
]
+[[package]]
+name = "ruff"
+version = "0.11.2"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/90/61/fb87430f040e4e577e784e325351186976516faef17d6fcd921fe28edfd7/ruff-0.11.2.tar.gz", hash = "sha256:ec47591497d5a1050175bdf4e1a4e6272cddff7da88a2ad595e1e326041d8d94", size = 3857511 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/62/99/102578506f0f5fa29fd7e0df0a273864f79af044757aef73d1cae0afe6ad/ruff-0.11.2-py3-none-linux_armv6l.whl", hash = "sha256:c69e20ea49e973f3afec2c06376eb56045709f0212615c1adb0eda35e8a4e477", size = 10113146 },
+ { url = "https://files.pythonhosted.org/packages/74/ad/5cd4ba58ab602a579997a8494b96f10f316e874d7c435bcc1a92e6da1b12/ruff-0.11.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:2c5424cc1c4eb1d8ecabe6d4f1b70470b4f24a0c0171356290b1953ad8f0e272", size = 10867092 },
+ { url = "https://files.pythonhosted.org/packages/fc/3e/d3f13619e1d152c7b600a38c1a035e833e794c6625c9a6cea6f63dbf3af4/ruff-0.11.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:ecf20854cc73f42171eedb66f006a43d0a21bfb98a2523a809931cda569552d9", size = 10224082 },
+ { url = "https://files.pythonhosted.org/packages/90/06/f77b3d790d24a93f38e3806216f263974909888fd1e826717c3ec956bbcd/ruff-0.11.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c543bf65d5d27240321604cee0633a70c6c25c9a2f2492efa9f6d4b8e4199bb", size = 10394818 },
+ { url = "https://files.pythonhosted.org/packages/99/7f/78aa431d3ddebfc2418cd95b786642557ba8b3cb578c075239da9ce97ff9/ruff-0.11.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:20967168cc21195db5830b9224be0e964cc9c8ecf3b5a9e3ce19876e8d3a96e3", size = 9952251 },
+ { url = "https://files.pythonhosted.org/packages/30/3e/f11186d1ddfaca438c3bbff73c6a2fdb5b60e6450cc466129c694b0ab7a2/ruff-0.11.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:955a9ce63483999d9f0b8f0b4a3ad669e53484232853054cc8b9d51ab4c5de74", size = 11563566 },
+ { url = "https://files.pythonhosted.org/packages/22/6c/6ca91befbc0a6539ee133d9a9ce60b1a354db12c3c5d11cfdbf77140f851/ruff-0.11.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:86b3a27c38b8fce73bcd262b0de32e9a6801b76d52cdb3ae4c914515f0cef608", size = 12208721 },
+ { url = "https://files.pythonhosted.org/packages/19/b0/24516a3b850d55b17c03fc399b681c6a549d06ce665915721dc5d6458a5c/ruff-0.11.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3b66a03b248c9fcd9d64d445bafdf1589326bee6fc5c8e92d7562e58883e30f", size = 11662274 },
+ { url = "https://files.pythonhosted.org/packages/d7/65/76be06d28ecb7c6070280cef2bcb20c98fbf99ff60b1c57d2fb9b8771348/ruff-0.11.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0397c2672db015be5aa3d4dac54c69aa012429097ff219392c018e21f5085147", size = 13792284 },
+ { url = "https://files.pythonhosted.org/packages/ce/d2/4ceed7147e05852876f3b5f3fdc23f878ce2b7e0b90dd6e698bda3d20787/ruff-0.11.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:869bcf3f9abf6457fbe39b5a37333aa4eecc52a3b99c98827ccc371a8e5b6f1b", size = 11327861 },
+ { url = "https://files.pythonhosted.org/packages/c4/78/4935ecba13706fd60ebe0e3dc50371f2bdc3d9bc80e68adc32ff93914534/ruff-0.11.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2a2b50ca35457ba785cd8c93ebbe529467594087b527a08d487cf0ee7b3087e9", size = 10276560 },
+ { url = "https://files.pythonhosted.org/packages/81/7f/1b2435c3f5245d410bb5dc80f13ec796454c21fbda12b77d7588d5cf4e29/ruff-0.11.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:7c69c74bf53ddcfbc22e6eb2f31211df7f65054bfc1f72288fc71e5f82db3eab", size = 9945091 },
+ { url = "https://files.pythonhosted.org/packages/39/c4/692284c07e6bf2b31d82bb8c32f8840f9d0627d92983edaac991a2b66c0a/ruff-0.11.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6e8fb75e14560f7cf53b15bbc55baf5ecbe373dd5f3aab96ff7aa7777edd7630", size = 10977133 },
+ { url = "https://files.pythonhosted.org/packages/94/cf/8ab81cb7dd7a3b0a3960c2769825038f3adcd75faf46dd6376086df8b128/ruff-0.11.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:842a472d7b4d6f5924e9297aa38149e5dcb1e628773b70e6387ae2c97a63c58f", size = 11378514 },
+ { url = "https://files.pythonhosted.org/packages/d9/3a/a647fa4f316482dacf2fd68e8a386327a33d6eabd8eb2f9a0c3d291ec549/ruff-0.11.2-py3-none-win32.whl", hash = "sha256:aca01ccd0eb5eb7156b324cfaa088586f06a86d9e5314b0eb330cb48415097cc", size = 10319835 },
+ { url = "https://files.pythonhosted.org/packages/86/54/3c12d3af58012a5e2cd7ebdbe9983f4834af3f8cbea0e8a8c74fa1e23b2b/ruff-0.11.2-py3-none-win_amd64.whl", hash = "sha256:3170150172a8f994136c0c66f494edf199a0bbea7a409f649e4bc8f4d7084080", size = 11373713 },
+ { url = "https://files.pythonhosted.org/packages/d6/d4/dd813703af8a1e2ac33bf3feb27e8a5ad514c9f219df80c64d69807e7f71/ruff-0.11.2-py3-none-win_arm64.whl", hash = "sha256:52933095158ff328f4c77af3d74f0379e34fd52f175144cefc1b192e7ccd32b4", size = 10441990 },
+]
+
[[package]]
name = "six"
version = "1.17.0"
@@ -1229,12 +1328,19 @@ dependencies = [
{ name = "numpy" },
{ name = "openai" },
{ name = "pexpect" },
+ { name = "pre-commit" },
{ name = "prompt-toolkit" },
+ { name = "pyright" },
{ name = "pytest" },
{ name = "rich" },
{ name = "termcolor" },
]
+[package.dev-dependencies]
+dev = [
+ { name = "ruff" },
+]
+
[package.metadata]
requires-dist = [
{ name = "anthropic", specifier = "==0.47.0" },
@@ -1246,12 +1352,17 @@ requires-dist = [
{ name = "numpy", specifier = ">=2.2.4" },
{ name = "openai", specifier = "==1.59.*" },
{ name = "pexpect", specifier = ">=4.9.0" },
+ { name = "pre-commit", specifier = ">=4.2.0" },
{ name = "prompt-toolkit", specifier = ">=3.0.50" },
+ { name = "pyright", specifier = ">=1.1.398" },
{ name = "pytest", specifier = "==7.4.3" },
{ name = "rich", specifier = ">=13.9.4" },
{ name = "termcolor", specifier = ">=2.5.0" },
]
+[package.metadata.requires-dev]
+dev = [{ name = "ruff", specifier = ">=0.11.2" }]
+
[[package]]
name = "termcolor"
version = "2.5.0"
@@ -1325,6 +1436,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
]
+[[package]]
+name = "virtualenv"
+version = "20.30.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "distlib" },
+ { name = "filelock" },
+ { name = "platformdirs" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/38/e0/633e369b91bbc664df47dcb5454b6c7cf441e8f5b9d0c250ce9f0546401e/virtualenv-20.30.0.tar.gz", hash = "sha256:800863162bcaa5450a6e4d721049730e7f2dae07720e0902b0e4040bd6f9ada8", size = 4346945 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/4c/ed/3cfeb48175f0671ec430ede81f628f9fb2b1084c9064ca67ebe8c0ed6a05/virtualenv-20.30.0-py3-none-any.whl", hash = "sha256:e34302959180fca3af42d1800df014b35019490b119eba981af27f2fa486e5d6", size = 4329461 },
+]
+
[[package]]
name = "wcwidth"
version = "0.2.13"