Skip to content
Merged
16 commits
dd9a1db
refactor(autocomplete): improve word extraction and fuzzy matching logic
BrunoV21 Sep 18, 2025
5287d7e
fix(agent): correct handling of word extraction results in agent
BrunoV21 Sep 18, 2025
4c568ae
test(autocomplete): add comprehensive tests for AutoComplete class
BrunoV21 Sep 18, 2025
34c0dce
fix(agent): add debug prints and correct identifier extraction in age…
BrunoV21 Sep 18, 2025
d3feb7e
feat(streaming): add chunk logger and background flusher modules
BrunoV21 Sep 18, 2025
e44439a
feat(streaming,ui): add optimized streaming logger and update usage
BrunoV21 Sep 18, 2025
85d0837
refactor(agent): remove debug print statements from agent_loop
BrunoV21 Sep 21, 2025
60c8c58
feat(streaming,ui): add async generator cancellation utility and upda…
BrunoV21 Sep 21, 2025
6c38833
fix(streaming): prevent printing special tokens in logger
BrunoV21 Sep 21, 2025
e987be7
feat(autocomplete): enhance extract_words_from_text with substring an…
BrunoV21 Sep 21, 2025
042d1c9
test(autocomplete): add tests for substring and subpath matching
BrunoV21 Sep 21, 2025
e2e89a0
docs(ui): update chainlit.md to document all slash commands
BrunoV21 Sep 21, 2025
952729b
refactor(agent): add rolling context identifier window for history
BrunoV21 Sep 21, 2025
ee27983
build: add portalocker to agent requirements
BrunoV21 Sep 21, 2025
4fb34ae
refactor(ui): move docker and postgres utils to persistance.py
BrunoV21 Sep 22, 2025
0bcb1b1
refactor(hf_demo_space): update streaming imports and cleanup actions
BrunoV21 Sep 22, 2025
79 changes: 48 additions & 31 deletions codetide/agents/tide/agent.py
@@ -1,6 +1,6 @@
from codetide import CodeTide
from ...mcp.tools.patch_code import file_exists, open_file, process_patch, remove_file, write_file, parse_patch_blocks
from ...core.defaults import DEFAULT_ENCODING, DEFAULT_STORAGE_PATH
from ...core.defaults import DEFAULT_STORAGE_PATH
from ...parsers import SUPPORTED_LANGUAGES
from ...autocomplete import AutoComplete
from .models import Steps
@@ -13,7 +13,8 @@

try:
from aicore.llm import Llm
from aicore.logger import _logger, SPECIAL_TOKENS
from aicore.logger import _logger
from .streaming.service import custom_logger_fn
except ImportError as e:
raise ImportError(
"The 'codetide.agents' module requires the 'aicore' package. "
@@ -29,18 +30,10 @@
from datetime import date
from pathlib import Path
from ulid import ulid
import aiofiles
import asyncio
import pygit2
import os

async def custom_logger_fn(message :str, session_id :str, filepath :str):
if message not in SPECIAL_TOKENS:
async with aiofiles.open(filepath, 'a', encoding=DEFAULT_ENCODING) as f:
await f.write(message)

await _logger.log_chunk_to_queue(message, session_id)

class AgentTide(BaseModel):
llm :Llm
tide :CodeTide
@@ -60,6 +53,11 @@ class AgentTide(BaseModel):
_has_patch :bool=False
_direct_mode :bool=False

# Number of previous interactions to remember for context identifiers
CONTEXT_WINDOW_SIZE: int = 3
# Rolling window of identifier sets from previous N interactions
_context_identifier_window: Optional[list] = None

model_config = ConfigDict(arbitrary_types_allowed=True)

@model_validator(mode="after")
@@ -134,23 +132,43 @@ async def agent_loop(self, codeIdentifiers :Optional[List[str]]=None):
await self.tide.check_for_updates(serialize=True, include_cached_ids=True)
self._clean_history()

# Initialize the context identifier window if not present
if self._context_identifier_window is None:
self._context_identifier_window = []

codeContext = None
if self._skip_context_retrieval:
...
else:
autocomplete = AutoComplete(self.tide.cached_ids)
if self._direct_mode:
self.contextIdentifiers = None
exact_matches = autocomplete.extract_words_from_text(self.history[-1], max_matches_per_word=1)["all_found_words"]
# Only extract matches from the last message
last_message = self.history[-1] if self.history else ""
exact_matches = autocomplete.extract_words_from_text(last_message, max_matches_per_word=1)["all_found_words"]
self.modifyIdentifiers = self.tide._as_file_paths(exact_matches)
codeIdentifiers = self.modifyIdentifiers
self._direct_mode = False

# Update the context identifier window
self._context_identifier_window.append(set(exact_matches))
if len(self._context_identifier_window) > self.CONTEXT_WINDOW_SIZE:
self._context_identifier_window.pop(0)
else:
matches = autocomplete.extract_words_from_text("\n\n".join(self.history), max_matches_per_word=1)

# --- Begin Unified Identifier Retrieval ---
identifiers_accum = set(matches["all_found_words"]) if codeIdentifiers is None else set(codeIdentifiers + matches["all_found_words"])
# Only extract matches from the last message
last_message = self.history[-1] if self.history else ""
matches = autocomplete.extract_words_from_text(last_message, max_matches_per_word=1)["all_found_words"]
print(f"{matches=}")
# Update the context identifier window
self._context_identifier_window.append(set(matches))
if len(self._context_identifier_window) > self.CONTEXT_WINDOW_SIZE:
self._context_identifier_window.pop(0)
# Combine identifiers from the last N interactions
window_identifiers = set()
for s in self._context_identifier_window:
window_identifiers.update(s)
# If codeIdentifiers is passed, include them as well
identifiers_accum = set(codeIdentifiers) if codeIdentifiers else set()
identifiers_accum.update(window_identifiers)
modify_accum = set()
reasoning_accum = []
repo_tree = None
@@ -166,57 +184,55 @@ async def agent_loop(self, codeIdentifiers :Optional[List[str]]=None):
repo_history = self.history
if previous_reason:
repo_history += [previous_reason]

repo_tree = await self.get_repo_tree_from_user_prompt(self.history, include_modules=bool(smart_search_attempts), expand_paths=expand_paths)

# 2. Single LLM call with unified prompt
# Pass accumulated identifiers for context if this isn't the first iteration
accumulated_context = "\n".join(
sorted((identifiers_accum or set()) | (modify_accum or set()))
) if (identifiers_accum or modify_accum) else ""

unified_response = await self.llm.acomplete(
self.history,
system_prompt=[GET_CODE_IDENTIFIERS_UNIFIED_PROMPT.format(
DATE=TODAY,
DATE=TODAY,
SUPPORTED_LANGUAGES=SUPPORTED_LANGUAGES,
IDENTIFIERS=accumulated_context
)],
prefix_prompt=repo_tree,
stream=False
)
print(f"{unified_response=}")

# Parse the unified response
contextIdentifiers = parse_blocks(unified_response, block_word="Context Identifiers", multiple=False)
modifyIdentifiers = parse_blocks(unified_response, block_word="Modify Identifiers", multiple=False)
expandPaths = parse_blocks(unified_response, block_word="Expand Paths", multiple=False)
expandPaths = parse_blocks(unified_response, block_word="Expand Paths", multiple=False)

# Extract reasoning (everything before the first "*** Begin")
reasoning_parts = unified_response.split("*** Begin")
if reasoning_parts:
reasoning_accum.append(reasoning_parts[0].strip())
previous_reason = reasoning_accum[-1]

# Accumulate identifiers
if contextIdentifiers:
if smart_search_attempts == 0:
### clean wrongly mismatched identifiers
identifiers_accum = set()
for ident in contextIdentifiers.splitlines():
if ident := self.get_valid_identifier(autocomplete, ident.strip()):
if ident := self.get_valid_identifier(autocomplete, ident.strip()):
identifiers_accum.add(ident)

if modifyIdentifiers:
for ident in modifyIdentifiers.splitlines():
if ident := self.get_valid_identifier(autocomplete, ident.strip()):
modify_accum.add(ident.strip())

if expandPaths:
expand_paths = [
path for ident in expandPaths if (path := self.get_valid_identifier(autocomplete, ident.strip()))
]

# Check if we have enough identifiers (unified prompt includes this decision)
if "ENOUGH_IDENTIFIERS: TRUE" in unified_response.upper():
done = True
@@ -235,7 +251,7 @@ async def agent_loop(self, codeIdentifiers :Optional[List[str]]=None):
self.modifyIdentifiers = self.tide._as_file_paths(self.modifyIdentifiers)
codeIdentifiers.extend(self.modifyIdentifiers)
# TODO preserve passed identifiers by the user
codeIdentifiers += matches["all_found_words"]
codeIdentifiers += matches

# --- End Unified Identifier Retrieval ---
if codeIdentifiers:
@@ -244,7 +260,8 @@ async def agent_loop(self, codeIdentifiers :Optional[List[str]]=None):

if not codeContext:
codeContext = REPO_TREE_CONTEXT_PROMPT.format(REPO_TREE=self.tide.codebase.get_tree_view())
readmeFile = self.tide.get(["README.md"] + matches["all_found_words"] , as_string_list=True)
# Use matches from the last message for README context
readmeFile = self.tide.get(["README.md"] + (matches if 'matches' in locals() else []), as_string_list=True)
if readmeFile:
codeContext = "\n".join([codeContext, README_CONTEXT_PROMPT.format(README=readmeFile)])

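The rolling context identifier window added to agent.py keeps the identifier sets extracted from the last CONTEXT_WINDOW_SIZE interactions and unions them before context retrieval. A minimal standalone sketch of that behaviour (illustrative names, not the PR's exact code):

from collections import deque
from typing import Iterable, List, Optional, Set

class RollingIdentifierWindow:
    """Sketch of a rolling window of identifier sets (illustrative only)."""

    def __init__(self, window_size: int = 3):
        # deque with maxlen drops the oldest set automatically,
        # mirroring the append / pop(0) pattern used in agent_loop
        self._window = deque(maxlen=window_size)

    def add(self, identifiers: Iterable[str]) -> None:
        self._window.append(set(identifiers))

    def combined(self, extra: Optional[List[str]] = None) -> Set[str]:
        # Union of the last N interactions plus any explicitly passed
        # identifiers, analogous to identifiers_accum in agent_loop
        result: Set[str] = set(extra) if extra else set()
        for s in self._window:
            result |= s
        return result

# One add() per user message; combined() feeds context retrieval
window = RollingIdentifierWindow(window_size=3)
window.add(["AgentTide.agent_loop"])
window.add(["AutoComplete.extract_words_from_text"])
print(sorted(window.combined(extra=["codetide/agents/tide/agent.py"])))

The maxlen-bounded deque gives the same drop-oldest behaviour as the explicit append/pop(0) pair in agent_loop.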
7 changes: 7 additions & 0 deletions codetide/agents/tide/streaming/__init__.py
@@ -0,0 +1,7 @@
from .background_flusher import BackgroundFlusher
from .chunk_logger import ChunkLogger

__all__ = [
"BackgroundFlusher",
"ChunkLogger"
]
61 changes: 61 additions & 0 deletions codetide/agents/tide/streaming/background_flusher.py
@@ -0,0 +1,61 @@
from .chunk_logger import ChunkLogger
from typing import Optional
import asyncio

class BackgroundFlusher:
"""
# For very high throughput, you can use the background flusher:
background_flusher = BackgroundFlusher(_optimized_logger, flush_interval=0.05)
await background_flusher.start()

# ... your application code ...

# Clean shutdown
await background_flusher.stop()
await _optimized_logger.shutdown()
"""
def __init__(self, logger: ChunkLogger, flush_interval: float = 0.1):
self.logger = logger
self.flush_interval = flush_interval
self._task: Optional[asyncio.Task] = None
self._running = False

async def start(self):
"""Start background flushing task"""
if self._task and not self._task.done():
return

self._running = True
self._task = asyncio.create_task(self._flush_loop())
self.logger._background_tasks.add(self._task)

async def stop(self):
"""Stop background flushing"""
self._running = False
if self._task:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass

async def _flush_loop(self):
"""Background flush loop"""
try:
while self._running:
await asyncio.sleep(self.flush_interval)
if not self._running:
break

# Flush all file buffers
flush_tasks = []
for filepath in list(self.logger._file_buffers.keys()):
if self.logger._file_buffers[filepath]:
flush_tasks.append(self.logger._flush_file_buffer(filepath))

if flush_tasks:
await asyncio.gather(*flush_tasks, return_exceptions=True)
except asyncio.CancelledError:
raise
except Exception:
pass # Ignore errors in background task
132 changes: 132 additions & 0 deletions codetide/agents/tide/streaming/chunk_logger.py
@@ -0,0 +1,132 @@
from ....core.defaults import DEFAULT_ENCODING
from aicore.logger import SPECIAL_TOKENS

from typing import List, Dict, AsyncGenerator
from collections import defaultdict, deque
from pathlib import Path
import portalocker
import asyncio
import time

class ChunkLogger:
def __init__(self, buffer_size: int = 1024, flush_interval: float = 0.1):
self.buffer_size = buffer_size
self.flush_interval = flush_interval
self._session_buffers: Dict[str, deque] = defaultdict(deque)
self._session_subscribers: Dict[str, List] = defaultdict(list)
self._file_buffers: Dict[str, List[str]] = defaultdict(list)
self._last_flush_time: Dict[str, float] = defaultdict(float)
self._background_tasks: set = set()
self._shutdown = False

async def log_chunk(self, message: str, session_id: str, filepath: str):
"""Optimized chunk logging with batched file writes and direct streaming"""
if message not in SPECIAL_TOKENS:
# Add to file buffer for batched writing
self._file_buffers[filepath].append(message)
current_time = time.time()

# Check if we should flush based on buffer size or time
should_flush = (
len(self._file_buffers[filepath]) >= self.buffer_size or
current_time - self._last_flush_time[filepath] >= self.flush_interval
)

if should_flush:
await self._flush_file_buffer(filepath)
self._last_flush_time[filepath] = current_time

# Directly notify subscribers without queue overhead
await self._notify_subscribers(session_id, message)

async def _flush_file_buffer(self, filepath: str):
"""Flush buffer to file with file locking"""
if not self._file_buffers[filepath]:
return

messages_to_write = self._file_buffers[filepath].copy()
self._file_buffers[filepath].clear()

# Create directory if it doesn't exist
Path(filepath).parent.mkdir(parents=True, exist_ok=True)

try:
# Use portalocker for safe concurrent file access
with open(filepath, 'a', encoding=DEFAULT_ENCODING) as f:
portalocker.lock(f, portalocker.LOCK_EX)
try:
f.writelines(messages_to_write)
f.flush() # Ensure data is written to disk
finally:
portalocker.unlock(f)
except Exception as e:
# Re-add messages to buffer if write failed
self._file_buffers[filepath].extendleft(reversed(messages_to_write))
raise e

async def _notify_subscribers(self, session_id: str, message: str):
"""Directly notify subscribers without queue overhead"""
if session_id in self._session_subscribers:
# Use a list copy to avoid modification during iteration
subscribers = list(self._session_subscribers[session_id])
for queue in subscribers:
try:
queue.put_nowait(message)
except asyncio.QueueFull:
# Remove full queues (slow consumers)
self._session_subscribers[session_id].remove(queue)
except Exception:
# Remove invalid queues
if queue in self._session_subscribers[session_id]:
self._session_subscribers[session_id].remove(queue)

async def get_session_logs(self, session_id: str) -> AsyncGenerator[str, None]:
"""Get streaming logs for a session without separate distributor task"""
# Create a queue for this subscriber
queue = asyncio.Queue(maxsize=1000) # Prevent memory issues

# Add to subscribers
self._session_subscribers[session_id].append(queue)

try:
while not self._shutdown:
try:
# Use a timeout to allow for cleanup checks
chunk = await asyncio.wait_for(queue.get(), timeout=1.0)
yield chunk
except asyncio.TimeoutError:
# Check if we should continue or if there are no more publishers
continue
except asyncio.CancelledError:
break
finally:
# Cleanup subscriber
if queue in self._session_subscribers[session_id]:
self._session_subscribers[session_id].remove(queue)

# Clean up empty session entries
if not self._session_subscribers[session_id]:
del self._session_subscribers[session_id]

async def ensure_all_flushed(self):
"""Ensure all buffers are flushed - call before shutdown"""
flush_tasks = []
for filepath in list(self._file_buffers.keys()):
if self._file_buffers[filepath]:
flush_tasks.append(self._flush_file_buffer(filepath))

if flush_tasks:
await asyncio.gather(*flush_tasks, return_exceptions=True)

async def shutdown(self):
"""Graceful shutdown"""
self._shutdown = True
await self.ensure_all_flushed()

# Cancel any background tasks
for task in self._background_tasks:
if not task.done():
task.cancel()

if self._background_tasks:
await asyncio.gather(*self._background_tasks, return_exceptions=True)
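Taken together, the new streaming modules batch file writes through ChunkLogger and force periodic flushes via BackgroundFlusher, while consumers read chunks straight from per-session queues. A minimal wiring sketch, assuming an async entry point; the session id and log path are illustrative:

import asyncio
from codetide.agents.tide.streaming import BackgroundFlusher, ChunkLogger

async def main():
    logger = ChunkLogger(buffer_size=256, flush_interval=0.05)
    flusher = BackgroundFlusher(logger, flush_interval=0.05)
    await flusher.start()

    # Consumer subscribes first: chunks are only delivered to live subscribers
    async def consume():
        async for piece in logger.get_session_logs("demo-session"):
            print(piece, end="")

    consumer = asyncio.create_task(consume())
    await asyncio.sleep(0)  # let the consumer task attach its queue

    # Producer: chunks are buffered for logs/demo.md and streamed to the subscriber
    for chunk in ["hello ", "world", "\n"]:
        await logger.log_chunk(chunk, session_id="demo-session", filepath="logs/demo.md")

    await asyncio.sleep(0.2)  # give chunks time to propagate
    consumer.cancel()

    # Clean shutdown flushes remaining buffers and cancels background tasks
    await flusher.stop()
    await logger.shutdown()

asyncio.run(main())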