Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/google/adk/tools/base_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
"""
self.tool_filter = tool_filter
self.tool_name_prefix = tool_name_prefix
self._tools_cache: dict[Optional[int], list[BaseTool]] = {}

@abstractmethod
async def get_tools(
Expand All @@ -103,6 +104,7 @@ async def get_tools_with_prefix(
"""Return all tools with optional prefix applied to tool names.

This method calls get_tools() and applies prefixing if tool_name_prefix is provided.
Tools are cached per readonly_context to avoid redundant calls to get_tools().

Args:
readonly_context (ReadonlyContext, optional): Context used to filter tools
Expand All @@ -111,7 +113,16 @@ async def get_tools_with_prefix(
Returns:
list[BaseTool]: A list of tools with prefixed names if tool_name_prefix is provided.
"""
tools = await self.get_tools(readonly_context)
# Create a cache key based on the readonly_context identity
context_id = id(readonly_context) if readonly_context else None

# Check if we have cached tools for this context
if context_id in self._tools_cache:
tools = self._tools_cache[context_id]
else:
# Fetch tools and cache them
tools = await self.get_tools(readonly_context)
self._tools_cache[context_id] = tools

if not self.tool_name_prefix:
return tools
Expand Down
92 changes: 92 additions & 0 deletions tests/unittests/tools/test_base_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,95 @@ async def test_no_duplicate_prefixing():
# The prefixed tools should be different instances
assert prefixed_tools_1[0] is not prefixed_tools_2[0]
assert prefixed_tools_1[0] is not original_tools[0]


@pytest.mark.asyncio
async def test_get_tools_caching():
"""Test that get_tools_with_prefix caches results to avoid duplicate calls."""

class _CountingToolset(_TestingToolset):
"""Toolset that counts how many times get_tools is called."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.get_tools_call_count = 0

async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> list[BaseTool]:
self.get_tools_call_count += 1
return await super().get_tools(readonly_context)

tool = _TestingTool(name='test_tool', description='Test tool')
toolset = _CountingToolset(tools=[tool])

# First call should invoke get_tools
tools1 = await toolset.get_tools_with_prefix()
assert len(tools1) == 1
assert toolset.get_tools_call_count == 1

# Second call with same context (None) should use cache
tools2 = await toolset.get_tools_with_prefix()
assert len(tools2) == 1
assert toolset.get_tools_call_count == 1 # Still 1, not 2

# Third call with same context should still use cache
tools3 = await toolset.get_tools_with_prefix()
assert len(tools3) == 1
assert toolset.get_tools_call_count == 1 # Still 1, not 3


@pytest.mark.asyncio
async def test_get_tools_caching_with_different_contexts():
"""Test that get_tools_with_prefix caches separately for different contexts."""

class _CountingToolset(_TestingToolset):
"""Toolset that counts how many times get_tools is called."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.get_tools_call_count = 0

async def get_tools(
self, readonly_context: Optional[ReadonlyContext] = None
) -> list[BaseTool]:
self.get_tools_call_count += 1
return await super().get_tools(readonly_context)

tool = _TestingTool(name='test_tool', description='Test tool')
toolset = _CountingToolset(tools=[tool])

# Create mock contexts
session_service = InMemorySessionService()
session = await session_service.create_session(
app_name='test_app', user_id='test_user'
)
agent = SequentialAgent(name='test_agent')
invocation_context = InvocationContext(
invocation_id='test_id',
agent=agent,
session=session,
session_service=session_service,
)
context1 = ReadonlyContext(invocation_context)
context2 = ReadonlyContext(invocation_context)

# First call with context1
tools1 = await toolset.get_tools_with_prefix(context1)
assert len(tools1) == 1
assert toolset.get_tools_call_count == 1

# Second call with same context1 should use cache
tools2 = await toolset.get_tools_with_prefix(context1)
assert len(tools2) == 1
assert toolset.get_tools_call_count == 1 # Still 1

# Third call with different context2 should invoke get_tools again
tools3 = await toolset.get_tools_with_prefix(context2)
assert len(tools3) == 1
assert toolset.get_tools_call_count == 2 # Now 2

# Fourth call with context2 should use cache
tools4 = await toolset.get_tools_with_prefix(context2)
assert len(tools4) == 1
assert toolset.get_tools_call_count == 2 # Still 2
Loading