Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,20 @@ def __init__(
self,
state_schema: Any | None = None,
predict_state_config: dict[str, dict[str, str]] | None = None,
use_service_thread: bool = False,
require_confirmation: bool = True,
):
"""Initialize agent configuration.

Args:
state_schema: Optional state schema for state management; accepts dict or Pydantic model/class
predict_state_config: Configuration for predictive state updates
use_service_thread: Whether the agent thread is service-managed
require_confirmation: Whether predictive updates require confirmation
"""
self.state_schema = self._normalize_state_schema(state_schema)
self.predict_state_config = predict_state_config or {}
self.use_service_thread = use_service_thread
self.require_confirmation = require_confirmation

@staticmethod
Expand Down Expand Up @@ -86,6 +89,7 @@ def __init__(
predict_state_config: dict[str, dict[str, str]] | None = None,
require_confirmation: bool = True,
orchestrators: list[Orchestrator] | None = None,
use_service_thread: bool = False,
confirmation_strategy: ConfirmationStrategy | None = None,
):
"""Initialize the AG-UI compatible agent wrapper.
Expand All @@ -101,6 +105,7 @@ def __init__(
Set to False for agentic generative UI that updates automatically.
orchestrators: Custom orchestrators (auto-configured if None).
Orchestrators are checked in order; first match handles the request.
use_service_thread: Whether the agent thread is service-managed.
confirmation_strategy: Strategy for generating confirmation messages.
Defaults to DefaultConfirmationStrategy if None.
"""
Expand All @@ -111,6 +116,7 @@ def __init__(
self.config = AgentConfig(
state_schema=state_schema,
predict_state_config=predict_state_config,
use_service_thread=use_service_thread,
require_confirmation=require_confirmation,
)

Expand Down
144 changes: 122 additions & 22 deletions python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import uuid
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Sequence
from typing import TYPE_CHECKING, Any

from ag_ui.core import (
Expand Down Expand Up @@ -53,11 +53,18 @@
merge_tools,
register_additional_client_tools,
)
from ._utils import convert_agui_tools_to_agent_framework, generate_event_id, get_role_value
from ._utils import (
convert_agui_tools_to_agent_framework,
generate_event_id,
get_conversation_id_from_update,
get_role_value,
)

if TYPE_CHECKING:
from ._agent import AgentConfig
from ._confirmation_strategies import ConfirmationStrategy
from ._events import AgentFrameworkEventBridge
from ._orchestration._state_manager import StateManager


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -92,6 +99,8 @@ def __init__(
self._last_message = None
self._run_id: str | None = None
self._thread_id: str | None = None
self._supplied_run_id: str | None = None
self._supplied_thread_id: str | None = None

@property
def messages(self):
Expand Down Expand Up @@ -125,26 +134,66 @@ def last_message(self):
self._last_message = self.messages[-1]
return self._last_message

@property
def supplied_run_id(self) -> str | None:
"""Get the supplied run ID, if any."""
if self._supplied_run_id is None:
self._supplied_run_id = self.input_data.get("run_id") or self.input_data.get("runId")
return self._supplied_run_id

@property
def run_id(self) -> str:
"""Get or generate run ID."""
"""Get supplied run ID or generate a new run ID."""
if self._run_id:
return self._run_id

if self.supplied_run_id:
self._run_id = self.supplied_run_id

if self._run_id is None:
self._run_id = self.input_data.get("run_id") or self.input_data.get("runId") or str(uuid.uuid4())
# This should never be None after the if block above, but satisfy type checkers
if self._run_id is None: # pragma: no cover
raise RuntimeError("Failed to initialize run_id")
self._run_id = str(uuid.uuid4())

return self._run_id

@property
def supplied_thread_id(self) -> str | None:
"""Get the supplied thread ID, if any."""
if self._supplied_thread_id is None:
self._supplied_thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId")
return self._supplied_thread_id

@property
def thread_id(self) -> str:
"""Get or generate thread ID."""
"""Get supplied thread ID or generate a new thread ID."""
if self._thread_id:
return self._thread_id

if self.supplied_thread_id:
self._thread_id = self.supplied_thread_id

if self._thread_id is None:
self._thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") or str(uuid.uuid4())
# This should never be None after the if block above, but satisfy type checkers
if self._thread_id is None: # pragma: no cover
raise RuntimeError("Failed to initialize thread_id")
self._thread_id = str(uuid.uuid4())

return self._thread_id

def update_run_id(self, new_run_id: str) -> None:
"""Update the run ID in the context.

Args:
new_run_id: The new run ID to set
"""
self._supplied_run_id = new_run_id
self._run_id = new_run_id

def update_thread_id(self, new_thread_id: str) -> None:
"""Update the thread ID in the context.

Args:
new_thread_id: The new thread ID to set
"""
self._supplied_thread_id = new_thread_id
self._thread_id = new_thread_id


class Orchestrator(ABC):
"""Base orchestrator for agent execution flows."""
Expand Down Expand Up @@ -297,6 +346,28 @@ def can_handle(self, context: ExecutionContext) -> bool:
"""
return True

def _create_initial_events(
self, event_bridge: "AgentFrameworkEventBridge", state_manager: "StateManager"
) -> Sequence[BaseEvent]:
"""Generate initial events for the run.

Args:
event_bridge: Event bridge for creating events
Returns:
Initial AG-UI events
"""
events: list[BaseEvent] = [event_bridge.create_run_started_event()]

predict_event = state_manager.predict_state_event()
if predict_event:
events.append(predict_event)

snapshot_event = state_manager.initial_snapshot_event(event_bridge)
if snapshot_event:
events.append(snapshot_event)

return events

async def run(
self,
context: ExecutionContext,
Expand Down Expand Up @@ -342,17 +413,11 @@ async def run(
approval_tool_name=approval_tool_name,
)

yield event_bridge.create_run_started_event()

predict_event = state_manager.predict_state_event()
if predict_event:
yield predict_event

snapshot_event = state_manager.initial_snapshot_event(event_bridge)
if snapshot_event:
yield snapshot_event
if context.config.use_service_thread:
thread = AgentThread(service_thread_id=context.supplied_thread_id)
else:
thread = AgentThread()

thread = AgentThread()
thread.metadata = { # type: ignore[attr-defined]
"ag_ui_thread_id": context.thread_id,
"ag_ui_run_id": context.run_id,
Expand All @@ -363,6 +428,8 @@ async def run(
provider_messages = context.messages or []
snapshot_messages = context.snapshot_messages
if not provider_messages:
for event in self._create_initial_events(event_bridge, state_manager):
yield event
logger.warning("No messages provided in AG-UI input")
yield event_bridge.create_run_finished_event()
return
Expand Down Expand Up @@ -554,13 +621,41 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap
confirmation_message = strategy.on_state_rejected()

message_id = generate_event_id()
for event in self._create_initial_events(event_bridge, state_manager):
yield event
yield TextMessageStartEvent(message_id=message_id, role="assistant")
yield TextMessageContentEvent(message_id=message_id, delta=confirmation_message)
yield TextMessageEndEvent(message_id=message_id)
yield event_bridge.create_run_finished_event()
return

should_recreate_event_bridge = False
async for update in context.agent.run_stream(messages_to_run, **run_kwargs):
conv_id = get_conversation_id_from_update(update)
if conv_id and conv_id != context.thread_id:
context.update_thread_id(conv_id)
should_recreate_event_bridge = True

if hasattr(update, "response_id") and update.response_id and update.response_id != context.run_id:
context.update_run_id(update.response_id)
should_recreate_event_bridge = True

if should_recreate_event_bridge:
event_bridge = AgentFrameworkEventBridge(
run_id=context.run_id,
thread_id=context.thread_id,
predict_state_config=context.config.predict_state_config,
current_state=current_state,
skip_text_content=skip_text_content,
require_confirmation=context.config.require_confirmation,
approval_tool_name=approval_tool_name,
)
should_recreate_event_bridge = False

if update_count == 0:
for event in self._create_initial_events(event_bridge, state_manager):
yield event

update_count += 1
logger.info(f"[STREAM] Received update #{update_count} from agent")
if all_updates is not None:
Expand Down Expand Up @@ -672,6 +767,11 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap
yield TextMessageEndEvent(message_id=message_id)
logger.info(f"Emitted conversational message with length={len(response_dict['message'])}")

if all_updates is not None and len(all_updates) == 0:
logger.info("No updates received from agent - emitting initial events")
for event in self._create_initial_events(event_bridge, state_manager):
yield event

logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}")
if event_bridge.current_message_id:
logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}")
Expand Down
16 changes: 15 additions & 1 deletion python/packages/ag-ui/agent_framework_ag_ui/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from datetime import date, datetime
from typing import Any

from agent_framework import AIFunction, Role, ToolProtocol
from agent_framework import AgentRunResponseUpdate, AIFunction, ChatResponseUpdate, Role, ToolProtocol

# Role mapping constants
AGUI_TO_FRAMEWORK_ROLE: dict[str, Role] = {
Expand Down Expand Up @@ -259,3 +259,17 @@ def convert_tools_to_agui_format(
continue

return results if results else None


def get_conversation_id_from_update(update: AgentRunResponseUpdate) -> str | None:
"""Extract conversation ID from AgentRunResponseUpdate metadata.

Args:
update: AgentRunResponseUpdate instance
Returns:
Conversation ID if present, else None

"""
if isinstance(update.raw_representation, ChatResponseUpdate):
return update.raw_representation.conversation_id
return None
54 changes: 54 additions & 0 deletions python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,60 @@ async def stream_fn(
assert "written" in full_text.lower() or "document" in full_text.lower()


async def test_agent_with_use_service_thread_is_false():
"""Test that when use_service_thread is False, the AgentThread used to run the agent is NOT set to the service thread ID."""
from agent_framework.ag_ui import AgentFrameworkAgent

request_service_thread_id: str | None = None

async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
nonlocal request_service_thread_id
thread = kwargs.get("thread")
request_service_thread_id = thread.service_thread_id if thread else None
yield ChatResponseUpdate(
contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345"
)

agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=False)

input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"}

events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
assert request_service_thread_id is None # type: ignore[attr-defined] (service_thread_id should be set)


async def test_agent_with_use_service_thread_is_true():
"""Test that when use_service_thread is True, the AgentThread used to run the agent is set to the service thread ID."""
from agent_framework.ag_ui import AgentFrameworkAgent

request_service_thread_id: str | None = None

async def stream_fn(
messages: MutableSequence[ChatMessage], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
nonlocal request_service_thread_id
thread = kwargs.get("thread")
request_service_thread_id = thread.service_thread_id if thread else None
yield ChatResponseUpdate(
contents=[TextContent(text="Response")], response_id="resp_67890", conversation_id="conv_12345"
)

agent = ChatAgent(chat_client=StreamingChatClientStub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent, use_service_thread=True)

input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"}

events: list[Any] = []
async for event in wrapper.run_agent(input_data):
events.append(event)
assert request_service_thread_id == "conv_123456" # type: ignore[attr-defined] (service_thread_id should be set)


async def test_function_approval_mode_executes_tool():
"""Test that function approval with approval_mode='always_require' sends the correct messages."""
from agent_framework import FunctionResultContent, ai_function
Expand Down
Loading
Loading