Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .env.integration-test
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
AMRS_API_KEY=your_amrs_api_key_here
OPENAI_API_KEY=your_openai_api_key_here
FAKE_API_KEY=your_fake_api_key_here
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ lazy_static = "1.5.0"
rand = "0.9.2"
reqwest = "0.12.26"
serde = "1.0.228"
tokio = "1.48.0"
49 changes: 49 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,55 @@

The Adaptive Model Routing System (AMRS) is a framework designed to select the best-fit model by balancing exploration and exploitation. (still under development)

Thanks to [async-openai](https://github.com/64bit/async-openai), AMRS builds on top of it to provide adaptive model routing capabilities.

## Features

- Flexible routing strategies, including:
- **Random**: Randomly selects a model from the available models.
- **WRR**: Weighted Round Robin selects models based on predefined weights.
- **UCB**: Upper Confidence Bound based model selection (coming soon).
- **Adaptive**: Dynamically selects models based on performance metrics (coming soon).


## How to use

Here's a simple example with random routing mode:


```rust
// Before running the code, make sure to set your OpenAI API key in the environment variable:
// export OPENAI_API_KEY="your_openai_api_key"
// Note: `create_response` is async, so run this inside an async runtime (e.g. tokio's `#[tokio::main]`).

use amrs::{Client, Config, ModelConfig, CreateResponseArgs, RoutingMode};

let config = Config::builder()
.provider("openai")
.routing_mode(RoutingMode::Random)
.model(
ModelConfig::builder()
.id("gpt-3.5-turbo")
.build()
.unwrap(),
)
.model(
ModelConfig::builder()
.id("gpt-4")
.build()
.unwrap(),
)
.build()
.unwrap();

let mut client = Client::new(config);
let request = CreateResponseArgs::default()
.input("give me a poem about nature")
.build()
.unwrap();

let response = client.create_response(request).await.unwrap();
```

## Contributing

🚀 All kinds of contributions are welcome! Please follow [Contributing](/CONTRIBUTING.md).
Expand Down
6 changes: 3 additions & 3 deletions src/client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ impl Client {
let providers = cfg
.models
.iter()
.map(|m| (m.id.clone(), provider::construct_provider(m)))
.map(|m| (m.id.clone(), provider::construct_provider(m.clone())))
.collect();

Self {
Expand All @@ -28,8 +28,8 @@ impl Client {

pub async fn create_response(
&mut self,
request: provider::ResponseRequest,
) -> Result<provider::ResponseResult, provider::APIError> {
request: provider::CreateResponseReq,
) -> Result<provider::CreateResponseRes, provider::APIError> {
let model_id = self.router.sample(&request);
let provider = self.providers.get(&model_id).unwrap();
provider.create_response(request).await
Expand Down
50 changes: 33 additions & 17 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ lazy_static! {
m.insert("OPENAI", "https://api.openai.com/v1");
m.insert("DEEPINFRA", "https://api.deepinfra.com/v1/openai");
m.insert("OPENROUTER", "https://openrouter.ai/api/v1");

m.insert("FAKE", "http://localhost:8080"); // test only
// TODO: support more providers here...
m
};
Expand All @@ -34,20 +36,34 @@ pub type ModelId = String;
pub struct ModelConfig {
// model-specific configs, will override global configs if provided
#[builder(default = "None")]
pub base_url: Option<String>,
#[builder(default = "None")]
pub provider: Option<ProviderName>,
pub(crate) base_url: Option<String>,
#[builder(default = "None", setter(custom))]
pub(crate) provider: Option<ProviderName>,
#[builder(default = "None")]
pub temperature: Option<f32>,
pub(crate) temperature: Option<f32>,
#[builder(default = "None")]
pub max_output_tokens: Option<usize>,
pub(crate) max_output_tokens: Option<usize>,

pub id: ModelId,
#[builder(setter(custom))]
pub(crate) id: ModelId,
#[builder(default=-1)]
pub weight: i32,
pub(crate) weight: i32,
}

impl ModelConfigBuilder {
pub fn id<S: AsRef<str>>(&mut self, name: S) -> &mut Self {
self.id = Some(name.as_ref().to_string());
self
}

pub fn provider<S>(&mut self, name: Option<S>) -> &mut Self
where
S: AsRef<str>,
{
self.provider = Some(name.map(|s| s.as_ref().to_string().to_uppercase()));
self
}

fn validate(&self) -> Result<(), String> {
if self.id.is_none() {
return Err("Model id must be provided.".to_string());
Expand All @@ -69,7 +85,7 @@ pub struct Config {
// global configs for models, will be overridden by model-specific configs
#[builder(default = "https://api.openai.com/v1".to_string())]
pub(crate) base_url: String,
#[builder(default = "ProviderName::from(OPENAI_PROVIDER)")]
#[builder(default = "ProviderName::from(OPENAI_PROVIDER)", setter(custom))]
pub(crate) provider: ProviderName,
#[builder(default = "0.8")]
pub(crate) temperature: f32,
Expand Down Expand Up @@ -124,6 +140,11 @@ impl ConfigBuilder {
self
}

pub fn provider<S: AsRef<str>>(&mut self, name: S) -> &mut Self {
self.provider = Some(name.as_ref().to_string().to_uppercase());
self
}

fn validate(&self) -> Result<(), String> {
if self.models.is_none() || self.models.as_ref().unwrap().is_empty() {
return Err("At least one model must be configured.".to_string());
Expand Down Expand Up @@ -258,7 +279,7 @@ mod tests {
.build()
.unwrap(),
)
.provider("unknown_provider".to_string())
.provider("unknown_provider")
.build();
assert!(invalid_cfg_with_no_api_key.is_err());

Expand All @@ -269,8 +290,8 @@ mod tests {
.max_output_tokens(2048)
.model(
ModelConfig::builder()
.id("custom-model".to_string())
.provider(Some("AMRS".to_string()))
.id("custom-model")
.provider(Some("AMRS"))
.build()
.unwrap(),
)
Expand Down Expand Up @@ -317,12 +338,7 @@ mod tests {
let mut valid_specified_cfg = Config::builder()
.provider("AMRS".to_string())
.base_url("http://custom-api.ai".to_string())
.model(
ModelConfig::builder()
.id("model-2".to_string())
.build()
.unwrap(),
)
.model(ModelConfig::builder().id("model-2").build().unwrap())
.build();
valid_specified_cfg.as_mut().unwrap().populate();

Expand Down
7 changes: 6 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@ mod router {
pub mod stats;
mod wrr;
}
mod config;
mod client {
pub mod client;
}
mod provider {
mod fake;
mod openai;
pub mod provider;
}

pub mod config;
pub use crate::client::client::Client;
pub use crate::config::{Config, ModelConfig, RoutingMode};
pub use crate::provider::provider::{
APIError, CreateResponseArgs, CreateResponseReq, CreateResponseRes,
};
82 changes: 82 additions & 0 deletions src/provider/fake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use std::str::FromStr;

use async_openai::types::responses::{
AssistantRole, OutputItem, OutputMessage, OutputMessageContent, OutputStatus,
OutputTextContent, Status,
};
use async_openai::{Client, config::OpenAIConfig};
use async_trait::async_trait;
use reqwest::header::HeaderName;

use crate::config::{ModelConfig, ModelId};
use crate::provider::provider::{
APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request,
};

pub struct FakeProvider {
model: ModelId,
}

impl FakeProvider {
pub fn new(config: ModelConfig) -> Self {
Self {
model: config.id.clone(),
}
}
}

#[async_trait]
impl Provider for FakeProvider {
fn name(&self) -> &'static str {
"FakeProvider"
}

async fn create_response(
&self,
request: CreateResponseReq,
) -> Result<CreateResponseRes, APIError> {
validate_request(&request)?;

Ok(CreateResponseRes {
id: "fake-response-id".to_string(),
object: "text_completion".to_string(),
model: self.model.clone(),
usage: None,
output: vec![OutputItem::Message(OutputMessage {
id: "fake-message-id".to_string(),
status: OutputStatus::Completed,
role: AssistantRole::Assistant,
content: vec![OutputMessageContent::OutputText(OutputTextContent {
annotations: vec![],
logprobs: None,
text: "This is a fake response.".to_string(),
})],
})],
created_at: 1_600_000_000,
background: None,
billing: None,
conversation: None,
error: None,
incomplete_details: None,
instructions: None,
max_output_tokens: None,
metadata: None,
prompt: None,
parallel_tool_calls: None,
previous_response_id: None,
prompt_cache_key: None,
prompt_cache_retention: None,
reasoning: None,
safety_identifier: None,
service_tier: None,
status: Status::Completed,
temperature: None,
text: None,
top_p: None,
tools: None,
tool_choice: None,
top_logprobs: None,
truncation: None,
})
}
}
48 changes: 24 additions & 24 deletions src/provider/openai.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
use std::str::FromStr;

use async_openai::{Client, config::OpenAIConfig};
use async_trait::async_trait;
use reqwest::header::HeaderName;
use derive_builder::Builder;

use crate::config::{ModelConfig, ModelId};
use crate::provider::provider::{APIError, Provider, ResponseRequest, ResponseResult};
use crate::provider::provider::{
APIError, CreateResponseReq, CreateResponseRes, Provider, validate_request,
};

#[derive(Debug, Clone, Builder)]
#[builder(pattern = "mutable", build_fn(skip))]
pub struct OpenAIProvider {
model: ModelId,
config: OpenAIConfig,
client: Option<Client<OpenAIConfig>>,
client: Client<OpenAIConfig>,
}

impl OpenAIProvider {
pub fn new(config: &ModelConfig) -> Self {
pub fn builder(config: ModelConfig) -> OpenAIProviderBuilder {
let api_key_var = format!(
"{}_API_KEY",
config.provider.as_ref().unwrap().to_uppercase()
Expand All @@ -25,26 +27,21 @@ impl OpenAIProvider {
.with_api_base(config.base_url.clone().unwrap())
.with_api_key(api_key);

OpenAIProvider {
model: config.id.clone(),
config: openai_config,
OpenAIProviderBuilder {
model: Some(config.id.clone()),
config: Some(openai_config),
client: None,
}
}
}

pub fn header(mut self, key: &str, value: &str) -> Result<Self, APIError> {
let name = HeaderName::from_str(key)
.map_err(|e| APIError::InvalidArgument(format!("Invalid header name: {}", e)))?;

self.config = self.config.with_header(name, value)?;
Ok(self)
}

pub fn build(mut self) -> Self {
if self.client.is_none() {
self.client = Some(Client::with_config(self.config.clone()));
impl OpenAIProviderBuilder {
pub fn build(&mut self) -> OpenAIProvider {
OpenAIProvider {
model: self.model.clone().unwrap(),
config: self.config.clone().unwrap(),
client: Client::with_config(self.config.as_ref().unwrap().clone()),
}
self
}
}

Expand All @@ -54,8 +51,11 @@ impl Provider for OpenAIProvider {
"OpenAIProvider"
}

async fn create_response(&self, request: ResponseRequest) -> Result<ResponseResult, APIError> {
let client = self.client.as_ref().unwrap();
client.responses().create(request).await
async fn create_response(
&self,
request: CreateResponseReq,
) -> Result<CreateResponseRes, APIError> {
validate_request(&request)?;
self.client.responses().create(request).await
}
}
Loading
Loading