From e9cc9da31716f3d5914ff8a6b7fea9f387680377 Mon Sep 17 00:00:00 2001 From: Daniela Petruzalek Date: Sat, 13 Dec 2025 01:59:51 +0000 Subject: [PATCH] fix(tools): Use function declaration for VertexAiRagRetrieval in Gemini 2.x+ Standardize VertexAiRagRetrieval to always use function declarations, even for Gemini 2.x+ models. This prevents unnecessary direct retrieval configurations and resolves premature API calls and 429 errors when the tool is not explicitly invoked. This ensures the RAG tool is only called when the model explicitly invokes its declared function. Fixes #3315 --- .../retrieval/vertex_ai_rag_retrieval.py | 25 +++---------------- .../retrieval/test_vertex_ai_rag_retrieval.py | 17 ++++--------- 2 files changed, 9 insertions(+), 33 deletions(-) diff --git a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py index b0acc0feb8..7d7dbf8f06 100644 --- a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +++ b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py @@ -23,7 +23,6 @@ from google.genai import types from typing_extensions import override -from ...utils.model_name_utils import is_gemini_2_or_above from ..tool_context import ToolContext from .base_retrieval_tool import BaseRetrievalTool @@ -62,26 +61,10 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - # Use Gemini built-in Vertex AI RAG tool for Gemini 2 models. - if is_gemini_2_or_above(llm_request.model): - llm_request.config = ( - types.GenerateContentConfig() - if not llm_request.config - else llm_request.config - ) - llm_request.config.tools = ( - [] if not llm_request.config.tools else llm_request.config.tools - ) - llm_request.config.tools.append( - types.Tool( - retrieval=types.Retrieval(vertex_rag_store=self.vertex_rag_store) - ) - ) - else: - # Add the function declaration to the tools - await super().process_llm_request( - tool_context=tool_context, llm_request=llm_request - ) + # Add the function declaration to the tools + await super().process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) @override async def run_async( diff --git a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py index 132e6b7b10..333b267718 100644 --- a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py +++ b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py @@ -133,15 +133,8 @@ def test_vertex_rag_retrieval_for_gemini_2_x(): ('user', 'test1'), ] assert len(mockModel.requests[0].config.tools) == 1 - assert mockModel.requests[0].config.tools == [ - types.Tool( - retrieval=types.Retrieval( - vertex_rag_store=types.VertexRagStore( - rag_corpora=[ - 'projects/123456789/locations/us-central1/ragCorpora/1234567890' - ] - ) - ) - ) - ] - assert 'rag_retrieval' not in mockModel.requests[0].tools_dict + assert ( + mockModel.requests[0].config.tools[0].function_declarations[0].name + == 'rag_retrieval' + ) + assert mockModel.requests[0].tools_dict['rag_retrieval'] is not None