diff --git a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py index b0acc0feb8..7d7dbf8f06 100644 --- a/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +++ b/src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py @@ -23,7 +23,6 @@ from google.genai import types from typing_extensions import override -from ...utils.model_name_utils import is_gemini_2_or_above from ..tool_context import ToolContext from .base_retrieval_tool import BaseRetrievalTool @@ -62,26 +61,10 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: - # Use Gemini built-in Vertex AI RAG tool for Gemini 2 models. - if is_gemini_2_or_above(llm_request.model): - llm_request.config = ( - types.GenerateContentConfig() - if not llm_request.config - else llm_request.config - ) - llm_request.config.tools = ( - [] if not llm_request.config.tools else llm_request.config.tools - ) - llm_request.config.tools.append( - types.Tool( - retrieval=types.Retrieval(vertex_rag_store=self.vertex_rag_store) - ) - ) - else: - # Add the function declaration to the tools - await super().process_llm_request( - tool_context=tool_context, llm_request=llm_request - ) + # Add the function declaration to the tools + await super().process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) @override async def run_async( diff --git a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py index 132e6b7b10..333b267718 100644 --- a/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py +++ b/tests/unittests/tools/retrieval/test_vertex_ai_rag_retrieval.py @@ -133,15 +133,8 @@ def test_vertex_rag_retrieval_for_gemini_2_x(): ('user', 'test1'), ] assert len(mockModel.requests[0].config.tools) == 1 - assert mockModel.requests[0].config.tools == [ - types.Tool( - retrieval=types.Retrieval( - vertex_rag_store=types.VertexRagStore( - rag_corpora=[ - 'projects/123456789/locations/us-central1/ragCorpora/1234567890' - ] - ) - ) - ) - ] - assert 'rag_retrieval' not in mockModel.requests[0].tools_dict + assert ( + mockModel.requests[0].config.tools[0].function_declarations[0].name + == 'rag_retrieval' + ) + assert mockModel.requests[0].tools_dict['rag_retrieval'] is not None