@@ -7,6 +7,7 @@ use super::client::{ApiErrorResponse, ApiResponse, Client, Usage};
77
88use rig:: {
99 completion:: { self , CompletionError , CompletionRequest } ,
10+ message:: { ImageMediaType , MimeType } ,
1011 providers:: openai:: Message ,
1112 OneOrMany ,
1213} ;
@@ -117,6 +118,71 @@ pub struct CompletionModel {
117118 pub model : String ,
118119}
119120
121+ fn user_text_to_json ( content : rig:: message:: UserContent ) -> serde_json:: Value {
122+ match content {
123+ rig:: message:: UserContent :: Text ( text) => json ! ( {
124+ "role" : "user" ,
125+ "content" : text. text,
126+ } ) ,
127+ _ => unreachable ! ( ) ,
128+ }
129+ }
130+
131+ fn user_content_to_json (
132+ content : rig:: message:: UserContent ,
133+ ) -> Result < serde_json:: Value , CompletionError > {
134+ match content {
135+ rig:: message:: UserContent :: Text ( text) => Ok ( json ! ( {
136+ "type" : "text" ,
137+ "text" : text. text
138+ } ) ) ,
139+ rig:: message:: UserContent :: Image ( image) => Ok ( json ! ( {
140+ "type" : "image_url" ,
141+ "image_url" : {
142+ "url" : format!( "data:{};base64,{}" , image. media_type. unwrap_or( ImageMediaType :: PNG ) . to_mime_type( ) , image. data) ,
143+ }
144+ } ) ) ,
145+ rig:: message:: UserContent :: Audio ( _) => Err ( CompletionError :: RequestError (
146+ "Audio is not supported" . into ( ) ,
147+ ) ) ,
148+ rig:: message:: UserContent :: Document ( _) => Err ( CompletionError :: RequestError (
149+ "Document is not supported" . into ( ) ,
150+ ) ) ,
151+ rig:: message:: UserContent :: ToolResult ( _) => unreachable ! ( ) ,
152+ }
153+ }
154+
155+ fn tool_content_to_json (
156+ content : Vec < rig:: message:: UserContent > ,
157+ ) -> Result < serde_json:: Value , CompletionError > {
158+ let mut str_content = String :: new ( ) ;
159+ let mut tool_id = String :: new ( ) ;
160+
161+ for content in content. into_iter ( ) {
162+ match content {
163+ rig:: message:: UserContent :: ToolResult ( tool_result) => {
164+ tool_id = tool_result. id ;
165+ str_content = tool_result
166+ . content
167+ . iter ( )
168+ . map ( |c| match c {
169+ rig:: message:: ToolResultContent :: Text ( text) => text. text . clone ( ) ,
170+ // ignore image content
171+ _ => "" . to_string ( ) ,
172+ } )
173+ . collect :: < Vec < _ > > ( )
174+ . join ( "" ) ;
175+ }
176+ _ => unreachable ! ( ) ,
177+ }
178+ }
179+ Ok ( json ! ( {
180+ "role" : "tool" ,
181+ "content" : str_content,
182+ "tool_call_id" : tool_id,
183+ } ) )
184+ }
185+
120186impl CompletionModel {
121187 pub fn new ( client : Client , model : & str ) -> Self {
122188 Self {
@@ -130,64 +196,103 @@ impl CompletionModel {
130196 completion_request : CompletionRequest ,
131197 ) -> Result < Value , CompletionError > {
132198 // Add preamble to chat history (if available)
133- let mut full_history: Vec < Message > = match & completion_request. preamble {
134- Some ( preamble) => vec ! [ Message :: system( preamble) ] ,
199+ let mut full_history: Vec < serde_json:: Value > = match & completion_request. preamble {
200+ Some ( preamble) => vec ! [ json!( {
201+ "role" : "system" ,
202+ "content" : preamble,
203+ } ) ] ,
135204 None => vec ! [ ] ,
136205 } ;
137206
138207 // Convert existing chat history
139- let chat_history: Vec < Message > = completion_request
140- . chat_history
141- . into_iter ( )
142- . map ( |message| message. try_into ( ) )
143- . collect :: < Result < Vec < Vec < Message > > , _ > > ( ) ?
144- . into_iter ( )
145- . flatten ( )
146- . collect ( ) ;
147-
148- // Combine all messages into a single history
149- full_history. extend ( chat_history) ;
150- let messages: Vec < Value > = full_history
151- . into_iter ( )
152- . map ( |ref m| match m {
153- Message :: Assistant {
154- content,
155- refusal : _,
156- audio : _,
157- name : _,
158- tool_calls,
159- } => {
160- if !tool_calls. is_empty ( ) {
161- json ! ( {
162- "role" : "assistant" ,
163- "content" : null,
164- "tool_calls" : tool_calls,
165- } )
208+ for message in completion_request. chat_history . into_iter ( ) {
209+ match message {
210+ rig:: message:: Message :: User { content } => {
211+ if content. len ( ) == 1
212+ && matches ! ( content. first( ) , rig:: message:: UserContent :: Text ( _) )
213+ {
214+ full_history. push ( user_text_to_json ( content. first ( ) ) ) ;
215+ } else if content
216+ . iter ( )
217+ . any ( |c| matches ! ( c, rig:: message:: UserContent :: ToolResult ( _) ) )
218+ {
219+ let ( tool_content, user_content) =
220+ content. into_iter ( ) . partition :: < Vec < _ > , _ > ( |c| {
221+ matches ! ( c, rig:: message:: UserContent :: ToolResult ( _) )
222+ } ) ;
223+ full_history. push ( tool_content_to_json ( tool_content. clone ( ) ) ?) ;
224+ for tool_content in tool_content. into_iter ( ) {
225+ match tool_content {
226+ rig:: message:: UserContent :: ToolResult ( result) => {
227+ for tool_result_content in result. content . into_iter ( ) {
228+ match tool_result_content {
229+ rig:: message:: ToolResultContent :: Image ( image) => {
230+ full_history. push ( json ! ( {
231+ "role" : "user" ,
232+ "content" : [ {
233+ "type" : "image_url" ,
234+ "image_url" : {
235+ "url" : format!( "data:{};base64,{}" , image. media_type. unwrap_or( ImageMediaType :: PNG ) . to_mime_type( ) , image. data) ,
236+ }
237+ } ]
238+ } ) ) ;
239+ }
240+ _ => { }
241+ }
242+ }
243+ }
244+ _ => unreachable ! ( ) ,
245+ }
246+ }
247+ if !user_content. is_empty ( ) {
248+ if user_content. len ( ) == 1 {
249+ full_history
250+ . push ( user_text_to_json ( user_content. first ( ) . unwrap ( ) . clone ( ) ) ) ;
251+ } else {
252+ let user_content = user_content
253+ . into_iter ( )
254+ . map ( user_content_to_json)
255+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
256+ full_history
257+ . push ( json ! ( { "role" : "user" , "content" : user_content} ) ) ;
258+ }
259+ }
166260 } else {
167- json ! ( {
168- "role" : "assistant" ,
169- "content" : match content. first( ) . unwrap( ) {
170- AssistantContent :: Text { text } => text,
171- _ => "" ,
172- } ,
173- } )
261+ let content = content
262+ . into_iter ( )
263+ . map ( user_content_to_json)
264+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
265+ full_history. push ( json ! ( { "role" : "user" , "content" : content} ) ) ;
174266 }
175267 }
176- Message :: ToolResult {
177- tool_call_id,
178- content,
179- } => {
180- let content = json ! ( content. first( ) ) ;
181- let text = content. as_object ( ) . unwrap ( ) . get ( "text" ) . unwrap ( ) ;
182- json ! ( {
183- "role" : "tool" ,
184- "content" : text,
185- "tool_call_id" : tool_call_id,
186- } )
268+ rig:: message:: Message :: Assistant { content } => {
269+ for content in content {
270+ match content {
271+ rig:: message:: AssistantContent :: Text ( text) => {
272+ full_history. push ( json ! ( {
273+ "role" : "assistant" ,
274+ "content" : text. text
275+ } ) ) ;
276+ }
277+ rig:: message:: AssistantContent :: ToolCall ( tool_call) => {
278+ full_history. push ( json ! ( {
279+ "role" : "assistant" ,
280+ "content" : null,
281+ "tool_calls" : [ {
282+ "id" : tool_call. id,
283+ "type" : "function" ,
284+ "function" : {
285+ "name" : tool_call. function. name,
286+ "arguments" : tool_call. function. arguments. to_string( )
287+ }
288+ } ]
289+ } ) ) ;
290+ }
291+ }
292+ }
187293 }
188- _ => json ! ( m) ,
189- } )
190- . collect ( ) ;
294+ } ;
295+ }
191296
192297 let tools = completion_request
193298 . tools
@@ -201,7 +306,7 @@ impl CompletionModel {
201306 . collect :: < Vec < _ > > ( ) ;
202307 let request = json ! ( {
203308 "model" : self . model,
204- "messages" : messages ,
309+ "messages" : full_history ,
205310 "tools" : tools,
206311 "temperature" : completion_request. temperature,
207312 } ) ;
0 commit comments