Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 49 additions & 9 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ def main():
action="store_true",
default=False,
)
parser.add_argument(
"--openai-base-url",
type=str,
default=None,
help="(Optional) Base URL for OpenAI API. Used when ANTHROPIC_API_KEY is not set.",
)
parser.add_argument(
"--openai-model",
type=str,
default="gpt-4-turbo",
help="(Optional) OpenAI model to use when ANTHROPIC_API_KEY is not set.",
)

args = parser.parse_args()

Expand All @@ -86,11 +98,31 @@ def main():
else:
logger_for_agent_logs.propagate = False

# Set OpenAI base URL if provided via command line
if args.openai_base_url:
os.environ["OPENAI_BASE_URL"] = args.openai_base_url

# Check if ANTHROPIC_API_KEY is set
if "ANTHROPIC_API_KEY" not in os.environ:
print("Error: ANTHROPIC_API_KEY environment variable is not set.")
print("Please set it to your Anthropic API key.")
sys.exit(1)
openai_base_url = os.getenv("OPENAI_BASE_URL", "")
base_url_info = f" with base URL {openai_base_url}" if openai_base_url else ""

if not args.minimize_stdout_logs:
console = Console()
console.print(
Panel(
f"[bold yellow]Warning: ANTHROPIC_API_KEY environment variable is not set.[/bold yellow]\n"
+ f"Using OpenAI{base_url_info} with model {args.openai_model} instead.",
title="[bold yellow]API Key Warning[/bold yellow]",
border_style="yellow",
padding=(1, 2),
)
)
else:
logger_for_agent_logs.info(
f"Warning: ANTHROPIC_API_KEY environment variable is not set. "
f"Using OpenAI{base_url_info} with model {args.openai_model} instead."
)

# Initialize console
console = Console()
Expand All @@ -113,11 +145,19 @@ def main():
)

# Initialize LLM client
client = get_client(
"anthropic-direct",
model_name="claude-3-7-sonnet-20250219",
use_caching=True,
)
if "ANTHROPIC_API_KEY" in os.environ:
client = get_client(
"anthropic-direct",
model_name="claude-3-7-sonnet-20250219",
use_caching=True,
)
else:
# Use OpenAI client when Anthropic API key is not available
client = get_client(
"openai-direct",
model_name=args.openai_model,
cot_model=False,
)

# Initialize workspace manager
workspace_path = Path(args.workspace).resolve()
Expand Down Expand Up @@ -189,4 +229,4 @@ def main():


if __name__ == "__main__":
main()
main()
22 changes: 16 additions & 6 deletions utils/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,14 @@ class OpenAIDirectClient(LLMClient):
def __init__(self, model_name: str, max_retries=2, cot_model: bool = True):
"""Initialize the OpenAI first party client."""
api_key = os.getenv("OPENAI_API_KEY")
self.client = openai.OpenAI(
api_key=api_key,
max_retries=1,
)
base_url = os.getenv("OPENAI_BASE_URL")

# Initialize the client with optional base_url if provided
client_kwargs = {"api_key": api_key, "max_retries": 1}
if base_url:
client_kwargs["base_url"] = base_url

self.client = openai.OpenAI(**client_kwargs)
self.model_name = model_name
self.max_retries = max_retries
self.cot_model = cot_model
Expand Down Expand Up @@ -454,12 +458,18 @@ def generate(
openai_message = {"role": "assistant", "content": [message_content]}
elif str(type(augment_message)) == str(ToolCall):
augment_message = cast(ToolCall, augment_message)
# Ensure arguments are always a JSON string, not an object
if isinstance(augment_message.tool_input, (dict, list)):
tool_input_str = json.dumps(augment_message.tool_input)
else:
tool_input_str = str(augment_message.tool_input)

tool_call = {
"type": "function",
"id": augment_message.tool_call_id,
"function": {
"name": augment_message.tool_name,
"arguments": augment_message.tool_input,
"arguments": tool_input_str,
},
}
openai_message = {
Expand Down Expand Up @@ -601,4 +611,4 @@ def get_client(client_name: str, **kwargs) -> LLMClient:
elif client_name == "openai-direct":
return OpenAIDirectClient(**kwargs)
else:
raise ValueError(f"Unknown client name: {client_name}")
raise ValueError(f"Unknown client name: {client_name}")