Skip to content

Commit f14a992

Browse files
committed
feat: update file system work to support huge workspaces
1 parent 1cc1df8 commit f14a992

File tree

6 files changed

+75
-32
lines changed

6 files changed

+75
-32
lines changed

src/agent/mod.rs

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use crate::tools::write_to_file::WriteToFileTool;
2525
use crate::Config;
2626
use anyhow::Result;
2727
use futures::StreamExt;
28+
use itertools::Itertools;
2829
use mcp_core::types::ProtocolVersion;
2930
use rig::agent::AgentBuilder;
3031
use rig::completion::CompletionError;
@@ -93,8 +94,8 @@ impl Display for AgentError {
9394
}
9495
}
9596

96-
fn count_tokens(system_prompt: &str) -> u32 {
97-
system_prompt.len() as u32 / 4
97+
fn count_tokens(text: &str) -> u32 {
98+
text.len() as u32 / 4
9899
}
99100

100101
impl Agent {
@@ -246,7 +247,8 @@ impl Agent {
246247
async fn configure_agent<M>(
247248
mut agent_builder: AgentBuilder<M>,
248249
context: BuildAgentContext<'_>,
249-
) -> Result<AgentBuilder<M>>
250+
tools_tokens: &mut u32,
251+
) -> Result<rig::agent::Agent<M>>
250252
where
251253
M: CompletionModel,
252254
{
@@ -256,10 +258,24 @@ impl Agent {
256258
let mcp_config = context.config.mcp.as_ref();
257259
agent_builder = Self::add_static_tools(agent_builder, context);
258260
agent_builder = Self::add_mcp_tools(agent_builder, mcp_config).await?;
259-
Ok(agent_builder)
261+
let agent = agent_builder.build();
262+
*tools_tokens = count_tokens(
263+
&agent
264+
.tools
265+
.documents()
266+
.await
267+
.unwrap()
268+
.iter()
269+
.map(|d| &d.text)
270+
.join("\n"),
271+
);
272+
Ok(agent)
260273
}
261274

262-
async fn build_agent(context: BuildAgentContext<'_>) -> Result<Box<dyn HulyAgent>> {
275+
async fn build_agent(
276+
context: BuildAgentContext<'_>,
277+
tools_tokens: &mut u32,
278+
) -> Result<Box<dyn HulyAgent>> {
263279
match context.config.provider {
264280
ProviderKind::OpenAI => {
265281
let agent_builder = rig::providers::openai::Client::new(
@@ -271,7 +287,7 @@ impl Agent {
271287
)
272288
.agent(&context.config.model);
273289
Ok(Box::new(
274-
Self::configure_agent(agent_builder, context).await?.build(),
290+
Self::configure_agent(agent_builder, context, tools_tokens).await?,
275291
))
276292
}
277293
ProviderKind::Anthropic => {
@@ -286,7 +302,7 @@ impl Agent {
286302
.agent(&context.config.model)
287303
.max_tokens(20000);
288304
Ok(Box::new(
289-
Self::configure_agent(agent_builder, context).await?.build(),
305+
Self::configure_agent(agent_builder, context, tools_tokens).await?,
290306
))
291307
}
292308
ProviderKind::OpenRouter => {
@@ -299,7 +315,7 @@ impl Agent {
299315
)
300316
.agent(&context.config.model);
301317
Ok(Box::new(
302-
Self::configure_agent(agent_builder, context).await?.build(),
318+
Self::configure_agent(agent_builder, context, tools_tokens).await?,
303319
))
304320
}
305321
ProviderKind::LMStudio => {
@@ -313,7 +329,7 @@ impl Agent {
313329
)
314330
.agent(&context.config.model);
315331
Ok(Box::new(
316-
Self::configure_agent(agent_builder, context).await?.build(),
332+
Self::configure_agent(agent_builder, context, tools_tokens).await?,
317333
))
318334
}
319335
}
@@ -332,6 +348,16 @@ impl Agent {
332348
self.sender
333349
.send(AgentOutputEvent::AddMessage(message.clone()))
334350
.unwrap();
351+
if let Message::User { .. } = &message {
352+
// strip env details from earlier user messages by keeping only their first content item
353+
self.messages.iter_mut().for_each(|m| {
354+
if let Message::User { content, .. } = m {
355+
if content.len() > 1 {
356+
*content = OneOrMany::one(content.first());
357+
}
358+
}
359+
});
360+
}
335361
self.messages.push(message);
336362
}
337363

@@ -477,7 +503,7 @@ impl Agent {
477503
} else {
478504
add_env_message(
479505
&mut result_message,
480-
None,
506+
self.memory_index.as_ref(),
481507
&self.config.workspace,
482508
self.process_registry.clone(),
483509
)
@@ -565,17 +591,23 @@ impl Agent {
565591
let system_prompt =
566592
prepare_system_prompt(&self.config.workspace, &self.config.user_instructions).await;
567593
let system_prompt_token_count = count_tokens(&system_prompt);
594+
let mut tools_tokens = 0;
568595
self.agent = Some(
569-
Self::build_agent(BuildAgentContext {
570-
config: &self.config,
571-
system_prompt,
572-
memory: self.memory.clone(),
573-
process_registry: self.process_registry.clone(),
574-
sender: self.sender.clone(),
575-
})
596+
Self::build_agent(
597+
BuildAgentContext {
598+
config: &self.config,
599+
system_prompt,
600+
memory: self.memory.clone(),
601+
process_registry: self.process_registry.clone(),
602+
sender: self.sender.clone(),
603+
},
604+
&mut tools_tokens,
605+
)
576606
.await
577607
.unwrap(),
578608
);
609+
// This is a workaround to estimate the tokens used by the system prompt and tools for providers like LMStudio
610+
let system_prompt_token_count = system_prompt_token_count + tools_tokens / 2;
579611
// restore state from messages
580612
self.set_state(if self.messages.is_empty() {
581613
AgentState::WaitingUserPrompt

src/agent/utils.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ pub async fn add_env_message<'a>(
5656

5757
for entry in ignore::WalkBuilder::new(&workspace)
5858
.filter_entry(|e| e.file_name() != "node_modules")
59+
.max_depth(Some(2))
5960
.build()
6061
.filter_map(|e| e.ok())
6162
.take(MAX_FILES)
@@ -81,9 +82,16 @@ pub async fn add_env_message<'a>(
8182
let text = content.first();
8283
let mut memory_entries = String::new();
8384
if let Some(memory_index) = memory_index {
84-
if let UserContent::Text(text) = text {
85-
let res: Vec<(f64, String, Entity)> =
86-
memory_index.top_n(&text.text, 10).await.unwrap();
85+
let txt = match text {
86+
UserContent::Text(text) => &text.text.to_string(),
87+
UserContent::ToolResult(tool_result) => match tool_result.content.first() {
88+
rig::message::ToolResultContent::Text(text) => &text.text.to_string(),
89+
rig::message::ToolResultContent::Image(_) => "",
90+
},
91+
_ => "",
92+
};
93+
if !txt.is_empty() {
94+
let res: Vec<(f64, String, Entity)> = memory_index.top_n(txt, 10).await.unwrap();
8795
let result: Vec<_> = res.into_iter().map(|(_, _, entity)| entity).collect();
8896
memory_entries = serde_yaml::to_string(&result).unwrap();
8997
}

src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ fn init_logger() {
5757
.with_target("ort", tracing::Level::WARN)
5858
.with_target("tokenizers", tracing::Level::WARN)
5959
.with_target("process_wrap", tracing::Level::INFO)
60-
.with_default(tracing::Level::TRACE),
60+
.with_default(tracing::Level::DEBUG),
6161
),
6262
)
6363
.init()

src/templates/env_details.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ ${MEMORY_ENTRIES}
1010
|------------|-------------------------|---------|
1111
${COMMANDS}
1212

13-
# Current Working Directory (${WORKING_DIR}) Files
13+
# Current Working Directory (${WORKING_DIR}) Files (max depth 2)
1414
${FILES}
1515
</environment_details>

src/tools/list_files.rs

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use super::{normalize_path, AgentToolError};
1414
#[derive(Debug, Clone, Serialize, Deserialize)]
1515
pub struct ListFilesToolArgs {
1616
pub path: String,
17-
pub recursive: Option<bool>,
17+
pub max_depth: Option<usize>,
1818
}
1919

2020
pub struct ListFilesTool {
@@ -38,10 +38,10 @@ impl Tool for ListFilesTool {
3838
ToolDefinition {
3939
name: self.name(),
4040
description: formatdoc! {"\
41-
Request to list files and directories within the specified directory. If recursive is true, it will list \
42-
all files and directories recursively. If recursive is false or not provided, it will only list the top-level contents. \
43-
Do not use this tool to confirm the existence of files you may have created, as the user will let you know \
44-
if the files were created successfully or not."}.to_string(),
41+
Request to list files and directories within the specified directory. If max_depth equals 1 or not provided, \
42+
it will only list the top-level contents. If max_depth is greater than 1, it will list the contents of the directory \
43+
and its subdirectories up to the specified depth. Do not use this tool to confirm the existence of files you may have created,\
44+
as the user will let you know if the files were created successfully or not."}.to_string(),
4545
parameters: json!({
4646
"type": "object",
4747
"properties": {
@@ -50,9 +50,9 @@ impl Tool for ListFilesTool {
5050
"description": formatdoc!{"The path of the directory to list contents for (relative to the current \
5151
working directory {})", workspace_to_string(&self.workspace)},
5252
},
53-
"recursive": {
54-
"type": "boolean",
55-
"description": "Whether to list files recursively. Use true for recursive listing, false or omit for top-level only."
53+
"max_depth": {
54+
"type": "number",
55+
"description": "Max depth to list files (default: 1)",
5656
}
5757
},
5858
"required": ["path"]
@@ -63,10 +63,10 @@ impl Tool for ListFilesTool {
6363

6464
async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
6565
let path = normalize_path(&self.workspace, &args.path);
66-
let recursive = args.recursive.unwrap_or(false);
66+
let max_depth = args.max_depth.unwrap_or(1);
6767
let mut files: Vec<String> = Vec::default();
6868
for entry in ignore::WalkBuilder::new(path.clone())
69-
.max_depth(if recursive { None } else { Some(1) })
69+
.max_depth(Some(max_depth))
7070
.filter_entry(|e| e.file_name() != "node_modules")
7171
.build()
7272
.filter_map(|e| e.ok())

src/tools/memory/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ impl MemoryManager {
244244
self.knowledge_graph
245245
.entities
246246
.retain(|entity| entity.name != entity_name);
247+
self.knowledge_graph.relations.retain(|relation| {
248+
relation.from != entity_name || relation.to != entity_name
249+
});
247250
}
248251
self.save();
249252
Ok("Entities deleted successfully".to_string())

0 commit comments

Comments
 (0)