From 0a432e563ad39f0ef4792df25b185babc061576c Mon Sep 17 00:00:00 2001 From: Hao Luo Date: Wed, 7 Jan 2026 17:54:47 -1000 Subject: [PATCH 1/8] added service thread support --- .../ag-ui/agent_framework_ag_ui/_agent.py | 6 + .../agent_framework_ag_ui/_orchestrators.py | 131 +++++++++++++++--- .../ag-ui/agent_framework_ag_ui/_utils.py | 18 ++- .../ag-ui/tests/test_orchestrators.py | 18 ++- .../ag-ui/tests/test_service_thread_id.py | 102 ++++++++++++++ 5 files changed, 248 insertions(+), 27 deletions(-) create mode 100644 python/packages/ag-ui/tests/test_service_thread_id.py diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py index 23860150be..806f5ab1bb 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py @@ -24,6 +24,7 @@ def __init__( self, state_schema: Any | None = None, predict_state_config: dict[str, dict[str, str]] | None = None, + use_service_thread: bool = False, require_confirmation: bool = True, ): """Initialize agent configuration. @@ -31,10 +32,12 @@ def __init__( Args: state_schema: Optional state schema for state management; accepts dict or Pydantic model/class predict_state_config: Configuration for predictive state updates + use_service_thread: Whether the agent thread is service-managed require_confirmation: Whether predictive updates require confirmation """ self.state_schema = self._normalize_state_schema(state_schema) self.predict_state_config = predict_state_config or {} + self.use_service_thread = use_service_thread self.require_confirmation = require_confirmation @staticmethod @@ -86,6 +89,7 @@ def __init__( predict_state_config: dict[str, dict[str, str]] | None = None, require_confirmation: bool = True, orchestrators: list[Orchestrator] | None = None, + use_service_thread: bool = False, confirmation_strategy: ConfirmationStrategy | None = None, ): """Initialize the AG-UI compatible agent wrapper. @@ -101,6 +105,7 @@ def __init__( Set to False for agentic generative UI that updates automatically. orchestrators: Custom orchestrators (auto-configured if None). Orchestrators are checked in order; first match handles the request. + use_service_thread: Whether the agent thread is service-managed. confirmation_strategy: Strategy for generating confirmation messages. Defaults to DefaultConfirmationStrategy if None. """ @@ -111,6 +116,7 @@ def __init__( self.config = AgentConfig( state_schema=state_schema, predict_state_config=predict_state_config, + use_service_thread=use_service_thread, require_confirmation=require_confirmation, ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index 6bdff552b6..a2624f5709 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -6,7 +6,7 @@ import logging import uuid from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Sequence from typing import TYPE_CHECKING, Any from ag_ui.core import ( @@ -26,11 +26,13 @@ TextContent, ) -from ._utils import convert_agui_tools_to_agent_framework, generate_event_id +from ._utils import convert_agui_tools_to_agent_framework, generate_event_id, get_conversation_id_from_update if TYPE_CHECKING: from ._agent import AgentConfig from ._confirmation_strategies import ConfirmationStrategy + from ._events import AgentFrameworkEventBridge + from ._orchestration._state_manager import StateManager logger = logging.getLogger(__name__) @@ -64,6 +66,8 @@ def __init__( self._last_message = None self._run_id: str | None = None self._thread_id: str | None = None + self._supplied_run_id: str | None = None + self._supplied_thread_id: str | None = None @property def messages(self): @@ -82,26 +86,58 @@ def last_message(self): self._last_message = self.messages[-1] return self._last_message + @property + def supplied_run_id(self) -> str | None: + """Get the supplied run ID, if any.""" + if self._supplied_run_id is None: + self._supplied_run_id = self.input_data.get("run_id") or self.input_data.get("runId") + return self._supplied_run_id + @property def run_id(self) -> str: - """Get or generate run ID.""" + """Get supplied run ID or generate a new run ID.""" + if self.supplied_run_id is not None: + self._run_id = self.supplied_run_id + if self._run_id is None: - self._run_id = self.input_data.get("run_id") or self.input_data.get("runId") or str(uuid.uuid4()) - # This should never be None after the if block above, but satisfy type checkers - if self._run_id is None: # pragma: no cover - raise RuntimeError("Failed to initialize run_id") + self._run_id = str(uuid.uuid4()) + return self._run_id + @property + def supplied_thread_id(self) -> str | None: + """Get the supplied thread ID, if any.""" + if self._supplied_thread_id is None: + self._supplied_thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") + return self._supplied_thread_id + @property def thread_id(self) -> str: - """Get or generate thread ID.""" + """Get supplied thread ID or generate a new thread ID.""" + if self.supplied_thread_id is not None: + self._thread_id = self.supplied_thread_id + if self._thread_id is None: - self._thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") or str(uuid.uuid4()) - # This should never be None after the if block above, but satisfy type checkers - if self._thread_id is None: # pragma: no cover - raise RuntimeError("Failed to initialize thread_id") + self._thread_id = str(uuid.uuid4()) + return self._thread_id + def update_run_id(self, new_run_id: str) -> None: + """Update the run ID in the context. + + Args: + new_run_id: The new run ID to set + """ + self._supplied_run_id = new_run_id + + def update_thread_id(self, new_thread_id: str) -> None: + """Update the thread ID in the context. + + Args: + new_thread_id: The new thread ID to set + """ + self._supplied_thread_id = new_thread_id + class Orchestrator(ABC): """Base orchestrator for agent execution flows.""" @@ -254,6 +290,28 @@ def can_handle(self, context: ExecutionContext) -> bool: """ return True + def _create_initial_events( + self, event_bridge: "AgentFrameworkEventBridge", state_manager: "StateManager" + ) -> Sequence[BaseEvent]: + """Generate initial events for the run. + + Args: + event_bridge: Event bridge for creating events + Returns: + Initial AG-UI events + """ + events: list[BaseEvent] = [event_bridge.create_run_started_event()] + + predict_event = state_manager.predict_state_event() + if predict_event: + events.append(predict_event) + + snapshot_event = state_manager.initial_snapshot_event(event_bridge) + if snapshot_event: + events.append(snapshot_event) + + return events + async def run( self, context: ExecutionContext, @@ -303,17 +361,11 @@ async def run( require_confirmation=context.config.require_confirmation, ) - yield event_bridge.create_run_started_event() - - predict_event = state_manager.predict_state_event() - if predict_event: - yield predict_event - - snapshot_event = state_manager.initial_snapshot_event(event_bridge) - if snapshot_event: - yield snapshot_event + if context.config.use_service_thread: + thread = AgentThread(service_thread_id=context.thread_id) + else: + thread = AgentThread() - thread = AgentThread() thread.metadata = { # type: ignore[attr-defined] "ag_ui_thread_id": context.thread_id, "ag_ui_run_id": context.run_id, @@ -323,6 +375,8 @@ async def run( raw_messages = context.messages or [] if not raw_messages: + for event in self._create_initial_events(event_bridge, state_manager): + yield event logger.warning("No messages provided in AG-UI input") yield event_bridge.create_run_finished_event() return @@ -358,6 +412,8 @@ async def run( provider_messages = deduplicate_messages(sanitized_messages) if not provider_messages: + for event in self._create_initial_events(event_bridge, state_manager): + yield event logger.info("No provider-eligible messages after filtering; finishing run without invoking agent.") yield event_bridge.create_run_finished_event() return @@ -441,7 +497,33 @@ async def run( if safe_metadata: run_kwargs["store"] = True + should_recreate_event_bridge = False async for update in context.agent.run_stream(messages_to_run, **run_kwargs): + conv_id = get_conversation_id_from_update(update) + if conv_id and conv_id != context.thread_id: + context.update_thread_id(conv_id) + should_recreate_event_bridge = True + + if hasattr(update, "response_id") and update.response_id and update.response_id != context.run_id: + context.update_run_id(update.response_id) + should_recreate_event_bridge = True + + if should_recreate_event_bridge: + event_bridge = AgentFrameworkEventBridge( + run_id=context.run_id, + thread_id=context.thread_id, + predict_state_config=context.config.predict_state_config, + current_state=current_state, + skip_text_content=skip_text_content, + input_messages=context.input_data.get("messages", []), + require_confirmation=context.config.require_confirmation, + ) + should_recreate_event_bridge = False + + if update_count == 0: + for events in self._create_initial_events(event_bridge, state_manager): + yield events + update_count += 1 logger.info(f"[STREAM] Received update #{update_count} from agent") all_updates.append(update) @@ -503,6 +585,11 @@ async def run( yield TextMessageEndEvent(message_id=message_id) logger.info(f"Emitted conversational message with length={len(response_dict['message'])}") + if len(all_updates) == 0: + logger.info("No updates received from agent - emitting initial events") + for event in self._create_initial_events(event_bridge, state_manager): + yield event + logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}") if event_bridge.current_message_id: logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py index 8b271988dc..6b473fc836 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py @@ -9,7 +9,7 @@ from datetime import date, datetime from typing import Any -from agent_framework import AIFunction, ToolProtocol +from agent_framework import AgentRunResponseUpdate, AIFunction, ToolProtocol def generate_event_id() -> str: @@ -164,3 +164,19 @@ def convert_tools_to_agui_format( continue return results if results else None + + +def get_conversation_id_from_update(update: AgentRunResponseUpdate) -> str | None: + """Extract conversation ID from AgentRunResponseUpdate metadata. + + Args: + update: AgentRunResponseUpdate instance + Returns: + Conversation ID if present, else None + + """ + # if update.conversation_id: + # return update.conversation_id + if update.additional_properties: + return update.additional_properties.get("conversation_id") + return None diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index af90ea2e88..3da8fe06f8 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -6,7 +6,8 @@ from types import SimpleNamespace from typing import Any -from agent_framework import AgentRunResponseUpdate, TextContent, ai_function +from ag_ui.core import BaseEvent, RunFinishedEvent +from agent_framework import AgentRunResponseUpdate, AgentThread, TextContent, ai_function from agent_framework._tools import FunctionInvocationConfiguration from agent_framework_ag_ui._agent import AgentConfig @@ -34,12 +35,19 @@ async def run_stream( self, messages: list[Any], *, - thread: Any, + thread: AgentThread, tools: list[Any] | None = None, **kwargs: Any, ) -> AsyncGenerator[AgentRunResponseUpdate, None]: self.seen_tools = tools - yield AgentRunResponseUpdate(contents=[TextContent(text="ok")], role="assistant") + yield AgentRunResponseUpdate( + contents=[TextContent(text="ok")], + role="assistant", + response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) + additional_properties={ + "conversation_id": thread.metadata.get("ag_ui_thread_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) + }, + ) async def test_default_orchestrator_merges_client_tools() -> None: @@ -114,6 +122,7 @@ async def test_default_orchestrator_with_camel_case_ids() -> None: events.append(event) # assert the last event has the expected run_id and thread_id + assert isinstance(events[-1], RunFinishedEvent) last_event = events[-1] assert last_event.run_id == "test-camelcase-runid" assert last_event.thread_id == "test-camelcase-threadid" @@ -143,11 +152,12 @@ async def test_default_orchestrator_with_snake_case_ids() -> None: config=AgentConfig(), ) - events = [] + events: list[BaseEvent] = [] async for event in orchestrator.run(context): events.append(event) # assert the last event has the expected run_id and thread_id + assert isinstance(events[-1], RunFinishedEvent) last_event = events[-1] assert last_event.run_id == "test-snakecase-runid" assert last_event.thread_id == "test-snakecase-threadid" diff --git a/python/packages/ag-ui/tests/test_service_thread_id.py b/python/packages/ag-ui/tests/test_service_thread_id.py new file mode 100644 index 0000000000..f717576cbd --- /dev/null +++ b/python/packages/ag-ui/tests/test_service_thread_id.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for service-managed thread IDs, and service-generated response ids.""" + +import sys +from pathlib import Path +from typing import Any + +from ag_ui.core import RunFinishedEvent, RunStartedEvent +from agent_framework import TextContent +from agent_framework._types import AgentRunResponseUpdate + +sys.path.insert(0, str(Path(__file__).parent)) +from .test_helpers_ag_ui import StubAgent + + +async def test_service_thread_id_when_there_are_updates(): + """Test that service-managed thread IDs (conversation_id) are correctly set as the thread_id in events.""" + from agent_framework.ag_ui import AgentFrameworkAgent + + updates: list[AgentRunResponseUpdate] = [ + AgentRunResponseUpdate( + contents=[TextContent(text="Hello, user!")], + response_id="resp_67890", + additional_properties={"conversation_id": "conv_12345"}, + ) + ] + agent = StubAgent(updates=updates) + wrapper = AgentFrameworkAgent(agent=agent) + + input_data = { + "messages": [{"role": "user", "content": "Hi"}], + } + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + assert events[0].run_id == "resp_67890" + assert events[0].thread_id == "conv_12345" + + +async def test_service_thread_id_when_no_updates(): + """Test when no response updates are returned, events still have with a thread_id""" + from agent_framework.ag_ui import AgentFrameworkAgent + + updates: list[AgentRunResponseUpdate] = [] + agent = StubAgent(updates=updates) + wrapper = AgentFrameworkAgent(agent=agent) + + input_data = { + "messages": [{"role": "user", "content": "Hi"}], + } + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + assert isinstance(events[0], RunStartedEvent) + assert events[0].thread_id + assert isinstance(events[-1], RunFinishedEvent) + + +async def test_service_thread_id_when_no_user_message(): + """Test when user submits no messages, emitted events still have with a thread_id""" + from agent_framework.ag_ui import AgentFrameworkAgent + + updates: list[AgentRunResponseUpdate] = [] + agent = StubAgent(updates=updates) + wrapper = AgentFrameworkAgent(agent=agent) + + input_data: dict[str, list[dict[str, str]]] = { + "messages": [], + } + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + assert len(events) == 2 + assert isinstance(events[0], RunStartedEvent) + assert events[0].thread_id + assert isinstance(events[-1], RunFinishedEvent) + + +async def test_service_thread_id_when_user_supplied_thread_id(): + """Test when user submits no messages, emitted events still have with a thread_id""" + from agent_framework.ag_ui import AgentFrameworkAgent + + updates: list[AgentRunResponseUpdate] = [] + agent = StubAgent(updates=updates) + wrapper = AgentFrameworkAgent(agent=agent) + + input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Hi"}], "threadId": "conv_12345"} + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + assert isinstance(events[0], RunStartedEvent) + assert events[0].thread_id == "conv_12345" + assert isinstance(events[-1], RunFinishedEvent) From abc5a11711de999ae38e25b37d0807bd7bb21b43 Mon Sep 17 00:00:00 2001 From: Hao Luo Date: Wed, 7 Jan 2026 21:19:11 -1000 Subject: [PATCH 2/8] set service_thread_id to only supplied_thread_id --- .../agent_framework_ag_ui/_orchestrators.py | 2 +- .../ag-ui/tests/test_service_thread_id.py | 21 +------------------ 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index a2624f5709..a24a8c67b2 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -362,7 +362,7 @@ async def run( ) if context.config.use_service_thread: - thread = AgentThread(service_thread_id=context.thread_id) + thread = AgentThread(service_thread_id=context.supplied_thread_id) else: thread = AgentThread() diff --git a/python/packages/ag-ui/tests/test_service_thread_id.py b/python/packages/ag-ui/tests/test_service_thread_id.py index f717576cbd..f407147ba1 100644 --- a/python/packages/ag-ui/tests/test_service_thread_id.py +++ b/python/packages/ag-ui/tests/test_service_thread_id.py @@ -36,28 +36,9 @@ async def test_service_thread_id_when_there_are_updates(): async for event in wrapper.run_agent(input_data): events.append(event) + assert isinstance(events[0], RunStartedEvent) assert events[0].run_id == "resp_67890" assert events[0].thread_id == "conv_12345" - - -async def test_service_thread_id_when_no_updates(): - """Test when no response updates are returned, events still have with a thread_id""" - from agent_framework.ag_ui import AgentFrameworkAgent - - updates: list[AgentRunResponseUpdate] = [] - agent = StubAgent(updates=updates) - wrapper = AgentFrameworkAgent(agent=agent) - - input_data = { - "messages": [{"role": "user", "content": "Hi"}], - } - - events: list[Any] = [] - async for event in wrapper.run_agent(input_data): - events.append(event) - - assert isinstance(events[0], RunStartedEvent) - assert events[0].thread_id assert isinstance(events[-1], RunFinishedEvent) From e99bf5f5ef517fb623c20c4cd19473b8771cdd07 Mon Sep 17 00:00:00 2001 From: Hao Luo Date: Wed, 7 Jan 2026 22:28:47 -1000 Subject: [PATCH 3/8] uses raw_representation to extract the conversation_id --- python/.vscode/settings.json | 3 ++- python/packages/ag-ui/agent_framework_ag_ui/_utils.py | 8 +++----- python/packages/ag-ui/tests/test_service_thread_id.py | 8 ++++++-- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/python/.vscode/settings.json b/python/.vscode/settings.json index 47da1de9e4..6804d9a87d 100644 --- a/python/.vscode/settings.json +++ b/python/.vscode/settings.json @@ -38,5 +38,6 @@ "name": "azure", "depth": 2 } - ] + ], + "editor.lineNumbers": "on" } diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py index 6b473fc836..2399d326bf 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py @@ -9,7 +9,7 @@ from datetime import date, datetime from typing import Any -from agent_framework import AgentRunResponseUpdate, AIFunction, ToolProtocol +from agent_framework import AgentRunResponseUpdate, AIFunction, ChatResponseUpdate, ToolProtocol def generate_event_id() -> str: @@ -175,8 +175,6 @@ def get_conversation_id_from_update(update: AgentRunResponseUpdate) -> str | Non Conversation ID if present, else None """ - # if update.conversation_id: - # return update.conversation_id - if update.additional_properties: - return update.additional_properties.get("conversation_id") + if isinstance(update.raw_representation, ChatResponseUpdate): + return update.raw_representation.conversation_id return None diff --git a/python/packages/ag-ui/tests/test_service_thread_id.py b/python/packages/ag-ui/tests/test_service_thread_id.py index f407147ba1..240d66fd15 100644 --- a/python/packages/ag-ui/tests/test_service_thread_id.py +++ b/python/packages/ag-ui/tests/test_service_thread_id.py @@ -8,7 +8,7 @@ from ag_ui.core import RunFinishedEvent, RunStartedEvent from agent_framework import TextContent -from agent_framework._types import AgentRunResponseUpdate +from agent_framework._types import AgentRunResponseUpdate, ChatResponseUpdate sys.path.insert(0, str(Path(__file__).parent)) from .test_helpers_ag_ui import StubAgent @@ -22,7 +22,11 @@ async def test_service_thread_id_when_there_are_updates(): AgentRunResponseUpdate( contents=[TextContent(text="Hello, user!")], response_id="resp_67890", - additional_properties={"conversation_id": "conv_12345"}, + raw_representation=ChatResponseUpdate( + contents=[TextContent(text="Hello, user!")], + conversation_id="conv_12345", + response_id="resp_67890", + ), ) ] agent = StubAgent(updates=updates) From fb8543c0cb048965d5c88fa9ee728d10b16d7e06 Mon Sep 17 00:00:00 2001 From: Hao Luo Date: Wed, 7 Jan 2026 22:30:48 -1000 Subject: [PATCH 4/8] removed accidental edit --- python/.vscode/settings.json | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/.vscode/settings.json b/python/.vscode/settings.json index 6804d9a87d..47da1de9e4 100644 --- a/python/.vscode/settings.json +++ b/python/.vscode/settings.json @@ -38,6 +38,5 @@ "name": "azure", "depth": 2 } - ], - "editor.lineNumbers": "on" + ] } From 48066fd6f3ab2a506537bf1b432cce771a5a7173 Mon Sep 17 00:00:00 2001 From: Hao Luo Date: Wed, 7 Jan 2026 22:43:06 -1000 Subject: [PATCH 5/8] updated test to use raw_representation --- python/packages/ag-ui/tests/test_orchestrators.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py index 3da8fe06f8..f2b9004c77 100644 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ b/python/packages/ag-ui/tests/test_orchestrators.py @@ -7,7 +7,7 @@ from typing import Any from ag_ui.core import BaseEvent, RunFinishedEvent -from agent_framework import AgentRunResponseUpdate, AgentThread, TextContent, ai_function +from agent_framework import AgentRunResponseUpdate, AgentThread, ChatResponseUpdate, TextContent, ai_function from agent_framework._tools import FunctionInvocationConfiguration from agent_framework_ag_ui._agent import AgentConfig @@ -44,9 +44,11 @@ async def run_stream( contents=[TextContent(text="ok")], role="assistant", response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) - additional_properties={ - "conversation_id": thread.metadata.get("ag_ui_thread_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) - }, + raw_representation=ChatResponseUpdate( + contents=[TextContent(text="ok")], + conversation_id=thread.metadata.get("ag_ui_thread_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) + response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) + ), ) From 1fecf5e8a9c49b9a6df085b8a9a55678937231df Mon Sep 17 00:00:00 2001 From: Hao Luo Date: Thu, 8 Jan 2026 00:34:49 -1000 Subject: [PATCH 6/8] resolves copilot review feedback --- .../agent_framework_ag_ui/_orchestrators.py | 16 +++-- .../tests/test_agent_wrapper_comprehensive.py | 58 ++++++++++++++++++- .../ag-ui/tests/test_helpers_ag_ui.py | 25 ++++++-- .../ag-ui/tests/test_service_thread_id.py | 2 +- 4 files changed, 90 insertions(+), 11 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py index a24a8c67b2..b9ba25fb2d 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py @@ -96,7 +96,10 @@ def supplied_run_id(self) -> str | None: @property def run_id(self) -> str: """Get supplied run ID or generate a new run ID.""" - if self.supplied_run_id is not None: + if self._run_id: + return self._run_id + + if self.supplied_run_id: self._run_id = self.supplied_run_id if self._run_id is None: @@ -114,7 +117,10 @@ def supplied_thread_id(self) -> str | None: @property def thread_id(self) -> str: """Get supplied thread ID or generate a new thread ID.""" - if self.supplied_thread_id is not None: + if self._thread_id: + return self._thread_id + + if self.supplied_thread_id: self._thread_id = self.supplied_thread_id if self._thread_id is None: @@ -129,6 +135,7 @@ def update_run_id(self, new_run_id: str) -> None: new_run_id: The new run ID to set """ self._supplied_run_id = new_run_id + self._run_id = new_run_id def update_thread_id(self, new_thread_id: str) -> None: """Update the thread ID in the context. @@ -137,6 +144,7 @@ def update_thread_id(self, new_thread_id: str) -> None: new_thread_id: The new thread ID to set """ self._supplied_thread_id = new_thread_id + self._thread_id = new_thread_id class Orchestrator(ABC): @@ -521,8 +529,8 @@ async def run( should_recreate_event_bridge = False if update_count == 0: - for events in self._create_initial_events(event_bridge, state_manager): - yield events + for event in self._create_initial_events(event_bridge, state_manager): + yield event update_count += 1 logger.info(f"[STREAM] Received update #{update_count} from agent") diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index beb6f8af2c..b210edbb0e 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -14,7 +14,7 @@ from pydantic import BaseModel sys.path.insert(0, str(Path(__file__).parent)) -from test_helpers_ag_ui import StreamingChatClientStub +from .test_helpers_ag_ui import StreamingChatClientStub async def test_agent_initialization_basic(): @@ -630,3 +630,59 @@ async def stream_fn( # Should contain some reference to the document full_text = "".join(e.delta for e in text_events) assert "written" in full_text.lower() or "document" in full_text.lower() + + +async def test_agent_with_use_service_thread_is_false(): + """Test that when use_service_thread is False, the AgentThread used to run the agent is NOT set to the service thread ID.""" + from agent_framework.ag_ui import AgentFrameworkAgent + + request_service_thread_id: str | None = None + + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + nonlocal request_service_thread_id + thread = kwargs.get("thread") + request_service_thread_id = thread.service_thread_id if thread else None + yield ChatResponseUpdate( + contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345" + ) + + agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn)) + wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=False) + + input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"} + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + assert request_service_thread_id is None # type: ignore[attr-defined] (service_thread_id should be set) + + +async def test_agent_with_use_service_thread_is_true(): + """Test that when use_service_thread is True, the AgentThread used to run the agent is set to the service thread ID.""" + from agent_framework.ag_ui import AgentFrameworkAgent + + request_service_thread_id: str | None = None + + async def stream_fn( + messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + nonlocal request_service_thread_id + thread = kwargs.get("thread") + request_service_thread_id = thread.service_thread_id if thread else None + yield ChatResponseUpdate( + contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345" + ) + + agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn)) + wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=True) + + input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"} + + events: list[Any] = [] + async for event in wrapper.run_agent(input_data): + events.append(event) + + assert request_service_thread_id == "conv_123456" # type: ignore[attr-defined] (service_thread_id should be set) diff --git a/python/packages/ag-ui/tests/test_helpers_ag_ui.py b/python/packages/ag-ui/tests/test_helpers_ag_ui.py index bfb528511e..47d449859b 100644 --- a/python/packages/ag-ui/tests/test_helpers_ag_ui.py +++ b/python/packages/ag-ui/tests/test_helpers_ag_ui.py @@ -75,14 +75,14 @@ def __init__( *, agent_id: str = "stub-agent", agent_name: str | None = "stub-agent", - chat_options: Any | None = None, - chat_client: Any | None = None, + chat_options: ChatOptions | None = None, + chat_client: BaseChatClient | None = None, ) -> None: self._id = agent_id self._name = agent_name self._description = "stub agent" self.updates = updates or [AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")] - self.chat_options = chat_options or SimpleNamespace(tools=None, response_format=None) + self.chat_options = chat_options or ChatOptions(tools=None, response_format=None) self.chat_client = chat_client or SimpleNamespace(function_invocation_configuration=None) self.messages_received: list[Any] = [] self.tools_received: list[Any] | None = None @@ -122,8 +122,23 @@ def run_stream( async def _stream() -> AsyncIterator[AgentRunResponseUpdate]: self.messages_received = [] if messages is None else list(messages) # type: ignore[arg-type] self.tools_received = kwargs.get("tools") - for update in self.updates: - yield update + if hasattr(self.chat_client, "get_streaming_response"): + # Simulate streaming from chat client if available + async for update in self.chat_client.get_streaming_response( + messages=self.messages_received, + chat_options=self.chat_options, + **kwargs, + ): + yield AgentRunResponseUpdate( + contents=update.contents, + role=update.role, + response_id=update.response_id, + raw_representation=update, + ) + return + else: + for update in self.updates: + yield update return _stream() diff --git a/python/packages/ag-ui/tests/test_service_thread_id.py b/python/packages/ag-ui/tests/test_service_thread_id.py index 240d66fd15..c0b7e74ba5 100644 --- a/python/packages/ag-ui/tests/test_service_thread_id.py +++ b/python/packages/ag-ui/tests/test_service_thread_id.py @@ -69,7 +69,7 @@ async def test_service_thread_id_when_no_user_message(): async def test_service_thread_id_when_user_supplied_thread_id(): - """Test when user submits no messages, emitted events still have with a thread_id""" + """Test that user-supplied thread IDs are preserved in emitted events.""" from agent_framework.ag_ui import AgentFrameworkAgent updates: list[AgentRunResponseUpdate] = [] From 11761fe12c4c5505a903dacd9ae95a39775e19b6 Mon Sep 17 00:00:00 2001 From: Hao Luo Date: Thu, 8 Jan 2026 00:39:11 -1000 Subject: [PATCH 7/8] revert back StubAgent, since not used --- .../ag-ui/tests/test_helpers_ag_ui.py | 25 ++++--------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/python/packages/ag-ui/tests/test_helpers_ag_ui.py b/python/packages/ag-ui/tests/test_helpers_ag_ui.py index 47d449859b..bfb528511e 100644 --- a/python/packages/ag-ui/tests/test_helpers_ag_ui.py +++ b/python/packages/ag-ui/tests/test_helpers_ag_ui.py @@ -75,14 +75,14 @@ def __init__( *, agent_id: str = "stub-agent", agent_name: str | None = "stub-agent", - chat_options: ChatOptions | None = None, - chat_client: BaseChatClient | None = None, + chat_options: Any | None = None, + chat_client: Any | None = None, ) -> None: self._id = agent_id self._name = agent_name self._description = "stub agent" self.updates = updates or [AgentRunResponseUpdate(contents=[TextContent(text="response")], role="assistant")] - self.chat_options = chat_options or ChatOptions(tools=None, response_format=None) + self.chat_options = chat_options or SimpleNamespace(tools=None, response_format=None) self.chat_client = chat_client or SimpleNamespace(function_invocation_configuration=None) self.messages_received: list[Any] = [] self.tools_received: list[Any] | None = None @@ -122,23 +122,8 @@ def run_stream( async def _stream() -> AsyncIterator[AgentRunResponseUpdate]: self.messages_received = [] if messages is None else list(messages) # type: ignore[arg-type] self.tools_received = kwargs.get("tools") - if hasattr(self.chat_client, "get_streaming_response"): - # Simulate streaming from chat client if available - async for update in self.chat_client.get_streaming_response( - messages=self.messages_received, - chat_options=self.chat_options, - **kwargs, - ): - yield AgentRunResponseUpdate( - contents=update.contents, - role=update.role, - response_id=update.response_id, - raw_representation=update, - ) - return - else: - for update in self.updates: - yield update + for update in self.updates: + yield update return _stream() From c9ed539040c51cba609321ce08b8f9f50b01192a Mon Sep 17 00:00:00 2001 From: Hao Luo Date: Thu, 8 Jan 2026 21:19:00 -1000 Subject: [PATCH 8/8] removed relative module import --- python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py | 2 +- python/packages/ag-ui/tests/test_service_thread_id.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index b210edbb0e..4a8dd6db9f 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -14,7 +14,7 @@ from pydantic import BaseModel sys.path.insert(0, str(Path(__file__).parent)) -from .test_helpers_ag_ui import StreamingChatClientStub +from test_helpers_ag_ui import StreamingChatClientStub async def test_agent_initialization_basic(): diff --git a/python/packages/ag-ui/tests/test_service_thread_id.py b/python/packages/ag-ui/tests/test_service_thread_id.py index c0b7e74ba5..92d88963e7 100644 --- a/python/packages/ag-ui/tests/test_service_thread_id.py +++ b/python/packages/ag-ui/tests/test_service_thread_id.py @@ -11,7 +11,7 @@ from agent_framework._types import AgentRunResponseUpdate, ChatResponseUpdate sys.path.insert(0, str(Path(__file__).parent)) -from .test_helpers_ag_ui import StubAgent +from test_helpers_ag_ui import StubAgent async def test_service_thread_id_when_there_are_updates():