Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ async def run_async(
app_name=child_app_name,
user_id=tool_context._invocation_context.user_id,
state=state_dict,
session_id=tool_context._invocation_context.session.id,
)

last_content = None
Expand Down
258 changes: 160 additions & 98 deletions tests/unittests/tools/test_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from google.genai import types
from google.genai.types import Part
from pydantic import BaseModel
from pytest import fixture
from pytest import mark

from .. import testing_utils
Expand All @@ -59,112 +60,147 @@ def change_state_callback(callback_context: CallbackContext):
print('change_state_callback: ', callback_context.state)


@mark.asyncio
async def test_agent_tool_inherits_parent_app_name(monkeypatch):
parent_app_name = 'parent_app'
captured: dict[str, str] = {}

class RecordingSessionService(InMemorySessionService):

async def create_session(
self,
*,
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
):
captured['session_app_name'] = app_name
return await super().create_session(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
)

monkeypatch.setattr(
'google.adk.sessions.in_memory_session_service.InMemorySessionService',
RecordingSessionService,
)

@fixture
def agent_tool_setup_factory(monkeypatch):
async def _empty_async_generator():
if False:
yield None

class StubRunner:
async def _create_setup(
*,
parent_app_name: str,
parent_session_id: Optional[str] | None = None,
capture_runner_app_name: bool = False,
capture_session_app_name: bool = False,
capture_child_session_id: bool = False,
):
captured: dict[str, Any] = {}

class RecordingSessionService(InMemorySessionService):

async def create_session(
self,
*,
app_name: str,
user_id: str,
state: Optional[dict[str, Any]] = None,
session_id: Optional[str] = None,
):
if capture_session_app_name:
captured['session_app_name'] = app_name
if capture_child_session_id:
captured['child_session_id'] = session_id
return await super().create_session(
app_name=app_name,
user_id=user_id,
state=state,
session_id=session_id,
)

def __init__(
self,
*,
app_name: str,
agent: Agent,
artifact_service,
session_service,
memory_service,
credential_service,
plugins,
):
del artifact_service, memory_service, credential_service
captured['runner_app_name'] = app_name
self.agent = agent
self.session_service = session_service
self.plugin_manager = PluginManager(plugins=plugins)
self.app_name = app_name

def run_async(
self,
*,
user_id: str,
session_id: str,
invocation_id: Optional[str] = None,
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
):
del (
user_id,
session_id,
invocation_id,
new_message,
state_delta,
run_config,
)
return _empty_async_generator()

async def close(self):
"""Mock close method."""
pass

monkeypatch.setattr('google.adk.runners.Runner', StubRunner)
monkeypatch.setattr(
'google.adk.sessions.in_memory_session_service.InMemorySessionService',
RecordingSessionService,
)

tool_agent = Agent(
name='tool_agent',
model='test-model',
)
agent_tool = AgentTool(agent=tool_agent)
root_agent = Agent(
name='root_agent',
model='test-model',
tools=[agent_tool],
)
class StubRunner:

def __init__(
self,
*,
app_name: str,
agent: Agent,
artifact_service,
session_service,
memory_service,
credential_service,
plugins,
):
del artifact_service, memory_service, credential_service
if capture_runner_app_name:
captured['runner_app_name'] = app_name
self.agent = agent
self.session_service = session_service
self.plugin_manager = PluginManager(plugins=plugins)
self.app_name = app_name

def run_async(
self,
*,
user_id: str,
session_id: str,
invocation_id: Optional[str] = None,
new_message: Optional[types.Content] = None,
state_delta: Optional[dict[str, Any]] = None,
run_config: Optional[RunConfig] = None,
):
del (
user_id,
session_id,
invocation_id,
new_message,
state_delta,
run_config,
)
return _empty_async_generator()

artifact_service = InMemoryArtifactService()
parent_session_service = InMemorySessionService()
parent_session = await parent_session_service.create_session(
app_name=parent_app_name,
user_id='user',
)
invocation_context = InvocationContext(
artifact_service=artifact_service,
session_service=parent_session_service,
memory_service=InMemoryMemoryService(),
plugin_manager=PluginManager(),
invocation_id='invocation-id',
agent=root_agent,
session=parent_session,
run_config=RunConfig(),
async def close(self):
"""Mock close method."""
pass

monkeypatch.setattr('google.adk.runners.Runner', StubRunner)

tool_agent = Agent(
name='tool_agent',
model='test-model',
)
agent_tool = AgentTool(agent=tool_agent)
root_agent = Agent(
name='root_agent',
model='test-model',
tools=[agent_tool],
)

artifact_service = InMemoryArtifactService()
parent_session_service = InMemorySessionService()
parent_session = await parent_session_service.create_session(
app_name=parent_app_name,
user_id='user',
session_id=parent_session_id,
)
invocation_context = InvocationContext(
artifact_service=artifact_service,
session_service=parent_session_service,
memory_service=InMemoryMemoryService(),
plugin_manager=PluginManager(),
invocation_id='invocation-id',
agent=root_agent,
session=parent_session,
run_config=RunConfig(),
)
tool_context = ToolContext(invocation_context)

return {
'agent_tool': agent_tool,
'tool_context': tool_context,
'captured': captured,
}

return _create_setup


@mark.asyncio
async def test_agent_tool_inherits_parent_app_name(agent_tool_setup_factory):
parent_app_name = 'parent_app'

setup = await agent_tool_setup_factory(
parent_app_name=parent_app_name,
capture_runner_app_name=True,
capture_session_app_name=True,
)
tool_context = ToolContext(invocation_context)

agent_tool = setup['agent_tool']
tool_context = setup['tool_context']
captured = setup['captured']

assert tool_context._invocation_context.app_name == parent_app_name

Expand All @@ -177,6 +213,32 @@ async def close(self):
assert captured['session_app_name'] == parent_app_name


@mark.asyncio
async def test_agent_tool_passes_parent_session_id(agent_tool_setup_factory):
"""Test that the parent session ID is passed to the child session."""
parent_app_name = 'parent_app'
parent_session_id = 'parent-session-123'
setup = await agent_tool_setup_factory(
parent_app_name=parent_app_name,
parent_session_id=parent_session_id,
capture_child_session_id=True,
)

agent_tool = setup['agent_tool']
tool_context = setup['tool_context']
captured = setup['captured']

assert tool_context._invocation_context.session.id == parent_session_id

await agent_tool.run_async(
args={'request': 'hello'},
tool_context=tool_context,
)

# Verify that the parent session ID was passed to the child session
assert captured['child_session_id'] == parent_session_id


def test_no_schema():
mock_model = testing_utils.MockModel.create(
responses=[
Expand Down