diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 24fdce9d59..51940ee2d3 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -206,6 +206,15 @@ class InvocationContext(BaseModel): canonical_tools_cache: Optional[list[BaseTool]] = None """The cache of canonical tools for this invocation.""" + metadata: Optional[dict[str, Any]] = None + """Per-request metadata passed from Runner.run_async(). + + This field allows passing arbitrary metadata that can be accessed during + the invocation lifecycle, particularly in callbacks like before_model_callback. + Common use cases include passing user_id, trace_id, memory context keys, or + other request-specific context that needs to be available during processing. + """ + _invocation_cost_manager: _InvocationCostManager = PrivateAttr( default_factory=_InvocationCostManager ) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 824cd26be1..6ccd9adb43 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -89,7 +89,7 @@ async def run_live( invocation_context: InvocationContext, ) -> AsyncGenerator[Event, None]: """Runs the flow using live api.""" - llm_request = LlmRequest() + llm_request = LlmRequest(metadata=invocation_context.metadata) event_id = Event.new_id() # Preprocess before calling the LLM. @@ -380,7 +380,7 @@ async def _run_one_step_async( invocation_context: InvocationContext, ) -> AsyncGenerator[Event, None]: """One step means one LLM call.""" - llm_request = LlmRequest() + llm_request = LlmRequest(metadata=invocation_context.metadata) # Preprocess before calling the LLM. async with Aclosing( diff --git a/src/google/adk/models/llm_request.py b/src/google/adk/models/llm_request.py index 287da34240..6bf81e8c8b 100644 --- a/src/google/adk/models/llm_request.py +++ b/src/google/adk/models/llm_request.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +from typing import Any from typing import Optional from typing import Union @@ -99,6 +100,15 @@ class LlmRequest(BaseModel): the full history. """ + metadata: Optional[dict[str, Any]] = None + """Per-request metadata for callbacks and custom processing. + + This field allows passing arbitrary metadata from the Runner.run_async() + call to callbacks like before_model_callback. This is useful for passing + request-specific context such as user_id, trace_id, or memory context keys + that need to be available during model invocation. + """ + def append_instructions( self, instructions: Union[list[str], types.Content] ) -> list[types.Content]: diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 1773729719..ae790a5c88 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -400,6 +400,7 @@ async def run_async( new_message: Optional[types.Content] = None, state_delta: Optional[dict[str, Any]] = None, run_config: Optional[RunConfig] = None, + metadata: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[Event, None]: """Main entry method to run the agent in this runner. @@ -417,6 +418,13 @@ async def run_async( new_message: A new message to append to the session. state_delta: Optional state changes to apply to the session. run_config: The run config for the agent. + metadata: Optional per-request metadata that will be passed to callbacks. + This allows passing request-specific context such as user_id, trace_id, + or memory context keys to before_model_callback and other callbacks. + Note: A shallow copy is made of this dictionary, so top-level changes + within callbacks won't affect the original. However, modifications to + nested mutable objects (e.g., nested dicts or lists) will affect the + original. Yields: The events generated by the agent. @@ -426,6 +434,8 @@ async def run_async( new_message are None. """ run_config = run_config or RunConfig() + # Create a shallow copy to isolate from caller's modifications + metadata = metadata.copy() if metadata is not None else None if new_message and not new_message.role: new_message.role = 'user' @@ -433,6 +443,7 @@ async def run_async( async def _run_with_trace( new_message: Optional[types.Content] = None, invocation_id: Optional[str] = None, + metadata: Optional[dict[str, Any]] = None, ) -> AsyncGenerator[Event, None]: with tracer.start_as_current_span('invocation'): session = await self.session_service.get_session( @@ -463,6 +474,7 @@ async def _run_with_trace( invocation_id=invocation_id, run_config=run_config, state_delta=state_delta, + metadata=metadata, ) if invocation_context.end_of_agents.get( invocation_context.agent.name @@ -476,6 +488,7 @@ async def _run_with_trace( new_message=new_message, # new_message is not None. run_config=run_config, state_delta=state_delta, + metadata=metadata, ) async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: @@ -502,7 +515,9 @@ async def execute(ctx: InvocationContext) -> AsyncGenerator[Event]: self.app, session, self.session_service ) - async with Aclosing(_run_with_trace(new_message, invocation_id)) as agen: + async with Aclosing( + _run_with_trace(new_message, invocation_id, metadata) + ) as agen: async for event in agen: yield event @@ -1186,6 +1201,7 @@ async def _setup_context_for_new_invocation( new_message: types.Content, run_config: RunConfig, state_delta: Optional[dict[str, Any]], + metadata: Optional[dict[str, Any]] = None, ) -> InvocationContext: """Sets up the context for a new invocation. @@ -1194,6 +1210,7 @@ async def _setup_context_for_new_invocation( new_message: The new message to process and append to the session. run_config: The run config of the agent. state_delta: Optional state changes to apply to the session. + metadata: Optional per-request metadata to pass to callbacks. Returns: The invocation context for the new invocation. @@ -1203,6 +1220,7 @@ async def _setup_context_for_new_invocation( session, new_message=new_message, run_config=run_config, + metadata=metadata, ) # Step 2: Handle new message, by running callbacks and appending to # session. @@ -1225,6 +1243,7 @@ async def _setup_context_for_resumed_invocation( invocation_id: Optional[str], run_config: RunConfig, state_delta: Optional[dict[str, Any]], + metadata: Optional[dict[str, Any]] = None, ) -> InvocationContext: """Sets up the context for a resumed invocation. @@ -1234,6 +1253,7 @@ async def _setup_context_for_resumed_invocation( invocation_id: The invocation id to resume. run_config: The run config of the agent. state_delta: Optional state changes to apply to the session. + metadata: Optional per-request metadata to pass to callbacks. Returns: The invocation context for the resumed invocation. @@ -1259,6 +1279,7 @@ async def _setup_context_for_resumed_invocation( new_message=user_message, run_config=run_config, invocation_id=invocation_id, + metadata=metadata, ) # Step 3: Maybe handle new message. if new_message: @@ -1303,6 +1324,7 @@ def _new_invocation_context( new_message: Optional[types.Content] = None, live_request_queue: Optional[LiveRequestQueue] = None, run_config: Optional[RunConfig] = None, + metadata: Optional[dict[str, Any]] = None, ) -> InvocationContext: """Creates a new invocation context. @@ -1312,6 +1334,7 @@ def _new_invocation_context( new_message: The new message for the context. live_request_queue: The live request queue for the context. run_config: The run config for the context. + metadata: Optional per-request metadata for the context. Returns: The new invocation context. @@ -1343,6 +1366,7 @@ def _new_invocation_context( live_request_queue=live_request_queue, run_config=run_config, resumability_config=self.resumability_config, + metadata=metadata, ) def _new_invocation_context_for_live( diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index d692f7e380..d8857758af 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -28,6 +28,8 @@ from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService from google.adk.cli.utils.agent_loader import AgentLoader from google.adk.events.event import Event +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse from google.adk.plugins.base_plugin import BasePlugin from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService @@ -1038,5 +1040,250 @@ def test_infer_agent_origin_detects_mismatch_for_user_agent( assert "actual_name" in runner._app_name_alignment_hint +class TestRunnerMetadata: + """Tests for Runner metadata parameter functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.session_service = InMemorySessionService() + self.artifact_service = InMemoryArtifactService() + self.root_agent = MockLlmAgent("root_agent") + self.runner = Runner( + app_name="test_app", + agent=self.root_agent, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + def test_new_invocation_context_with_metadata(self): + """Test that _new_invocation_context correctly passes metadata.""" + mock_session = Session( + id=TEST_SESSION_ID, + app_name=TEST_APP_ID, + user_id=TEST_USER_ID, + events=[], + ) + + test_metadata = {"user_id": "test123", "trace_id": "trace456"} + invocation_context = self.runner._new_invocation_context( + mock_session, metadata=test_metadata + ) + + assert invocation_context.metadata == test_metadata + assert invocation_context.metadata["user_id"] == "test123" + assert invocation_context.metadata["trace_id"] == "trace456" + + def test_new_invocation_context_without_metadata(self): + """Test that _new_invocation_context works without metadata.""" + mock_session = Session( + id=TEST_SESSION_ID, + app_name=TEST_APP_ID, + user_id=TEST_USER_ID, + events=[], + ) + + invocation_context = self.runner._new_invocation_context(mock_session) + + assert invocation_context.metadata is None + + @pytest.mark.asyncio + async def test_run_async_passes_metadata_to_invocation_context(self): + """Test that run_async correctly passes metadata to before_model_callback.""" + # Capture metadata received in callback + captured_metadata = None + + def before_model_callback(callback_context, llm_request): + nonlocal captured_metadata + captured_metadata = llm_request.metadata + # Return a response to skip actual LLM call + return LlmResponse( + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ) + ) + + # Create agent with before_model_callback + agent_with_callback = LlmAgent( + name="callback_agent", + model="gemini-2.0-flash", + before_model_callback=before_model_callback, + ) + + runner_with_callback = Runner( + app_name="test_app", + agent=agent_with_callback, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + session = await self.session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + + test_metadata = {"experiment_id": "exp-001", "variant": "B"} + + async for event in runner_with_callback.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + metadata=test_metadata, + ): + pass + + # Verify metadata was passed to before_model_callback + assert captured_metadata is not None + assert captured_metadata == test_metadata + assert captured_metadata["experiment_id"] == "exp-001" + assert captured_metadata["variant"] == "B" + + def test_metadata_field_in_invocation_context(self): + """Test that InvocationContext model accepts metadata field.""" + mock_session = Session( + id=TEST_SESSION_ID, + app_name=TEST_APP_ID, + user_id=TEST_USER_ID, + events=[], + ) + + test_metadata = {"key1": "value1", "key2": 123} + + # This should not raise a validation error + invocation_context = InvocationContext( + session_service=self.session_service, + invocation_id="test_inv_id", + agent=self.root_agent, + session=mock_session, + metadata=test_metadata, + ) + + assert invocation_context.metadata == test_metadata + + def test_metadata_field_in_llm_request(self): + """Test that LlmRequest model accepts metadata field.""" + test_metadata = {"context_key": "ctx123", "user_info": {"name": "test"}} + + llm_request = LlmRequest(metadata=test_metadata) + + assert llm_request.metadata == test_metadata + assert llm_request.metadata["context_key"] == "ctx123" + assert llm_request.metadata["user_info"]["name"] == "test" + + def test_llm_request_without_metadata(self): + """Test that LlmRequest works without metadata.""" + llm_request = LlmRequest() + + assert llm_request.metadata is None + + @pytest.mark.asyncio + async def test_empty_metadata_dict_not_converted_to_none(self): + """Test that empty dict {} is preserved and not converted to None.""" + captured_metadata = None + + def before_model_callback(callback_context, llm_request): + nonlocal captured_metadata + captured_metadata = llm_request.metadata + return LlmResponse( + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ) + ) + + agent_with_callback = LlmAgent( + name="callback_agent", + model="gemini-2.0-flash", + before_model_callback=before_model_callback, + ) + + runner_with_callback = Runner( + app_name="test_app", + agent=agent_with_callback, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + await self.session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + + # Pass empty dict - should NOT become None + async for event in runner_with_callback.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + metadata={}, + ): + pass + + # Empty dict should be preserved, not converted to None + assert captured_metadata is not None + assert captured_metadata == {} + assert isinstance(captured_metadata, dict) + + @pytest.mark.asyncio + async def test_metadata_shallow_copy_isolation(self): + """Test that shallow copy isolates top-level changes but shares nested objects.""" + # Track modifications made in callback + callback_received_metadata = None + + def before_model_callback(callback_context, llm_request): + nonlocal callback_received_metadata + callback_received_metadata = llm_request.metadata + # Modify top-level key (should NOT affect original due to shallow copy) + llm_request.metadata["top_level_key"] = "modified_in_callback" + # Modify nested object (WILL affect original due to shallow copy) + llm_request.metadata["nested"]["inner_key"] = "modified_nested" + return LlmResponse( + content=types.Content( + role="model", parts=[types.Part(text="Test response")] + ) + ) + + agent_with_callback = LlmAgent( + name="callback_agent", + model="gemini-2.0-flash", + before_model_callback=before_model_callback, + ) + + runner_with_callback = Runner( + app_name="test_app", + agent=agent_with_callback, + session_service=self.session_service, + artifact_service=self.artifact_service, + ) + + await self.session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + + # Original metadata with nested mutable object + original_metadata = { + "top_level_key": "original_value", + "nested": {"inner_key": "original_nested"}, + } + + async for event in runner_with_callback.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + metadata=original_metadata, + ): + pass + + # Verify callback received metadata + assert callback_received_metadata is not None + + # Top-level changes in callback should NOT affect original (shallow copy) + assert original_metadata["top_level_key"] == "original_value" + + # Nested object changes in callback WILL affect original (shallow copy behavior) + assert original_metadata["nested"]["inner_key"] == "modified_nested" + + if __name__ == "__main__": pytest.main([__file__])