diff --git a/src/google/adk/tools/base_toolset.py b/src/google/adk/tools/base_toolset.py index 201eec9087..07f8bd8867 100644 --- a/src/google/adk/tools/base_toolset.py +++ b/src/google/adk/tools/base_toolset.py @@ -79,6 +79,7 @@ def __init__( """ self.tool_filter = tool_filter self.tool_name_prefix = tool_name_prefix + self._tools_cache: dict[Optional[int], list[BaseTool]] = {} @abstractmethod async def get_tools( @@ -103,6 +104,7 @@ async def get_tools_with_prefix( """Return all tools with optional prefix applied to tool names. This method calls get_tools() and applies prefixing if tool_name_prefix is provided. + Tools are cached per readonly_context to avoid redundant calls to get_tools(). Args: readonly_context (ReadonlyContext, optional): Context used to filter tools @@ -111,7 +113,16 @@ async def get_tools_with_prefix( Returns: list[BaseTool]: A list of tools with prefixed names if tool_name_prefix is provided. """ - tools = await self.get_tools(readonly_context) + # Create a cache key based on the readonly_context identity + context_id = id(readonly_context) if readonly_context else None + + # Check if we have cached tools for this context + if context_id in self._tools_cache: + tools = self._tools_cache[context_id] + else: + # Fetch tools and cache them + tools = await self.get_tools(readonly_context) + self._tools_cache[context_id] = tools if not self.tool_name_prefix: return tools diff --git a/tests/unittests/tools/test_base_toolset.py b/tests/unittests/tools/test_base_toolset.py index 20d7f9d825..3c94f60709 100644 --- a/tests/unittests/tools/test_base_toolset.py +++ b/tests/unittests/tools/test_base_toolset.py @@ -386,3 +386,95 @@ async def test_no_duplicate_prefixing(): # The prefixed tools should be different instances assert prefixed_tools_1[0] is not prefixed_tools_2[0] assert prefixed_tools_1[0] is not original_tools[0] + + +@pytest.mark.asyncio +async def test_get_tools_caching(): + """Test that get_tools_with_prefix caches results to avoid duplicate calls.""" + + class _CountingToolset(_TestingToolset): + """Toolset that counts how many times get_tools is called.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.get_tools_call_count = 0 + + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> list[BaseTool]: + self.get_tools_call_count += 1 + return await super().get_tools(readonly_context) + + tool = _TestingTool(name='test_tool', description='Test tool') + toolset = _CountingToolset(tools=[tool]) + + # First call should invoke get_tools + tools1 = await toolset.get_tools_with_prefix() + assert len(tools1) == 1 + assert toolset.get_tools_call_count == 1 + + # Second call with same context (None) should use cache + tools2 = await toolset.get_tools_with_prefix() + assert len(tools2) == 1 + assert toolset.get_tools_call_count == 1 # Still 1, not 2 + + # Third call with same context should still use cache + tools3 = await toolset.get_tools_with_prefix() + assert len(tools3) == 1 + assert toolset.get_tools_call_count == 1 # Still 1, not 3 + + +@pytest.mark.asyncio +async def test_get_tools_caching_with_different_contexts(): + """Test that get_tools_with_prefix caches separately for different contexts.""" + + class _CountingToolset(_TestingToolset): + """Toolset that counts how many times get_tools is called.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.get_tools_call_count = 0 + + async def get_tools( + self, readonly_context: Optional[ReadonlyContext] = None + ) -> list[BaseTool]: + self.get_tools_call_count += 1 + return await super().get_tools(readonly_context) + + tool = _TestingTool(name='test_tool', description='Test tool') + toolset = _CountingToolset(tools=[tool]) + + # Create mock contexts + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + agent = SequentialAgent(name='test_agent') + invocation_context = InvocationContext( + invocation_id='test_id', + agent=agent, + session=session, + session_service=session_service, + ) + context1 = ReadonlyContext(invocation_context) + context2 = ReadonlyContext(invocation_context) + + # First call with context1 + tools1 = await toolset.get_tools_with_prefix(context1) + assert len(tools1) == 1 + assert toolset.get_tools_call_count == 1 + + # Second call with same context1 should use cache + tools2 = await toolset.get_tools_with_prefix(context1) + assert len(tools2) == 1 + assert toolset.get_tools_call_count == 1 # Still 1 + + # Third call with different context2 should invoke get_tools again + tools3 = await toolset.get_tools_with_prefix(context2) + assert len(tools3) == 1 + assert toolset.get_tools_call_count == 2 # Now 2 + + # Fourth call with context2 should use cache + tools4 = await toolset.get_tools_with_prefix(context2) + assert len(tools4) == 1 + assert toolset.get_tools_call_count == 2 # Still 2