From f2cc98f4de6e1c6530f2f082628b96d299858d9c Mon Sep 17 00:00:00 2001 From: "anan.yablonko" Date: Mon, 22 Dec 2025 17:36:37 +0200 Subject: [PATCH 1/9] Add map agent --- src/google/adk/agents/map_agent.py | 161 ++++++++++++++++ tests/unittests/agents/test_map_agent.py | 229 +++++++++++++++++++++++ 2 files changed, 390 insertions(+) create mode 100644 src/google/adk/agents/map_agent.py create mode 100644 tests/unittests/agents/test_map_agent.py diff --git a/src/google/adk/agents/map_agent.py b/src/google/adk/agents/map_agent.py new file mode 100644 index 0000000000..675d561f53 --- /dev/null +++ b/src/google/adk/agents/map_agent.py @@ -0,0 +1,161 @@ +from typing import Annotated +from typing import AsyncGenerator + +from annotated_types import Len +from google.adk.agents import BaseAgent +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.parallel_agent import _merge_agent_run +from google.adk.events import Event +from google.genai import types +from pydantic import Field +from pydantic import RootModel +from typing_extensions import override + + +class MapAgent(BaseAgent): + sub_agents: Annotated[list[BaseAgent], Len(1, 1)] = Field( + min_length=1, + max_length=1, + default_factory=list, + description=( + "A single base agent that will be copied and invoked for each prompt" + ), + ) + + @override + async def _run_async_impl( + self, invocation_context: InvocationContext + ) -> AsyncGenerator[Event, None]: + """Core logic of this workflow agent. + + Args: + invocation_context: InvocationContext, provides access to the input prompts. + + Yields: + Event: the events generated by the sub-agent for each input prompt. + """ + + # Create a branch string if it doesn't exist, to ensure parallel invocations don't interfere with each other + invocation_context.branch = invocation_context.branch or self.name + + prompts, invoker = self._extract_input_prompts(invocation_context) + + # for agent naming - e.g. if there are 100-999 prompts, sub-agent copies are named 001, 002, 003 and so on + number_field_width = len(str(len(prompts))) + + # Create a separate invocation context for each prompt, each with a numbered copy of the sub-agent. + contexts = [ + self._branch_context( + invocation_context, + idx=i, + prompt=prompt, + invoker=invoker, + width=number_field_width, + ) + for i, prompt in enumerate(prompts) + ] + + async for event in _merge_agent_run( + [ctx.agent.run_async(ctx) for ctx in contexts] + ): + yield event + + def _extract_input_prompts( + self, ctx: InvocationContext + ) -> tuple[list[str], str]: + """ + The input to the gather agent is a list of strings. + We extract the text content from the latest event, and assume it is a list of strings serialized as a json string. + """ + invoker = "user" + + for i in range(len(ctx.session.events) - 1, -1, -1): + event = ctx.session.events[i] + if event.branch is None or ( + ctx.branch is not None and event.branch.startswith(ctx.branch) + ): + break + else: + return [], "user" + + invoker: str = event.author + input_message: str = ( + (event.content or types.Content()).parts or [types.Part()] + )[0].text or "" + + """ + Remove the event which has the prompt list, so that a sub agent does not see the prompts of its siblings, which may confuse it. + The event is removed only for this invocation. + """ + ctx.session.events.pop(i) + + agent_input = RootModel[list[str]].model_validate_json(input_message).root + + return agent_input, invoker + + @staticmethod + def _get_unique_name(name: str, idx: int, width: int) -> str: + """e.g. my_sub_agent_046""" + return f"{name}_{idx:0{width}d}" + + def _branch_context( + self, + ctx: InvocationContext, + *, + prompt: str, + invoker: str, + idx: int, + width: int, + ) -> InvocationContext: + """Creates a numbered copy of the sub-agent that sees a single prompt, and can run separately from its siblings. + + Args: + ctx: The current invocation context of the gather agent. To be copied and edited for the sub-agent copy. + prompt: the prompt on which the sub-agent copy should be invoked + invoker: the invoker of the gather agent in this invocation. + idx: index of the prompt in the input prompts, serves as a unique postfix to the agent name + width: number of digits in the total number of prompts, to ensure naming is consistent in field width + (e.g. 001, 002, ... 010, 011, ... 100, 101; and not 1, 2, ... 10, 11, ... 100, 101) + + Returns: + InvocationContext: A new invocation context ready to run with the unique sub-agent copy and the prompt + """ + + agent = self._branch_agent_tree(self.sub_agents[0], idx, width) + + branch = f"{ctx.branch}.{agent.name}" + prompt_part = [types.Part(text=prompt)] + + # Add the prompt to the user_content of this branch to easily access agent input in callbacks + user_content = types.Content( + role="user", + parts=((ctx.user_content or types.Content()).parts or []) + prompt_part, + ) + new_ctx = ctx.model_copy( + update=dict(branch=branch, agent=agent, user_content=user_content) + ) + + # Add the prompt as a temporary event of this branch in place of the prompt list as the natural input of the sub-agent. + prompt_content = types.Content( + role="user" if invoker == "user" else "model", parts=prompt_part + ) + new_ctx.session.events.append( + Event(author=invoker, branch=branch, content=prompt_content) + ) + + return new_ctx + + def _branch_agent_tree( + self, agent: BaseAgent, idx: int, width: int + ) -> BaseAgent: + """ + Clone and rename an agent and its sub-tree to create a thread-safe branch. + """ + new_agent = agent.model_copy( + update={"name": self._get_unique_name(agent.name, idx=idx, width=width)} + ) + + new_agent.sub_agents = [ + self._branch_agent_tree(a, idx, width) for a in agent.sub_agents + ] + return new_agent diff --git a/tests/unittests/agents/test_map_agent.py b/tests/unittests/agents/test_map_agent.py new file mode 100644 index 0000000000..80ca7c9619 --- /dev/null +++ b/tests/unittests/agents/test_map_agent.py @@ -0,0 +1,229 @@ +import json +import re +from typing import AsyncGenerator + +from google.adk.agents import LlmAgent +from google.adk.agents import LoopAgent +from google.adk.agents import MapAgent +from google.adk.agents import ParallelAgent +from google.adk.agents import SequentialAgent +from google.adk.agents.callback_context import CallbackContext +from google.adk.events import Event +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.genai import types +import pytest + +from ..testing_utils import MockModel +from ..testing_utils import ModelContent +from ..testing_utils import TestInMemoryRunner + + +class OneTwoThreeModel(MockModel): + """Maps an input of 'i' to output of "['i', 'i+1', 'i+2']", e.g. '5' -> "['5', '6', '7']" """ + + responses: list[LlmResponse] = [] + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + agent_input: str | None = ( + (llm_request.contents[-1] or types.Content()).parts or [types.Part()] + )[-1].text + assert agent_input is not None + agent_input = re.sub(r"\[\w+\] said: ", "", agent_input) + assert agent_input.isnumeric() + res = json.dumps([str(int(agent_input) + i) for i in range(3)]) + yield LlmResponse(content=ModelContent([types.Part(text=res)])) + + +def extract_event_text(events: list[Event], agent_prefix: str) -> list[str]: + filtered_events = [e for e in events if e.author.startswith(agent_prefix)] + sorted_events = sorted( + filtered_events, + key=lambda e: ( + e.author, + ((e.content or types.Content()).parts or [types.Part()])[0].text + or "", + ), + ) + contents = [e.content or types.Content() for e in sorted_events] + return [(c.parts or [types.Part()])[0].text or "" for c in contents] + + +@pytest.mark.asyncio +async def test_gather_agent_empty_input(): + def delete_events(callback_context: CallbackContext) -> None: + callback_context._invocation_context.session.events.clear() + + gather = MapAgent( + name="gather_agent", + sub_agents=[ + LlmAgent( + name="test", model=MockModel.create([], error=RuntimeError()) + ) + ], + before_agent_callback=delete_events, + ) + + runner = TestInMemoryRunner(gather) + await runner.run_async_with_new_session("") + + +@pytest.mark.asyncio +async def test_gather_agent_text_input(): + gather = MapAgent( + name="gather_agent", + sub_agents=[LlmAgent(name="mock_agent", model=OneTwoThreeModel())], + ) + + runner = TestInMemoryRunner(gather) + + n_runs = 100 + + input_data = json.dumps([str(i) for i in range(n_runs)]) + expected_output = [ + json.dumps([str(j) for j in range(i, i + 3)]) for i in range(n_runs) + ] + + events = await runner.run_async_with_new_session(input_data) + res = extract_event_text(events, "mock_agent") + + assert res == expected_output + + +@pytest.mark.asyncio +async def test_gather_agent_with_loop_agent_parent(): + gather_agent = MapAgent( + name="gather_agent", + sub_agents=[LlmAgent(name="mock_agent", model=OneTwoThreeModel())], + ) + + loop_agent = LoopAgent( + name="test_loop", + sub_agents=[gather_agent], + max_iterations=2, + ) + + runner = TestInMemoryRunner(loop_agent) + + input_data = json.dumps(["0"]) + expected_output = [json.dumps(["0", "1", "2"])] + [ + json.dumps([str(j) for j in range(i, i + 3)]) for i in range(3) + ] + + events = await runner.run_async_with_new_session(input_data) + res = extract_event_text(events, "mock_agent") + assert res == expected_output + + +@pytest.mark.parametrize("SubagentClass", [ParallelAgent, SequentialAgent]) +@pytest.mark.asyncio +async def test_gather_agent_with_sequential_or_parallel_agent(SubagentClass): + """test gather agent with a parallel / sequential sub-agent whose sub-agents don't communicate""" + + # A lone parallel agent wrapper hides mock_1's output from its 'cousin' mock_2 + mock1 = ParallelAgent( + name="seq_1", + sub_agents=[LlmAgent(name="mock_1", model=OneTwoThreeModel())], + ) + mock2 = LlmAgent(name="mock_2", model=OneTwoThreeModel()) + + subagent = SubagentClass( + name="subagent", + sub_agents=[mock1, mock2], + ) + + gather = MapAgent( + name="gather_agent", + sub_agents=[subagent], + ) + + runner = TestInMemoryRunner(gather) + + input_data = json.dumps(["0", "1"]) + expected_output = [ + json.dumps([str(j) for j in range(i, i + 3)]) for i in [0, 1, 0, 1] + ] + + events = await runner.run_async_with_new_session(input_data) + res = extract_event_text(events, "mock_") + assert res == expected_output + + +@pytest.mark.asyncio +async def test_gather_agent_with_gather_agent(): + mock_leaf = LlmAgent(name="nested_mock", model=OneTwoThreeModel()) + + inner_gather = MapAgent( + name="inner_gather", + sub_agents=[mock_leaf], + ) + + outer_gather = MapAgent( + name="outer_gather", + sub_agents=[inner_gather], + ) + + runner = TestInMemoryRunner(outer_gather) + + input_data = json.dumps( + [json.dumps([str(i), str(i + 1)]) for i in [10, 20, 30]] + ) + expected_output = [ + json.dumps([str(j) for j in range(i, i + 3)]) + for i in [10, 11, 20, 21, 30, 31] + ] + + events = await runner.run_async_with_new_session(input_data) + + res = [e for e in events if e.author.startswith("nested_mock")] + res = sorted( + res, + key=lambda e: ( + e.author, + ((e.content or types.Content()).parts or [types.Part()])[0].text + or "", + ), + ) + res = [ + ((e.content or types.Content()).parts or [types.Part()])[0].text or "" + for e in res + ] + assert len(res) == 6 + assert res == expected_output + + +@pytest.mark.asyncio +async def test_gather_agent_tree(): + inner_gather = MapAgent( + name="gather_inner", + sub_agents=[LlmAgent(name="mock_1", model=OneTwoThreeModel())], + ) + + main_loop = LoopAgent( + name="main_sequential", + sub_agents=[ + LlmAgent(name="mock_0", model=OneTwoThreeModel()), + inner_gather, + ], + max_iterations=1, + ) + + outer_gather = MapAgent( + name="gather_outer", + sub_agents=[main_loop], + ) + + runner = TestInMemoryRunner(outer_gather) + + input_data = json.dumps(["0", "1"]) + expected_output = [ + json.dumps([str(j) for j in range(i, i + 3)]) + for i in [0, 1, 0, 1, 2, 1, 2, 3] + ] + + events = await runner.run_async_with_new_session(input_data) + res = extract_event_text(events, "mock_") + + assert res == expected_output From d886a403f9c38f74b90de5b3f8fcadb1c35a0f2c Mon Sep 17 00:00:00 2001 From: "anan.yablonko" Date: Mon, 22 Dec 2025 17:45:53 +0200 Subject: [PATCH 2/9] rename gather to map agent in comments etc --- src/google/adk/agents/map_agent.py | 6 +-- tests/unittests/agents/test_map_agent.py | 62 ++++++++++++------------ 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/src/google/adk/agents/map_agent.py b/src/google/adk/agents/map_agent.py index 675d561f53..cd4385584c 100644 --- a/src/google/adk/agents/map_agent.py +++ b/src/google/adk/agents/map_agent.py @@ -64,7 +64,7 @@ def _extract_input_prompts( self, ctx: InvocationContext ) -> tuple[list[str], str]: """ - The input to the gather agent is a list of strings. + The input to the map agent is a list of strings. We extract the text content from the latest event, and assume it is a list of strings serialized as a json string. """ invoker = "user" @@ -110,9 +110,9 @@ def _branch_context( """Creates a numbered copy of the sub-agent that sees a single prompt, and can run separately from its siblings. Args: - ctx: The current invocation context of the gather agent. To be copied and edited for the sub-agent copy. + ctx: The current invocation context of the map agent. To be copied and edited for the sub-agent copy. prompt: the prompt on which the sub-agent copy should be invoked - invoker: the invoker of the gather agent in this invocation. + invoker: the invoker of the map agent in this invocation. idx: index of the prompt in the input prompts, serves as a unique postfix to the agent name width: number of digits in the total number of prompts, to ensure naming is consistent in field width (e.g. 001, 002, ... 010, 011, ... 100, 101; and not 1, 2, ... 10, 11, ... 100, 101) diff --git a/tests/unittests/agents/test_map_agent.py b/tests/unittests/agents/test_map_agent.py index 80ca7c9619..701d89f584 100644 --- a/tests/unittests/agents/test_map_agent.py +++ b/tests/unittests/agents/test_map_agent.py @@ -52,12 +52,12 @@ def extract_event_text(events: list[Event], agent_prefix: str) -> list[str]: @pytest.mark.asyncio -async def test_gather_agent_empty_input(): +async def test_map_agent_empty_input(): def delete_events(callback_context: CallbackContext) -> None: callback_context._invocation_context.session.events.clear() - gather = MapAgent( - name="gather_agent", + map = MapAgent( + name="map_agent", sub_agents=[ LlmAgent( name="test", model=MockModel.create([], error=RuntimeError()) @@ -66,18 +66,18 @@ def delete_events(callback_context: CallbackContext) -> None: before_agent_callback=delete_events, ) - runner = TestInMemoryRunner(gather) + runner = TestInMemoryRunner(map) await runner.run_async_with_new_session("") @pytest.mark.asyncio -async def test_gather_agent_text_input(): - gather = MapAgent( - name="gather_agent", +async def test_map_agent_text_input(): + map = MapAgent( + name="map_agent", sub_agents=[LlmAgent(name="mock_agent", model=OneTwoThreeModel())], ) - runner = TestInMemoryRunner(gather) + runner = TestInMemoryRunner(map) n_runs = 100 @@ -93,15 +93,15 @@ async def test_gather_agent_text_input(): @pytest.mark.asyncio -async def test_gather_agent_with_loop_agent_parent(): - gather_agent = MapAgent( - name="gather_agent", +async def test_map_agent_with_loop_agent_parent(): + map_agent = MapAgent( + name="map_agent", sub_agents=[LlmAgent(name="mock_agent", model=OneTwoThreeModel())], ) loop_agent = LoopAgent( name="test_loop", - sub_agents=[gather_agent], + sub_agents=[map_agent], max_iterations=2, ) @@ -119,8 +119,8 @@ async def test_gather_agent_with_loop_agent_parent(): @pytest.mark.parametrize("SubagentClass", [ParallelAgent, SequentialAgent]) @pytest.mark.asyncio -async def test_gather_agent_with_sequential_or_parallel_agent(SubagentClass): - """test gather agent with a parallel / sequential sub-agent whose sub-agents don't communicate""" +async def test_map_agent_with_sequential_or_parallel_agent(SubagentClass): + """test map agent with a parallel / sequential sub-agent whose sub-agents don't communicate""" # A lone parallel agent wrapper hides mock_1's output from its 'cousin' mock_2 mock1 = ParallelAgent( @@ -134,12 +134,12 @@ async def test_gather_agent_with_sequential_or_parallel_agent(SubagentClass): sub_agents=[mock1, mock2], ) - gather = MapAgent( - name="gather_agent", + map = MapAgent( + name="map_agent", sub_agents=[subagent], ) - runner = TestInMemoryRunner(gather) + runner = TestInMemoryRunner(map) input_data = json.dumps(["0", "1"]) expected_output = [ @@ -152,20 +152,20 @@ async def test_gather_agent_with_sequential_or_parallel_agent(SubagentClass): @pytest.mark.asyncio -async def test_gather_agent_with_gather_agent(): +async def test_map_agent_with_map_agent(): mock_leaf = LlmAgent(name="nested_mock", model=OneTwoThreeModel()) - inner_gather = MapAgent( - name="inner_gather", + inner_map = MapAgent( + name="inner_map", sub_agents=[mock_leaf], ) - outer_gather = MapAgent( - name="outer_gather", - sub_agents=[inner_gather], + outer_map = MapAgent( + name="outer_map", + sub_agents=[inner_map], ) - runner = TestInMemoryRunner(outer_gather) + runner = TestInMemoryRunner(outer_map) input_data = json.dumps( [json.dumps([str(i), str(i + 1)]) for i in [10, 20, 30]] @@ -195,9 +195,9 @@ async def test_gather_agent_with_gather_agent(): @pytest.mark.asyncio -async def test_gather_agent_tree(): - inner_gather = MapAgent( - name="gather_inner", +async def test_map_agent_tree(): + inner_map = MapAgent( + name="map_inner", sub_agents=[LlmAgent(name="mock_1", model=OneTwoThreeModel())], ) @@ -205,17 +205,17 @@ async def test_gather_agent_tree(): name="main_sequential", sub_agents=[ LlmAgent(name="mock_0", model=OneTwoThreeModel()), - inner_gather, + inner_map, ], max_iterations=1, ) - outer_gather = MapAgent( - name="gather_outer", + outer_map = MapAgent( + name="map_outer", sub_agents=[main_loop], ) - runner = TestInMemoryRunner(outer_gather) + runner = TestInMemoryRunner(outer_map) input_data = json.dumps(["0", "1"]) expected_output = [ From 0beaaacb9281581f618443f649fae142a970a62c Mon Sep 17 00:00:00 2001 From: "anan.yablonko" Date: Mon, 22 Dec 2025 18:50:41 +0200 Subject: [PATCH 3/9] Add init with import path Fix comments and test (Gemini code assist) --- src/google/adk/agents/__init__.py | 2 ++ src/google/adk/agents/map_agent.py | 7 +++---- tests/unittests/agents/test_map_agent.py | 15 +-------------- 3 files changed, 6 insertions(+), 18 deletions(-) diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py index b5f8e88cde..55036a092b 100644 --- a/src/google/adk/agents/__init__.py +++ b/src/google/adk/agents/__init__.py @@ -19,6 +19,7 @@ from .llm_agent import Agent from .llm_agent import LlmAgent from .loop_agent import LoopAgent +from .map_agent import MapAgent from .mcp_instruction_provider import McpInstructionProvider from .parallel_agent import ParallelAgent from .run_config import RunConfig @@ -29,6 +30,7 @@ 'BaseAgent', 'LlmAgent', 'LoopAgent', + 'MapAgent', 'McpInstructionProvider', 'ParallelAgent', 'SequentialAgent', diff --git a/src/google/adk/agents/map_agent.py b/src/google/adk/agents/map_agent.py index cd4385584c..413dbceae8 100644 --- a/src/google/adk/agents/map_agent.py +++ b/src/google/adk/agents/map_agent.py @@ -83,10 +83,9 @@ def _extract_input_prompts( (event.content or types.Content()).parts or [types.Part()] )[0].text or "" - """ - Remove the event which has the prompt list, so that a sub agent does not see the prompts of its siblings, which may confuse it. - The event is removed only for this invocation. - """ + # Remove the event which has the prompt list, so that a sub agent does not + # see the prompts of its siblings, which may confuse it. + # The event is removed only for this invocation. ctx.session.events.pop(i) agent_input = RootModel[list[str]].model_validate_json(input_message).root diff --git a/tests/unittests/agents/test_map_agent.py b/tests/unittests/agents/test_map_agent.py index 701d89f584..014ff505fd 100644 --- a/tests/unittests/agents/test_map_agent.py +++ b/tests/unittests/agents/test_map_agent.py @@ -177,20 +177,7 @@ async def test_map_agent_with_map_agent(): events = await runner.run_async_with_new_session(input_data) - res = [e for e in events if e.author.startswith("nested_mock")] - res = sorted( - res, - key=lambda e: ( - e.author, - ((e.content or types.Content()).parts or [types.Part()])[0].text - or "", - ), - ) - res = [ - ((e.content or types.Content()).parts or [types.Part()])[0].text or "" - for e in res - ] - assert len(res) == 6 + res = extract_event_text(events, "nested_mock") assert res == expected_output From b6ee76f64107906619d863315867db2f2f33f33e Mon Sep 17 00:00:00 2001 From: "anan.yablonko" Date: Tue, 23 Dec 2025 11:35:06 +0200 Subject: [PATCH 4/9] use existing branch-related utils to fix branching issues --- src/google/adk/agents/map_agent.py | 16 +++++++--------- tests/unittests/agents/test_map_agent.py | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/google/adk/agents/map_agent.py b/src/google/adk/agents/map_agent.py index 413dbceae8..2fc27b1463 100644 --- a/src/google/adk/agents/map_agent.py +++ b/src/google/adk/agents/map_agent.py @@ -4,8 +4,10 @@ from annotated_types import Len from google.adk.agents import BaseAgent from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.parallel_agent import _create_branch_ctx_for_sub_agent from google.adk.agents.parallel_agent import _merge_agent_run from google.adk.events import Event +from google.adk.flows.llm_flows.contents import _should_include_event_in_context from google.genai import types from pydantic import Field from pydantic import RootModel @@ -36,8 +38,6 @@ async def _run_async_impl( """ # Create a branch string if it doesn't exist, to ensure parallel invocations don't interfere with each other - invocation_context.branch = invocation_context.branch or self.name - prompts, invoker = self._extract_input_prompts(invocation_context) # for agent naming - e.g. if there are 100-999 prompts, sub-agent copies are named 001, 002, 003 and so on @@ -71,9 +71,7 @@ def _extract_input_prompts( for i in range(len(ctx.session.events) - 1, -1, -1): event = ctx.session.events[i] - if event.branch is None or ( - ctx.branch is not None and event.branch.startswith(ctx.branch) - ): + if _should_include_event_in_context(ctx.branch, event): break else: return [], "user" @@ -122,7 +120,6 @@ def _branch_context( agent = self._branch_agent_tree(self.sub_agents[0], idx, width) - branch = f"{ctx.branch}.{agent.name}" prompt_part = [types.Part(text=prompt)] # Add the prompt to the user_content of this branch to easily access agent input in callbacks @@ -130,8 +127,9 @@ def _branch_context( role="user", parts=((ctx.user_content or types.Content()).parts or []) + prompt_part, ) - new_ctx = ctx.model_copy( - update=dict(branch=branch, agent=agent, user_content=user_content) + + new_ctx = _create_branch_ctx_for_sub_agent(self, agent, ctx).model_copy( + update=dict(agent=agent, user_content=user_content) ) # Add the prompt as a temporary event of this branch in place of the prompt list as the natural input of the sub-agent. @@ -139,7 +137,7 @@ def _branch_context( role="user" if invoker == "user" else "model", parts=prompt_part ) new_ctx.session.events.append( - Event(author=invoker, branch=branch, content=prompt_content) + Event(author=invoker, branch=new_ctx.branch, content=prompt_content) ) return new_ctx diff --git a/tests/unittests/agents/test_map_agent.py b/tests/unittests/agents/test_map_agent.py index 014ff505fd..14ae9372b1 100644 --- a/tests/unittests/agents/test_map_agent.py +++ b/tests/unittests/agents/test_map_agent.py @@ -214,3 +214,24 @@ async def test_map_agent_tree(): res = extract_event_text(events, "mock_") assert res == expected_output + + +@pytest.mark.asyncio +async def test_map_agent_callback_event_branch(): + def callback(callback_context: CallbackContext) -> types.Content: + return ModelContent([types.Part(text="OK")]) + + agent = MapAgent( + name="map_agent", + sub_agents=[LlmAgent(name="mock", model=OneTwoThreeModel())], + after_agent_callback=callback + ) + + runner = TestInMemoryRunner(agent) + events = await runner.run_async_with_new_session(json.dumps(["0"])) + + event = events[-1] + assert event.branch is None + assert event.author == agent.name + assert event.content == ModelContent([types.Part(text="OK")]) + From 9a013ebbf274233f7da0c2b41ef7f7872173cc4f Mon Sep 17 00:00:00 2001 From: "anan.yablonko" Date: Tue, 23 Dec 2025 12:47:13 +0200 Subject: [PATCH 5/9] remove unused variable --- src/google/adk/agents/map_agent.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/google/adk/agents/map_agent.py b/src/google/adk/agents/map_agent.py index 2fc27b1463..05c71409d5 100644 --- a/src/google/adk/agents/map_agent.py +++ b/src/google/adk/agents/map_agent.py @@ -67,8 +67,6 @@ def _extract_input_prompts( The input to the map agent is a list of strings. We extract the text content from the latest event, and assume it is a list of strings serialized as a json string. """ - invoker = "user" - for i in range(len(ctx.session.events) - 1, -1, -1): event = ctx.session.events[i] if _should_include_event_in_context(ctx.branch, event): From 304a91b72e7c8bb76874395a8b1c0666fa420d9e Mon Sep 17 00:00:00 2001 From: "anan.yablonko" Date: Tue, 23 Dec 2025 12:48:39 +0200 Subject: [PATCH 6/9] fix cloning the sub agent tree. --- src/google/adk/agents/map_agent.py | 48 +++++++++++++---------- tests/unittests/agents/test_map_agent.py | 50 ++++++++++++++++++------ 2 files changed, 65 insertions(+), 33 deletions(-) diff --git a/src/google/adk/agents/map_agent.py b/src/google/adk/agents/map_agent.py index 05c71409d5..987ab592e0 100644 --- a/src/google/adk/agents/map_agent.py +++ b/src/google/adk/agents/map_agent.py @@ -1,5 +1,6 @@ from typing import Annotated from typing import AsyncGenerator +from typing import Optional from annotated_types import Len from google.adk.agents import BaseAgent @@ -43,16 +44,24 @@ async def _run_async_impl( # for agent naming - e.g. if there are 100-999 prompts, sub-agent copies are named 001, 002, 003 and so on number_field_width = len(str(len(prompts))) + # Clone the sub-agent's tree for each prompt + sub_agents = [ + self._branch_agent_tree(self.sub_agents[0], i, number_field_width) + for i, _ in enumerate(prompts) + ] + + # Set the map agent as the parent of the clones + self.clone(update={"sub_agents": sub_agents}) + # Create a separate invocation context for each prompt, each with a numbered copy of the sub-agent. contexts = [ self._branch_context( invocation_context, - idx=i, - prompt=prompt, + agent=agent, invoker=invoker, - width=number_field_width, + prompt=prompt, ) - for i, prompt in enumerate(prompts) + for prompt, agent in zip(prompts, sub_agents) ] async for event in _merge_agent_run( @@ -97,27 +106,21 @@ def _branch_context( self, ctx: InvocationContext, *, - prompt: str, + agent: BaseAgent, invoker: str, - idx: int, - width: int, + prompt: str, ) -> InvocationContext: - """Creates a numbered copy of the sub-agent that sees a single prompt, and can run separately from its siblings. + """Creates a an invocation context for invoking a sub-agent clone with a single prompt. Args: ctx: The current invocation context of the map agent. To be copied and edited for the sub-agent copy. - prompt: the prompt on which the sub-agent copy should be invoked + agent: the sub-agent clone to be invoked in the returned context. invoker: the invoker of the map agent in this invocation. - idx: index of the prompt in the input prompts, serves as a unique postfix to the agent name - width: number of digits in the total number of prompts, to ensure naming is consistent in field width - (e.g. 001, 002, ... 010, 011, ... 100, 101; and not 1, 2, ... 10, 11, ... 100, 101) + prompt: the prompt on which the sub-agent copy should be invoked Returns: InvocationContext: A new invocation context ready to run with the unique sub-agent copy and the prompt """ - - agent = self._branch_agent_tree(self.sub_agents[0], idx, width) - prompt_part = [types.Part(text=prompt)] # Add the prompt to the user_content of this branch to easily access agent input in callbacks @@ -145,12 +148,17 @@ def _branch_agent_tree( ) -> BaseAgent: """ Clone and rename an agent and its sub-tree to create a thread-safe branch. + Args: + agent: the root of the current sub-agent tree - in the first call it is the main sub-agent of the map agent + idx: index of the prompt in the input prompts, serves as a unique postfix to the agent name + width: number of digits in the total number of prompts, to ensure naming is consistent in field width + (e.g. 001, 002, ... 010, 011, ... 100, 101; and not 1, 2, ... 10, 11, ... 100, 101) """ - new_agent = agent.model_copy( - update={"name": self._get_unique_name(agent.name, idx=idx, width=width)} - ) - - new_agent.sub_agents = [ + new_name = self._get_unique_name(agent.name, idx=idx, width=width) + new_sub_agents = [ self._branch_agent_tree(a, idx, width) for a in agent.sub_agents ] + new_agent = agent.clone( + update={"name": new_name, "sub_agents": new_sub_agents} + ) return new_agent diff --git a/tests/unittests/agents/test_map_agent.py b/tests/unittests/agents/test_map_agent.py index 14ae9372b1..df611fe8cb 100644 --- a/tests/unittests/agents/test_map_agent.py +++ b/tests/unittests/agents/test_map_agent.py @@ -218,20 +218,44 @@ async def test_map_agent_tree(): @pytest.mark.asyncio async def test_map_agent_callback_event_branch(): - def callback(callback_context: CallbackContext) -> types.Content: - return ModelContent([types.Part(text="OK")]) + def callback(callback_context: CallbackContext) -> types.Content: + return ModelContent([types.Part(text="OK")]) + + agent = MapAgent( + name="map_agent", + sub_agents=[LlmAgent(name="mock", model=OneTwoThreeModel())], + after_agent_callback=callback, + ) + + runner = TestInMemoryRunner(agent) + events = await runner.run_async_with_new_session(json.dumps(["0"])) + + event = events[-1] + assert event.branch is None + assert event.author == agent.name + assert event.content == ModelContent([types.Part(text="OK")]) - agent = MapAgent( - name="map_agent", - sub_agents=[LlmAgent(name="mock", model=OneTwoThreeModel())], - after_agent_callback=callback - ) - runner = TestInMemoryRunner(agent) - events = await runner.run_async_with_new_session(json.dumps(["0"])) +@pytest.mark.asyncio +async def test_map_agent_sub_agent_cloned_with_correct_parent(): + def callback(callback_context: CallbackContext) -> None: + current = callback_context._invocation_context.agent + parent = current.parent_agent + assert parent is not None + assert parent.sub_agents[0] == current + assert current.name == "sub_agent_0" + assert current.sub_agents[0].name == "mock_0" + + sub_agent = SequentialAgent( + name="sub_agent", + sub_agents=[LlmAgent(name="mock", model=OneTwoThreeModel())], + after_agent_callback=callback, + ) - event = events[-1] - assert event.branch is None - assert event.author == agent.name - assert event.content == ModelContent([types.Part(text="OK")]) + map_agent = MapAgent( + name="map_agent", + sub_agents=[sub_agent], + ) + runner = TestInMemoryRunner(map_agent) + await runner.run_async_with_new_session(json.dumps(["0"])) From 5faf33179523240c70c66b8c233b01087a2c32f8 Mon Sep 17 00:00:00 2001 From: "anan.yablonko" Date: Tue, 23 Dec 2025 13:00:37 +0200 Subject: [PATCH 7/9] add test for errors. --- tests/unittests/agents/test_map_agent.py | 58 +++++++++++++++++------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/tests/unittests/agents/test_map_agent.py b/tests/unittests/agents/test_map_agent.py index df611fe8cb..2179731ab4 100644 --- a/tests/unittests/agents/test_map_agent.py +++ b/tests/unittests/agents/test_map_agent.py @@ -12,14 +12,13 @@ from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.genai import types +from pydantic import ValidationError import pytest -from ..testing_utils import MockModel -from ..testing_utils import ModelContent -from ..testing_utils import TestInMemoryRunner +from .. import testing_utils -class OneTwoThreeModel(MockModel): +class OneTwoThreeModel(testing_utils.MockModel): """Maps an input of 'i' to output of "['i', 'i+1', 'i+2']", e.g. '5' -> "['5', '6', '7']" """ responses: list[LlmResponse] = [] @@ -34,7 +33,9 @@ async def generate_content_async( agent_input = re.sub(r"\[\w+\] said: ", "", agent_input) assert agent_input.isnumeric() res = json.dumps([str(int(agent_input) + i) for i in range(3)]) - yield LlmResponse(content=ModelContent([types.Part(text=res)])) + yield LlmResponse( + content=testing_utils.ModelContent([types.Part(text=res)]) + ) def extract_event_text(events: list[Event], agent_prefix: str) -> list[str]: @@ -60,13 +61,14 @@ def delete_events(callback_context: CallbackContext) -> None: name="map_agent", sub_agents=[ LlmAgent( - name="test", model=MockModel.create([], error=RuntimeError()) + name="test", + model=testing_utils.MockModel.create([], error=RuntimeError()), ) ], before_agent_callback=delete_events, ) - runner = TestInMemoryRunner(map) + runner = testing_utils.TestInMemoryRunner(map) await runner.run_async_with_new_session("") @@ -77,7 +79,7 @@ async def test_map_agent_text_input(): sub_agents=[LlmAgent(name="mock_agent", model=OneTwoThreeModel())], ) - runner = TestInMemoryRunner(map) + runner = testing_utils.TestInMemoryRunner(map) n_runs = 100 @@ -105,7 +107,7 @@ async def test_map_agent_with_loop_agent_parent(): max_iterations=2, ) - runner = TestInMemoryRunner(loop_agent) + runner = testing_utils.TestInMemoryRunner(loop_agent) input_data = json.dumps(["0"]) expected_output = [json.dumps(["0", "1", "2"])] + [ @@ -139,7 +141,7 @@ async def test_map_agent_with_sequential_or_parallel_agent(SubagentClass): sub_agents=[subagent], ) - runner = TestInMemoryRunner(map) + runner = testing_utils.TestInMemoryRunner(map) input_data = json.dumps(["0", "1"]) expected_output = [ @@ -165,7 +167,7 @@ async def test_map_agent_with_map_agent(): sub_agents=[inner_map], ) - runner = TestInMemoryRunner(outer_map) + runner = testing_utils.TestInMemoryRunner(outer_map) input_data = json.dumps( [json.dumps([str(i), str(i + 1)]) for i in [10, 20, 30]] @@ -202,7 +204,7 @@ async def test_map_agent_tree(): sub_agents=[main_loop], ) - runner = TestInMemoryRunner(outer_map) + runner = testing_utils.TestInMemoryRunner(outer_map) input_data = json.dumps(["0", "1"]) expected_output = [ @@ -219,7 +221,7 @@ async def test_map_agent_tree(): @pytest.mark.asyncio async def test_map_agent_callback_event_branch(): def callback(callback_context: CallbackContext) -> types.Content: - return ModelContent([types.Part(text="OK")]) + return testing_utils.ModelContent([types.Part(text="OK")]) agent = MapAgent( name="map_agent", @@ -227,13 +229,13 @@ def callback(callback_context: CallbackContext) -> types.Content: after_agent_callback=callback, ) - runner = TestInMemoryRunner(agent) + runner = testing_utils.TestInMemoryRunner(agent) events = await runner.run_async_with_new_session(json.dumps(["0"])) event = events[-1] assert event.branch is None assert event.author == agent.name - assert event.content == ModelContent([types.Part(text="OK")]) + assert event.content == testing_utils.ModelContent([types.Part(text="OK")]) @pytest.mark.asyncio @@ -257,5 +259,29 @@ def callback(callback_context: CallbackContext) -> None: sub_agents=[sub_agent], ) - runner = TestInMemoryRunner(map_agent) + runner = testing_utils.TestInMemoryRunner(map_agent) await runner.run_async_with_new_session(json.dumps(["0"])) + + +@pytest.mark.asyncio +async def test_map_agent_errors(): + """Map agent can only operate on a list input. Should crash on Dev Error""" + with pytest.raises(ValidationError): + MapAgent( + name="map_agent", + sub_agents=[], + ) + with pytest.raises(ValidationError): + MapAgent( + name="map_agent", + sub_agents=[LlmAgent(name="mock_0"), LlmAgent(name="mock_1")], + ) + + map_agent = MapAgent( + name="map_agent", + sub_agents=[LlmAgent(name="mock", model=OneTwoThreeModel())], + ) + runner = testing_utils.TestInMemoryRunner(map_agent) + + with pytest.raises(ValidationError): + await runner.run_async_with_new_session("Not a list") From 3faa44e3435a014aa7137c140f179e6e2cda8155 Mon Sep 17 00:00:00 2001 From: "anan.yablonko" Date: Tue, 23 Dec 2025 13:20:47 +0200 Subject: [PATCH 8/9] remove unused import --- src/google/adk/agents/map_agent.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/google/adk/agents/map_agent.py b/src/google/adk/agents/map_agent.py index 987ab592e0..a2b3b6761b 100644 --- a/src/google/adk/agents/map_agent.py +++ b/src/google/adk/agents/map_agent.py @@ -1,6 +1,5 @@ from typing import Annotated from typing import AsyncGenerator -from typing import Optional from annotated_types import Len from google.adk.agents import BaseAgent From b14e3b53ebffa5fc76a1b2b7c353ba0bd1ffc411 Mon Sep 17 00:00:00 2001 From: "anan.yablonko" Date: Wed, 24 Dec 2025 10:03:45 +0200 Subject: [PATCH 9/9] Fix for python version <3.11 --- src/google/adk/agents/map_agent.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/google/adk/agents/map_agent.py b/src/google/adk/agents/map_agent.py index a2b3b6761b..a98b6f5e12 100644 --- a/src/google/adk/agents/map_agent.py +++ b/src/google/adk/agents/map_agent.py @@ -1,3 +1,4 @@ +import sys from typing import Annotated from typing import AsyncGenerator @@ -6,6 +7,7 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.parallel_agent import _create_branch_ctx_for_sub_agent from google.adk.agents.parallel_agent import _merge_agent_run +from google.adk.agents.parallel_agent import _merge_agent_run_pre_3_11 from google.adk.events import Event from google.adk.flows.llm_flows.contents import _should_include_event_in_context from google.genai import types @@ -13,6 +15,8 @@ from pydantic import RootModel from typing_extensions import override +from ..utils.context_utils import Aclosing + class MapAgent(BaseAgent): sub_agents: Annotated[list[BaseAgent], Len(1, 1)] = Field( @@ -63,10 +67,16 @@ async def _run_async_impl( for prompt, agent in zip(prompts, sub_agents) ] - async for event in _merge_agent_run( - [ctx.agent.run_async(ctx) for ctx in contexts] - ): - yield event + agent_runs = [ctx.agent.run_async(ctx) for ctx in contexts] + + merge_func = ( + _merge_agent_run + if sys.version_info >= (3, 11) + else _merge_agent_run_pre_3_11 + ) + async with Aclosing(merge_func(agent_runs)) as agen: + async for event in agen: + yield event def _extract_input_prompts( self, ctx: InvocationContext