diff --git a/crates/llm/src/aicore_invoke.rs b/crates/llm/src/aicore/anthropic.rs similarity index 86% rename from crates/llm/src/aicore_invoke.rs rename to crates/llm/src/aicore/anthropic.rs index ef0baee2..5c8bb059 100644 --- a/crates/llm/src/aicore_invoke.rs +++ b/crates/llm/src/aicore/anthropic.rs @@ -1,5 +1,10 @@ +//! AI Core client for Anthropic (Claude) API +//! +//! This client wraps the Anthropic client with AI Core authentication. + use crate::{ anthropic::{AnthropicClient, AuthProvider, DefaultMessageConverter, RequestCustomizer}, + auth::TokenManager, types::*, LLMProvider, StreamingCallback, }; @@ -8,9 +13,7 @@ use async_trait::async_trait; use serde_json::Value; use std::sync::Arc; -use super::auth::TokenManager; - -/// AiCore authentication provider using TokenManager +/// AI Core authentication provider using TokenManager pub struct AiCoreAuthProvider { token_manager: Arc, } @@ -32,10 +35,10 @@ impl AuthProvider for AiCoreAuthProvider { } } -/// AiCore request customizer -pub struct AiCoreRequestCustomizer; +/// AI Core request customizer for Anthropic API +pub struct AiCoreAnthropicRequestCustomizer; -impl RequestCustomizer for AiCoreRequestCustomizer { +impl RequestCustomizer for AiCoreAnthropicRequestCustomizer { fn customize_request(&self, request: &mut serde_json::Value) -> Result<()> { if let Value::Object(ref mut map) = request { // Remove stream and model fields after URL routing is done @@ -70,18 +73,19 @@ impl RequestCustomizer for AiCoreRequestCustomizer { } } -pub struct AiCoreClient { +/// AI Core client for Anthropic (Claude) models +pub struct AiCoreAnthropicClient { anthropic_client: AnthropicClient, custom_config: Option, } -impl AiCoreClient { +impl AiCoreAnthropicClient { fn create_anthropic_client( token_manager: Arc, base_url: String, ) -> AnthropicClient { let auth_provider = Box::new(AiCoreAuthProvider::new(token_manager)); - let request_customizer = Box::new(AiCoreRequestCustomizer); + let request_customizer = Box::new(AiCoreAnthropicRequestCustomizer); let message_converter = Box::new(DefaultMessageConverter::new()); AnthropicClient::with_customization( @@ -127,7 +131,7 @@ impl AiCoreClient { } #[async_trait] -impl LLMProvider for AiCoreClient { +impl LLMProvider for AiCoreAnthropicClient { async fn send_message( &mut self, request: LLMRequest, diff --git a/crates/llm/src/aicore/mod.rs b/crates/llm/src/aicore/mod.rs new file mode 100644 index 00000000..9619b2aa --- /dev/null +++ b/crates/llm/src/aicore/mod.rs @@ -0,0 +1,69 @@ +//! AI Core provider module +//! +//! AI Core acts as a proxy service that can route to different backend vendors. +//! This module provides support for multiple vendor API types: +//! - Anthropic (Claude models via Bedrock-style API) +//! - OpenAI (Chat Completions API) +//! 
- Vertex (Google Gemini API) + +mod anthropic; +mod openai; +mod types; +mod vertex; + +pub use anthropic::AiCoreAnthropicClient; +pub use openai::AiCoreOpenAIClient; +pub use types::AiCoreApiType; +pub use vertex::AiCoreVertexClient; + +use crate::auth::TokenManager; +use crate::LLMProvider; +use std::path::Path; +use std::sync::Arc; + +/// Create an AI Core client based on the API type +pub fn create_aicore_client( + api_type: AiCoreApiType, + token_manager: Arc, + base_url: String, + model_id: String, +) -> Box { + match api_type { + AiCoreApiType::Anthropic => Box::new(AiCoreAnthropicClient::new(token_manager, base_url)), + AiCoreApiType::OpenAI => { + Box::new(AiCoreOpenAIClient::new(token_manager, base_url, model_id)) + } + AiCoreApiType::Vertex => { + Box::new(AiCoreVertexClient::new(token_manager, base_url, model_id)) + } + } +} + +/// Create an AI Core client with recording capability +pub fn create_aicore_client_with_recorder>( + api_type: AiCoreApiType, + token_manager: Arc, + base_url: String, + model_id: String, + recording_path: P, +) -> Box { + match api_type { + AiCoreApiType::Anthropic => Box::new(AiCoreAnthropicClient::new_with_recorder( + token_manager, + base_url, + recording_path, + )), + AiCoreApiType::OpenAI => Box::new(AiCoreOpenAIClient::new_with_recorder( + token_manager, + base_url, + model_id, + recording_path, + )), + AiCoreApiType::Vertex => Box::new(AiCoreVertexClient::new_with_recorder( + token_manager, + base_url, + model_id, + recording_path, + )), + } +} diff --git a/crates/llm/src/aicore/openai.rs b/crates/llm/src/aicore/openai.rs new file mode 100644 index 00000000..56721aed --- /dev/null +++ b/crates/llm/src/aicore/openai.rs @@ -0,0 +1,119 @@ +//! AI Core client for OpenAI Chat Completions API +//! +//! This client wraps the OpenAI client with AI Core authentication. 
+ +use crate::{ + auth::TokenManager, + openai::{AuthProvider, OpenAIClient, RequestCustomizer}, + types::*, + LLMProvider, StreamingCallback, +}; +use anyhow::Result; +use async_trait::async_trait; +use std::sync::Arc; + +/// AI Core authentication provider for OpenAI-style API +struct AiCoreOpenAIAuthProvider { + token_manager: Arc, +} + +impl AiCoreOpenAIAuthProvider { + fn new(token_manager: Arc) -> Self { + Self { token_manager } + } +} + +#[async_trait] +impl AuthProvider for AiCoreOpenAIAuthProvider { + async fn get_auth_headers(&self) -> Result> { + let token = self.token_manager.get_valid_token().await?; + Ok(vec![( + "Authorization".to_string(), + format!("Bearer {token}"), + )]) + } +} + +/// AI Core request customizer for OpenAI Chat Completions API +struct AiCoreOpenAIRequestCustomizer; + +impl RequestCustomizer for AiCoreOpenAIRequestCustomizer { + fn customize_request(&self, _request: &mut serde_json::Value) -> Result<()> { + // No additional customization needed for OpenAI-style requests + Ok(()) + } + + fn get_additional_headers(&self) -> Vec<(String, String)> { + vec![ + ("AI-Resource-Group".to_string(), "default".to_string()), + ("Content-Type".to_string(), "application/json".to_string()), + ] + } + + fn customize_url(&self, base_url: &str, _streaming: bool) -> String { + // AI Core uses /chat/completions endpoint for OpenAI-compatible models + format!("{base_url}/chat/completions") + } +} + +/// AI Core client for OpenAI Chat Completions API +pub struct AiCoreOpenAIClient { + openai_client: OpenAIClient, + custom_config: Option, +} + +impl AiCoreOpenAIClient { + fn create_openai_client( + token_manager: Arc, + base_url: String, + model_id: String, + ) -> OpenAIClient { + let auth_provider = Box::new(AiCoreOpenAIAuthProvider::new(token_manager)); + let request_customizer = Box::new(AiCoreOpenAIRequestCustomizer); + + OpenAIClient::with_customization(model_id, base_url, auth_provider, request_customizer) + } + + pub fn new(token_manager: Arc, base_url: String, model_id: String) -> Self { + let openai_client = Self::create_openai_client(token_manager, base_url, model_id); + Self { + openai_client, + custom_config: None, + } + } + + /// Create a new client with recording capability + /// + /// Note: Recording is not yet implemented for OpenAI client. + /// This constructor exists for API consistency. + pub fn new_with_recorder>( + token_manager: Arc, + base_url: String, + model_id: String, + _recording_path: P, + ) -> Self { + // TODO: Add recording support to OpenAIClient + Self::new(token_manager, base_url, model_id) + } + + /// Set custom model configuration to be merged into API requests + pub fn with_custom_config(mut self, custom_config: serde_json::Value) -> Self { + self.openai_client = self.openai_client.with_custom_config(custom_config.clone()); + self.custom_config = Some(custom_config); + self + } +} + +#[async_trait] +impl LLMProvider for AiCoreOpenAIClient { + async fn send_message( + &mut self, + request: LLMRequest, + streaming_callback: Option<&StreamingCallback>, + ) -> Result { + // Delegate to the wrapped OpenAIClient + self.openai_client + .send_message(request, streaming_callback) + .await + } +} diff --git a/crates/llm/src/aicore/types.rs b/crates/llm/src/aicore/types.rs new file mode 100644 index 00000000..0a5a082c --- /dev/null +++ b/crates/llm/src/aicore/types.rs @@ -0,0 +1,42 @@ +//! 
Types for AI Core provider configuration + +use serde::{Deserialize, Serialize}; + +/// Specifies which vendor API type to use for an AI Core deployment +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum AiCoreApiType { + /// Anthropic Claude API (Bedrock-style invoke/converse endpoints) + #[default] + Anthropic, + /// OpenAI Chat Completions API + OpenAI, + /// Google Vertex AI / Gemini API + Vertex, +} + +impl std::fmt::Display for AiCoreApiType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AiCoreApiType::Anthropic => write!(f, "anthropic"), + AiCoreApiType::OpenAI => write!(f, "openai"), + AiCoreApiType::Vertex => write!(f, "vertex"), + } + } +} + +impl std::str::FromStr for AiCoreApiType { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "anthropic" => Ok(AiCoreApiType::Anthropic), + "openai" => Ok(AiCoreApiType::OpenAI), + "vertex" => Ok(AiCoreApiType::Vertex), + _ => Err(anyhow::anyhow!( + "Unknown AI Core API type: '{}'. Expected one of: anthropic, openai, vertex", + s + )), + } + } +} diff --git a/crates/llm/src/aicore/vertex.rs b/crates/llm/src/aicore/vertex.rs new file mode 100644 index 00000000..181dd6ec --- /dev/null +++ b/crates/llm/src/aicore/vertex.rs @@ -0,0 +1,135 @@ +//! AI Core client for Google Vertex AI / Gemini API +//! +//! This client wraps the VertexClient with AI Core authentication. + +use crate::{ + auth::TokenManager, + types::*, + vertex::{AuthProvider, RequestCustomizer, VertexAuth, VertexClient}, + LLMProvider, StreamingCallback, +}; +use anyhow::Result; +use async_trait::async_trait; +use std::sync::Arc; + +// ============================================================================ +// AI Core Authentication Provider for Vertex +// ============================================================================ + +/// AI Core authentication provider for Vertex API (uses Bearer token in headers) +struct AiCoreVertexAuthProvider { + token_manager: Arc, +} + +impl AiCoreVertexAuthProvider { + fn new(token_manager: Arc) -> Self { + Self { token_manager } + } +} + +#[async_trait] +impl AuthProvider for AiCoreVertexAuthProvider { + async fn get_auth(&self) -> Result { + let token = self.token_manager.get_valid_token().await?; + Ok(VertexAuth { + query_params: vec![], // AI Core doesn't use query params for auth + headers: vec![("Authorization".to_string(), format!("Bearer {token}"))], + }) + } +} + +// ============================================================================ +// AI Core Request Customizer for Vertex +// ============================================================================ + +/// AI Core request customizer for Vertex API +struct AiCoreVertexRequestCustomizer; + +impl RequestCustomizer for AiCoreVertexRequestCustomizer { + fn customize_request(&self, _request: &mut serde_json::Value) -> Result<()> { + Ok(()) + } + + fn get_additional_headers(&self) -> Vec<(String, String)> { + vec![ + ("AI-Resource-Group".to_string(), "default".to_string()), + ("Content-Type".to_string(), "application/json".to_string()), + ] + } + + fn customize_url(&self, base_url: &str, model: &str, streaming: bool) -> String { + // AI Core Vertex deployments use the same URL pattern + if streaming { + format!("{}/models/{}:streamGenerateContent", base_url, model) + } else { + format!("{}/models/{}:generateContent", base_url, model) + } + } +} + +// 
============================================================================ +// AI Core Vertex Client +// ============================================================================ + +/// AI Core client for Google Vertex AI / Gemini models +pub struct AiCoreVertexClient { + vertex_client: VertexClient, + custom_config: Option, +} + +impl AiCoreVertexClient { + fn create_vertex_client( + token_manager: Arc, + base_url: String, + model: String, + ) -> VertexClient { + let auth_provider = Box::new(AiCoreVertexAuthProvider::new(token_manager)); + let request_customizer = Box::new(AiCoreVertexRequestCustomizer); + + VertexClient::with_customization(model, base_url, auth_provider, request_customizer) + } + + pub fn new(token_manager: Arc, base_url: String, model: String) -> Self { + let vertex_client = Self::create_vertex_client(token_manager, base_url, model); + Self { + vertex_client, + custom_config: None, + } + } + + /// Create a new client with recording capability + pub fn new_with_recorder>( + token_manager: Arc, + base_url: String, + model: String, + recording_path: P, + ) -> Self { + let vertex_client = Self::create_vertex_client(token_manager, base_url, model) + .with_recorder(recording_path); + Self { + vertex_client, + custom_config: None, + } + } + + /// Set custom model configuration to be merged into API requests + pub fn with_custom_config(mut self, custom_config: serde_json::Value) -> Self { + self.vertex_client = self.vertex_client.with_custom_config(custom_config.clone()); + self.custom_config = Some(custom_config); + self + } +} + +#[async_trait] +impl LLMProvider for AiCoreVertexClient { + async fn send_message( + &mut self, + request: LLMRequest, + streaming_callback: Option<&StreamingCallback>, + ) -> Result { + // Delegate to the wrapped VertexClient + self.vertex_client + .send_message(request, streaming_callback) + .await + } +} diff --git a/crates/llm/src/aicore_converse.rs b/crates/llm/src/aicore_converse.rs deleted file mode 100644 index 91b2dbf7..00000000 --- a/crates/llm/src/aicore_converse.rs +++ /dev/null @@ -1,812 +0,0 @@ -use crate::llm::{ - recording::APIRecorder, types::*, utils, ApiError, LLMProvider, RateLimitHandler, - StreamingCallback, StreamingChunk, -}; -use anyhow::Result; -use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use reqwest::{Client, Response}; -use serde::{Deserialize, Serialize}; -use std::str::{self}; -use std::sync::Arc; -use std::time::Duration; -use tracing::debug; - -use super::auth::TokenManager; - -/// Response structure for Anthropic error messages -#[derive(Debug, Serialize, serde::Deserialize)] -struct AnthropicErrorResponse { - #[serde(rename = "type")] - error_type: String, - error: AnthropicErrorPayload, -} - -#[derive(Debug, Serialize, serde::Deserialize)] -struct AnthropicErrorPayload { - #[serde(rename = "type")] - error_type: String, - message: String, -} - -/// Rate limit information extracted from response headers -#[derive(Debug)] -struct AnthropicRateLimitInfo { - requests_limit: Option, - requests_remaining: Option, - requests_reset: Option>, - tokens_limit: Option, - tokens_remaining: Option, - tokens_reset: Option>, - retry_after: Option, -} - -impl RateLimitHandler for AnthropicRateLimitInfo { - /// Extract rate limit information from response headers - fn from_response(response: &Response) -> Self { - let headers = response.headers(); - - fn parse_header( - headers: &reqwest::header::HeaderMap, - name: &str, - ) -> Option { - headers - .get(name) - .and_then(|h| h.to_str().ok()) - .and_then(|s| 
s.parse().ok()) - } - - fn parse_datetime( - headers: &reqwest::header::HeaderMap, - name: &str, - ) -> Option> { - headers - .get(name) - .and_then(|h| h.to_str().ok()) - .and_then(|s| DateTime::parse_from_rfc3339(s).ok()) - .map(|dt| dt.into()) - } - - Self { - requests_limit: parse_header(headers, "anthropic-ratelimit-requests-limit"), - requests_remaining: parse_header(headers, "anthropic-ratelimit-requests-remaining"), - requests_reset: parse_datetime(headers, "anthropic-ratelimit-requests-reset"), - tokens_limit: parse_header(headers, "anthropic-ratelimit-tokens-limit"), - tokens_remaining: parse_header(headers, "anthropic-ratelimit-tokens-remaining"), - tokens_reset: parse_datetime(headers, "anthropic-ratelimit-tokens-reset"), - retry_after: parse_header::(headers, "retry-after").map(Duration::from_secs), - } - } - - /// Calculate how long to wait before retrying based on rate limit information - fn get_retry_delay(&self) -> Duration { - // If we have a specific retry-after duration, use that - if let Some(retry_after) = self.retry_after { - return retry_after; - } - - // Otherwise, calculate based on reset times - let now = Utc::now(); - let mut shortest_wait = Duration::from_secs(60); // Default to 60 seconds if no information - - // Check requests reset time - if let Some(reset_time) = self.requests_reset { - if reset_time > now { - shortest_wait = shortest_wait.min(Duration::from_secs( - (reset_time - now).num_seconds().max(0) as u64, - )); - } - } - - // Check tokens reset time - if let Some(reset_time) = self.tokens_reset { - if reset_time > now { - shortest_wait = shortest_wait.min(Duration::from_secs( - (reset_time - now).num_seconds().max(0) as u64, - )); - } - } - - // Add a small buffer to avoid hitting the limit exactly at reset time - shortest_wait + Duration::from_secs(1) - } - - /// Log current rate limit status - fn log_status(&self) { - debug!( - "Rate limits - Requests: {}/{} (reset: {}), Tokens: {}/{} (reset: {})", - self.requests_remaining - .map_or("?".to_string(), |r| r.to_string()), - self.requests_limit - .map_or("?".to_string(), |l| l.to_string()), - self.requests_reset - .map_or("unknown".to_string(), |r| r.to_string()), - self.tokens_remaining - .map_or("?".to_string(), |r| r.to_string()), - self.tokens_limit.map_or("?".to_string(), |l| l.to_string()), - self.tokens_reset - .map_or("unknown".to_string(), |r| r.to_string()), - ); - } -} - -/// Cache control settings for Anthropic API request -#[derive(Debug, Serialize)] -struct CacheControl { - #[serde(rename = "type")] - cache_type: String, -} - -/// System content block with optional cache control -#[derive(Debug, Serialize)] -struct SystemBlock { - #[serde(rename = "type")] - block_type: String, - text: String, - #[serde(skip_serializing_if = "Option::is_none")] - cache_control: Option, -} - -#[derive(Debug, Serialize)] -struct ThinkingConfiguration { - #[serde(rename = "type")] - thinking_type: String, - budget_tokens: usize, -} - -/// AWS Bedrock Converse request structure for all models -#[derive(Debug, Serialize)] -struct ConverseRequest { - messages: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - system: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(rename = "inferenceConfig")] - inference_config: Option, - #[serde( - skip_serializing_if = "Option::is_none", - rename = "additionalModelRequestFields" - )] - additional_model_request_fields: Option, - #[serde(skip_serializing_if = "Option::is_none", rename = "toolConfig")] - tool_config: Option, -} - 
-#[derive(Debug, Serialize)] -struct InferenceConfiguration { - max_tokens: usize, - temperature: f32, -} - -#[derive(Debug, Serialize)] -struct ToolConfiguration { - #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - tool_choice: Option, -} -/// Response structure for AWS Bedrock API responses -#[derive(Debug, Deserialize)] -struct ConverseResponse { - output: ConverseOutput, - #[serde(default)] - #[allow(dead_code)] - stop_reason: String, - usage: TokenUsage, -} - -#[derive(Debug, Deserialize)] -struct ConverseOutput { - message: Message, -} - -/// Usage information from AiCore Converse API -#[derive(Debug, Deserialize)] -struct TokenUsage { - #[serde(default, rename = "inputTokens")] - input_tokens: u32, - #[serde(default, rename = "outputTokens")] - output_tokens: u32, - #[serde(default, rename = "totalTokens")] - total_tokens: u32, - #[serde(default, rename = "cacheCreationInputTokens")] - cache_creation_input_tokens: u32, - #[serde(default, rename = "cacheReadInputTokens")] - cache_read_input_tokens: u32, -} - -#[derive(Debug, Deserialize)] -struct StreamEventCommon { - #[serde(rename = "contentBlockIndex")] - index: usize, -} - -// Structure to parse SSE events from the API -#[derive(Debug, Deserialize)] -struct StreamEvent { - #[serde(default)] - #[serde(rename = "messageStart")] - message_start: Option, - - #[serde(default)] - #[serde(rename = "contentBlockStart")] - content_block_start: Option, - - #[serde(default)] - #[serde(rename = "contentBlockDelta")] - content_block_delta: Option, - - #[serde(default)] - #[serde(rename = "contentBlockStop")] - content_block_stop: Option, - - #[serde(default)] - #[serde(rename = "messageStop")] - message_stop: Option, - - #[serde(default)] - metadata: Option, - - #[serde(default)] - ping: Option, -} - -#[derive(Debug, Deserialize)] -struct MessageStartEvent { - role: String, -} - -#[derive(Debug, Deserialize)] -struct ContentBlockStartEvent { - start: StreamContentBlockStart, - #[serde(rename = "contentBlockIndex")] - index: usize, -} - -#[derive(Debug, Deserialize)] -struct ContentBlockDeltaEvent { - delta: serde_json::Value, - #[serde(rename = "contentBlockIndex")] - index: usize, -} - -#[derive(Debug, Deserialize)] -struct ContentBlockStopEvent { - #[serde(rename = "contentBlockIndex")] - index: usize, -} - -#[derive(Debug, Deserialize)] -struct MessageStopEvent { - #[serde(rename = "stopReason")] - stop_reason: String, - #[serde(default)] - #[serde(rename = "additionalModelResponseFields")] - additional_model_response_fields: Option, -} - -#[derive(Debug, Deserialize)] -struct MetadataEvent { - usage: Option, - metrics: Option, - trace: Option, -} - -#[derive(Debug, Deserialize)] -struct ConverseMetrics { - #[serde(rename = "latencyMs")] - latency_ms: u64, -} - -#[derive(Debug, Deserialize)] -struct ConverseTrace { - guardrail: Option, -} - -#[derive(Debug, Deserialize)] -struct StreamContentBlockStart { - #[serde(rename = "type")] - block_type: String, - // Fields for text blocks - text: Option, - // Fields for thinking blocks - thinking: Option, - signature: Option, - // Fields for tool use blocks - id: Option, - name: Option, - input: Option, -} - -// We use serde_json::Value directly for content deltas since -// the response structure can vary and may be transformed by proxies - -pub struct AiCoreClient { - token_manager: Arc, - client: Client, - base_url: String, - recorder: Option, -} - -impl AiCoreClient { - pub fn new(token_manager: Arc, base_url: String) -> 
Self { - Self { - token_manager, - client: Client::new(), - base_url, - recorder: None, - } - } - - /// Create a new client with recording capability - pub fn new_with_recorder>( - token_manager: Arc, - base_url: String, - recording_path: P, - ) -> Self { - Self { - token_manager, - client: Client::new(), - base_url, - recorder: Some(APIRecorder::new(recording_path)), - } - } - - fn get_url(&self, streaming: bool) -> String { - if streaming { - format!("{}/converse-stream", self.base_url) - } else { - format!("{}/converse", self.base_url) - } - } - - async fn send_with_retry( - &mut self, - request: &ConverseRequest, - streaming_callback: Option<&StreamingCallback>, - max_retries: u32, - ) -> Result { - let mut attempts = 0; - - loop { - match self.try_send_request(request, streaming_callback).await { - Ok((response, rate_limits)) => { - // Log rate limit status on successful response - rate_limits.log_status(); - return Ok(response); - } - Err(e) => { - if utils::handle_retryable_error::( - &e, - attempts, - max_retries, - streaming_callback, - ) - .await - { - attempts += 1; - continue; - } - return Err(e); - } - } - } - } - - async fn try_send_request( - &mut self, - request: &ConverseRequest, - streaming_callback: Option<&StreamingCallback>, - ) -> Result<(LLMResponse, AnthropicRateLimitInfo)> { - let token = self.token_manager.get_valid_token().await?; - - // Start recording before HTTP request to capture real latency - if let Some(recorder) = &self.recorder { - let request_json = serde_json::to_value(request)?; - recorder.start_recording(request_json)?; - } - - let request_builder = self - .client - .post(&self.get_url(streaming_callback.is_some())) - .header("AI-Resource-Group", "default") - .header("Content-Type", "application/json") - .header("Authorization", format!("Bearer {}", token)); - - let response = request_builder - .json(&request) - .send() - .await - .map_err(|e| ApiError::NetworkError(e.to_string()))?; - - // Log raw headers for debugging - debug!("Response headers: {:?}", response.headers()); - - let mut response = utils::check_response_error::(response).await?; - let rate_limits = AnthropicRateLimitInfo::from_response(&response); - - // Log parsed rate limits - debug!("Parsed rate limits: {:?}", rate_limits); - - if let Some(callback) = streaming_callback { - let mut blocks: Vec = Vec::new(); - let mut current_content = String::new(); - let mut line_buffer = String::new(); - let mut usage = TokenUsage { - input_tokens: 0, - output_tokens: 0, - total_tokens: 0, - cache_creation_input_tokens: 0, - cache_read_input_tokens: 0, - }; - - fn process_chunk( - chunk: &[u8], - line_buffer: &mut String, - blocks: &mut Vec, - usage: &mut TokenUsage, - current_content: &mut String, - callback: &StreamingCallback, - recorder: &Option, - ) -> Result<()> { - let chunk_str = str::from_utf8(chunk)?; - - for c in chunk_str.chars() { - if c == '\n' { - if !line_buffer.is_empty() { - process_sse_line( - line_buffer, - blocks, - usage, - current_content, - callback, - recorder, - )?; - line_buffer.clear(); - } - } else { - line_buffer.push(c); - } - } - Ok(()) - } - - fn process_sse_line( - line: &str, - blocks: &mut Vec, - usage: &mut TokenUsage, - current_content: &mut String, - callback: &StreamingCallback, - recorder: &Option, - ) -> Result<()> { - if let Some(data) = line.strip_prefix("data: ") { - println!("SSE line: {}", data); - match serde_json::from_str::(data) { - Ok(event) => { - // Record the chunk if recorder is available - if let Some(recorder) = &recorder { - 
recorder.record_chunk(data)?; - } - - // Process based on which event type was received - if let Some(content_start) = &event.content_block_start { - // Check the index matches expected - let index = content_start.index; - if index != blocks.len() { - return Err(anyhow::anyhow!( - "Start index {} does not match expected block {}", - index, - blocks.len() - )); - } - - current_content.clear(); - let block = match content_start.start.block_type.as_str() { - "thinking" => { - if let Some(thinking) = &content_start.start.thinking { - current_content.push_str(thinking); - } - ContentBlock::Thinking { - thinking: current_content.clone(), - signature: content_start - .start - .signature - .clone() - .unwrap_or_default(), - } - } - "text" => { - if let Some(text) = &content_start.start.text { - current_content.push_str(text); - } - ContentBlock::Text { - text: current_content.clone(), - } - } - "tool_use" => { - if let Some(input) = &content_start.start.input { - current_content.push_str(input); - } - ContentBlock::ToolUse { - id: content_start.start.id.clone().unwrap_or_default(), - name: content_start - .start - .name - .clone() - .unwrap_or_default(), - input: serde_json::Value::Null, - thought_signature: None, - } - } - _ => ContentBlock::Text { - text: String::new(), - }, - }; - blocks.push(block); - } else if let Some(content_delta) = &event.content_block_delta { - let index = content_delta.index; - - // Check if we have any blocks at all - if blocks.is_empty() { - return Err(anyhow::anyhow!( - "Received Delta but no blocks exist" - )); - } - - if index != blocks.len() - 1 { - return Err(anyhow::anyhow!( - "Delta index {} does not match current block {}", - index, - blocks.len() - 1 - )); - } - - // Try to extract the text from the delta - if let Some(text_obj) = content_delta.delta.get("text") { - if let Some(text) = text_obj.as_str() { - callback(&StreamingChunk::Text(text.to_string()))?; - current_content.push_str(text); - } - } else if content_delta.delta.get("SDK_UNKNOWN_MEMBER").is_some() { - // This is the reasoningContent delta - if let Some(reasoning) = - content_delta.delta["SDK_UNKNOWN_MEMBER"].get("name") - { - if reasoning.as_str() == Some("reasoningContent") { - // Treat this as thinking content for now - // In a real implementation, we'd extract the text, but it's not available in the payload - callback(&StreamingChunk::Thinking( - "Thinking...".to_string(), - ))?; - } - } - } else if let Some(partial_json) = - content_delta.delta.get("partialJson") - { - if let Some(partial_json_str) = partial_json.as_str() { - current_content.push_str(partial_json_str); - } - } - } else if let Some(content_stop) = &event.content_block_stop { - let index = content_stop.index; - - // Check if we have any blocks at all - if blocks.is_empty() { - return Err(anyhow::anyhow!( - "Received Stop but no blocks exist" - )); - } - - if index != blocks.len() - 1 { - return Err(anyhow::anyhow!( - "Stop index {} does not match current block {}", - index, - blocks.len() - 1 - )); - } - - match blocks.last_mut().unwrap() { - ContentBlock::Text { text } => { - *text = current_content.clone(); - } - ContentBlock::ToolUse { input, .. 
} => { - if let Ok(json) = serde_json::from_str(¤t_content) { - *input = json; - } - } - _ => {} - } - } else if let Some(metadata) = &event.metadata { - if let Some(meta_usage) = &metadata.usage { - usage.input_tokens = meta_usage.input_tokens; - usage.output_tokens = meta_usage.output_tokens; - usage.cache_creation_input_tokens = - meta_usage.cache_creation_input_tokens; - usage.cache_read_input_tokens = - meta_usage.cache_read_input_tokens; - } - } - } - Err(e) => { - println!("[ERROR] Failed to parse event: {}", e); - } - } - } - Ok(()) - } - - while let Some(chunk) = response.chunk().await? { - process_chunk( - &chunk, - &mut line_buffer, - &mut blocks, - &mut usage, - &mut current_content, - callback, - &self.recorder, - )?; - } - - // Process any remaining data in the buffer - if !line_buffer.is_empty() { - process_sse_line( - &line_buffer, - &mut blocks, - &mut usage, - &mut current_content, - callback, - &self.recorder, - )?; - } - - // Send StreamingComplete to indicate streaming has finished - callback(&StreamingChunk::StreamingComplete)?; - - // End recording if a recorder is available - if let Some(recorder) = &self.recorder { - recorder.end_recording()?; - } - - Ok(( - LLMResponse { - content: blocks, - usage: Usage { - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - cache_creation_input_tokens: usage.cache_creation_input_tokens, - cache_read_input_tokens: usage.cache_read_input_tokens, - }, - rate_limit_info: None, - }, - rate_limits, - )) - } else { - let response_text = response - .text() - .await - .map_err(|e| ApiError::NetworkError(e.to_string()))?; - - let converse_response: ConverseResponse = serde_json::from_str(&response_text) - .map_err(|e| ApiError::Unknown(format!("Failed to parse response: {}", e)))?; - - let content = match converse_response.output.message.content { - MessageContent::Text(text) => vec![ContentBlock::Text { text }], - MessageContent::Structured(blocks) => blocks, - }; - - let llm_response = LLMResponse { - content, - usage: Usage { - input_tokens: converse_response.usage.input_tokens, - output_tokens: converse_response.usage.output_tokens, - cache_creation_input_tokens: converse_response - .usage - .cache_creation_input_tokens, - cache_read_input_tokens: converse_response.usage.cache_read_input_tokens, - }, - rate_limit_info: None, - }; - - Ok((llm_response, rate_limits)) - } - } -} - -#[async_trait] -impl LLMProvider for AiCoreClient { - async fn send_message( - &mut self, - request: LLMRequest, - streaming_callback: Option<&StreamingCallback>, - ) -> Result { - // Convert system prompt to system blocks with cache control - let system = Some(vec![SystemBlock { - block_type: "text".to_string(), - text: request.system_prompt, - // Add cache_control to the system prompt to utilize Anthropic's caching - cache_control: Some(CacheControl { - cache_type: "ephemeral".to_string(), - }), - }]); - - // Determine if we have tools and create tool_choice - let has_tools = request.tools.is_some(); - let tool_choice = if has_tools { - Some(serde_json::json!({ - "type": "any", - })) - } else { - None - }; - - // Create tools array with cache control on the last tool if present - let tools = request.tools.map(|tools| { - let mut tools_json = tools - .into_iter() - .map(|tool| { - serde_json::json!({ - "name": tool.name, - "description": tool.description, - "input_schema": tool.parameters - }) - }) - .collect::>(); - - // Add cache_control to the last tool if any exist - if let Some(last_tool) = tools_json.last_mut() { - if let Some(obj) 
= last_tool.as_object_mut() { - obj.insert( - "cache_control".to_string(), - serde_json::json!({"type": "ephemeral"}), - ); - } - } - - tools_json - }); - - // Always enable thinking mode and max tokens for large models - let thinking = Some(ThinkingConfiguration { - thinking_type: "enabled".to_string(), - budget_tokens: 4000, - }); - let max_tokens = 128000; - - // In der send_message-Methode - let tool_config = if has_tools { - Some(ToolConfiguration { tools, tool_choice }) - } else { - None - }; - - let inference_config = Some(InferenceConfiguration { - max_tokens, - temperature: 1.0, - }); - - let messages = request - .messages - .into_iter() - .map(|mut message| { - if let MessageContent::Text(text) = message.content { - message.content = MessageContent::Structured(vec![ContentBlock::Text { text }]); - } - message - }) - .collect(); - - let converse_request = ConverseRequest { - messages, - system, - inference_config, - tool_config, - additional_model_request_fields: Some(serde_json::json!({ - // "streaming": streaming_callback.is_some(), - "thinking": thinking - })), - }; - - self.send_with_retry(&converse_request, streaming_callback, 3) - .await - } -} diff --git a/crates/llm/src/factory.rs b/crates/llm/src/factory.rs index 8d379371..894ebf22 100644 --- a/crates/llm/src/factory.rs +++ b/crates/llm/src/factory.rs @@ -1,9 +1,10 @@ +use crate::aicore::{AiCoreAnthropicClient, AiCoreApiType, AiCoreOpenAIClient, AiCoreVertexClient}; use crate::auth::TokenManager; use crate::provider_config::{ConfigurationSystem, ModelConfig, ProviderConfig}; use crate::{ - recording::PlaybackState, AiCoreClient, AnthropicClient, CerebrasClient, GroqClient, - LLMProvider, MistralAiClient, OllamaClient, OpenAIClient, OpenAIResponsesClient, - OpenRouterClient, VertexClient, + recording::PlaybackState, AnthropicClient, CerebrasClient, GroqClient, LLMProvider, + MistralAiClient, OllamaClient, OpenAIClient, OpenAIResponsesClient, OpenRouterClient, + VertexClient, }; use anyhow::{Context, Result}; use clap::ValueEnum; @@ -106,7 +107,19 @@ impl WithCustomConfig for VertexClient { } } -impl WithCustomConfig for AiCoreClient { +impl WithCustomConfig for AiCoreAnthropicClient { + fn with_custom_config(self, custom_config: Value) -> Self { + self.with_custom_config(custom_config) + } +} + +impl WithCustomConfig for AiCoreOpenAIClient { + fn with_custom_config(self, custom_config: Value) -> Self { + self.with_custom_config(custom_config) + } +} + +impl WithCustomConfig for AiCoreVertexClient { fn with_custom_config(self, custom_config: Value) -> Self { self.with_custom_config(custom_config) } @@ -271,6 +284,86 @@ pub async fn create_llm_client_from_configs( } } +/// AI Core model deployment configuration +/// +/// The `models` field in the AI Core provider config can be specified in two formats: +/// +/// 1. Simple format (backwards compatible) - just the deployment UUID: +/// ```json +/// "models": { +/// "claude-3.5-sonnet": "deployment-uuid-here" +/// } +/// ``` +/// This defaults to Anthropic API type. +/// +/// 2. 
Extended format - object with `deployment` and `api_type`: +/// ```json +/// "models": { +/// "claude-3.5-sonnet": { +/// "deployment": "deployment-uuid-here", +/// "api_type": "anthropic" +/// }, +/// "gpt-4o": { +/// "deployment": "another-deployment-uuid", +/// "api_type": "openai" +/// }, +/// "gemini-pro": { +/// "deployment": "vertex-deployment-uuid", +/// "api_type": "vertex" +/// } +/// } +/// ``` +#[derive(Debug)] +struct AiCoreDeployment { + deployment_uuid: String, + api_type: AiCoreApiType, +} + +fn parse_aicore_deployment(model_id: &str, value: &Value) -> Result { + // Try simple string format first (backwards compatible) + if let Some(uuid) = value.as_str() { + return Ok(AiCoreDeployment { + deployment_uuid: uuid.to_string(), + api_type: AiCoreApiType::default(), // Defaults to Anthropic + }); + } + + // Try extended object format + if let Some(obj) = value.as_object() { + let deployment_uuid = obj + .get("deployment") + .and_then(|v| v.as_str()) + .ok_or_else(|| { + anyhow::anyhow!( + "Missing 'deployment' field in AI Core config for model '{}'", + model_id + ) + })? + .to_string(); + + let api_type = if let Some(api_type_str) = obj.get("api_type").and_then(|v| v.as_str()) { + api_type_str.parse().with_context(|| { + format!( + "Invalid api_type for model '{}'. Valid values are: anthropic, openai, vertex", + model_id + ) + })? + } else { + AiCoreApiType::default() + }; + + return Ok(AiCoreDeployment { + deployment_uuid, + api_type, + }); + } + + Err(anyhow::anyhow!( + "Invalid deployment configuration for model '{}'. Expected string (deployment UUID) or object with 'deployment' and optional 'api_type' fields", + model_id + )) +} + async fn create_ai_core_client( model_config: &ModelConfig, provider_config: &ProviderConfig, @@ -303,15 +396,14 @@ async fn create_ai_core_client( .and_then(|v| v.as_object()) .ok_or_else(|| anyhow::anyhow!("models not found in AI Core provider config"))?; - let deployment_uuid = models - .get(&model_config.id) - .and_then(|v| v.as_str()) - .ok_or_else(|| { - anyhow::anyhow!( - "No deployment found for model '{}' in AI Core config", - model_config.id - ) - })?; + let deployment_value = models.get(&model_config.id).ok_or_else(|| { + anyhow::anyhow!( + "No deployment found for model '{}' in AI Core config", + model_config.id + ) + })?; + + let deployment = parse_aicore_deployment(&model_config.id, deployment_value)?; let token_manager = TokenManager::new( client_id.to_string(), @@ -324,17 +416,49 @@ async fn create_ai_core_client( let api_url = format!( "{}/deployments/{}", api_base_url.trim_end_matches('/'), - deployment_uuid + deployment.deployment_uuid ); - let client = if let Some(path) = record_path { - AiCoreClient::new_with_recorder(token_manager, api_url, path) - } else { - AiCoreClient::new(token_manager, api_url) - }; - - let client = apply_custom_config(client, model_config); - Ok(Box::new(client)) + // Create the appropriate client based on API type + match deployment.api_type { + AiCoreApiType::Anthropic => { + let client = if let Some(path) = record_path { + AiCoreAnthropicClient::new_with_recorder(token_manager, api_url, path) + } else { + AiCoreAnthropicClient::new(token_manager, api_url) + }; + let client = apply_custom_config(client, model_config); + Ok(Box::new(client)) + } + AiCoreApiType::OpenAI => { + let client = if let Some(path) = record_path { + AiCoreOpenAIClient::new_with_recorder( + token_manager, + api_url, + model_config.id.clone(), + path, + ) + } else { + AiCoreOpenAIClient::new(token_manager, api_url, 
model_config.id.clone()) + }; + let client = apply_custom_config(client, model_config); + Ok(Box::new(client)) + } + AiCoreApiType::Vertex => { + let client = if let Some(path) = record_path { + AiCoreVertexClient::new_with_recorder( + token_manager, + api_url, + model_config.id.clone(), + path, + ) + } else { + AiCoreVertexClient::new(token_manager, api_url, model_config.id.clone()) + }; + let client = apply_custom_config(client, model_config); + Ok(Box::new(client)) + } + } } async fn create_anthropic_client( diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index e96315c6..f2964569 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -2,7 +2,7 @@ //! //! This module implements: //! - Common interface for LLM interactions via the LLMProvider trait -//! - Support for multiple providers (Anthropic, OpenAI, Ollama, Vertex) +//! - Support for multiple providers (Anthropic, OpenAI, Ollama, Vertex, AI Core) //! - Message streaming capabilities //! - Provider-specific implementations and optimizations //! - Shared types and utilities for LLM interactions @@ -13,8 +13,7 @@ mod tests; mod utils; -//pub mod aicore_converse; -pub mod aicore_invoke; +pub mod aicore; pub mod anthropic; pub mod auth; pub mod cerebras; @@ -33,7 +32,10 @@ pub mod streaming; pub mod types; pub mod vertex; -pub use aicore_invoke::AiCoreClient; +pub use aicore::{ + create_aicore_client, create_aicore_client_with_recorder, AiCoreAnthropicClient, AiCoreApiType, + AiCoreOpenAIClient, AiCoreVertexClient, +}; pub use anthropic::AnthropicClient; pub use cerebras::CerebrasClient; pub use groq::GroqClient; diff --git a/crates/llm/src/vertex.rs b/crates/llm/src/vertex.rs index 972095d1..0ab8c3b8 100644 --- a/crates/llm/src/vertex.rs +++ b/crates/llm/src/vertex.rs @@ -10,6 +10,86 @@ use serde_json::json; use std::time::{Duration, SystemTime}; use tracing::{debug, trace, warn}; +// ============================================================================ +// Customization Traits +// ============================================================================ + +/// Trait for providing authentication for Vertex API requests +#[async_trait] +pub trait AuthProvider: Send + Sync { + /// Get authentication to apply to the request. + /// Returns either query parameters or headers (or both). 
+ async fn get_auth(&self) -> Result; +} + +/// Authentication configuration for Vertex API +pub struct VertexAuth { + /// Query parameters to add to the URL (e.g., `key=...`) + pub query_params: Vec<(String, String)>, + /// Headers to add to the request (e.g., `Authorization: Bearer ...`) + pub headers: Vec<(String, String)>, +} + +/// Trait for customizing Vertex API requests +pub trait RequestCustomizer: Send + Sync { + /// Customize the request JSON before sending + fn customize_request(&self, request: &mut serde_json::Value) -> Result<()>; + /// Get additional headers to include in requests + fn get_additional_headers(&self) -> Vec<(String, String)>; + /// Customize the URL for a request + fn customize_url(&self, base_url: &str, model: &str, streaming: bool) -> String; +} + +// ============================================================================ +// Default Implementations +// ============================================================================ + +/// Default API key authentication provider (uses query parameter) +pub struct ApiKeyAuth { + api_key: String, +} + +impl ApiKeyAuth { + pub fn new(api_key: String) -> Self { + Self { api_key } + } +} + +#[async_trait] +impl AuthProvider for ApiKeyAuth { + async fn get_auth(&self) -> Result { + Ok(VertexAuth { + query_params: vec![("key".to_string(), self.api_key.clone())], + headers: vec![], + }) + } +} + +/// Default request customizer for Google Generative Language API +pub struct DefaultRequestCustomizer; + +impl RequestCustomizer for DefaultRequestCustomizer { + fn customize_request(&self, _request: &mut serde_json::Value) -> Result<()> { + Ok(()) + } + + fn get_additional_headers(&self) -> Vec<(String, String)> { + vec![("Content-Type".to_string(), "application/json".to_string())] + } + + fn customize_url(&self, base_url: &str, model: &str, streaming: bool) -> String { + if streaming { + format!("{}/models/{}:streamGenerateContent", base_url, model) + } else { + format!("{}/models/{}:generateContent", base_url, model) + } + } +} + +// ============================================================================ +// Request/Response Types +// ============================================================================ + #[derive(Debug, Serialize)] struct VertexRequest { #[serde(skip_serializing_if = "Option::is_none")] @@ -84,14 +164,14 @@ struct VertexResponse { response_id: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Default)] struct VertexUsageMetadata { - #[serde(rename = "promptTokenCount")] + #[serde(rename = "promptTokenCount", default)] prompt_token_count: u32, - #[serde(rename = "candidatesTokenCount")] + #[serde(rename = "candidatesTokenCount", default)] candidates_token_count: u32, #[allow(dead_code)] - #[serde(rename = "totalTokenCount")] + #[serde(rename = "totalTokenCount", default)] total_token_count: u32, #[serde(rename = "cachedContentTokenCount")] cached_content_token_count: Option, @@ -165,27 +245,29 @@ impl RateLimitHandler for VertexRateLimitInfo { pub struct VertexClient { client: Client, - api_key: String, model: String, base_url: String, recorder: Option, custom_config: Option, + // Customization points + auth_provider: Box, + request_customizer: Box, } impl VertexClient { pub fn default_base_url() -> String { "https://generativelanguage.googleapis.com/v1beta".to_string() - //"https://aiplatform.googleapis.com/v1/publishers/google".to_string() } pub fn new(api_key: String, model: String, base_url: String) -> Self { Self { client: Client::new(), - api_key, model, base_url, 
recorder: None, custom_config: None, + auth_provider: Box::new(ApiKeyAuth::new(api_key)), + request_customizer: Box::new(DefaultRequestCustomizer), } } @@ -198,14 +280,39 @@ impl VertexClient { ) -> Self { Self { client: Client::new(), - api_key, model, base_url, recorder: Some(APIRecorder::new(recording_path)), custom_config: None, + auth_provider: Box::new(ApiKeyAuth::new(api_key)), + request_customizer: Box::new(DefaultRequestCustomizer), } } + /// Create a new client with custom authentication and request handling + pub fn with_customization( + model: String, + base_url: String, + auth_provider: Box, + request_customizer: Box, + ) -> Self { + Self { + client: Client::new(), + model, + base_url, + recorder: None, + custom_config: None, + auth_provider, + request_customizer, + } + } + + /// Add recording capability to an existing client + pub fn with_recorder>(mut self, recording_path: P) -> Self { + self.recorder = Some(APIRecorder::new(recording_path)); + self + } + /// Set custom model configuration to be merged into API requests pub fn with_custom_config(mut self, custom_config: serde_json::Value) -> Self { self.custom_config = Some(custom_config); @@ -213,14 +320,8 @@ impl VertexClient { } fn get_url(&self, streaming: bool) -> String { - if streaming { - format!( - "{}/models/{}:streamGenerateContent", - self.base_url, self.model - ) - } else { - format!("{}/models/{}:generateContent", self.base_url, self.model) - } + self.request_customizer + .customize_url(&self.base_url, &self.model, streaming) } fn convert_message(message: &Message) -> VertexMessage { @@ -372,17 +473,38 @@ impl VertexClient { request_json = crate::config_merge::merge_json(request_json, custom_config.clone()); } - trace!( + // Allow request customizer to modify the request + self.request_customizer + .customize_request(&mut request_json)?; + + debug!( "Sending Vertex request to {}:\n{}", self.model, serde_json::to_string_pretty(&request_json)? ); - let response = self - .client - .post(&url) - .query(&[("key", &self.api_key)]) - .header("Content-Type", "application/json") + // Get authentication + let auth = self.auth_provider.get_auth().await?; + + // Build request + let mut request_builder = self.client.post(&url); + + // Add query parameters from auth + if !auth.query_params.is_empty() { + request_builder = request_builder.query(&auth.query_params); + } + + // Add headers from auth + for (key, value) in auth.headers { + request_builder = request_builder.header(key, value); + } + + // Add additional headers from customizer + for (key, value) in self.request_customizer.get_additional_headers() { + request_builder = request_builder.header(key, value); + } + + let response = request_builder .json(&request_json) .send() .await @@ -490,16 +612,43 @@ impl VertexClient { request_json = crate::config_merge::merge_json(request_json, custom_config.clone()); } + // Allow request customizer to modify the request + self.request_customizer + .customize_request(&mut request_json)?; + + debug!( + "Sending Vertex streaming request to {}:\n{}", + self.model, + serde_json::to_string_pretty(&request_json)? 
+ ); + // Start recording if a recorder is available if let Some(recorder) = &self.recorder { recorder.start_recording(request_json.clone())?; } - let response = self - .client - .post(self.get_url(true)) - .query(&[("key", &self.api_key), ("alt", &"sse".to_string())]) - .header("Content-Type", "application/json") + // Get authentication + let auth = self.auth_provider.get_auth().await?; + + // Build request - start with URL and add SSE alt parameter + let mut request_builder = self.client.post(self.get_url(true)); + + // Combine auth query params with alt=sse + let mut query_params = auth.query_params; + query_params.push(("alt".to_string(), "sse".to_string())); + request_builder = request_builder.query(&query_params); + + // Add headers from auth + for (key, value) in auth.headers { + request_builder = request_builder.header(key, value); + } + + // Add additional headers from customizer + for (key, value) in self.request_customizer.get_additional_headers() { + request_builder = request_builder.header(key, value); + } + + let response = request_builder .json(&request_json) .send() .await diff --git a/providers.example.json b/providers.example.json index 2d2c8510..fd082b23 100644 --- a/providers.example.json +++ b/providers.example.json @@ -103,8 +103,14 @@ "api_base_url": "https://api.ai.dev.your-region.aws.ml.hana.ondemand.com/v2/inference", "models": { "claude-sonnet-4": "your-claude-deployment-uuid", - "gpt-4": "your-gpt4-deployment-uuid", - "gemini-pro": "your-gemini-deployment-uuid" + "gpt-4": { + "deployment": "your-gpt4-deployment-uuid", + "api_type": "openai" + }, + "gemini-2.5-flash": { + "deployment": "your-gemini-deployment-uuid", + "api_type": "vertex" + } } } }, @@ -117,8 +123,14 @@ "token_url": "https://your-prod-instance.authentication.sap.hana.ondemand.com/oauth/token", "api_base_url": "https://api.ai.your-region.aws.ml.hana.ondemand.com/v2/inference", "models": { - "claude-sonnet-4": "your-prod-claude-deployment-uuid", - "gpt-4": "your-prod-gpt4-deployment-uuid" + "claude-sonnet-4": { + "deployment": "your-prod-claude-deployment-uuid", + "api_type": "anthropic" + }, + "gpt-4o": { + "deployment": "your-prod-gpt4o-deployment-uuid", + "api_type": "openai" + } } } }
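
For reference, a minimal wiring sketch of the new factory entry point (not part of the diff). The crate path `llm`, the pre-built `TokenManager`, and the deployment URL are assumptions for illustration; in practice `create_ai_core_client` in factory.rs derives all of these from the provider config and the `models` entry shown above.

```rust
use std::sync::Arc;

use llm::auth::TokenManager;
use llm::{create_aicore_client, AiCoreApiType, LLMProvider};

/// Build an AI Core provider for one configured deployment.
/// `base_url` is assumed to already end in "/deployments/<uuid>".
fn build_provider(
    token_manager: Arc<TokenManager>,
    base_url: String,
    model_id: &str,
    api_type: &str,
) -> anyhow::Result<Box<dyn LLMProvider>> {
    // Accepted values are "anthropic", "openai", "vertex"; anything else is
    // rejected with a descriptive error by the FromStr impl in aicore/types.rs.
    let api_type: AiCoreApiType = api_type.parse()?;

    // The factory returns the matching wrapper client; the model id is only
    // used by the OpenAI and Vertex variants for request/URL routing.
    Ok(create_aicore_client(
        api_type,
        token_manager,
        base_url,
        model_id.to_string(),
    ))
}
```

Each wrapper delegates to the existing vendor client (AnthropicClient, OpenAIClient, VertexClient) through the AuthProvider/RequestCustomizer hooks introduced in this change, so streaming, recording, and custom-config merging remain implemented once per vendor.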