diff --git a/src/google/adk/agents/__init__.py b/src/google/adk/agents/__init__.py
index b5f8e88cde..55036a092b 100644
--- a/src/google/adk/agents/__init__.py
+++ b/src/google/adk/agents/__init__.py
@@ -19,6 +19,7 @@
 from .llm_agent import Agent
 from .llm_agent import LlmAgent
 from .loop_agent import LoopAgent
+from .map_agent import MapAgent
 from .mcp_instruction_provider import McpInstructionProvider
 from .parallel_agent import ParallelAgent
 from .run_config import RunConfig
@@ -29,6 +30,7 @@
     'BaseAgent',
     'LlmAgent',
     'LoopAgent',
+    'MapAgent',
     'McpInstructionProvider',
     'ParallelAgent',
     'SequentialAgent',
diff --git a/src/google/adk/agents/map_agent.py b/src/google/adk/agents/map_agent.py
new file mode 100644
index 0000000000..a98b6f5e12
--- /dev/null
+++ b/src/google/adk/agents/map_agent.py
@@ -0,0 +1,201 @@
+import sys
+from typing import Annotated
+from typing import AsyncGenerator
+
+from annotated_types import Len
+from google.genai import types
+from pydantic import Field
+from pydantic import RootModel
+from typing_extensions import override
+
+from ..events import Event
+from ..flows.llm_flows.contents import _should_include_event_in_context
+from ..utils.context_utils import Aclosing
+from .base_agent import BaseAgent
+from .invocation_context import InvocationContext
+from .parallel_agent import _create_branch_ctx_for_sub_agent
+from .parallel_agent import _merge_agent_run
+from .parallel_agent import _merge_agent_run_pre_3_11
+
+
+class MapAgent(BaseAgent):
+  """Runs copies of a single sub-agent in parallel, one per input prompt.
+
+  The input is read from the latest session event visible to this agent and
+  must be a JSON array of strings. For each string the sub-agent tree is
+  cloned under a unique, index-suffixed name and invoked in its own branch.
+  """
+
+  # Deliberately has no default: pydantic does not validate default values, so
+  # an empty default list would bypass the exactly-one-sub-agent constraint.
+  sub_agents: Annotated[list[BaseAgent], Len(1, 1)] = Field(
+      min_length=1,
+      max_length=1,
+      description=(
+          "A single base agent that will be copied and invoked for each prompt"
+      ),
+  )
+
+  @override
+  async def _run_async_impl(
+      self, invocation_context: InvocationContext
+  ) -> AsyncGenerator[Event, None]:
+    """Core logic of this workflow agent.
+
+    Args:
+      invocation_context: InvocationContext, provides access to the input prompts.
+
+    Yields:
+      Event: the events generated by the sub-agent for each input prompt.
+    """
+    prompts, invoker = self._extract_input_prompts(invocation_context)
+    if not prompts:
+      # Nothing to map over, so there is nothing to clone or run.
+      return
+
+    # Zero-pad the index suffix to the number of digits in len(prompts) so that
+    # clone names have a consistent width, e.g. 000 ... 099 for 100 prompts.
+    number_field_width = len(str(len(prompts)))
+
+    # Clone the sub-agent's tree for each prompt
+    sub_agents = [
+        self._branch_agent_tree(self.sub_agents[0], i, number_field_width)
+        for i in range(len(prompts))
+    ]
+
+    # Set a copy of the map agent as the parent of the clones. The returned
+    # copy is intentionally discarded; `self` keeps its single template
+    # sub-agent.
+    # NOTE(review): relies on `BaseAgent.clone` assigning `parent_agent` on the
+    # agents passed in `update["sub_agents"]` - confirm against base_agent.py.
+    self.clone(update={"sub_agents": sub_agents})
+
+    # Create a separate invocation context for each prompt, each with a
+    # numbered copy of the sub-agent.
+    contexts = [
+        self._branch_context(
+            invocation_context,
+            agent=agent,
+            invoker=invoker,
+            prompt=prompt,
+        )
+        for prompt, agent in zip(prompts, sub_agents)
+    ]
+
+    agent_runs = [ctx.agent.run_async(ctx) for ctx in contexts]
+
+    merge_func = (
+        _merge_agent_run
+        if sys.version_info >= (3, 11)
+        else _merge_agent_run_pre_3_11
+    )
+    async with Aclosing(merge_func(agent_runs)) as agen:
+      async for event in agen:
+        yield event
+
+  def _extract_input_prompts(
+      self, ctx: InvocationContext
+  ) -> tuple[list[str], str]:
+    """Extracts the prompt list and its author from the latest visible event.
+
+    The most recent session event accepted by
+    `_should_include_event_in_context` for this branch is taken as the input.
+    The text of its first part must be a JSON array of strings.
+
+    Args:
+      ctx: The invocation context whose session events are searched.
+
+    Returns:
+      A `(prompts, invoker)` tuple, or `([], "user")` if no event qualifies.
+
+    Raises:
+      pydantic.ValidationError: If the text is not a JSON list of strings.
+    """
+    events = ctx.session.events
+    for i in reversed(range(len(events))):
+      if _should_include_event_in_context(ctx.branch, events[i]):
+        break
+    else:
+      return [], "user"
+
+    event = events[i]
+    parts = (event.content or types.Content()).parts or [types.Part()]
+    input_message = parts[0].text or ""
+
+    # Validate before touching the session, so that malformed input raises
+    # without the event having been dropped from the history.
+    prompts = RootModel[list[str]].model_validate_json(input_message).root
+
+    # Remove the event which has the prompt list, so that a sub agent does not
+    # see the prompts of its siblings, which may confuse it.
+    # NOTE(review): this mutates `ctx.session.events` in place - confirm the
+    # removal really is scoped to this invocation.
+    events.pop(i)
+
+    return prompts, event.author
+
+  @staticmethod
+  def _get_unique_name(name: str, idx: int, width: int) -> str:
+    """Returns `name` plus a zero-padded index, e.g. my_sub_agent_046."""
+    return f"{name}_{idx:0{width}d}"
+
+  def _branch_context(
+      self,
+      ctx: InvocationContext,
+      *,
+      agent: BaseAgent,
+      invoker: str,
+      prompt: str,
+  ) -> InvocationContext:
+    """Creates an invocation context for invoking a sub-agent clone with a single prompt.
+
+    Args:
+      ctx: The current invocation context of the map agent. To be copied and edited for the sub-agent copy.
+      agent: the sub-agent clone to be invoked in the returned context.
+      invoker: the invoker of the map agent in this invocation.
+      prompt: the prompt on which the sub-agent copy should be invoked
+
+    Returns:
+      InvocationContext: A new invocation context ready to run with the unique sub-agent copy and the prompt
+    """
+    prompt_part = [types.Part(text=prompt)]
+
+    # Add the prompt to the user_content of this branch to easily access agent input in callbacks
+    user_content = types.Content(
+        role="user",
+        parts=((ctx.user_content or types.Content()).parts or []) + prompt_part,
+    )
+
+    new_ctx = _create_branch_ctx_for_sub_agent(self, agent, ctx).model_copy(
+        update={"agent": agent, "user_content": user_content}
+    )
+
+    # Add the prompt as a temporary event of this branch in place of the prompt list as the natural input of the sub-agent.
+    prompt_content = types.Content(
+        role="user" if invoker == "user" else "model", parts=prompt_part
+    )
+    new_ctx.session.events.append(
+        Event(author=invoker, branch=new_ctx.branch, content=prompt_content)
+    )
+
+    return new_ctx
+
+  def _branch_agent_tree(
+      self, agent: BaseAgent, idx: int, width: int
+  ) -> BaseAgent:
+    """Clones and renames an agent and its sub-tree for one prompt.
+
+    Args:
+      agent: the root of the current sub-agent tree - in the first call it is the main sub-agent of the map agent
+      idx: index of the prompt in the input prompts, serves as a unique postfix to the agent name
+      width: number of digits in the total number of prompts, to ensure naming is consistent in field width
+        (e.g. 001, 002, ... 010, 011, ... 100, 101; and not 1, 2, ... 10, 11, ... 100, 101)
+
+    Returns:
+      The renamed clone of `agent`, whose sub-agents are renamed clones too.
+    """
+    new_name = self._get_unique_name(agent.name, idx=idx, width=width)
+    new_sub_agents = [
+        self._branch_agent_tree(a, idx, width) for a in agent.sub_agents
+    ]
+    return agent.clone(update={"name": new_name, "sub_agents": new_sub_agents})
diff --git a/tests/unittests/agents/test_map_agent.py b/tests/unittests/agents/test_map_agent.py
new file mode 100644
index 0000000000..2179731ab4
--- /dev/null
+++ b/tests/unittests/agents/test_map_agent.py
@@ -0,0 +1,289 @@
+import json
+import re
+from typing import AsyncGenerator
+
+from google.adk.agents import LlmAgent
+from google.adk.agents import LoopAgent
+from google.adk.agents import MapAgent
+from google.adk.agents import ParallelAgent
+from google.adk.agents import SequentialAgent
+from google.adk.agents.callback_context import CallbackContext
+from google.adk.events import Event
+from google.adk.models.llm_request import LlmRequest
+from google.adk.models.llm_response import LlmResponse
+from google.genai import types
+from pydantic import ValidationError
+import pytest
+
+from .. import testing_utils
+
+
+class OneTwoThreeModel(testing_utils.MockModel):
+  """Maps an input of 'i' to output of "['i', 'i+1', 'i+2']", e.g. '5' -> "['5', '6', '7']" """
+
+  responses: list[LlmResponse] = []
+
+  async def generate_content_async(
+      self, llm_request: LlmRequest, stream: bool = False
+  ) -> AsyncGenerator[LlmResponse, None]:
+    agent_input: str | None = (
+        (llm_request.contents[-1] or types.Content()).parts or [types.Part()]
+    )[-1].text
+    assert agent_input is not None
+    agent_input = re.sub(r"\[\w+\] said: ", "", agent_input)
+    assert agent_input.isnumeric()
+    res = json.dumps([str(int(agent_input) + i) for i in range(3)])
+    yield LlmResponse(
+        content=testing_utils.ModelContent([types.Part(text=res)])
+    )
+
+
+def extract_event_text(events: list[Event], agent_prefix: str) -> list[str]:
+  filtered_events = [e for e in events if e.author.startswith(agent_prefix)]
+  sorted_events = sorted(
+      filtered_events,
+      key=lambda e: (
+          e.author,
+          ((e.content or types.Content()).parts or [types.Part()])[0].text
+          or "",
+      ),
+  )
+  contents = [e.content or types.Content() for e in sorted_events]
+  return [(c.parts or [types.Part()])[0].text or "" for c in contents]
+
+
+@pytest.mark.asyncio
+async def test_map_agent_empty_input():
+  def delete_events(callback_context: CallbackContext) -> None:
+    callback_context._invocation_context.session.events.clear()
+
+  map_agent = MapAgent(
+      name="map_agent",
+      sub_agents=[
+          LlmAgent(
+              name="test",
+              model=testing_utils.MockModel.create([], error=RuntimeError()),
+          )
+      ],
+      before_agent_callback=delete_events,
+  )
+
+  runner = testing_utils.TestInMemoryRunner(map_agent)
+  await runner.run_async_with_new_session("")
+
+
+@pytest.mark.asyncio
+async def test_map_agent_text_input():
+  map_agent = MapAgent(
+      name="map_agent",
+      sub_agents=[LlmAgent(name="mock_agent", model=OneTwoThreeModel())],
+  )
+
+  runner = testing_utils.TestInMemoryRunner(map_agent)
+
+  n_runs = 100
+
+  input_data = json.dumps([str(i) for i in range(n_runs)])
+  expected_output = [
+      json.dumps([str(j) for j in range(i, i + 3)]) for i in range(n_runs)
+  ]
+
+  events = await runner.run_async_with_new_session(input_data)
+  res = extract_event_text(events, "mock_agent")
+
+  assert res == expected_output
+
+
+@pytest.mark.asyncio
+async def test_map_agent_with_loop_agent_parent():
+  map_agent = MapAgent(
+      name="map_agent",
+      sub_agents=[LlmAgent(name="mock_agent", model=OneTwoThreeModel())],
+  )
+
+  loop_agent = LoopAgent(
+      name="test_loop",
+      sub_agents=[map_agent],
+      max_iterations=2,
+  )
+
+  runner = testing_utils.TestInMemoryRunner(loop_agent)
+
+  input_data = json.dumps(["0"])
+  expected_output = [json.dumps(["0", "1", "2"])] + [
+      json.dumps([str(j) for j in range(i, i + 3)]) for i in range(3)
+  ]
+
+  events = await runner.run_async_with_new_session(input_data)
+  res = extract_event_text(events, "mock_agent")
+  assert res == expected_output
+
+
+@pytest.mark.parametrize("SubagentClass", [ParallelAgent, SequentialAgent])
+@pytest.mark.asyncio
+async def test_map_agent_with_sequential_or_parallel_agent(SubagentClass):
+  """test map agent with a parallel / sequential sub-agent whose sub-agents don't communicate"""
+
+  # A lone parallel agent wrapper hides mock_1's output from its 'cousin' mock_2
+  mock1 = ParallelAgent(
+      name="seq_1",
+      sub_agents=[LlmAgent(name="mock_1", model=OneTwoThreeModel())],
+  )
+  mock2 = LlmAgent(name="mock_2", model=OneTwoThreeModel())
+
+  subagent = SubagentClass(
+      name="subagent",
+      sub_agents=[mock1, mock2],
+  )
+
+  map_agent = MapAgent(
+      name="map_agent",
+      sub_agents=[subagent],
+  )
+
+  runner = testing_utils.TestInMemoryRunner(map_agent)
+
+  input_data = json.dumps(["0", "1"])
+  expected_output = [
+      json.dumps([str(j) for j in range(i, i + 3)]) for i in [0, 1, 0, 1]
+  ]
+
+  events = await runner.run_async_with_new_session(input_data)
+  res = extract_event_text(events, "mock_")
+  assert res == expected_output
+
+
+@pytest.mark.asyncio
+async def test_map_agent_with_map_agent():
+  mock_leaf = LlmAgent(name="nested_mock", model=OneTwoThreeModel())
+
+  inner_map = MapAgent(
+      name="inner_map",
+      sub_agents=[mock_leaf],
+  )
+
+  outer_map = MapAgent(
+      name="outer_map",
+      sub_agents=[inner_map],
+  )
+
+  runner = testing_utils.TestInMemoryRunner(outer_map)
+
+  input_data = json.dumps(
+      [json.dumps([str(i), str(i + 1)]) for i in [10, 20, 30]]
+  )
+  expected_output = [
+      json.dumps([str(j) for j in range(i, i + 3)])
+      for i in [10, 11, 20, 21, 30, 31]
+  ]
+
+  events = await runner.run_async_with_new_session(input_data)
+
+  res = extract_event_text(events, "nested_mock")
+  assert res == expected_output
+
+
+@pytest.mark.asyncio
+async def test_map_agent_tree():
+  inner_map = MapAgent(
+      name="map_inner",
+      sub_agents=[LlmAgent(name="mock_1", model=OneTwoThreeModel())],
+  )
+
+  main_loop = LoopAgent(
+      name="main_sequential",
+      sub_agents=[
+          LlmAgent(name="mock_0", model=OneTwoThreeModel()),
+          inner_map,
+      ],
+      max_iterations=1,
+  )
+
+  outer_map = MapAgent(
+      name="map_outer",
+      sub_agents=[main_loop],
+  )
+
+  runner = testing_utils.TestInMemoryRunner(outer_map)
+
+  input_data = json.dumps(["0", "1"])
+  expected_output = [
+      json.dumps([str(j) for j in range(i, i + 3)])
+      for i in [0, 1, 0, 1, 2, 1, 2, 3]
+  ]
+
+  events = await runner.run_async_with_new_session(input_data)
+  res = extract_event_text(events, "mock_")
+
+  assert res == expected_output
+
+
+@pytest.mark.asyncio
+async def test_map_agent_callback_event_branch():
+  def callback(callback_context: CallbackContext) -> types.Content:
+    return testing_utils.ModelContent([types.Part(text="OK")])
+
+  agent = MapAgent(
+      name="map_agent",
+      sub_agents=[LlmAgent(name="mock", model=OneTwoThreeModel())],
+      after_agent_callback=callback,
+  )
+
+  runner = testing_utils.TestInMemoryRunner(agent)
+  events = await runner.run_async_with_new_session(json.dumps(["0"]))
+
+  event = events[-1]
+  assert event.branch is None
+  assert event.author == agent.name
+  assert event.content == testing_utils.ModelContent([types.Part(text="OK")])
+
+
+@pytest.mark.asyncio
+async def test_map_agent_sub_agent_cloned_with_correct_parent():
+  def callback(callback_context: CallbackContext) -> None:
+    current = callback_context._invocation_context.agent
+    parent = current.parent_agent
+    assert parent is not None
+    assert parent.sub_agents[0] == current
+    assert current.name == "sub_agent_0"
+    assert current.sub_agents[0].name == "mock_0"
+
+  sub_agent = SequentialAgent(
+      name="sub_agent",
+      sub_agents=[LlmAgent(name="mock", model=OneTwoThreeModel())],
+      after_agent_callback=callback,
+  )
+
+  map_agent = MapAgent(
+      name="map_agent",
+      sub_agents=[sub_agent],
+  )
+
+  runner = testing_utils.TestInMemoryRunner(map_agent)
+  await runner.run_async_with_new_session(json.dumps(["0"]))
+
+
+@pytest.mark.asyncio
+async def test_map_agent_errors():
+  """Map agent can only operate on a list input. Should crash on Dev Error"""
+  with pytest.raises(ValidationError):
+    MapAgent(
+        name="map_agent",
+        sub_agents=[],
+    )
+  with pytest.raises(ValidationError):
+    MapAgent(
+        name="map_agent",
+        sub_agents=[LlmAgent(name="mock_0"), LlmAgent(name="mock_1")],
+    )
+  with pytest.raises(ValidationError):
+    MapAgent(name="map_agent")
+
+  map_agent = MapAgent(
+      name="map_agent",
+      sub_agents=[LlmAgent(name="mock", model=OneTwoThreeModel())],
+  )
+  runner = testing_utils.TestInMemoryRunner(map_agent)
+
+  with pytest.raises(ValidationError):
+    await runner.run_async_with_new_session("Not a list")