@@ -51,7 +51,6 @@ pub use event::AgentOutputEvent;
5151use tokio:: sync:: RwLock ;
5252
5353use self :: event:: AgentState ;
54- use self :: event:: AgentStatus ;
5554use self :: utils:: * ;
5655
5756pub struct Agent {
@@ -66,7 +65,8 @@ pub struct Agent {
6665 memory : Arc < RwLock < MemoryManager > > ,
6766 memory_index : Option < InMemoryVectorIndex < rig_fastembed:: EmbeddingModel , Entity > > ,
6867 process_registry : Arc < RwLock < ProcessRegistry > > ,
69- current_tokens : u32 ,
68+ current_input_tokens : u32 ,
69+ current_completion_tokens : u32 ,
7070 state : AgentState ,
7171}
7272
@@ -93,6 +93,10 @@ impl Display for AgentError {
9393 }
9494}
9595
96+ fn count_tokens ( system_prompt : & str ) -> u32 {
97+ system_prompt. len ( ) as u32 / 4
98+ }
99+
96100impl Agent {
97101 pub fn new (
98102 config : Config ,
@@ -108,7 +112,8 @@ impl Agent {
108112 messages,
109113 stream : None ,
110114 assistant_content : None ,
111- current_tokens : 0 ,
115+ current_input_tokens : 0 ,
116+ current_completion_tokens : 0 ,
112117 memory : Arc :: new ( RwLock :: new ( MemoryManager :: new ( false ) ) ) ,
113118 process_registry : Arc :: new ( RwLock :: new ( ProcessRegistry :: default ( ) ) ) ,
114119 memory_index : None ,
@@ -338,7 +343,7 @@ impl Agent {
338343 self . messages [ last_idx] = message;
339344 }
340345
341- async fn process_messages ( & mut self ) -> Result < ( ) , AgentError > {
346+ async fn process_messages ( & mut self , system_prompt_token_count : u32 ) -> Result < ( ) , AgentError > {
342347 if self . state . is_paused ( ) {
343348 return Ok ( ( ) ) ;
344349 }
@@ -494,7 +499,49 @@ impl Agent {
494499 if let Some ( raw_response) = response. raw_response {
495500 let usage = raw_response. usage ;
496501 tracing:: info!( "Usage: {:?}" , usage) ;
497- self . current_tokens = usage. total_tokens as u32 ;
502+ if usage. total_tokens > 0 {
503+ self . current_input_tokens = usage. prompt_tokens as u32 ;
504+ self . current_completion_tokens =
505+ ( usage. total_tokens - usage. prompt_tokens ) as u32 ;
506+ } else {
507+ // try to calculate aproximate tokens
508+ self . current_input_tokens = system_prompt_token_count
509+ + self
510+ . messages
511+ . iter ( )
512+ . map ( |m| match m {
513+ Message :: User { content } => content
514+ . iter ( )
515+ . map ( |c| match c {
516+ UserContent :: Text ( text) => count_tokens ( & text. text ) ,
517+ UserContent :: ToolResult ( tool_result) => tool_result
518+ . content
519+ . iter ( )
520+ . map ( |t| match t {
521+ ToolResultContent :: Text ( text) => {
522+ count_tokens ( & text. text )
523+ }
524+ _ => 0 ,
525+ } )
526+ . sum :: < u32 > ( ) ,
527+ _ => 0 ,
528+ } )
529+ . sum :: < u32 > ( ) ,
530+ Message :: Assistant { content } => content
531+ . iter ( )
532+ . map ( |c| match c {
533+ AssistantContent :: Text ( text) => {
534+ count_tokens ( & text. text )
535+ }
536+ AssistantContent :: ToolCall ( tool_call) => count_tokens (
537+ & serde_json:: to_string ( tool_call) . unwrap ( ) ,
538+ ) ,
539+ } )
540+ . sum :: < u32 > ( ) ,
541+ } )
542+ . sum :: < u32 > ( ) ;
543+ self . current_completion_tokens = 0 ;
544+ }
498545 }
499546 self . assistant_content = None ;
500547 if matches ! ( self . state, AgentState :: Completed ( false ) ) {
@@ -517,6 +564,7 @@ impl Agent {
517564 ) ;
518565 let system_prompt =
519566 prepare_system_prompt ( & self . config . workspace , & self . config . user_instructions ) . await ;
567+ let system_prompt_token_count = count_tokens ( & system_prompt) ;
520568 self . agent = Some (
521569 Self :: build_agent ( BuildAgentContext {
522570 config : & self . config ,
@@ -577,7 +625,7 @@ impl Agent {
577625 }
578626 }
579627 }
580- if let Err ( e) = self . process_messages ( ) . await {
628+ if let Err ( e) = self . process_messages ( system_prompt_token_count ) . await {
581629 tracing:: debug!( "persist_history" ) ;
582630 persist_history ( & self . messages ) ;
583631 tracing:: error!( "Error processing messages: {}" , e) ;
@@ -615,11 +663,11 @@ impl Agent {
615663 self . state = state;
616664 if !self . sender . is_closed ( ) {
617665 self . sender
618- . send ( AgentOutputEvent :: AgentStatus ( AgentStatus {
619- current_tokens : self . current_tokens ,
620- max_tokens : 1 ,
621- state : self . state . clone ( ) ,
622- } ) )
666+ . send ( AgentOutputEvent :: AgentStatus (
667+ self . current_input_tokens ,
668+ self . current_completion_tokens ,
669+ self . state . clone ( ) ,
670+ ) )
623671 . unwrap ( ) ;
624672 }
625673 }
0 commit comments