diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py index 23860150be..806f5ab1bb 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py @@ -24,6 +24,7 @@ def __init__( self, state_schema: Any | None = None, predict_state_config: dict[str, dict[str, str]] | None = None, + use_service_thread: bool = False, require_confirmation: bool = True, ): """Initialize agent configuration. @@ -31,10 +32,12 @@ def __init__( Args: state_schema: Optional state schema for state management; accepts dict or Pydantic model/class predict_state_config: Configuration for predictive state updates + use_service_thread: Whether the agent thread is service-managed require_confirmation: Whether predictive updates require confirmation """ self.state_schema = self._normalize_state_schema(state_schema) self.predict_state_config = predict_state_config or {} + self.use_service_thread = use_service_thread self.require_confirmation = require_confirmation @staticmethod @@ -86,6 +89,7 @@ def __init__( predict_state_config: dict[str, dict[str, str]] | None = None, require_confirmation: bool = True, orchestrators: list[Orchestrator] | None = None, + use_service_thread: bool = False, confirmation_strategy: ConfirmationStrategy | None = None, ): """Initialize the AG-UI compatible agent wrapper. @@ -101,6 +105,7 @@ def __init__( Set to False for agentic generative UI that updates automatically. orchestrators: Custom orchestrators (auto-configured if None). Orchestrators are checked in order; first match handles the request. + use_service_thread: Whether the agent thread is service-managed. confirmation_strategy: Strategy for generating confirmation messages. Defaults to DefaultConfirmationStrategy if None. """ @@ -111,6 +116,7 @@ def __init__( self.config = AgentConfig( state_schema=state_schema, predict_state_config=predict_state_config, + use_service_thread=use_service_thread, require_confirmation=require_confirmation, ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 3067e3e4a7..689d61add7 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -6,7 +6,7 @@ import logging import uuid from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Sequence from typing import TYPE_CHECKING, Any from ag_ui.core import ( @@ -53,11 +53,18 @@ merge_tools, register_additional_client_tools, ) -from ._utils import convert_agui_tools_to_agent_framework, generate_event_id, get_role_value +from ._utils import ( + convert_agui_tools_to_agent_framework, + generate_event_id, + get_conversation_id_from_update, + get_role_value, +) if TYPE_CHECKING: from ._agent import AgentConfig from ._confirmation_strategies import ConfirmationStrategy + from ._events import AgentFrameworkEventBridge + from ._orchestration._state_manager import StateManager logger = logging.getLogger(__name__) @@ -92,6 +99,8 @@ def __init__( self._last_message = None self._run_id: str | None = None self._thread_id: str | None = None + self._supplied_run_id: str | None = None + self._supplied_thread_id: str | None = None @property def messages(self): @@ -125,26 +134,66 @@ def last_message(self): self._last_message = self.messages[-1] return self._last_message + @property + def supplied_run_id(self) -> str | None: + """Get the supplied run ID, if any.""" + if self._supplied_run_id is None: + self._supplied_run_id = self.input_data.get("run_id") or self.input_data.get("runId") + return self._supplied_run_id + @property def run_id(self) -> str: - """Get or generate run ID.""" + """Get supplied run ID or generate a new run ID.""" + if self._run_id: + return self._run_id + + if self.supplied_run_id: + self._run_id = self.supplied_run_id + if self._run_id is None: - self._run_id = self.input_data.get("run_id") or self.input_data.get("runId") or str(uuid.uuid4()) - # This should never be None after the if block above, but satisfy type checkers - if self._run_id is None: # pragma: no cover - raise RuntimeError("Failed to initialize run_id") + self._run_id = str(uuid.uuid4()) + return self._run_id + @property + def supplied_thread_id(self) -> str | None: + """Get the supplied thread ID, if any.""" + if self._supplied_thread_id is None: + self._supplied_thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") + return self._supplied_thread_id + @property def thread_id(self) -> str: - """Get or generate thread ID.""" + """Get supplied thread ID or generate a new thread ID.""" + if self._thread_id: + return self._thread_id + + if self.supplied_thread_id: + self._thread_id = self.supplied_thread_id + if self._thread_id is None: - self._thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") or str(uuid.uuid4()) - # This should never be None after the if block above, but satisfy type checkers - if self._thread_id is None: # pragma: no cover - raise RuntimeError("Failed to initialize thread_id") + self._thread_id = str(uuid.uuid4()) + return self._thread_id + def update_run_id(self, new_run_id: str) -> None: + """Update the run ID in the context. + + Args: + new_run_id: The new run ID to set + """ + self._supplied_run_id = new_run_id + self._run_id = new_run_id + + def update_thread_id(self, new_thread_id: str) -> None: + """Update the thread ID in the context. + + Args: + new_thread_id: The new thread ID to set + """ + self._supplied_thread_id = new_thread_id + self._thread_id = new_thread_id + class Orchestrator(ABC): """Base orchestrator for agent execution flows.""" @@ -297,6 +346,28 @@ def can_handle(self, context: ExecutionContext) -> bool: """ return True + def _create_initial_events( + self, event_bridge: "AgentFrameworkEventBridge", state_manager: "StateManager" + ) -> Sequence[BaseEvent]: + """Generate initial events for the run. + + Args: + event_bridge: Event bridge for creating events + Returns: + Initial AG-UI events + """ + events: list[BaseEvent] = [event_bridge.create_run_started_event()] + + predict_event = state_manager.predict_state_event() + if predict_event: + events.append(predict_event) + + snapshot_event = state_manager.initial_snapshot_event(event_bridge) + if snapshot_event: + events.append(snapshot_event) + + return events + async def run( self, context: ExecutionContext, @@ -342,17 +413,11 @@ async def run( approval_tool_name=approval_tool_name, ) - yield event_bridge.create_run_started_event() - - predict_event = state_manager.predict_state_event() - if predict_event: - yield predict_event - - snapshot_event = state_manager.initial_snapshot_event(event_bridge) - if snapshot_event: - yield snapshot_event + if context.config.use_service_thread: + thread = AgentThread(service_thread_id=context.supplied_thread_id) + else: + thread = AgentThread() - thread = AgentThread() thread.metadata = { # type: ignore[attr-defined] "ag_ui_thread_id": context.thread_id, "ag_ui_run_id": context.run_id, @@ -363,6 +428,8 @@ async def run( provider_messages = context.messages or [] snapshot_messages = context.snapshot_messages if not provider_messages: + for event in self._create_initial_events(event_bridge, state_manager): + yield event logger.warning("No messages provided in AG-UI input") yield event_bridge.create_run_finished_event() return @@ -554,13 +621,41 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap confirmation_message = strategy.on_state_rejected() message_id = generate_event_id() + for event in self._create_initial_events(event_bridge, state_manager): + yield event yield TextMessageStartEvent(message_id=message_id, role="assistant") yield TextMessageContentEvent(message_id=message_id, delta=confirmation_message) yield TextMessageEndEvent(message_id=message_id) yield event_bridge.create_run_finished_event() return + should_recreate_event_bridge = False async for update in context.agent.run_stream(messages_to_run, **run_kwargs): + conv_id = get_conversation_id_from_update(update) + if conv_id and conv_id != context.thread_id: + context.update_thread_id(conv_id) + should_recreate_event_bridge = True + + if hasattr(update, "response_id") and update.response_id and update.response_id != context.run_id: + context.update_run_id(update.response_id) + should_recreate_event_bridge = True + + if should_recreate_event_bridge: + event_bridge = AgentFrameworkEventBridge( + run_id=context.run_id, + thread_id=context.thread_id, + predict_state_config=context.config.predict_state_config, + current_state=current_state, + skip_text_content=skip_text_content, + require_confirmation=context.config.require_confirmation, + approval_tool_name=approval_tool_name, + ) + should_recreate_event_bridge = False + + if update_count == 0: + for event in self._create_initial_events(event_bridge, state_manager): + yield event + update_count += 1 logger.info(f"[STREAM] Received update #{update_count} from agent") if all_updates is not None: @@ -672,6 +767,11 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap yield TextMessageEndEvent(message_id=message_id) logger.info(f"Emitted conversational message with length={len(response_dict['message'])}") + if all_updates is not None and len(all_updates) == 0: + logger.info("No updates received from agent - emitting initial events") + for event in self._create_initial_events(event_bridge, state_manager): + yield event + logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}") if event_bridge.current_message_id: logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py index c0da986308..7de3b4f55f 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py @@ -10,7 +10,7 @@ from datetime import date, datetime from typing import Any -from agent_framework import AIFunction, Role, ToolProtocol +from agent_framework import AgentRunResponseUpdate, AIFunction, ChatResponseUpdate, Role, ToolProtocol # Role mapping constants AGUI_TO_FRAMEWORK_ROLE: dict[str, Role] = { @@ -259,3 +259,17 @@ def convert_tools_to_agui_format( continue return results if results else None + + +def get_conversation_id_from_update(update: AgentRunResponseUpdate) -> str | None: + """Extract conversation ID from AgentRunResponseUpdate metadata. + + Args: + update: AgentRunResponseUpdate instance + Returns: + Conversation ID if present, else None + + """ + if isinstance(update.raw_representation, ChatResponseUpdate): + return update.raw_representation.conversation_id + return None diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index 281b81c968..17af8b956d 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -632,6 +632,60 @@ async def stream_fn( assert "written" in full_text.lower() or "document" in full_text.lower() +async def test_agent_with_use_service_thread_is_false(): + """Test that when use_service_thread is False, the AgentThread used to run the agent is NOT set to the service thread ID.""" + from agent_framework.ag_ui import AgentFrameworkAgent + + request_service_thread_id: str | None = None + + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + nonlocal request_service_thread_id + thread = kwargs.get("thread") + request_service_thread_id = thread.service_thread_id if thread else None + yield ChatResponseUpdate( + contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345" + ) + + agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn)) + wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=False) + + input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"} + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + assert request_service_thread_id is None # type: ignore[attr-defined] (service_thread_id should be set) + + +async def test_agent_with_use_service_thread_is_true(): + """Test that when use_service_thread is True, the AgentThread used to run the agent is set to the service thread ID.""" + from agent_framework.ag_ui import AgentFrameworkAgent + + request_service_thread_id: str | None = None + + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + nonlocal request_service_thread_id + thread = kwargs.get("thread") + request_service_thread_id = thread.service_thread_id if thread else None + yield ChatResponseUpdate( + contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345" + ) + + agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn)) + wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=True) + + input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"} + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + assert request_service_thread_id == "conv_123456" # type: ignore[attr-defined] (service_thread_id should be set) + + async def test_function_approval_mode_executes_tool(): """Test that function approval with approval_mode='always_require' sends the correct messages.""" from agent_framework import FunctionResultContent, ai_function diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index 8c00602538..841a5493f7 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -6,7 +6,8 @@ from types import SimpleNamespace from typing import Any -from agent_framework import AgentRunResponseUpdate, TextContent, ai_function +from ag_ui.core import BaseEvent, RunFinishedEvent +from agent_framework import AgentRunResponseUpdate, AgentThread, ChatResponseUpdate, TextContent, ai_function from agent_framework._tools import FunctionInvocationConfiguration from agent_framework_ag_ui._agent import AgentConfig @@ -34,12 +35,21 @@ async def run_stream( self, messages: list[Any], *, - thread: Any, + thread: AgentThread, tools: list[Any] | None = None, **kwargs: Any, ) -> AsyncGenerator[AgentRunResponseUpdate, None]: self.seen_tools = tools - yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant") + yield AgentRunResponseUpdate( + contents=[TextContent(text="ok")], + role="assistant", + response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) + raw_representation=ChatResponseUpdate( + contents=[TextContent(text="ok")], + conversation_id=thread.metadata.get("ag_ui_thread_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) + response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) + ), + ) class RecordingAgent: @@ -137,6 +147,7 @@ async def test_default_orchestrator_with_camel_case_ids() -> None: events.append(event) # assert the last event has the expected run_id and thread_id + assert isinstance(events[-1], RunFinishedEvent) last_event = events[-1] assert last_event.run_id == "test-camelcase-runid" assert last_event.thread_id == "test-camelcase-threadid" @@ -166,11 +177,12 @@ async def test_default_orchestrator_with_snake_case_ids() -> None: config=AgentConfig(), ) - events = [] + events: list[BaseEvent] = [] async for event in orchestrator.run(context): events.append(event) # assert the last event has the expected run_id and thread_id + assert isinstance(events[-1], RunFinishedEvent) last_event = events[-1] assert last_event.run_id == "test-snakecase-runid" assert last_event.thread_id == "test-snakecase-threadid" diff --git a/python/packages/ag-ui/tests/test_service_thread_id.py b/python/packages/ag-ui/tests/test_service_thread_id.py new file mode 100644 index 0000000000..92d88963e7 --- /dev/null +++ b/python/packages/ag-ui/tests/test_service_thread_id.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for service-managed thread IDs, and service-generated response ids.""" + +import sys +from pathlib import Path +from typing import Any + +from ag_ui.core import RunFinishedEvent, RunStartedEvent +from agent_framework import TextContent +from agent_framework._types import AgentRunResponseUpdate, ChatResponseUpdate + +sys.path.insert(0, str(Path(__file__).parent)) +from test_helpers_ag_ui import StubAgent + + +async def test_service_thread_id_when_there_are_updates(): + """Test that service-managed thread IDs (conversation_id) are correctly set as the thread_id in events.""" + from agent_framework.ag_ui import AgentFrameworkAgent + + updates: list[AgentRunResponseUpdate] = [ + AgentRunResponseUpdate( + contents=[TextContent(text="Hello, user!")], + response_id="resp_67890", + raw_representation=ChatResponseUpdate( + contents=[TextContent(text="Hello, user!")], + conversation_id="conv_12345", + response_id="resp_67890", + ), + ) + ] + agent = StubAgent(updates=updates) + wrapper = AgentFrameworkAgent(agent=agent) + + input_data = { + "messages": [{"role": "user", "content": "Hi"}], + } + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + assert isinstance(events[0], RunStartedEvent) + assert events[0].run_id == "resp_67890" + assert events[0].thread_id == "conv_12345" + assert isinstance(events[-1], RunFinishedEvent) + + +async def test_service_thread_id_when_no_user_message(): + """Test when user submits no messages, emitted events still have with a thread_id""" + from agent_framework.ag_ui import AgentFrameworkAgent + + updates: list[AgentRunResponseUpdate] = [] + agent = StubAgent(updates=updates) + wrapper = AgentFrameworkAgent(agent=agent) + + input_data: dict[str, list[dict[str, str]]] = { + "messages": [], + } + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + assert len(events) == 2 + assert isinstance(events[0], RunStartedEvent) + assert events[0].thread_id + assert isinstance(events[-1], RunFinishedEvent) + + +async def test_service_thread_id_when_user_supplied_thread_id(): + """Test that user-supplied thread IDs are preserved in emitted events.""" + from agent_framework.ag_ui import AgentFrameworkAgent + + updates: list[AgentRunResponseUpdate] = [] + agent = StubAgent(updates=updates) + wrapper = AgentFrameworkAgent(agent=agent) + + input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Hi"}], "threadId": "conv_12345"} + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + assert isinstance(events[0], RunStartedEvent) + assert events[0].thread_id == "conv_12345" + assert isinstance(events[-1], RunFinishedEvent)