From 74f7c9dcae25691d2abe40cb208eb14134999dec Mon Sep 17 00:00:00 2001 From: kerthcet Date: Wed, 24 Dec 2025 11:26:05 +0800 Subject: [PATCH 1/2] add client Signed-off-by: kerthcet --- .env.integration-test | 3 ++ Cargo.lock | 1 + Cargo.toml | 1 + README.md | 49 ++++++++++++++++++++++++ src/client/client.rs | 6 +-- src/config.rs | 50 +++++++++++++++--------- src/lib.rs | 7 +++- src/provider/fake.rs | 82 ++++++++++++++++++++++++++++++++++++++++ src/provider/openai.rs | 48 +++++++++++------------ src/provider/provider.rs | 37 +++++++++++++----- src/router/random.rs | 6 +-- src/router/router.rs | 4 +- src/router/wrr.rs | 6 +-- tests/client.rs | 72 +++++++++++++++++++++++++++++++++++ 14 files changed, 309 insertions(+), 63 deletions(-) create mode 100644 .env.integration-test create mode 100644 src/provider/fake.rs create mode 100644 tests/client.rs diff --git a/.env.integration-test b/.env.integration-test new file mode 100644 index 0000000..099acf2 --- /dev/null +++ b/.env.integration-test @@ -0,0 +1,3 @@ +AMRS_API_KEY=your_amrs_api_key_here +OPENAI_API_KEY=your_openai_api_key_here +FAKE_API_KEY=your_fake_api_key_here diff --git a/Cargo.lock b/Cargo.lock index b4cca6b..6bb4ba0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14,6 +14,7 @@ dependencies = [ "rand 0.9.2", "reqwest", "serde", + "tokio", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 66ecc70..563dd9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,3 +12,4 @@ lazy_static = "1.5.0" rand = "0.9.2" reqwest = "0.12.26" serde = "1.0.228" +tokio = "1.48.0" diff --git a/README.md b/README.md index 173ba32..0181cec 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,55 @@ The Adaptive Model Routing System (AMRS) is a framework designed to select the best-fit model for exploration and exploitation. (still under development) +AMRS builds on top of [async-openai](https://github.com/64bit/async-openai) to provide adaptive model routing capabilities. + +## Features + +- Flexible routing strategies, including: + - **Random**: Randomly selects a model from the available models. + - **WRR**: Weighted Round Robin selects models based on predefined weights. + - **UCB**: Upper Confidence Bound based model selection (coming soon). + - **Adaptive**: Dynamically selects models based on performance metrics (coming soon). + + +## How to use + +Here's a simple example using the random routing mode: + + +```rust +// Before running the code, make sure to set your OpenAI API key in the environment variable: +// export OPENAI_API_KEY="your_openai_api_key" + +use arms::{Client, Config, ModelConfig, CreateResponseArgs, RoutingMode}; + +let config = Config::builder() + .provider("openai") + .routing_mode(RoutingMode::Random) + .model( + ModelConfig::builder() + .id("gpt-3.5-turbo") + .build() + .unwrap(), + ) + .model( + ModelConfig::builder() + .id("gpt-4") + .build() + .unwrap(), + ) + .build() + .unwrap(); + +let mut client = Client::new(config); +let request = CreateResponseArgs::default() + .input("give me a poem about nature") + .build() + .unwrap(); + +let response = client.create_response(request).await.unwrap(); +``` + ## Contributing 🚀 All kinds of contributions are welcome! Please follow [Contributing](/CONTRIBUTING.md). 
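Note: `create_response` is async, so the README snippet above needs an async runtime to drive it. A minimal way to run it, assuming the `tokio` dependency added by this patch is built with the `macros` and `rt-multi-thread` features (a sketch, not part of the patch itself):

```rust
use arms::{Client, Config, CreateResponseArgs, ModelConfig, RoutingMode};

// Hypothetical entry point wrapping the README example in a Tokio runtime.
#[tokio::main]
async fn main() {
    // Same configuration as the README example, reduced to a single model.
    let config = Config::builder()
        .provider("openai")
        .routing_mode(RoutingMode::Random)
        .model(ModelConfig::builder().id("gpt-3.5-turbo").build().unwrap())
        .build()
        .unwrap();

    let mut client = Client::new(config);
    let request = CreateResponseArgs::default()
        .input("give me a poem about nature")
        .build()
        .unwrap();

    // The router picks the model; the request itself must not name one.
    let response = client.create_response(request).await.unwrap();
    println!("{}", response.id);
}
```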
diff --git a/src/client/client.rs b/src/client/client.rs index 36a72b3..7b9dc2b 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -17,7 +17,7 @@ impl Client { let providers = cfg .models .iter() - .map(|m| (m.id.clone(), provider::construct_provider(m))) + .map(|m| (m.id.clone(), provider::construct_provider(m.clone()))) .collect(); Self { @@ -28,8 +28,8 @@ impl Client { pub async fn create_response( &mut self, - request: provider::ResponseRequest, - ) -> Result<provider::ResponseResult, provider::APIError> { + request: provider::CreateResponseInput, + ) -> Result<provider::CreateResponseOutput, provider::APIError> { let model_id = self.router.sample(&request); let provider = self.providers.get(&model_id).unwrap(); provider.create_response(request).await diff --git a/src/config.rs b/src/config.rs index 876144d..a056407 100644 --- a/src/config.rs +++ b/src/config.rs @@ -14,6 +14,8 @@ lazy_static! { m.insert("OPENAI", "https://api.openai.com/v1"); m.insert("DEEPINFRA", "https://api.deepinfra.com/v1/openai"); m.insert("OPENROUTER", "https://openrouter.ai/api/v1"); + + m.insert("FAKE", "http://localhost:8080"); // test only // TODO: support more providers here... m }; @@ -34,20 +36,34 @@ pub type ModelId = String; pub struct ModelConfig { // model-specific configs, will override global configs if provided #[builder(default = "None")] - pub base_url: Option<String>, - #[builder(default = "None")] - pub provider: Option<ProviderName>, + pub(crate) base_url: Option<String>, + #[builder(default = "None", setter(custom))] + pub(crate) provider: Option<ProviderName>, #[builder(default = "None")] - pub temperature: Option<f32>, + pub(crate) temperature: Option<f32>, #[builder(default = "None")] - pub max_output_tokens: Option<u32>, + pub(crate) max_output_tokens: Option<u32>, - pub id: ModelId, + #[builder(setter(custom))] + pub(crate) id: ModelId, #[builder(default=-1)] - pub weight: i32, + pub(crate) weight: i32, } impl ModelConfigBuilder { + pub fn id<S: AsRef<str>>(&mut self, name: S) -> &mut Self { + self.id = Some(name.as_ref().to_string()); + self + } + + pub fn provider<S>(&mut self, name: Option<S>) -> &mut Self + where + S: AsRef<str>, + { + self.provider = Some(name.map(|s| s.as_ref().to_string().to_uppercase())); + self + } + fn validate(&self) -> Result<(), String> { if self.id.is_none() { return Err("Model id must be provided.".to_string()); @@ -69,7 +85,7 @@ pub struct Config { // global configs for models, will be overridden by model-specific configs #[builder(default = "https://api.openai.com/v1".to_string())] pub(crate) base_url: String, - #[builder(default = "ProviderName::from(OPENAI_PROVIDER)")] + #[builder(default = "ProviderName::from(OPENAI_PROVIDER)", setter(custom))] pub(crate) provider: ProviderName, #[builder(default = "0.8")] pub(crate) temperature: f32, @@ -124,6 +140,11 @@ impl ConfigBuilder { self } + pub fn provider<S: AsRef<str>>(&mut self, name: S) -> &mut Self { + self.provider = Some(name.as_ref().to_string().to_uppercase()); + self + } + fn validate(&self) -> Result<(), String> { if self.models.is_none() || self.models.as_ref().unwrap().is_empty() { return Err("At least one model must be configured.".to_string()); @@ -258,7 +279,7 @@ mod tests { .build() .unwrap(), ) - .provider("unknown_provider".to_string()) + .provider("unknown_provider") .build(); assert!(invalid_cfg_with_no_api_key.is_err()); @@ -269,8 +290,8 @@ .max_output_tokens(2048) .model( ModelConfig::builder() - .id("custom-model".to_string()) - .provider(Some("AMRS".to_string())) + .id("custom-model") + .provider(Some("AMRS")) .build() .unwrap(), ) @@ -317,12 +338,7 @@ let mut valid_specified_cfg = Config::builder() .provider("AMRS".to_string()) 
.base_url("http://custom-api.ai".to_string()) - .model( - ModelConfig::builder() - .id("model-2".to_string()) - .build() - .unwrap(), - ) + .model(ModelConfig::builder().id("model-2").build().unwrap()) .build(); valid_specified_cfg.as_mut().unwrap().populate(); diff --git a/src/lib.rs b/src/lib.rs index 3f0f656..81e7b45 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,13 +4,18 @@ mod router { pub mod stats; mod wrr; } +mod config; mod client { pub mod client; } mod provider { + mod fake; mod openai; pub mod provider; } -pub mod config; pub use crate::client::client::Client; +pub use crate::config::{Config, ModelConfig, RoutingMode}; +pub use crate::provider::provider::{ + APIError, CreateResponseArgs, CreateResponseInput, CreateResponseOutput, +}; diff --git a/src/provider/fake.rs b/src/provider/fake.rs new file mode 100644 index 0000000..dc710d0 --- /dev/null +++ b/src/provider/fake.rs @@ -0,0 +1,82 @@ +use std::str::FromStr; + +use async_openai::types::responses::{ + AssistantRole, OutputItem, OutputMessage, OutputMessageContent, OutputStatus, + OutputTextContent, Status, +}; +use async_openai::{Client, config::OpenAIConfig}; +use async_trait::async_trait; +use reqwest::header::HeaderName; + +use crate::config::{ModelConfig, ModelId}; +use crate::provider::provider::{ + APIError, CreateResponseInput, CreateResponseOutput, Provider, validate_request, +}; + +pub struct FakeProvider { + model: ModelId, +} + +impl FakeProvider { + pub fn new(config: ModelConfig) -> Self { + Self { + model: config.id.clone(), + } + } +} + +#[async_trait] +impl Provider for FakeProvider { + fn name(&self) -> &'static str { + "FakeProvider" + } + + async fn create_response( + &self, + request: CreateResponseInput, + ) -> Result { + validate_request(&request)?; + + Ok(CreateResponseOutput { + id: "fake-response-id".to_string(), + object: "text_completion".to_string(), + model: self.model.clone(), + usage: None, + output: vec![OutputItem::Message(OutputMessage { + id: "fake-message-id".to_string(), + status: OutputStatus::Completed, + role: AssistantRole::Assistant, + content: vec![OutputMessageContent::OutputText(OutputTextContent { + annotations: vec![], + logprobs: None, + text: "This is a fake response.".to_string(), + })], + })], + created_at: 1_600_000_000, + background: None, + billing: None, + conversation: None, + error: None, + incomplete_details: None, + instructions: None, + max_output_tokens: None, + metadata: None, + prompt: None, + parallel_tool_calls: None, + previous_response_id: None, + prompt_cache_key: None, + prompt_cache_retention: None, + reasoning: None, + safety_identifier: None, + service_tier: None, + status: Status::Completed, + temperature: None, + text: None, + top_p: None, + tools: None, + tool_choice: None, + top_logprobs: None, + truncation: None, + }) + } +} diff --git a/src/provider/openai.rs b/src/provider/openai.rs index 601013a..13349a2 100644 --- a/src/provider/openai.rs +++ b/src/provider/openai.rs @@ -1,20 +1,22 @@ -use std::str::FromStr; - use async_openai::{Client, config::OpenAIConfig}; use async_trait::async_trait; -use reqwest::header::HeaderName; +use derive_builder::Builder; use crate::config::{ModelConfig, ModelId}; -use crate::provider::provider::{APIError, Provider, ResponseRequest, ResponseResult}; +use crate::provider::provider::{ + APIError, CreateResponseInput, CreateResponseOutput, Provider, validate_request, +}; +#[derive(Debug, Clone, Builder)] +#[builder(pattern = "mutable", build_fn(skip))] pub struct OpenAIProvider { model: ModelId, config: OpenAIConfig, 
- client: Option<Client<OpenAIConfig>>, + client: Client<OpenAIConfig>, } impl OpenAIProvider { - pub fn new(config: &ModelConfig) -> Self { + pub fn builder(config: ModelConfig) -> OpenAIProviderBuilder { let api_key_var = format!( "{}_API_KEY", config.provider.as_ref().unwrap().to_uppercase() ); @@ -25,26 +27,21 @@ impl OpenAIProvider { .with_api_base(config.base_url.clone().unwrap()) .with_api_key(api_key); - OpenAIProvider { - model: config.id.clone(), - config: openai_config, + OpenAIProviderBuilder { + model: Some(config.id.clone()), + config: Some(openai_config), client: None, } } +} - pub fn header(mut self, key: &str, value: &str) -> Result<Self, APIError> { - let name = HeaderName::from_str(key) - .map_err(|e| APIError::InvalidArgument(format!("Invalid header name: {}", e)))?; - - self.config = self.config.with_header(name, value)?; - Ok(self) - } - - pub fn build(mut self) -> Self { - if self.client.is_none() { - self.client = Some(Client::with_config(self.config.clone())); +impl OpenAIProviderBuilder { + pub fn build(&mut self) -> OpenAIProvider { + OpenAIProvider { + model: self.model.clone().unwrap(), + config: self.config.clone().unwrap(), + client: Client::with_config(self.config.as_ref().unwrap().clone()), } - self } } @@ -54,8 +51,11 @@ impl Provider for OpenAIProvider { "OpenAIProvider" } - async fn create_response(&self, request: ResponseRequest) -> Result<ResponseResult, APIError> { - let client = self.client.as_ref().unwrap(); - client.responses().create(request).await + async fn create_response( + &self, + request: CreateResponseInput, + ) -> Result<CreateResponseOutput, APIError> { + validate_request(&request)?; + self.client.responses().create(request).await } } diff --git a/src/provider/provider.rs b/src/provider/provider.rs index 0f4eb44..5792a05 100644 --- a/src/provider/provider.rs +++ b/src/provider/provider.rs @@ -1,18 +1,23 @@ -use async_openai::error::OpenAIError; -use async_openai::types::responses::{CreateResponse as OpenAIRequest, Response as OpenAIResponse}; +use async_openai::error::OpenAIError as OpenAI_Error; +use async_openai::types::responses::{ + CreateResponse, CreateResponseArgs as OpenAICreateResponseArgs, Response, +}; use async_trait::async_trait; use crate::config::ModelConfig; +use crate::provider::fake::FakeProvider; use crate::provider::openai::OpenAIProvider; -pub type ResponseRequest = OpenAIRequest; -pub type ResponseResult = OpenAIResponse; -pub type APIError = OpenAIError; +pub type CreateResponseInput = CreateResponse; +pub type CreateResponseArgs = OpenAICreateResponseArgs; +pub type CreateResponseOutput = Response; +pub type APIError = OpenAI_Error; -pub fn construct_provider(config: &ModelConfig) -> Box<dyn Provider> { +pub fn construct_provider(config: ModelConfig) -> Box<dyn Provider> { let provider = config.provider.as_ref().unwrap(); match provider.to_uppercase().as_ref() { - "OPENAI" => Box::new(OpenAIProvider::new(config).build()), + "FAKE" => Box::new(FakeProvider::new(config)), + "OPENAI" => Box::new(OpenAIProvider::builder(config).build()), _ => panic!("Unsupported provider: {}", provider), } } @@ -20,7 +25,19 @@ pub fn construct_provider(config: &ModelConfig) -> Box<dyn Provider> { #[async_trait] pub trait Provider: Send + Sync { fn name(&self) -> &'static str; - async fn create_response(&self, request: ResponseRequest) -> Result<ResponseResult, APIError>; + async fn create_response( + &self, + request: CreateResponseInput, + ) -> Result<CreateResponseOutput, APIError>; } +pub fn validate_request(request: &CreateResponseInput) -> Result<(), APIError> { + if request.model.is_some() { + return Err(APIError::InvalidArgument( + "Model must not be set on the request; it is selected via the config".to_string(), + )); + } + Ok(()) } #[cfg(test)] @@ -61,7 
+78,7 @@ mod tests { for case in cases { if case.expect_provider_type.is_empty() { let result = std::panic::catch_unwind(|| { - construct_provider(&case.config); + construct_provider(case.config); }); assert!( result.is_err(), case.name ); } else { - let provider = construct_provider(&case.config); + let provider = construct_provider(case.config); assert!( provider.name() == case.expect_provider_type, "Test case '{}': expected provider type '{}', got '{}'", diff --git a/src/router/random.rs b/src/router/random.rs index c774d50..e3933b1 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -1,7 +1,7 @@ use rand::Rng; use crate::config::ModelId; -use crate::provider::provider::ResponseRequest; +use crate::provider::provider::CreateResponseInput; use crate::router::router::{ModelInfo, Router}; pub struct RandomRouter { @@ -19,7 +19,7 @@ impl Router for RandomRouter { "RandomRouter" } - fn sample(&mut self, _input: &ResponseRequest) -> ModelId { + fn sample(&mut self, _input: &CreateResponseInput) -> ModelId { let mut rng = rand::rng(); let idx = rng.random_range(0..self.model_infos.len()); self.model_infos[idx].id.clone() @@ -50,7 +50,7 @@ mod tests { let mut counts = std::collections::HashMap::new(); for _ in 0..1000 { - let sampled_id = router.sample(&ResponseRequest::default()); + let sampled_id = router.sample(&CreateResponseInput::default()); *counts.entry(sampled_id.clone()).or_insert(0) += 1; } assert!(counts.len() == model_infos.len()); diff --git a/src/router/router.rs b/src/router/router.rs index ecec90b..4919ba1 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::sync::atomic::AtomicUsize; use crate::config::{ModelConfig, ModelId, RoutingMode}; -use crate::provider::provider::ResponseRequest; +use crate::provider::provider::CreateResponseInput; use crate::router::random::RandomRouter; use crate::router::wrr::WeightedRoundRobinRouter; @@ -28,7 +28,7 @@ pub fn construct_router(mode: RoutingMode, models: Vec<ModelConfig>) -> Box<dyn Router> pub trait Router: Send + Sync { fn name(&self) -> &'static str; - fn sample(&mut self, input: &ResponseRequest) -> ModelId; + fn sample(&mut self, input: &CreateResponseInput) -> ModelId; } #[cfg(test)] diff --git a/src/router/wrr.rs b/src/router/wrr.rs index b1ea9c7..00a307b 100644 --- a/src/router/wrr.rs +++ b/src/router/wrr.rs @@ -1,5 +1,5 @@ use crate::router::router::{ModelInfo, Router}; -use crate::{config::ModelId, provider::provider::ResponseRequest}; +use crate::{config::ModelId, provider::provider::CreateResponseInput}; pub struct WeightedRoundRobinRouter { total_weight: i32, @@ -27,7 +27,7 @@ impl Router for WeightedRoundRobinRouter { } // Use Smooth Weighted Round Robin Algorithm. - fn sample(&mut self, _input: &ResponseRequest) -> ModelId { + fn sample(&mut self, _input: &CreateResponseInput) -> ModelId { // return early if only one model. 
if self.model_infos.len() == 1 { return self.model_infos[0].id.clone(); @@ -76,7 +76,7 @@ mod tests { let mut wrr = WeightedRoundRobinRouter::new(model_infos.clone()); let mut counts = HashMap::new(); for _ in 0..1000 { - let sampled_id = wrr.sample(&ResponseRequest::default()); + let sampled_id = wrr.sample(&CreateResponseInput::default()); *counts.entry(sampled_id.clone()).or_insert(0) += 1; } assert!(counts.len() == model_infos.len()); diff --git a/tests/client.rs b/tests/client.rs new file mode 100644 index 0000000..ecdea8c --- /dev/null +++ b/tests/client.rs @@ -0,0 +1,72 @@ +use dotenvy::from_filename; + +use arms::{Client, Config, CreateResponseArgs, ModelConfig, RoutingMode}; + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_response() { + from_filename(".env.integration-test").ok(); + + // case 1: one model. + let config = Config::builder() + .provider("fake") + .model(ModelConfig::builder().id("fake-model").build().unwrap()) + .build() + .unwrap(); + + let mut client = Client::new(config); + let request = CreateResponseArgs::default() + .input("tell me the weather today") + .build() + .unwrap(); + + let response = client.create_response(request).await.unwrap(); + assert!(response.id.starts_with("fake-response-id")); + assert!(response.model == "fake-model"); + + // case 2: specifying the model in the request should be rejected. + let config = Config::builder() + .provider("openai") + .model(ModelConfig::builder().id("gpt-3.5-turbo").build().unwrap()) + .build() + .unwrap(); + let mut client = Client::new(config); + let request = CreateResponseArgs::default() + .model("gpt-3.5-turbo") + .input("tell me a joke") + .build() + .unwrap(); + let response = client.create_response(request).await; + assert!(response.is_err()); + + // case 3: multiple models with router. 
+ let config = Config::builder() + .provider("fake") + .routing_mode(RoutingMode::WRR) + .model( + ModelConfig::builder() + .id("gpt-3.5-turbo") + .weight(1) + .build() + .unwrap(), + ) + .model( + ModelConfig::builder() + .id("gpt-4") + .weight(1) + .build() + .unwrap(), + ) + .build() + .unwrap(); + let mut client = Client::new(config); + let request = CreateResponseArgs::default() + .input("give me a poem about nature") + .build() + .unwrap(); + let _ = client.create_response(request).await.unwrap(); + } +} From cabaf2464c5b3aeb439d9c3548a4cfe5cbd6f6c7 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Wed, 24 Dec 2025 11:28:42 +0800 Subject: [PATCH 2/2] rename Signed-off-by: kerthcet --- src/client/client.rs | 4 ++-- src/lib.rs | 2 +- src/provider/fake.rs | 8 ++++---- src/provider/openai.rs | 6 +++--- src/provider/provider.rs | 10 +++++----- src/router/random.rs | 6 +++--- src/router/router.rs | 4 ++-- src/router/wrr.rs | 6 +++--- 8 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/client/client.rs b/src/client/client.rs index 7b9dc2b..dca327c 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -28,8 +28,8 @@ impl Client { pub async fn create_response( &mut self, - request: provider::CreateResponseInput, - ) -> Result<provider::CreateResponseOutput, provider::APIError> { + request: provider::CreateResponseReq, + ) -> Result<provider::CreateResponseRes, provider::APIError> { let model_id = self.router.sample(&request); let provider = self.providers.get(&model_id).unwrap(); provider.create_response(request).await diff --git a/src/lib.rs b/src/lib.rs index 81e7b45..7027782 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,5 +17,5 @@ mod provider { pub use crate::client::client::Client; pub use crate::config::{Config, ModelConfig, RoutingMode}; pub use crate::provider::provider::{ - APIError, CreateResponseArgs, CreateResponseInput, CreateResponseOutput, + APIError, CreateResponseArgs, CreateResponseReq, CreateResponseRes, }; diff --git a/src/provider/fake.rs b/src/provider/fake.rs index dc710d0..a38b7b7 100644 --- a/src/provider/fake.rs +++ b/src/provider/fake.rs @@ -10,7 +10,7 @@ use reqwest::header::HeaderName; use crate::config::{ModelConfig, ModelId}; use crate::provider::provider::{ - APIError, CreateResponseInput, CreateResponseOutput, Provider, validate_request, + APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request, }; pub struct FakeProvider { @@ -33,11 +33,11 @@ impl Provider for FakeProvider { async fn create_response( &self, - request: CreateResponseInput, - ) -> Result<CreateResponseOutput, APIError> { + request: CreateResponseReq, + ) -> Result<CreateResponseRes, APIError> { validate_request(&request)?; - Ok(CreateResponseOutput { + Ok(CreateResponseRes { id: "fake-response-id".to_string(), object: "text_completion".to_string(), model: self.model.clone(), diff --git a/src/provider/openai.rs b/src/provider/openai.rs index 13349a2..753e591 100644 --- a/src/provider/openai.rs +++ b/src/provider/openai.rs @@ -4,7 +4,7 @@ use derive_builder::Builder; use crate::config::{ModelConfig, ModelId}; use crate::provider::provider::{ - APIError, CreateResponseInput, CreateResponseOutput, Provider, validate_request, + APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request, }; #[derive(Debug, Clone, Builder)] @@ -53,8 +53,8 @@ impl Provider for OpenAIProvider { async fn create_response( &self, - request: CreateResponseInput, - ) -> Result<CreateResponseOutput, APIError> { + request: CreateResponseReq, + ) -> Result<CreateResponseRes, APIError> { validate_request(&request)?; self.client.responses().create(request).await } diff --git a/src/provider/provider.rs b/src/provider/provider.rs index 5792a05..6ffe305 100644 --- a/src/provider/provider.rs 
+++ b/src/provider/provider.rs @@ -8,9 +8,9 @@ use crate::config::ModelConfig; use crate::provider::fake::FakeProvider; use crate::provider::openai::OpenAIProvider; -pub type CreateResponseInput = CreateResponse; +pub type CreateResponseReq = CreateResponse; pub type CreateResponseArgs = OpenAICreateResponseArgs; -pub type CreateResponseOutput = Response; +pub type CreateResponseRes = Response; pub type APIError = OpenAI_Error; pub fn construct_provider(config: ModelConfig) -> Box<dyn Provider> { @@ -27,11 +27,11 @@ pub trait Provider: Send + Sync { fn name(&self) -> &'static str; async fn create_response( &self, - request: CreateResponseInput, - ) -> Result<CreateResponseOutput, APIError>; + request: CreateResponseReq, + ) -> Result<CreateResponseRes, APIError>; } -pub fn validate_request(request: &CreateResponseInput) -> Result<(), APIError> { +pub fn validate_request(request: &CreateResponseReq) -> Result<(), APIError> { if request.model.is_some() { return Err(APIError::InvalidArgument( "Model must not be set on the request; it is selected via the config".to_string(), diff --git a/src/router/random.rs b/src/router/random.rs index e3933b1..1a5591b 100644 --- a/src/router/random.rs +++ b/src/router/random.rs @@ -1,7 +1,7 @@ use rand::Rng; use crate::config::ModelId; -use crate::provider::provider::CreateResponseInput; +use crate::provider::provider::CreateResponseReq; use crate::router::router::{ModelInfo, Router}; pub struct RandomRouter { @@ -19,7 +19,7 @@ impl Router for RandomRouter { "RandomRouter" } - fn sample(&mut self, _input: &CreateResponseInput) -> ModelId { + fn sample(&mut self, _input: &CreateResponseReq) -> ModelId { let mut rng = rand::rng(); let idx = rng.random_range(0..self.model_infos.len()); self.model_infos[idx].id.clone() @@ -50,7 +50,7 @@ mod tests { let mut counts = std::collections::HashMap::new(); for _ in 0..1000 { - let sampled_id = router.sample(&CreateResponseInput::default()); + let sampled_id = router.sample(&CreateResponseReq::default()); *counts.entry(sampled_id.clone()).or_insert(0) += 1; } assert!(counts.len() == model_infos.len()); diff --git a/src/router/router.rs b/src/router/router.rs index 4919ba1..0647cec 100644 --- a/src/router/router.rs +++ b/src/router/router.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::sync::atomic::AtomicUsize; use crate::config::{ModelConfig, ModelId, RoutingMode}; -use crate::provider::provider::CreateResponseInput; +use crate::provider::provider::CreateResponseReq; use crate::router::random::RandomRouter; use crate::router::wrr::WeightedRoundRobinRouter; @@ -28,7 +28,7 @@ pub fn construct_router(mode: RoutingMode, models: Vec<ModelConfig>) -> Box<dyn Router> pub trait Router: Send + Sync { fn name(&self) -> &'static str; - fn sample(&mut self, input: &CreateResponseInput) -> ModelId; + fn sample(&mut self, input: &CreateResponseReq) -> ModelId; } #[cfg(test)] diff --git a/src/router/wrr.rs b/src/router/wrr.rs index 00a307b..fa5e481 100644 --- a/src/router/wrr.rs +++ b/src/router/wrr.rs @@ -1,5 +1,5 @@ use crate::router::router::{ModelInfo, Router}; -use crate::{config::ModelId, provider::provider::CreateResponseInput}; +use crate::{config::ModelId, provider::provider::CreateResponseReq}; pub struct WeightedRoundRobinRouter { total_weight: i32, @@ -27,7 +27,7 @@ impl Router for WeightedRoundRobinRouter { } // Use Smooth Weighted Round Robin Algorithm. - fn sample(&mut self, _input: &CreateResponseInput) -> ModelId { + fn sample(&mut self, _input: &CreateResponseReq) -> ModelId { // return early if only one model. 
if self.model_infos.len() == 1 { return self.model_infos[0].id.clone(); @@ -76,7 +76,7 @@ mod tests { let mut wrr = WeightedRoundRobinRouter::new(model_infos.clone()); let mut counts = HashMap::new(); for _ in 0..1000 { - let sampled_id = wrr.sample(&CreateResponseInput::default()); + let sampled_id = wrr.sample(&CreateResponseReq::default()); *counts.entry(sampled_id.clone()).or_insert(0) += 1; } assert!(counts.len() == model_infos.len());
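The `sample` body that the `// Use Smooth Weighted Round Robin Algorithm.` comment refers to is elided from the hunks above. For reference, here is a minimal self-contained sketch of smooth weighted round robin; `Candidate` and `pick` are illustrative names, not the crate's actual `WeightedRoundRobinRouter` internals:

```rust
// Smooth weighted round robin (the nginx-style algorithm).
struct Candidate {
    id: String,
    weight: i32,         // static, configured weight
    current_weight: i32, // dynamic state, updated every round
}

fn pick(candidates: &mut [Candidate], total_weight: i32) -> usize {
    // 1. Every candidate earns its static weight.
    for c in candidates.iter_mut() {
        c.current_weight += c.weight;
    }
    // 2. The candidate with the highest current weight wins this round.
    let best = (0..candidates.len())
        .max_by_key(|&i| candidates[i].current_weight)
        .unwrap();
    // 3. The winner pays back the total weight, letting the others catch up,
    //    which interleaves selections instead of producing long runs.
    candidates[best].current_weight -= total_weight;
    best
}

fn main() {
    let mut candidates = vec![
        Candidate { id: "gpt-3.5-turbo".into(), weight: 3, current_weight: 0 },
        Candidate { id: "gpt-4".into(), weight: 2, current_weight: 0 },
    ];
    let total: i32 = candidates.iter().map(|c| c.weight).sum();
    // With weights 3:2 this prints the smooth cycle A B A B A, repeated.
    for _ in 0..10 {
        let i = pick(&mut candidates, total);
        println!("{}", candidates[i].id);
    }
}
```

Over every `total_weight` consecutive picks each candidate is selected exactly `weight` times, which is the distribution the 1000-iteration test above is counting.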