diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index e15f9af981..f5a495c398 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -294,6 +294,10 @@ async def run_async( async for event in agen: yield event + # Propagate branch changes back to parent context. + if ctx.branch != parent_context.branch: + parent_context.branch = ctx.branch + if ctx.end_invocation: return diff --git a/src/google/adk/agents/branch.py b/src/google/adk/agents/branch.py new file mode 100644 index 0000000000..99146fa5f4 --- /dev/null +++ b/src/google/adk/agents/branch.py @@ -0,0 +1,163 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Branch context for provenance-based event filtering in parallel agents.""" + +from __future__ import annotations + +import threading +from typing import Optional + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field +from pydantic import model_serializer +from pydantic import PrivateAttr + + +class BranchTokenFactory: + """Thread-safe global counter for branch tokens. + + Each fork operation in a parallel agent execution creates new unique tokens + that are used to track provenance and determine event visibility across + branches WITHIN a single invocation. + + The counter resets at the start of each invocation, ensuring tokens are + only used for parallel execution isolation within that invocation. Events + from previous invocations are always visible (branch filtering only applies + within current invocation). + """ + + _lock = threading.Lock() + _next = 0 + + @classmethod + def new_token(cls) -> int: + """Generate a new unique token. + + Returns: + A unique integer token. + """ + with cls._lock: + cls._next += 1 + return cls._next + + @classmethod + def reset(cls) -> None: + """Reset the counter to zero. + + This should be called at the start of each invocation to ensure tokens + are fresh for that invocation's parallel execution tracking. + """ + with cls._lock: + cls._next = 0 + + +class Branch(BaseModel): + """Branch tracking using token sets for parallel agent execution. + + Tracks event provenance across parallel and sequential agent execution. + Event visibility is determined by subset relationships: an event is visible + to a context if all the event's tokens are present in the context's token set. + + Example: + Root context: {} + After fork(): child_0 has {1}, child_1 has {2} + After join: parent has {1, 2} + + Events from child_0 (tokens={1}) are visible to parent (tokens={1,2}) + because {1} ⊆ {1,2}. + """ + + model_config = ConfigDict( + frozen=True, # Make instances immutable for hashing + arbitrary_types_allowed=True, + ) + """The pydantic model config.""" + + tokens: frozenset[int] = Field(default_factory=frozenset) + """Set of integer tokens representing branch provenance. + + If empty, represents the root context. Use frozenset for immutability + and to enable hashing for use in sets/dicts. + """ + + @model_serializer + def serialize_model(self): + """Custom serializer to convert frozenset to list for JSON serialization.""" + return {'tokens': list(self.tokens)} + + def fork(self) -> Branch: + """Create a child context for parallel execution. + + The child gets a unique new token added to the parent's token set. + This ensures: + 1. Child can see parent's events (parent tokens ⊆ child tokens) + 2. Siblings cannot see each other's events (sibling tokens are disjoint) + + Returns: + A new Branch with parent.tokens ∪ {new_token}. + """ + new_token = BranchTokenFactory.new_token() + return Branch(tokens=self.tokens | {new_token}) + + def join(self, others: list[Branch]) -> Branch: + """Merge token sets from parallel branches. + + This is called when parallel execution completes and we need to merge + the provenance from all branches. The result contains the union of all + token sets, ensuring subsequent agents can see events from all branches. + + Args: + others: List of other Branches to join with self. + + Returns: + New Branch with union of all token sets. + """ + combined = set(self.tokens) + for ctx in others: + combined |= ctx.tokens + return Branch(tokens=frozenset(combined)) + + def can_see(self, event_ctx: Branch) -> bool: + """Check if an event is visible from this context. + + An event is visible if all of its tokens are present in the current + context's token set (subset relationship). + + Args: + event_ctx: The Branch of the event to check. + + Returns: + True if the event is visible, False otherwise. + """ + return event_ctx.tokens.issubset(self.tokens) + + def __str__(self) -> str: + """Human-readable string representation. + + Returns: + String showing token set or "root" if empty. + """ + if not self.tokens: + return 'Branch(root)' + return f'Branch({sorted(self.tokens)})' + + def __repr__(self) -> str: + """Developer representation. + + Returns: + String representation for debugging. + """ + return str(self) diff --git a/src/google/adk/agents/invocation_context.py b/src/google/adk/agents/invocation_context.py index 24fdce9d59..8aff4f2132 100644 --- a/src/google/adk/agents/invocation_context.py +++ b/src/google/adk/agents/invocation_context.py @@ -36,6 +36,7 @@ from .active_streaming_tool import ActiveStreamingTool from .base_agent import BaseAgent from .base_agent import BaseAgentState +from .branch import Branch from .context_cache_config import ContextCacheConfig from .live_request_queue import LiveRequestQueue from .run_config import RunConfig @@ -149,15 +150,8 @@ class InvocationContext(BaseModel): invocation_id: str """The id of this invocation context. Readonly.""" - branch: Optional[str] = None - """The branch of the invocation context. - - The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of - agent_2, and agent_2 is the parent of agent_3. - - Branch is used when multiple sub-agents shouldn't see their peer agents' - conversation history. - """ + branch: Branch = Field(default_factory=Branch) + """The branch context tracking event provenance for visibility filtering.""" agent: BaseAgent """The current agent of this invocation context. Readonly.""" user_content: Optional[types.Content] = None @@ -349,7 +343,11 @@ def _get_events( if event.invocation_id == self.invocation_id ] if current_branch: - results = [event for event in results if event.branch == self.branch] + results = [ + event + for event in results + if event.branch is None or event.branch == self.branch + ] return results def should_pause_invocation(self, event: Event) -> bool: diff --git a/src/google/adk/agents/parallel_agent.py b/src/google/adk/agents/parallel_agent.py index 09e65a67a4..3805e7e171 100644 --- a/src/google/adk/agents/parallel_agent.py +++ b/src/google/adk/agents/parallel_agent.py @@ -28,6 +28,7 @@ from .base_agent import BaseAgent from .base_agent import BaseAgentState from .base_agent_config import BaseAgentConfig +from .branch import Branch from .invocation_context import InvocationContext from .parallel_agent_config import ParallelAgentConfig @@ -39,12 +40,8 @@ def _create_branch_ctx_for_sub_agent( ) -> InvocationContext: """Create isolated branch for every sub-agent.""" invocation_context = invocation_context.model_copy() - branch_suffix = f'{agent.name}.{sub_agent.name}' - invocation_context.branch = ( - f'{invocation_context.branch}.{branch_suffix}' - if invocation_context.branch - else branch_suffix - ) + parent_branch = invocation_context.branch or Branch() + invocation_context.branch = parent_branch.fork() return invocation_context @@ -173,9 +170,11 @@ async def _run_async_impl( yield self._create_agent_state_event(ctx) agent_runs = [] + sub_agent_contexts = [] # Prepare and collect async generators for each sub-agent. for sub_agent in self.sub_agents: sub_agent_ctx = _create_branch_ctx_for_sub_agent(self, sub_agent, ctx) + sub_agent_contexts.append(sub_agent_ctx) # Only include sub-agents that haven't finished in a previous run. if not sub_agent_ctx.end_of_agents.get(sub_agent.name): @@ -197,6 +196,11 @@ async def _run_async_impl( if pause_invocation: return + # Join all child branches back together after parallel execution completes + parent_branch = ctx.branch or Branch() + joined_branch = parent_branch.join([c.branch for c in sub_agent_contexts]) + ctx.branch = joined_branch + # Once all sub-agents are done, mark the ParallelAgent as final. if ctx.is_resumable and all( ctx.end_of_agents.get(sub_agent.name) for sub_agent in self.sub_agents diff --git a/src/google/adk/events/event.py b/src/google/adk/events/event.py index cca086430b..d7b0e9bce0 100644 --- a/src/google/adk/events/event.py +++ b/src/google/adk/events/event.py @@ -23,6 +23,7 @@ from pydantic import ConfigDict from pydantic import Field +from ..agents.branch import Branch from ..models.llm_response import LlmResponse from .event_actions import EventActions @@ -56,15 +57,8 @@ class Event(LlmResponse): Agent client will know from this field about which function call is long running. only valid for function call event """ - branch: Optional[str] = None - """The branch of the event. - - The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of - agent_2, and agent_2 is the parent of agent_3. - - Branch is used when multiple sub-agent shouldn't see their peer agents' - conversation history. - """ + branch: Optional[Branch] = None + """The branch context of the event. Used for provenance-based event filtering in parallel agents.""" # The following are computed fields. # Do not assign the ID. It will be assigned by the session. diff --git a/src/google/adk/flows/llm_flows/audio_cache_manager.py b/src/google/adk/flows/llm_flows/audio_cache_manager.py index a6308b3fe6..c9a08de8e0 100644 --- a/src/google/adk/flows/llm_flows/audio_cache_manager.py +++ b/src/google/adk/flows/llm_flows/audio_cache_manager.py @@ -185,6 +185,7 @@ async def _flush_cache_to_services( id=Event.new_id(), invocation_id=invocation_context.invocation_id, author=audio_cache[0].role, + branch=invocation_context.branch, content=types.Content( role=audio_cache[0].role, parts=[ diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 824cd26be1..c2af56cd86 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -324,6 +324,7 @@ def get_author_for_event(llm_response): id=Event.new_id(), invocation_id=invocation_context.invocation_id, author=get_author_for_event(llm_response), + branch=invocation_context.branch, ) async with Aclosing( diff --git a/src/google/adk/flows/llm_flows/contents.py b/src/google/adk/flows/llm_flows/contents.py index fefa014c45..fe2f68401c 100644 --- a/src/google/adk/flows/llm_flows/contents.py +++ b/src/google/adk/flows/llm_flows/contents.py @@ -22,6 +22,7 @@ from google.genai import types from typing_extensions import override +from ...agents.branch import Branch from ...agents.invocation_context import InvocationContext from ...events.event import Event from ...models.llm_request import LlmRequest @@ -54,6 +55,7 @@ async def run_async( invocation_context.branch, invocation_context.session.events, agent.name, + invocation_context.invocation_id, ) else: # Include current turn context only (no conversation history) @@ -61,6 +63,7 @@ async def run_async( invocation_context.branch, invocation_context.session.events, agent.name, + invocation_context.invocation_id, ) # Add instruction-related contents to proper position in conversation @@ -252,7 +255,9 @@ def _contains_empty_content(event: Event) -> bool: def _should_include_event_in_context( - current_branch: Optional[str], event: Event + current_branch: Optional[Branch], + event: Event, + current_invocation_id: str = '', ) -> bool: """Determines if an event should be included in the LLM context. @@ -263,13 +268,16 @@ def _should_include_event_in_context( Args: current_branch: The current branch of the agent. event: The event to filter. + current_invocation_id: The current invocation ID for branch filtering. Returns: True if the event should be included in the context, False otherwise. """ return not ( _contains_empty_content(event) - or not _is_event_belongs_to_branch(current_branch, event) + or not _is_event_belongs_to_branch( + current_branch, event, current_invocation_id + ) or _is_auth_event(event) or _is_request_confirmation_event(event) ) @@ -334,7 +342,10 @@ def _process_compaction_events(events: list[Event]) -> list[Event]: def _get_contents( - current_branch: Optional[str], events: list[Event], agent_name: str = '' + current_branch: Optional[Branch], + events: list[Event], + agent_name: str = '', + current_invocation_id: str = '', ) -> list[types.Content]: """Get the contents for the LLM request. @@ -344,6 +355,7 @@ def _get_contents( current_branch: The current branch of the agent. events: Events to process. agent_name: The name of the agent. + current_invocation_id: The current invocation ID for branch filtering. Returns: A list of processed contents. @@ -375,7 +387,9 @@ def _get_contents( raw_filtered_events = [ e for e in rewind_filtered_events - if _should_include_event_in_context(current_branch, e) + if _should_include_event_in_context( + current_branch, e, current_invocation_id + ) ] has_compaction_events = any( @@ -449,7 +463,10 @@ def _get_contents( def _get_current_turn_contents( - current_branch: Optional[str], events: list[Event], agent_name: str = '' + current_branch: Optional[Branch], + events: list[Event], + agent_name: str = '', + current_invocation_id: str = '', ) -> list[types.Content]: """Get contents for the current turn only (no conversation history). @@ -465,6 +482,7 @@ def _get_current_turn_contents( current_branch: The current branch of the agent. events: A list of all session events. agent_name: The name of the agent. + current_invocation_id: The current invocation ID for branch filtering. Returns: A list of contents for the current turn only, preserving context needed @@ -473,10 +491,12 @@ def _get_current_turn_contents( # Find the latest event that starts the current turn and process from there for i in range(len(events) - 1, -1, -1): event = events[i] - if _should_include_event_in_context(current_branch, event) and ( - event.author == 'user' or _is_other_agent_reply(agent_name, event) - ): - return _get_contents(current_branch, events[i:], agent_name) + if _should_include_event_in_context( + current_branch, event, current_invocation_id + ) and (event.author == 'user' or _is_other_agent_reply(agent_name, event)): + return _get_contents( + current_branch, events[i:], agent_name, current_invocation_id + ) return [] @@ -617,21 +637,40 @@ def _merge_function_response_events( def _is_event_belongs_to_branch( - invocation_branch: Optional[str], event: Event + invocation_branch: Optional[Branch], + event: Event, + current_invocation_id: str = '', ) -> bool: """Check if an event belongs to the current branch. - This is for event context segregation between agents. E.g. agent A shouldn't - see output of agent B. + This is for event context segregation between agents within the same + invocation. E.g. parallel agent A shouldn't see output of parallel agent B. + + Within the current invocation, uses Branch's token-set visibility: + An Event is visible if its branch tokens are a subset of the current branch's tokens + (event.tokens ⊆ current.tokens). + + Args: + invocation_branch: The current branch context. + event: The event to check visibility for. + current_invocation_id: The current invocation ID. + + Returns: + True if the event should be visible, False otherwise. """ - if not invocation_branch or not event.branch: + if not invocation_branch: return True - # We use dot to delimit branch nodes. To avoid simple prefix match - # (e.g. agent_0 unexpectedly matching agent_00), require either perfect branch - # match, or match prefix with an additional explicit '.' - return invocation_branch == event.branch or invocation_branch.startswith( - f'{event.branch}.' - ) + + # Events from different invocations are ALWAYS visible (multi-turn history) + if event.invocation_id != current_invocation_id: + return True + + # Events without Branch are from old code - considered visible + if not event.branch: + return True + + # Check token-set visibility: event.tokens ⊆ invocation_branch.tokens + return invocation_branch.can_see(event.branch) def _is_function_call_event(event: Event, function_name: str) -> bool: diff --git a/src/google/adk/flows/llm_flows/transcription_manager.py b/src/google/adk/flows/llm_flows/transcription_manager.py index e44e2ad493..3f7e79011f 100644 --- a/src/google/adk/flows/llm_flows/transcription_manager.py +++ b/src/google/adk/flows/llm_flows/transcription_manager.py @@ -87,6 +87,7 @@ async def _create_and_save_transcription_event( id=Event.new_id(), invocation_id=invocation_context.invocation_id, author=author, + branch=invocation_context.branch, input_transcription=transcription if is_input else None, output_transcription=transcription if not is_input else None, timestamp=time.time(), diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 33187ff0c3..5491af2bf2 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -1289,6 +1289,10 @@ def _new_invocation_context( run_config = run_config or RunConfig() invocation_id = invocation_id or new_invocation_context_id() + from .agents.branch import BranchTokenFactory + + BranchTokenFactory.reset() + if run_config.support_cfc and isinstance(self.agent, LlmAgent): model_name = self.agent.canonical_model.model if not model_name.startswith('gemini-2'): diff --git a/src/google/adk/sessions/vertex_ai_session_service.py b/src/google/adk/sessions/vertex_ai_session_service.py index cce7e99b32..3701ed791b 100644 --- a/src/google/adk/sessions/vertex_ai_session_service.py +++ b/src/google/adk/sessions/vertex_ai_session_service.py @@ -30,6 +30,7 @@ import vertexai from . import _session_util +from ..agents.branch import Branch from ..events.event import Event from ..events.event_actions import EventActions from ..utils.vertex_ai_utils import get_express_mode_api_key @@ -359,7 +360,16 @@ def _from_api_event(api_event_obj: vertexai.types.SessionEvent) -> Event: partial = getattr(event_metadata, 'partial', None) turn_complete = getattr(event_metadata, 'turn_complete', None) interrupted = getattr(event_metadata, 'interrupted', None) - branch = getattr(event_metadata, 'branch', None) + + branch_raw = getattr(event_metadata, 'branch', None) + branch: Optional[Branch] = None + if isinstance(branch_raw, dict): + branch = Branch.model_validate(branch_raw) + elif isinstance(branch_raw, Branch): + branch = branch_raw + elif branch_raw is not None: + branch = None + custom_metadata = getattr(event_metadata, 'custom_metadata', None) grounding_metadata = _session_util.decode_model( getattr(event_metadata, 'grounding_metadata', None), diff --git a/tests/unittests/a2a/converters/test_event_converter.py b/tests/unittests/a2a/converters/test_event_converter.py index 09fd65c3bf..73e5263016 100644 --- a/tests/unittests/a2a/converters/test_event_converter.py +++ b/tests/unittests/a2a/converters/test_event_converter.py @@ -22,6 +22,7 @@ from a2a.types import TaskState from a2a.types import TaskStatusUpdateEvent from google.adk.a2a.converters.event_converter import _create_artifact_id +from google.adk.agents.branch import Branch from google.adk.a2a.converters.event_converter import _create_error_status_event from google.adk.a2a.converters.event_converter import _create_status_update_event from google.adk.a2a.converters.event_converter import _get_adk_metadata_key @@ -137,7 +138,7 @@ def test_get_context_metadata_success(self): def test_get_context_metadata_with_optional_fields(self): """Test context metadata creation with optional fields.""" - self.mock_event.branch = "test-branch" + self.mock_event.branch = Branch() self.mock_event.error_code = "ERROR_001" mock_metadata = Mock() @@ -154,7 +155,6 @@ def test_get_context_metadata_with_optional_fields(self): assert f"{ADK_METADATA_KEY_PREFIX}branch" in result assert f"{ADK_METADATA_KEY_PREFIX}grounding_metadata" in result assert f"{ADK_METADATA_KEY_PREFIX}actions" in result - assert result[f"{ADK_METADATA_KEY_PREFIX}branch"] == "test-branch" assert result[f"{ADK_METADATA_KEY_PREFIX}actions"] == { "test_actions": "value" } @@ -603,7 +603,7 @@ def setup_method(self): """Set up test fixtures.""" self.mock_invocation_context = Mock(spec=InvocationContext) self.mock_invocation_context.invocation_id = "test-invocation-id" - self.mock_invocation_context.branch = "test-branch" + self.mock_invocation_context.branch = Branch() def test_convert_a2a_task_to_event_with_artifacts_priority(self): """Test convert_a2a_task_to_event prioritizes artifacts over status/history.""" @@ -737,7 +737,7 @@ def test_convert_a2a_task_to_event_no_message(self): # Verify minimal event was created with correct invocation_id assert result.author == "test-author" - assert result.branch == "test-branch" + assert isinstance(result.branch, Branch) assert result.invocation_id == "test-invocation-id" @patch("google.adk.a2a.converters.event_converter.uuid.uuid4") @@ -758,7 +758,7 @@ def test_convert_a2a_task_to_event_default_author(self, mock_uuid): # Verify default author was used and UUID was generated for invocation_id assert result.author == "a2a agent" - assert result.branch is None + assert result.branch is None # No invocation context means no branch assert result.invocation_id == "generated-uuid" def test_convert_a2a_task_to_event_none_task(self): @@ -813,7 +813,7 @@ def test_convert_a2a_message_to_event_success(self): # Verify conversion was successful assert result.author == "test-author" - assert result.branch == "test-branch" + assert isinstance(result.branch, Branch) assert result.invocation_id == "test-invocation-id" assert result.content.role == "model" assert len(result.content.parts) == 1 @@ -1004,5 +1004,5 @@ def test_convert_a2a_message_to_event_default_author(self, mock_uuid): # Verify default author was used and UUID was generated for invocation_id assert result.author == "a2a agent" - assert result.branch is None + assert result.branch is None # No invocation context means no branch assert result.invocation_id == "generated-uuid" diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 259bdd51c2..a4c7fa8979 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -25,6 +25,7 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent import BaseAgentState +from google.adk.agents.branch import Branch from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext from google.adk.apps.app import ResumabilityConfig @@ -149,21 +150,23 @@ async def _run_live_impl( async def _create_parent_invocation_context( test_name: str, agent: BaseAgent, - branch: Optional[str] = None, + branch: Optional[Branch] = None, plugins: list[BasePlugin] = [], ) -> InvocationContext: session_service = InMemorySessionService() session = await session_service.create_session( app_name='test_app', user_id='test_user' ) - return InvocationContext( - invocation_id=f'{test_name}_invocation_id', - branch=branch, - agent=agent, - session=session, - session_service=session_service, - plugin_manager=PluginManager(plugins=plugins), - ) + context_kwargs = { + 'invocation_id': f'{test_name}_invocation_id', + 'agent': agent, + 'session': session, + 'session_service': session_service, + 'plugin_manager': PluginManager(plugins=plugins), + } + if branch is not None: + context_kwargs['branch'] = branch + return InvocationContext(**context_kwargs) def test_invalid_agent_name(): @@ -189,7 +192,7 @@ async def test_run_async(request: pytest.FixtureRequest): async def test_run_async_with_branch(request: pytest.FixtureRequest): agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') parent_ctx = await _create_parent_invocation_context( - request.function.__name__, agent, branch='parent_branch' + request.function.__name__, agent, branch=Branch() ) events = [e async for e in agent.run_async(parent_ctx)] @@ -197,7 +200,7 @@ async def test_run_async_with_branch(request: pytest.FixtureRequest): assert len(events) == 1 assert events[0].author == agent.name assert events[0].content.parts[0].text == 'Hello, world!' - assert events[0].branch == 'parent_branch' + assert events[0].branch == parent_ctx.branch @pytest.mark.asyncio @@ -713,7 +716,7 @@ async def test_run_live(request: pytest.FixtureRequest): async def test_run_live_with_branch(request: pytest.FixtureRequest): agent = _TestingAgent(name=f'{request.function.__name__}_test_agent') parent_ctx = await _create_parent_invocation_context( - request.function.__name__, agent, branch='parent_branch' + request.function.__name__, agent, branch=Branch() ) events = [e async for e in agent.run_live(parent_ctx)] @@ -721,7 +724,7 @@ async def test_run_live_with_branch(request: pytest.FixtureRequest): assert len(events) == 1 assert events[0].author == agent.name assert events[0].content.parts[0].text == 'Hello, live!' - assert events[0].branch == 'parent_branch' + assert events[0].branch == parent_ctx.branch @pytest.mark.asyncio @@ -1034,7 +1037,7 @@ async def test_create_agent_state_event(): session_service=session_service, ) - ctx.branch = 'test_branch' + ctx.branch = Branch() # Test case 1: set agent state in context state = _TestAgentState(test_field='checkpoint') @@ -1043,7 +1046,7 @@ async def test_create_agent_state_event(): assert event is not None assert event.invocation_id == ctx.invocation_id assert event.author == agent.name - assert event.branch == 'test_branch' + assert event.branch == ctx.branch assert event.actions is not None assert event.actions.agent_state is not None assert event.actions.agent_state == state.model_dump(mode='json') @@ -1055,7 +1058,7 @@ async def test_create_agent_state_event(): assert event is not None assert event.invocation_id == ctx.invocation_id assert event.author == agent.name - assert event.branch == 'test_branch' + assert event.branch == ctx.branch assert event.actions is not None assert event.actions.end_of_agent assert event.actions.agent_state is None diff --git a/tests/unittests/agents/test_branch_context.py b/tests/unittests/agents/test_branch_context.py new file mode 100644 index 0000000000..6267fb5913 --- /dev/null +++ b/tests/unittests/agents/test_branch_context.py @@ -0,0 +1,484 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Branch token-set based branch tracking.""" + +from __future__ import annotations + +from google.adk.agents.branch import Branch +from google.adk.agents.branch import BranchTokenFactory +import pytest + + +class TestTokenFactory: + """Tests for the TokenFactory class.""" + + def test_new_token_increments(self): + """Test that new_token generates unique incrementing tokens.""" + # Reset the factory + BranchTokenFactory.reset() + + token1 = BranchTokenFactory.new_token() + token2 = BranchTokenFactory.new_token() + token3 = BranchTokenFactory.new_token() + + assert token1 < token2 < token3 + assert token2 == token1 + 1 + assert token3 == token2 + 1 + + def test_new_token_thread_safe(self): + """Test that token generation is thread-safe.""" + import threading + + # Reset the factory + BranchTokenFactory.reset() + tokens = [] + + def generate_tokens(): + for _ in range(100): + tokens.append(BranchTokenFactory.new_token()) + + threads = [threading.Thread(target=generate_tokens) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + # All tokens should be unique + assert len(tokens) == len(set(tokens)) + # Should have 1000 total tokens + assert len(tokens) == 1000 + + +class TestBranchContext: + """Tests for the Branch class.""" + + def test_initialization_default(self): + """Test that default initialization creates root context.""" + ctx = Branch() + assert ctx.tokens == frozenset() + + def test_initialization_with_tokens(self): + """Test initialization with specific tokens.""" + ctx = Branch(tokens=frozenset({1, 2, 3})) + assert ctx.tokens == frozenset({1, 2, 3}) + + def test_fork_creates_children(self): + """Test that fork creates child contexts.""" + BranchTokenFactory.reset() + parent = Branch() + child1 = parent.fork() + child2 = parent.fork() + child3 = parent.fork() + + assert isinstance(child1, Branch) + assert isinstance(child2, Branch) + assert isinstance(child3, Branch) + + def test_fork_children_have_unique_tokens(self): + """Test that each forked child has a unique token.""" + BranchTokenFactory.reset() + parent = Branch(tokens=frozenset({0})) + child1 = parent.fork() + child2 = parent.fork() + child3 = parent.fork() + + # Each child should have parent tokens plus one new unique token + assert len(child1.tokens) == 2 + assert len(child2.tokens) == 2 + assert len(child3.tokens) == 2 + + # Extract the new tokens (the ones not in parent) + new_token1 = list(child1.tokens - parent.tokens)[0] + new_token2 = list(child2.tokens - parent.tokens)[0] + new_token3 = list(child3.tokens - parent.tokens)[0] + + # All new tokens should be unique + assert len({new_token1, new_token2, new_token3}) == 3 + + def test_fork_children_inherit_parent_tokens(self): + """Test that forked children inherit all parent tokens.""" + BranchTokenFactory.reset() + parent = Branch(tokens=frozenset({10, 20, 30})) + child1 = parent.fork() + child2 = parent.fork() + + assert parent.tokens.issubset(child1.tokens) + assert parent.tokens.issubset(child2.tokens) + + def test_join_unions_all_tokens(self): + """Test that join creates union of all token sets.""" + BranchTokenFactory.reset() + parent = Branch(tokens=frozenset({0})) + child1 = Branch(tokens=frozenset({0, 1})) + child2 = Branch(tokens=frozenset({0, 2})) + child3 = Branch(tokens=frozenset({0, 3})) + + joined = parent.join([child1, child2, child3]) + + assert joined.tokens == frozenset({0, 1, 2, 3}) + + def test_can_see_subset_relationship(self): + """Test that can_see implements correct subset logic.""" + parent = Branch(tokens=frozenset({1, 2, 3, 4})) + event1 = Branch(tokens=frozenset({1, 2})) + event2 = Branch(tokens=frozenset({1, 2, 3})) + event3 = Branch(tokens=frozenset({1, 2, 3, 4, 5})) + + # Parent can see events whose tokens are subsets + assert parent.can_see(event1) # {1,2} ⊆ {1,2,3,4} + assert parent.can_see(event2) # {1,2,3} ⊆ {1,2,3,4} + + # Parent cannot see events with tokens it doesn't have + assert not parent.can_see(event3) # {1,2,3,4,5} ⊄ {1,2,3,4} + + def test_can_see_empty_context(self): + """Test visibility with empty (root) contexts.""" + root = Branch() + child = Branch(tokens=frozenset({1})) + + # Root can see itself + assert root.can_see(root) + + # Child can see root (empty set is subset of any set) + assert child.can_see(root) + + # Root cannot see child + assert not root.can_see(child) + + def test_equality(self): + """Test equality based on token sets.""" + ctx1 = Branch(tokens=frozenset({1, 2, 3})) + ctx2 = Branch(tokens=frozenset({1, 2, 3})) + ctx3 = Branch(tokens=frozenset({1, 2})) + + assert ctx1 == ctx2 + assert ctx1 != ctx3 + assert ctx2 != ctx3 + + def test_hashable(self): + """Test that Branch can be used in sets and dicts.""" + ctx1 = Branch(tokens=frozenset({1, 2})) + ctx2 = Branch(tokens=frozenset({1, 2})) + ctx3 = Branch(tokens=frozenset({3, 4})) + + # Should be able to add to set + context_set = {ctx1, ctx2, ctx3} + # ctx1 and ctx2 are equal, so set should have 2 elements + assert len(context_set) == 2 + + # Should be able to use as dict key + context_dict = {ctx1: "first", ctx3: "second"} + assert context_dict[ctx2] == "first" # ctx2 == ctx1 + + def test_str_representation(self): + """Test string representation.""" + root = Branch() + assert str(root) == "Branch(root)" + + ctx = Branch(tokens=frozenset({3, 1, 2})) + # Should show sorted tokens + assert str(ctx) == "Branch([1, 2, 3])" + + def test_parallel_to_sequential_scenario(self): + """Test the actual bug scenario: parallel → sequential → parallel.""" + BranchTokenFactory.reset() + + # Root context + root = Branch() + + # First parallel agent forks to 2 children + agent1_ctx = root.fork() # tokens={1} + agent2_ctx = root.fork() # tokens={2} + + # After parallel execution, join the branches + after_parallel1 = root.join([agent1_ctx, agent2_ctx]) # tokens={1,2} + + # Sequential agent passes context through (second parallel agent) + agent3_ctx = after_parallel1.fork() # tokens={1,2,3} + agent4_ctx = after_parallel1.fork() # tokens={1,2,4} + + # THE BUG FIX: agent3 should be able to see agent1's events + assert agent3_ctx.can_see(agent1_ctx) # {1} ⊆ {1,2,3} ✓ + + # agent3 should also see agent2's events + assert agent3_ctx.can_see(agent2_ctx) # {2} ⊆ {1,2,3} ✓ + + # agent4 should see both agent1 and agent2 + assert agent4_ctx.can_see(agent1_ctx) # {1} ⊆ {1,2,4} ✓ + assert agent4_ctx.can_see(agent2_ctx) # {2} ⊆ {1,2,4} ✓ + + # But siblings shouldn't see each other during parallel execution + assert not agent1_ctx.can_see(agent2_ctx) # {2} ⊄ {1} ✗ + assert not agent2_ctx.can_see(agent1_ctx) # {1} ⊄ {2} ✗ + assert not agent3_ctx.can_see(agent4_ctx) # {1,2,4} ⊄ {1,2,3} ✗ + assert not agent4_ctx.can_see(agent3_ctx) # {1,2,3} ⊄ {1,2,4} ✗ + + def test_pydantic_serialization(self): + """Test that Branch can be serialized by Pydantic.""" + ctx = Branch(tokens=frozenset({1, 2, 3})) + + # Test model_dump (Pydantic serialization) + dumped = ctx.model_dump() + assert "tokens" in dumped + # Frozenset gets converted to some iterable + assert set(dumped["tokens"]) == {1, 2, 3} + + # Test round-trip + restored = Branch(**dumped) + assert restored.tokens == ctx.tokens + + def test_immutability(self): + """Test that Branch is immutable (frozen).""" + ctx = Branch(tokens=frozenset({1, 2, 3})) + + # Should not be able to modify tokens + with pytest.raises( + Exception + ): # Pydantic raises ValidationError or AttributeError + ctx.tokens = frozenset({4, 5, 6}) + + +class TestGitHubIssue3470Scenarios: + """Tests for the exact scenarios described in GitHub issue #3470. + + Issue: https://github.com/google/adk-python/issues/3470 + Two problematic architectures: + 1. Reducer architecture: Sequential[Parallel[A,B,C], Reducer] + 2. Sequence of parallels: Sequential[Parallel1[A,B,C], Parallel2[D,E,F]] + """ + + def test_reducer_architecture_single(self): + """Test reducer architecture: Sequential[Parallel[A,B,C], Reducer]. + + The reducer R1 should be able to see outputs from A, B, and C. + This is the basic reducer pattern that should work. + """ + BranchTokenFactory.reset() + + # Root context + root = Branch() + + # Sequential agent S1 has sub-agents: [Parallel1, Reducer1] + # Parallel1 forks into A, B, C + agent_a_ctx = root.fork() # tokens={1} + agent_b_ctx = root.fork() # tokens={2} + agent_c_ctx = root.fork() # tokens={3} + + # After parallel execution, join the branches for sequential continuation + after_parallel1 = root.join( + [agent_a_ctx, agent_b_ctx, agent_c_ctx] + ) # tokens={1,2,3} + + # Reducer1 runs in sequential after parallel, uses joined context + reducer1_ctx = after_parallel1 + + # CRITICAL: Reducer1 should see all outputs from A, B, C + assert reducer1_ctx.can_see(agent_a_ctx) # {1} ⊆ {1,2,3} ✓ + assert reducer1_ctx.can_see(agent_b_ctx) # {2} ⊆ {1,2,3} ✓ + assert reducer1_ctx.can_see(agent_c_ctx) # {3} ⊆ {1,2,3} ✓ + + def test_nested_reducer_architecture(self): + """Test nested reducer architecture from issue #3470. + + Architecture: + Sequential[ + Parallel[ + Sequential[Parallel[A,B,C], R1], + Sequential[Parallel[D,E,F], R2] + ], + R3 + ] + + This is the failing case where: + - R1 should see A, B, C + - R2 should see D, E, F + - R3 should see R1, R2 (and transitively A-F) + """ + BranchTokenFactory.reset() + + root = Branch() + + # Top-level parallel splits into two sequential branches + seq1_ctx = root.fork() # Group1: tokens={1} + seq2_ctx = root.fork() # Group2: tokens={2} + + # === GROUP 1: Sequential[Parallel[A,B,C], R1] === + # Parallel1 (ABC) forks from seq1_ctx + agent_a_ctx = seq1_ctx.fork() # tokens={1,3} + agent_b_ctx = seq1_ctx.fork() # tokens={1,4} + agent_c_ctx = seq1_ctx.fork() # tokens={1,5} + + # After parallel1, join for R1 + after_parallel1 = seq1_ctx.join( + [agent_a_ctx, agent_b_ctx, agent_c_ctx] + ) # tokens={1,3,4,5} + reducer1_ctx = after_parallel1 + + # R1 should see A, B, C + assert reducer1_ctx.can_see(agent_a_ctx) # {1,3} ⊆ {1,3,4,5} ✓ + assert reducer1_ctx.can_see(agent_b_ctx) # {1,4} ⊆ {1,3,4,5} ✓ + assert reducer1_ctx.can_see(agent_c_ctx) # {1,5} ⊆ {1,3,4,5} ✓ + + # === GROUP 2: Sequential[Parallel[D,E,F], R2] === + # Parallel2 (DEF) forks from seq2_ctx + agent_d_ctx = seq2_ctx.fork() # tokens={2,6} + agent_e_ctx = seq2_ctx.fork() # tokens={2,7} + agent_f_ctx = seq2_ctx.fork() # tokens={2,8} + + # After parallel2, join for R2 + after_parallel2 = seq2_ctx.join( + [agent_d_ctx, agent_e_ctx, agent_f_ctx] + ) # tokens={2,6,7,8} + reducer2_ctx = after_parallel2 + + # R2 should see D, E, F + assert reducer2_ctx.can_see(agent_d_ctx) # {2,6} ⊆ {2,6,7,8} ✓ + assert reducer2_ctx.can_see(agent_e_ctx) # {2,7} ⊆ {2,6,7,8} ✓ + assert reducer2_ctx.can_see(agent_f_ctx) # {2,8} ⊆ {2,6,7,8} ✓ + + # === FINAL: Join both groups and run R3 === + # After top-level parallel completes, join for final reducer + final_joined = root.join( + [after_parallel1, after_parallel2] + ) # tokens={1,2,3,4,5,6,7,8} + reducer3_ctx = final_joined + + # R3 should see R1 and R2's contexts + assert reducer3_ctx.can_see(reducer1_ctx) # {1,3,4,5} ⊆ {1,2,3,4,5,6,7,8} ✓ + assert reducer3_ctx.can_see(reducer2_ctx) # {2,6,7,8} ⊆ {1,2,3,4,5,6,7,8} ✓ + + # R3 should also see all original agents transitively + assert reducer3_ctx.can_see(agent_a_ctx) # {1,3} ⊆ {1,2,3,4,5,6,7,8} ✓ + assert reducer3_ctx.can_see(agent_b_ctx) # {1,4} ⊆ {1,2,3,4,5,6,7,8} ✓ + assert reducer3_ctx.can_see(agent_c_ctx) # {1,5} ⊆ {1,2,3,4,5,6,7,8} ✓ + assert reducer3_ctx.can_see(agent_d_ctx) # {2,6} ⊆ {1,2,3,4,5,6,7,8} ✓ + assert reducer3_ctx.can_see(agent_e_ctx) # {2,7} ⊆ {1,2,3,4,5,6,7,8} ✓ + assert reducer3_ctx.can_see(agent_f_ctx) # {2,8} ⊆ {1,2,3,4,5,6,7,8} ✓ + + # But groups shouldn't see each other during parallel execution + assert not agent_a_ctx.can_see(agent_d_ctx) # {2,6} ⊄ {1,3} ✗ + assert not reducer1_ctx.can_see(reducer2_ctx) # {2,6,7,8} ⊄ {1,3,4,5} ✗ + + def test_sequence_of_parallels(self): + """Test sequence of parallels from issue #3470. + + Architecture: + Sequential[ + Parallel1[A, B, C], + Parallel2[D, E, F], + Parallel3[G, H, I] + ] + + The bug: With string-based branches: + - A, B, C have branches: parallel1.A, parallel1.B, parallel1.C + - D, E, F have branches: parallel2.D, parallel2.E, parallel2.F + - G, H, I have branches: parallel3.G, parallel3.H, parallel3.I + + These are NOT prefixes of each other, so D/E/F can't see A/B/C, + and G/H/I can't see anyone before them. + + With token-sets: Each subsequent parallel group inherits tokens from + previous groups via join, enabling proper visibility. + """ + BranchTokenFactory.reset() + + root = Branch() + + # === PARALLEL GROUP 1: A, B, C === + agent_a_ctx = root.fork() # tokens={1} + agent_b_ctx = root.fork() # tokens={2} + agent_c_ctx = root.fork() # tokens={3} + + # After parallel1, join for sequential continuation + after_parallel1 = root.join( + [agent_a_ctx, agent_b_ctx, agent_c_ctx] + ) # tokens={1,2,3} + + # === PARALLEL GROUP 2: D, E, F === + # Fork from joined context, so inherits all previous tokens + agent_d_ctx = after_parallel1.fork() # tokens={1,2,3,4} + agent_e_ctx = after_parallel1.fork() # tokens={1,2,3,5} + agent_f_ctx = after_parallel1.fork() # tokens={1,2,3,6} + + # CRITICAL: D, E, F should see A, B, C's outputs + assert agent_d_ctx.can_see(agent_a_ctx) # {1} ⊆ {1,2,3,4} ✓ + assert agent_d_ctx.can_see(agent_b_ctx) # {2} ⊆ {1,2,3,4} ✓ + assert agent_d_ctx.can_see(agent_c_ctx) # {3} ⊆ {1,2,3,4} ✓ + + assert agent_e_ctx.can_see(agent_a_ctx) # {1} ⊆ {1,2,3,5} ✓ + assert agent_f_ctx.can_see(agent_a_ctx) # {1} ⊆ {1,2,3,6} ✓ + + # But parallel2 siblings can't see each other + assert not agent_d_ctx.can_see(agent_e_ctx) # {1,2,3,5} ⊄ {1,2,3,4} ✗ + assert not agent_d_ctx.can_see(agent_f_ctx) # {1,2,3,6} ⊄ {1,2,3,4} ✗ + + # After parallel2, join for sequential continuation + after_parallel2 = after_parallel1.join( + [agent_d_ctx, agent_e_ctx, agent_f_ctx] + ) # tokens={1,2,3,4,5,6} + + # === PARALLEL GROUP 3: G, H, I === + agent_g_ctx = after_parallel2.fork() # tokens={1,2,3,4,5,6,7} + agent_h_ctx = after_parallel2.fork() # tokens={1,2,3,4,5,6,8} + agent_i_ctx = after_parallel2.fork() # tokens={1,2,3,4,5,6,9} + + # CRITICAL: G, H, I should see ALL previous agents' outputs + # Can see group 1 + assert agent_g_ctx.can_see(agent_a_ctx) # {1} ⊆ {1,2,3,4,5,6,7} ✓ + assert agent_g_ctx.can_see(agent_b_ctx) # {2} ⊆ {1,2,3,4,5,6,7} ✓ + assert agent_g_ctx.can_see(agent_c_ctx) # {3} ⊆ {1,2,3,4,5,6,7} ✓ + + # Can see group 2 + assert agent_g_ctx.can_see(agent_d_ctx) # {1,2,3,4} ⊆ {1,2,3,4,5,6,7} ✓ + assert agent_g_ctx.can_see(agent_e_ctx) # {1,2,3,5} ⊆ {1,2,3,4,5,6,7} ✓ + assert agent_g_ctx.can_see(agent_f_ctx) # {1,2,3,6} ⊆ {1,2,3,4,5,6,7} ✓ + + # Same for H and I + assert agent_h_ctx.can_see(agent_a_ctx) + assert agent_h_ctx.can_see(agent_d_ctx) + assert agent_i_ctx.can_see(agent_a_ctx) + assert agent_i_ctx.can_see(agent_d_ctx) + + # But parallel3 siblings can't see each other + assert not agent_g_ctx.can_see( + agent_h_ctx + ) # {1,2,3,4,5,6,8} ⊄ {1,2,3,4,5,6,7} ✗ + assert not agent_g_ctx.can_see( + agent_i_ctx + ) # {1,2,3,4,5,6,9} ⊄ {1,2,3,4,5,6,7} ✗ + + def test_string_based_approach_fails(self): + """Demonstrate why string-based prefix matching fails for sequence of parallels. + + This test documents the OLD broken behavior to show why token-sets are necessary. + """ + # With string-based branches (OLD APPROACH - BROKEN): + # Parallel1: "parallel1.A", "parallel1.B", "parallel1.C" + # Parallel2: "parallel2.D", "parallel2.E", "parallel2.F" + + # Check if "parallel2.D" starts with "parallel1.A" + assert not "parallel2.D".startswith("parallel1.A") # FALSE - Can't see! + + # Check if "parallel1.A" starts with "parallel2.D" + assert not "parallel1.A".startswith("parallel2.D") # FALSE - Can't see! + + # Neither direction works with prefix matching for sibling parallel groups! + # This is why the bug exists in the original implementation. + + # With token-sets: + # After parallel1, context has tokens {1,2,3} + # Parallel2 forks from {1,2,3}, so D gets {1,2,3,4} + # Agent A has tokens {1} + # Check: {1} ⊆ {1,2,3,4} = TRUE ✓ diff --git a/tests/unittests/agents/test_invocation_context.py b/tests/unittests/agents/test_invocation_context.py index 620453e817..8b3ebde222 100644 --- a/tests/unittests/agents/test_invocation_context.py +++ b/tests/unittests/agents/test_invocation_context.py @@ -16,6 +16,7 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent import BaseAgentState +from google.adk.agents.branch import Branch from google.adk.agents.invocation_context import InvocationContext from google.adk.apps import ResumabilityConfig from google.adk.events.event import Event @@ -36,32 +37,38 @@ class TestInvocationContext: @pytest.fixture def mock_events(self): """Create mock events for testing.""" + # Create a parent branch and fork it to create two children + parent_branch = Branch() + agent1_branch = parent_branch.fork() + agent2_branch = parent_branch.fork() + event1 = Mock(spec=Event) event1.invocation_id = 'inv_1' - event1.branch = 'agent_1' + event1.branch = agent1_branch event2 = Mock(spec=Event) event2.invocation_id = 'inv_1' - event2.branch = 'agent_2' + event2.branch = agent2_branch event3 = Mock(spec=Event) event3.invocation_id = 'inv_2' - event3.branch = 'agent_1' + event3.branch = agent1_branch # Same as event1 event4 = Mock(spec=Event) event4.invocation_id = 'inv_2' - event4.branch = 'agent_2' + event4.branch = agent2_branch # Same as event2 return [event1, event2, event3, event4] @pytest.fixture def mock_invocation_context(self, mock_events): """Create a mock invocation context for testing.""" + # Use agent1_branch so it can see event1 and event3 but not event2 and event4 ctx = InvocationContext( session_service=Mock(spec=BaseSessionService), agent=Mock(spec=BaseAgent), invocation_id='inv_1', - branch='agent_1', + branch=mock_events[0].branch, # Use agent1_branch session=Mock(spec=Session, events=mock_events), ) return ctx @@ -109,7 +116,7 @@ def test_get_events_with_no_events_in_session(self, mock_invocation_context): def test_get_events_with_no_matching_events(self, mock_invocation_context): """Tests get_events when no events match the filters.""" mock_invocation_context.invocation_id = 'inv_3' - mock_invocation_context.branch = 'branch_C' + mock_invocation_context.branch = Branch() # Different branch from events # Filter by invocation events = mock_invocation_context._get_events(current_invocation=True) diff --git a/tests/unittests/agents/test_langgraph_agent.py b/tests/unittests/agents/test_langgraph_agent.py index 026f3130c0..f990d98266 100644 --- a/tests/unittests/agents/test_langgraph_agent.py +++ b/tests/unittests/agents/test_langgraph_agent.py @@ -19,6 +19,7 @@ # Skip all tests in this module if LangGraph dependencies are not available LANGGRAPH_AVAILABLE = True try: + from google.adk.agents.branch import Branch from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.langgraph_agent import LangGraphAgent from google.adk.events.event import Event @@ -76,6 +77,7 @@ def __getattr__(self, name): def __call__(self, *args, **kwargs): return DummyTypes() + Branch = DummyTypes() InvocationContext = DummyTypes() LangGraphAgent = DummyTypes() Event = DummyTypes() @@ -232,7 +234,7 @@ async def test_langgraph_agent( mock_parent_context = MagicMock(spec=InvocationContext) mock_session = MagicMock() mock_parent_context.session = mock_session - mock_parent_context.branch = "parent_agent" + mock_parent_context.branch = Branch() mock_parent_context.end_invocation = False mock_session.events = events_list mock_parent_context.invocation_id = "test_invocation_id" diff --git a/tests/unittests/agents/test_nested_agent_branch_visibility.py b/tests/unittests/agents/test_nested_agent_branch_visibility.py new file mode 100644 index 0000000000..b11bc5462e --- /dev/null +++ b/tests/unittests/agents/test_nested_agent_branch_visibility.py @@ -0,0 +1,658 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for branch visibility in nested agent architectures. + +Tests that agents in complex multi-agent orchestrations can correctly see +events from previous agents using token-based branch tracking. + +Two key architectures tested: + +1. Nested Parallel + Reduce: + Sequential[Parallel[A,B,C], Reducer1] in parallel with + Sequential[Parallel[D,E,F], Reducer2], followed by Reducer3 + + Tests that reducers can see outputs from their parallel groups, and + that a final reducer can see all nested outputs. + +2. Sequence of Parallels: + Sequential[Parallel1[A,B,C], Parallel2[D,E,F], Parallel3[G,H,I]] + + Tests that each subsequent parallel group can see outputs from all + previous parallel groups. + +Note: These tests validate the fix for GitHub issue #3470, where string-based +branch prefixes failed to provide proper visibility across parallel groups. +""" + +from __future__ import annotations + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.parallel_agent import ParallelAgent +from google.adk.agents.sequential_agent import SequentialAgent +import pytest + +from tests.unittests import testing_utils + + +def test_nested_parallel_reduce_architecture(): + """Test the nested parallel + reduce architecture from GitHub issue #3470. + + Architecture: + Sequential1 = Parallel[A, B, C] -> Reducer1 + Sequential2 = Parallel[D, E, F] -> Reducer2 + Final = Parallel[Sequential1, Sequential2] -> Reducer3 + + The bug was that: + - Reducer1 couldn't see outputs from A, B, C + - Reducer2 couldn't see outputs from D, E, F + - Reducer3 couldn't see outputs from Reducer1 and Reducer2 + + With Branch fix: + - A, B, C get tokens {1}, {2}, {3} + - Parallel1 joins to {1,2,3} + - Reducer1 gets {1,2,3} and can see all events from {1}, {2}, {3} + - Same for D, E, F in Sequential2 + - Final reducer can see all previous events + """ + print("\n" + "=" * 70) + print("INTEGRATION TEST: Nested Parallel + Reduce (GitHub Issue #3470)") + print("=" * 70) + print("\nArchitecture:") + print(" Sequential[") + print(" Parallel[") + print(" Sequential[Parallel[Alice,Bob,Charlie], Reducer1], ← Group 1") + print(" Sequential[Parallel[David,Eve,Frank], Reducer2] ← Group 2") + print(" ],") + print(" Final_Reducer ← Sees all outputs") + print(" ]") + print() + + # Group 1 agents + agent_a = LlmAgent( + name="Alice", + description="Agent A", + instruction="Say: I am Alice", + model=testing_utils.MockModel.create(responses=["I am Alice"]), + ) + agent_b = LlmAgent( + name="Bob", + description="Agent B", + instruction="Say: I am Bob", + model=testing_utils.MockModel.create(responses=["I am Bob"]), + ) + agent_c = LlmAgent( + name="Charlie", + description="Agent C", + instruction="Say: I am Charlie", + model=testing_utils.MockModel.create(responses=["I am Charlie"]), + ) + + # Group 2 agents + agent_d = LlmAgent( + name="David", + description="Agent D", + instruction="Say: I am David", + model=testing_utils.MockModel.create(responses=["I am David"]), + ) + agent_e = LlmAgent( + name="Eve", + description="Agent E", + instruction="Say: I am Eve", + model=testing_utils.MockModel.create(responses=["I am Eve"]), + ) + agent_f = LlmAgent( + name="Frank", + description="Agent F", + instruction="Say: I am Frank", + model=testing_utils.MockModel.create(responses=["I am Frank"]), + ) + + # Parallel groups + parallel_abc = ParallelAgent( + name="ABC_Parallel", + description="Parallel group ABC", + sub_agents=[agent_a, agent_b, agent_c], + ) + + parallel_def = ParallelAgent( + name="DEF_Parallel", + description="Parallel group DEF", + sub_agents=[agent_d, agent_e, agent_f], + ) + + # Reducers with models that track requests + reducer1_model = testing_utils.MockModel.create(responses=["Summary of ABC"]) + reducer1 = LlmAgent( + name="Reducer1", + description="Reducer for ABC", + instruction="Summarize responses from A, B, and C", + model=reducer1_model, + ) + + reducer2_model = testing_utils.MockModel.create(responses=["Summary of DEF"]) + reducer2 = LlmAgent( + name="Reducer2", + description="Reducer for DEF", + instruction="Summarize responses from D, E, and F", + model=reducer2_model, + ) + + # Sequential groups (Parallel -> Reducer) + sequential1 = SequentialAgent( + name="Group1_Sequential", + description="Sequential ABC -> Reducer1", + sub_agents=[parallel_abc, reducer1], + ) + + sequential2 = SequentialAgent( + name="Group2_Sequential", + description="Sequential DEF -> Reducer2", + sub_agents=[parallel_def, reducer2], + ) + + # Run both sequential groups in parallel + final_parallel = ParallelAgent( + name="Final_Parallel", + description="Run both groups in parallel", + sub_agents=[sequential1, sequential2], + ) + + # Final reducer with model that tracks requests + final_reducer_model = testing_utils.MockModel.create( + responses=["Final summary"] + ) + final_reducer = LlmAgent( + name="Final_Reducer", + description="Final reducer", + instruction="Summarize all outputs", + model=final_reducer_model, + ) + + # Top-level sequential + root_agent = SequentialAgent( + name="Root_Sequential", + description="Root sequential agent", + sub_agents=[final_parallel, final_reducer], + ) + + # Run the agent + runner = testing_utils.InMemoryRunner(root_agent=root_agent) + runner.run("Start") + session = runner.session + + # Debug: print all events and their branches + print("\n=== Token Distribution (Nested Parallel) ===") + print(f" {'Agent':<15} {'Tokens':<30}") + print(f" {'-'*15} {'-'*30}") + for event in session.events: + if event.author and event.branch: + tokens_sorted = sorted(event.branch.tokens) + print(f" {event.author:15} | tokens={tokens_sorted}") + print("=" * 70 + "\n") + + # Verify all agents ran + agent_names = {event.author for event in session.events if event.author} + expected_agents = { + "Alice", + "Bob", + "Charlie", + "David", + "Eve", + "Frank", + "Reducer1", + "Reducer2", + "Final_Reducer", + } + assert expected_agents.issubset( + agent_names + ), f"Missing agents: {expected_agents - agent_names}" + + # Verify event visibility using branch tokens + # Get reducer events + reducer1_events = [e for e in session.events if e.author == "Reducer1"] + reducer2_events = [e for e in session.events if e.author == "Reducer2"] + final_reducer_events = [ + e for e in session.events if e.author == "Final_Reducer" + ] + + assert len(reducer1_events) > 0, "Reducer1 should have events" + assert len(reducer2_events) > 0, "Reducer2 should have events" + assert len(final_reducer_events) > 0, "Final_Reducer should have events" + + # Check that reducers can see their parallel group outputs + # Reducer1 should see A, B, C + abc_events = [ + e + for e in session.events + if e.author in ["Alice", "Bob", "Charlie"] and e.branch + ] + for abc_event in abc_events: + for reducer1_event in reducer1_events: + if reducer1_event.branch: + # Reducer1's tokens should be a superset of ABC tokens + assert reducer1_event.branch.can_see(abc_event.branch), ( + f"Reducer1 (tokens={reducer1_event.branch.tokens}) should see" + f" {abc_event.author} (tokens={abc_event.branch.tokens})" + ) + + # Reducer2 should see D, E, F + def_events = [ + e + for e in session.events + if e.author in ["David", "Eve", "Frank"] and e.branch + ] + for def_event in def_events: + for reducer2_event in reducer2_events: + if reducer2_event.branch: + # Reducer2's tokens should be a superset of DEF tokens + assert reducer2_event.branch.can_see(def_event.branch), ( + f"Reducer2 (tokens={reducer2_event.branch.tokens}) should see" + f" {def_event.author} (tokens={def_event.branch.tokens})" + ) + + # Final reducer should see all reducers + all_reducer_events = reducer1_events + reducer2_events + for reducer_event in all_reducer_events: + if reducer_event.branch: + for final_event in final_reducer_events: + if final_event.branch: + assert final_event.branch.can_see(reducer_event.branch), ( + f"Final_Reducer (tokens={final_event.branch.tokens}) should see" + f" {reducer_event.author} (tokens={reducer_event.branch.tokens})" + ) + + # Verify LLM request contents - the actual text sent to the model + # This is the critical test: does the reducer actually receive the parallel agents' outputs? + + # Helper to extract text from simplified contents + def extract_text(contents): + """Extract all text from simplified contents.""" + texts = [] + for role, content in contents: + if isinstance(content, str): + texts.append(content) + elif isinstance(content, list): + for part in content: + if hasattr(part, "text") and part.text: + texts.append(part.text) + elif hasattr(content, "text") and content.text: + texts.append(content.text) + return " ".join(texts) + + # Reducer1 should receive outputs from A, B, C in its LLM request + assert ( + len(reducer1_model.requests) > 0 + ), "Reducer1 should have made LLM requests" + reducer1_contents = testing_utils.simplify_contents( + reducer1_model.requests[0].contents + ) + reducer1_text = extract_text(reducer1_contents) + + # Check that A, B, C outputs are in the context + assert "Alice" in reducer1_text or "I am Alice" in reducer1_text, ( + "Reducer1 should see Alice's output in LLM request. Got:" + f" {reducer1_text[:200]}" + ) + assert "Bob" in reducer1_text or "I am Bob" in reducer1_text, ( + "Reducer1 should see Bob's output in LLM request. Got:" + f" {reducer1_text[:200]}" + ) + assert "Charlie" in reducer1_text or "I am Charlie" in reducer1_text, ( + "Reducer1 should see Charlie's output in LLM request. Got:" + f" {reducer1_text[:200]}" + ) + + # Reducer2 should receive outputs from D, E, F in its LLM request + assert ( + len(reducer2_model.requests) > 0 + ), "Reducer2 should have made LLM requests" + reducer2_contents = testing_utils.simplify_contents( + reducer2_model.requests[0].contents + ) + reducer2_text = extract_text(reducer2_contents) + + assert "David" in reducer2_text or "I am David" in reducer2_text, ( + "Reducer2 should see David's output in LLM request. Got:" + f" {reducer2_text[:200]}" + ) + assert "Eve" in reducer2_text or "I am Eve" in reducer2_text, ( + "Reducer2 should see Eve's output in LLM request. Got:" + f" {reducer2_text[:200]}" + ) + assert "Frank" in reducer2_text or "I am Frank" in reducer2_text, ( + "Reducer2 should see Frank's output in LLM request. Got:" + f" {reducer2_text[:200]}" + ) + + # Final reducer should receive outputs from both reducers AND nested agents + assert ( + len(final_reducer_model.requests) > 0 + ), "Final_Reducer should have made LLM requests" + final_contents = testing_utils.simplify_contents( + final_reducer_model.requests[0].contents + ) + final_text = extract_text(final_contents) + + # Should see the reducer summaries + assert "Summary of ABC" in final_text, ( + "Final_Reducer should see Reducer1's summary in LLM request. Got:" + f" {final_text[:200]}" + ) + assert "Summary of DEF" in final_text, ( + "Final_Reducer should see Reducer2's summary in LLM request. Got:" + f" {final_text[:200]}" + ) + + # Should also see the original agent outputs (nested visibility) + assert "Alice" in final_text or "I am Alice" in final_text, ( + "Final_Reducer should see Alice's output in LLM request. Got:" + f" {final_text[:200]}" + ) + assert "David" in final_text or "I am David" in final_text, ( + "Final_Reducer should see David's output in LLM request. Got:" + f" {final_text[:200]}" + ) + + +def test_sequence_of_parallel_agents(): + """Test sequence of parallel agents from GitHub issue #3470. + + Architecture: + Sequential[Parallel1[A,B,C], Parallel2[D,E,F], Parallel3[G,H,I]] + + The bug was that agents in Parallel2 and Parallel3 couldn't see outputs + from previous parallel groups. + + With Branch fix: + - Parallel1: A={1}, B={2}, C={3}, joins to {1,2,3} + - Parallel2 forks from {1,2,3}: D={1,2,3,4}, E={1,2,3,5}, F={1,2,3,6} + - D, E, F can all see A, B, C because {1}⊆{1,2,3,4} + - Parallel3 forks from joined tokens and can see all previous events + """ + print("\n" + "=" * 70) + print("INTEGRATION TEST: Sequence of Parallels (GitHub Issue #3470)") + print("=" * 70) + print("\nArchitecture:") + print(" Sequential[") + print(" Parallel1[Alice, Bob, Charlie], ← Group 1") + print(" Parallel2[David, Eve, Frank], ← Group 2 (sees Group 1)") + print(" Parallel3[Grace, Henry, Iris] ← Group 3 (sees Groups 1 & 2)") + print(" ]") + print() + + # Group 1 + agent_a_model = testing_utils.MockModel.create(responses=["I am Alice"]) + agent_a = LlmAgent( + name="Alice", + description="Agent A", + instruction="Say: I am Alice", + model=agent_a_model, + ) + agent_b = LlmAgent( + name="Bob", + description="Agent B", + instruction="Say: I am Bob", + model=testing_utils.MockModel.create(responses=["I am Bob"]), + ) + agent_c = LlmAgent( + name="Charlie", + description="Agent C", + instruction="Say: I am Charlie", + model=testing_utils.MockModel.create(responses=["I am Charlie"]), + ) + + # Group 2 - track David's model to check it sees Group 1 + agent_d_model = testing_utils.MockModel.create(responses=["I am David"]) + agent_d = LlmAgent( + name="David", + description="Agent D", + instruction="Say: I am David", + model=agent_d_model, + ) + agent_e = LlmAgent( + name="Eve", + description="Agent E", + instruction="Say: I am Eve", + model=testing_utils.MockModel.create(responses=["I am Eve"]), + ) + agent_f = LlmAgent( + name="Frank", + description="Agent F", + instruction="Say: I am Frank", + model=testing_utils.MockModel.create(responses=["I am Frank"]), + ) + + # Group 3 - track Grace's model to check it sees Groups 1 and 2 + agent_g_model = testing_utils.MockModel.create(responses=["I am Grace"]) + agent_g = LlmAgent( + name="Grace", + description="Agent G", + instruction="Say: I am Grace", + model=agent_g_model, + ) + agent_h = LlmAgent( + name="Henry", + description="Agent H", + instruction="Say: I am Henry", + model=testing_utils.MockModel.create(responses=["I am Henry"]), + ) + agent_i = LlmAgent( + name="Iris", + description="Agent I", + instruction="Say: I am Iris", + model=testing_utils.MockModel.create(responses=["I am Iris"]), + ) + + # Create parallel groups + parallel1 = ParallelAgent( + name="Parallel1", + description="First parallel group", + sub_agents=[agent_a, agent_b, agent_c], + ) + + parallel2 = ParallelAgent( + name="Parallel2", + description="Second parallel group", + sub_agents=[agent_d, agent_e, agent_f], + ) + + parallel3 = ParallelAgent( + name="Parallel3", + description="Third parallel group", + sub_agents=[agent_g, agent_h, agent_i], + ) + + # Create sequential agent + root_agent = SequentialAgent( + name="Root_Sequential", + description="Sequential of parallels", + sub_agents=[parallel1, parallel2, parallel3], + ) + + # Run the agent + runner = testing_utils.InMemoryRunner(root_agent=root_agent) + runner.run("Start") + session = runner.session + + # Verify all agents ran + agent_names = {event.author for event in session.events if event.author} + expected_agents = { + "Alice", + "Bob", + "Charlie", + "David", + "Eve", + "Frank", + "Grace", + "Henry", + "Iris", + } + assert expected_agents.issubset( + agent_names + ), f"Missing agents: {expected_agents - agent_names}" + + # Get events by agent group + parallel1_events = [ + e + for e in session.events + if e.author in ["Alice", "Bob", "Charlie"] and e.branch + ] + parallel2_events = [ + e + for e in session.events + if e.author in ["David", "Eve", "Frank"] and e.branch + ] + parallel3_events = [ + e + for e in session.events + if e.author in ["Grace", "Henry", "Iris"] and e.branch + ] + + assert len(parallel1_events) > 0, "Parallel1 should have events" + assert len(parallel2_events) > 0, "Parallel2 should have events" + assert len(parallel3_events) > 0, "Parallel3 should have events" + + # Verify visibility: Parallel2 should see Parallel1 + for p1_event in parallel1_events: + for p2_event in parallel2_events: + # Parallel2 tokens should be superset of Parallel1 tokens + assert p2_event.branch.can_see(p1_event.branch), ( + f"{p2_event.author} (tokens={p2_event.branch.tokens}) should see" + f" {p1_event.author} (tokens={p1_event.branch.tokens})" + ) + + # Verify visibility: Parallel3 should see Parallel1 and Parallel2 + for p1_event in parallel1_events: + for p3_event in parallel3_events: + assert p3_event.branch.can_see(p1_event.branch), ( + f"{p3_event.author} (tokens={p3_event.branch.tokens}) should see" + f" {p1_event.author} (tokens={p1_event.branch.tokens})" + ) + + for p2_event in parallel2_events: + for p3_event in parallel3_events: + assert p3_event.branch.can_see(p2_event.branch), ( + f"{p3_event.author} (tokens={p3_event.branch.tokens}) should see" + f" {p2_event.author} (tokens={p2_event.branch.tokens})" + ) + + # Print token sets for verification + print("\n=== Token Distribution ===") + print(f" {'Agent':<15} {'Tokens':<30} {'Can See'}") + print(f" {'-'*15} {'-'*30} {'-'*40}") + + # Organize events by group for clearer display + group1_agents = ["Alice", "Bob", "Charlie"] + group2_agents = ["David", "Eve", "Frank"] + group3_agents = ["Grace", "Henry", "Iris"] + + print(f" {'--- Group 1 ---':<15}") + for event in session.events: + if event.author in group1_agents and event.branch: + tokens_sorted = str(sorted(event.branch.tokens)) + print(f" {event.author:15} | tokens={tokens_sorted:<28} {'Root'}") + + print(f" {'--- Group 2 ---':<15}") + for event in session.events: + if event.author in group2_agents and event.branch: + tokens_sorted = str(sorted(event.branch.tokens)) + print( + f" {event.author:15} |" + f" tokens={tokens_sorted:<28} {'Root, Group 1 (A,B,C)'}" + ) + + print(f" {'--- Group 3 ---':<15}") + for event in session.events: + if event.author in group3_agents and event.branch: + tokens_sorted = str(sorted(event.branch.tokens)) + print( + f" {event.author:15} |" + f" tokens={tokens_sorted:<28} {'Root, Groups 1 & 2 (A-F)'}" + ) + + print("\nKey Observations:") + print(" ✓ Group 2 agents have tokens {1,2,3,...} - inherit from Group 1") + print( + " ✓ Group 3 agents have tokens {1,2,3,4,5,6,...} - inherit from Groups" + " 1 & 2" + ) + print(" ✓ Each agent can see all events with token subsets") + print("=" * 70 + "\n") + + # Verify LLM request contents - the actual text sent to the models + # This is the critical test from the GitHub issue: does each parallel group + # actually receive the previous groups' outputs in their LLM context? + + # Helper to extract text from simplified contents + def extract_text(contents): + """Extract all text from simplified contents.""" + texts = [] + for role, content in contents: + if isinstance(content, str): + texts.append(content) + elif isinstance(content, list): + for part in content: + if hasattr(part, "text") and part.text: + texts.append(part.text) + elif hasattr(content, "text") and content.text: + texts.append(content.text) + return " ".join(texts) + + # David (in Parallel2) should see Alice, Bob, Charlie from Parallel1 + assert len(agent_d_model.requests) > 0, "David should have made LLM requests" + david_contents = testing_utils.simplify_contents( + agent_d_model.requests[0].contents + ) + david_text = extract_text(david_contents) + + assert "Alice" in david_text or "I am Alice" in david_text, ( + "David should see Alice's output in LLM request (Parallel2 seeing" + f" Parallel1). Got: {david_text[:200]}" + ) + assert "Bob" in david_text or "I am Bob" in david_text, ( + "David should see Bob's output in LLM request (Parallel2 seeing" + f" Parallel1). Got: {david_text[:200]}" + ) + assert "Charlie" in david_text or "I am Charlie" in david_text, ( + "David should see Charlie's output in LLM request (Parallel2 seeing" + f" Parallel1). Got: {david_text[:200]}" + ) + + # Grace (in Parallel3) should see all previous agents + assert len(agent_g_model.requests) > 0, "Grace should have made LLM requests" + grace_contents = testing_utils.simplify_contents( + agent_g_model.requests[0].contents + ) + grace_text = extract_text(grace_contents) + + # Should see Parallel1 agents + assert "Alice" in grace_text or "I am Alice" in grace_text, ( + "Grace should see Alice's output in LLM request (Parallel3 seeing" + f" Parallel1). Got: {grace_text[:200]}" + ) + assert "Bob" in grace_text or "I am Bob" in grace_text, ( + "Grace should see Bob's output in LLM request (Parallel3 seeing" + f" Parallel1). Got: {grace_text[:200]}" + ) + + # Should see Parallel2 agents + assert "David" in grace_text or "I am David" in grace_text, ( + "Grace should see David's output in LLM request (Parallel3 seeing" + f" Parallel2). Got: {grace_text[:200]}" + ) + assert "Eve" in grace_text or "I am Eve" in grace_text, ( + "Grace should see Eve's output in LLM request (Parallel3 seeing" + f" Parallel2). Got: {grace_text[:200]}" + ) diff --git a/tests/unittests/agents/test_parallel_agent.py b/tests/unittests/agents/test_parallel_agent.py index 5b6c046f54..5d61835fa8 100644 --- a/tests/unittests/agents/test_parallel_agent.py +++ b/tests/unittests/agents/test_parallel_agent.py @@ -102,8 +102,11 @@ async def test_run_async(request: pytest.FixtureRequest, is_resumable: bool): # and agent1 has a delay. assert events[1].author == agent2.name assert events[2].author == agent1.name - assert events[1].branch == f'{parallel_agent.name}.{agent2.name}' - assert events[2].branch == f'{parallel_agent.name}.{agent1.name}' + # Branches are now Branch objects with unique tokens + assert events[1].branch is not None + assert events[2].branch is not None + # Parallel siblings should have different branches (different tokens) + assert events[1].branch != events[2].branch assert events[1].content.parts[0].text == f'Hello, async {agent2.name}!' assert events[2].content.parts[0].text == f'Hello, async {agent1.name}!' @@ -114,8 +117,11 @@ async def test_run_async(request: pytest.FixtureRequest, is_resumable: bool): assert events[0].author == agent2.name assert events[1].author == agent1.name - assert events[0].branch == f'{parallel_agent.name}.{agent2.name}' - assert events[1].branch == f'{parallel_agent.name}.{agent1.name}' + # Branches are now Branch objects with unique tokens + assert events[0].branch is not None + assert events[1].branch is not None + # Parallel siblings should have different branches + assert events[0].branch != events[1].branch assert events[0].content.parts[0].text == f'Hello, async {agent2.name}!' assert events[1].content.parts[0].text == f'Hello, async {agent1.name}!' @@ -158,26 +164,27 @@ async def test_run_async_branches( assert events[1].author == sequential_agent.name assert not events[1].actions.end_of_agent assert events[1].actions.agent_state['current_sub_agent'] == agent2.name - assert events[1].branch == f'{parallel_agent.name}.{sequential_agent.name}' + assert events[1].branch is not None + sequential_branch = events[1].branch # 3. agent 2 event assert events[2].author == agent2.name - assert events[2].branch == f'{parallel_agent.name}.{sequential_agent.name}' + assert events[2].branch is not None # 4. sequential agent checkpoint assert events[3].author == sequential_agent.name assert not events[3].actions.end_of_agent assert events[3].actions.agent_state['current_sub_agent'] == agent3.name - assert events[3].branch == f'{parallel_agent.name}.{sequential_agent.name}' + assert events[3].branch is not None # 5. agent 3 event assert events[4].author == agent3.name - assert events[4].branch == f'{parallel_agent.name}.{sequential_agent.name}' + assert events[4].branch is not None # 6. sequential agent checkpoint (end) assert events[5].author == sequential_agent.name assert events[5].actions.end_of_agent - assert events[5].branch == f'{parallel_agent.name}.{sequential_agent.name}' + assert events[5].branch is not None # Descendants of the same sub-agent should have the same branch. assert events[1].branch == events[2].branch @@ -187,10 +194,11 @@ async def test_run_async_branches( # 7. agent 1 event assert events[6].author == agent1.name - assert events[6].branch == f'{parallel_agent.name}.{agent1.name}' + assert events[6].branch is not None + agent1_branch = events[6].branch # Sub-agents should have different branches. - assert events[6].branch != events[1].branch + assert agent1_branch != sequential_branch # 8. parallel agent checkpoint (end) assert events[7].author == parallel_agent.name @@ -200,15 +208,20 @@ async def test_run_async_branches( # 1. agent 2 event assert events[0].author == agent2.name - assert events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}' + assert events[0].branch is not None + sequential_branch = events[0].branch # 2. agent 3 event assert events[1].author == agent3.name - assert events[1].branch == f'{parallel_agent.name}.{sequential_agent.name}' + assert events[1].branch is not None + # Sequential sub-agents share the same branch + assert events[1].branch == sequential_branch # 3. agent 1 event assert events[2].author == agent1.name - assert events[2].branch == f'{parallel_agent.name}.{agent1.name}' + assert events[2].branch is not None + # Parallel siblings have different branches + assert events[2].branch != sequential_branch @pytest.mark.asyncio @@ -246,17 +259,22 @@ async def test_resume_async_branches(request: pytest.FixtureRequest): # The sequential agent resumes from agent3. # 1. Agent 3 event assert events[0].author == agent3.name - assert events[0].branch == f'{parallel_agent.name}.{sequential_agent.name}' + assert events[0].branch is not None + sequential_branch = events[0].branch # 2. Sequential agent checkpoint (end) assert events[1].author == sequential_agent.name assert events[1].actions.end_of_agent - assert events[1].branch == f'{parallel_agent.name}.{sequential_agent.name}' + assert events[1].branch is not None + # Same branch as agent3 (sequential) + assert events[1].branch == sequential_branch # Agent 1 runs in parallel but has a delay. # 3. Agent 1 event assert events[2].author == agent1.name - assert events[2].branch == f'{parallel_agent.name}.{agent1.name}' + assert events[2].branch is not None + # Different branch from sequential (parallel sibling) + assert events[2].branch != sequential_branch # 4. Parallel agent checkpoint (end) assert events[3].author == parallel_agent.name diff --git a/tests/unittests/agents/test_parallel_event_visibility_integration.py b/tests/unittests/agents/test_parallel_event_visibility_integration.py new file mode 100644 index 0000000000..01ba5fb14e --- /dev/null +++ b/tests/unittests/agents/test_parallel_event_visibility_integration.py @@ -0,0 +1,79 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Integration tests for parallel agent event visibility (GitHub issue #3470).""" + +from __future__ import annotations + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.parallel_agent import ParallelAgent +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.runners import InMemoryRunner +from google.genai import types +import pytest + +from tests.unittests import testing_utils + + +@pytest.mark.asyncio +async def test_sequence_of_parallels(): + """Test: Sequential[Parallel1[A,B,C], Parallel2[D,E,F]]. + + KEY test from GitHub issue #3470. D,E,F should see A,B,C outputs. + """ + agent_a = LlmAgent( + name="AgentA", model=testing_utils.MockModel.create(responses=["A"]) + ) + agent_d = LlmAgent( + name="AgentD", model=testing_utils.MockModel.create(responses=["D"]) + ) + + parallel1 = ParallelAgent(name="P1", sub_agents=[agent_a]) + parallel2 = ParallelAgent(name="P2", sub_agents=[agent_d]) + root = SequentialAgent(name="Root", sub_agents=[parallel1, parallel2]) + + runner = InMemoryRunner(agent=root, app_name="test") + session = await runner.session_service.create_session( + app_name="test", user_id="user" + ) + + async for event in runner.run_async( + user_id="user", + session_id=session.id, + new_message=types.Content(role="user", parts=[types.Part(text="go")]), + ): + pass + + final_session = await runner.session_service.get_session( + app_name="test", user_id="user", session_id=session.id + ) + + # Debug: print all events and their branches + print("\n=== All Events in Session ===") + for event in final_session.events: + branch_tokens = event.branch.tokens if event.branch else frozenset() + print(f"{event.author:15} | tokens={branch_tokens}") + + agent_a_branch = next( + e.branch for e in final_session.events if e.author == "AgentA" + ) + agent_d_branch = next( + e.branch for e in final_session.events if e.author == "AgentD" + ) + + # KEY: D's tokens should be superset of A's tokens + assert agent_a_branch.tokens.issubset(agent_d_branch.tokens), ( + f"AgentD should see AgentA. A={agent_a_branch.tokens}," + f" D={agent_d_branch.tokens}" + ) diff --git a/tests/unittests/agents/test_remote_a2a_agent.py b/tests/unittests/agents/test_remote_a2a_agent.py index b3894d73d0..91f2fb1fe2 100644 --- a/tests/unittests/agents/test_remote_a2a_agent.py +++ b/tests/unittests/agents/test_remote_a2a_agent.py @@ -37,6 +37,7 @@ from a2a.types import TaskStatus from a2a.types import TaskStatusUpdateEvent from a2a.types import TextPart +from google.adk.agents.branch import Branch from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.remote_a2a_agent import A2A_METADATA_PREFIX from google.adk.agents.remote_a2a_agent import AgentCardResolutionError @@ -542,7 +543,7 @@ def setup_method(self): self.mock_context = Mock(spec=InvocationContext) self.mock_context.session = self.mock_session self.mock_context.invocation_id = "invocation-123" - self.mock_context.branch = "main" + self.mock_context.branch = Branch() def test_create_a2a_request_for_user_function_response_no_function_call(self): """Test function response request creation when no function call exists.""" @@ -1036,7 +1037,7 @@ def setup_method(self): self.mock_context = Mock(spec=InvocationContext) self.mock_context.session = self.mock_session self.mock_context.invocation_id = "invocation-123" - self.mock_context.branch = "main" + self.mock_context.branch = Branch() def test_create_a2a_request_for_user_function_response_no_function_call(self): """Test function response request creation when no function call exists.""" @@ -1435,7 +1436,7 @@ def setup_method(self): self.mock_context = Mock(spec=InvocationContext) self.mock_context.session = self.mock_session self.mock_context.invocation_id = "invocation-123" - self.mock_context.branch = "main" + self.mock_context.branch = Branch() @pytest.mark.asyncio async def test_run_async_impl_initialization_failure(self): @@ -1711,7 +1712,7 @@ def setup_method(self): self.mock_context = Mock(spec=InvocationContext) self.mock_context.session = self.mock_session self.mock_context.invocation_id = "invocation-123" - self.mock_context.branch = "main" + self.mock_context.branch = Branch() @pytest.mark.asyncio async def test_run_async_impl_initialization_failure(self): @@ -1998,7 +1999,7 @@ async def test_full_workflow_with_direct_agent_card(self): mock_context = Mock(spec=InvocationContext) mock_context.session = mock_session mock_context.invocation_id = "invocation-123" - mock_context.branch = "main" + mock_context.branch = Branch() # Mock dependencies with patch( @@ -2094,7 +2095,7 @@ async def test_full_workflow_with_direct_agent_card_and_factory(self): mock_context = Mock(spec=InvocationContext) mock_context.session = mock_session mock_context.invocation_id = "invocation-123" - mock_context.branch = "main" + mock_context.branch = Branch() # Mock dependencies with patch( diff --git a/tests/unittests/agents/test_sequence_of_parallel_agents.py b/tests/unittests/agents/test_sequence_of_parallel_agents.py new file mode 100644 index 0000000000..857a7eccb4 --- /dev/null +++ b/tests/unittests/agents/test_sequence_of_parallel_agents.py @@ -0,0 +1,121 @@ +"""Test sequential parallel agents to verify common prefix visibility.""" + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.parallel_agent import ParallelAgent +from google.adk.agents.sequential_agent import SequentialAgent + +from tests.unittests import testing_utils + + +def test_sequential_parallels(): + """Test Sequential[Parallel1[A,B], Parallel2[D,E]]. + + D and E should be able to see A and B's outputs because: + - Parallel1 creates: "Parallel1.A", "Parallel1.B" + - Parallel1 joins: ctx.branch = "Parallel1" + - Parallel2 creates: "Parallel1.Parallel2.D", "Parallel1.Parallel2.E" + - Common prefix check: "Parallel1.Parallel2.D" and "Parallel1.A" share "Parallel1" + """ + # Parallel1 agents + alice_model = testing_utils.MockModel.create(responses=["I am Alice"]) + alice = LlmAgent( + name="Alice", + description="Agent A", + instruction="Say: I am Alice", + model=alice_model, + ) + + bob_model = testing_utils.MockModel.create(responses=["I am Bob"]) + bob = LlmAgent( + name="Bob", + description="Agent B", + instruction="Say: I am Bob", + model=bob_model, + ) + + # Parallel2 agents - David should see Alice and Bob + david_model = testing_utils.MockModel.create(responses=["I am David"]) + david = LlmAgent( + name="David", + description="Agent D", + instruction="Respond based on context", + model=david_model, + ) + + eve_model = testing_utils.MockModel.create(responses=["I am Eve"]) + eve = LlmAgent( + name="Eve", + description="Agent E", + instruction="Respond based on context", + model=eve_model, + ) + + # Create parallel groups + parallel1 = ParallelAgent( + name="Parallel1", + description="First parallel group", + sub_agents=[alice, bob], + ) + + parallel2 = ParallelAgent( + name="Parallel2", + description="Second parallel group", + sub_agents=[david, eve], + ) + + # Create sequential agent + root = SequentialAgent( + name="Root", + description="Sequential of parallels", + sub_agents=[parallel1, parallel2], + ) + + # Run the agent + runner = testing_utils.InMemoryRunner(root_agent=root) + runner.run("Start") + session = runner.session + + # Print branch contexts for debugging + print("\n=== Branch Hierarchy ===") + for event in session.events: + if event.author and event.branch: + print(f"{event.author:15} | branch={event.branch}") + + # Helper to extract text from simplified contents + def extract_text(contents): + texts = [] + for role, content in contents: + if isinstance(content, str): + texts.append(content) + elif isinstance(content, list): + for part in content: + if hasattr(part, "text") and part.text: + texts.append(part.text) + elif hasattr(content, "text") and content.text: + texts.append(content.text) + return " ".join(texts) + + # David (in Parallel2) should see Alice and Bob from Parallel1 + assert len(david_model.requests) > 0, "David should have made LLM requests" + david_contents = testing_utils.simplify_contents( + david_model.requests[0].contents + ) + david_text = extract_text(david_contents) + + print(f"\nDavid's LLM request text (first 300 chars):\n{david_text[:300]}") + + assert ( + "Alice" in david_text or "I am Alice" in david_text + ), f"David should see Alice's output. Got: {david_text[:200]}" + assert ( + "Bob" in david_text or "I am Bob" in david_text + ), f"David should see Bob's output. Got: {david_text[:200]}" + + print( + "\n✅ SUCCESS! David can see Alice and Bob (common prefix filtering" + " works!)" + ) + + +if __name__ == "__main__": + test_sequential_parallels() diff --git a/tests/unittests/flows/llm_flows/test_contents.py b/tests/unittests/flows/llm_flows/test_contents.py index b2aa91dbee..988fdd3d8c 100644 --- a/tests/unittests/flows/llm_flows/test_contents.py +++ b/tests/unittests/flows/llm_flows/test_contents.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from google.adk.agents.branch import Branch from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event from google.adk.events.event_actions import EventActions @@ -207,29 +208,34 @@ async def test_include_contents_none_multi_branch_current_turn(): invocation_context = await testing_utils.create_invocation_context( agent=agent ) - invocation_context.branch = "root.parent_agent" + # Set current branch - agent is in branch with token 1 + invocation_context.branch = Branch(tokens=frozenset({1})) # Create multi-branch conversation where current turn starts from user # This can arise from having a Parallel Agent with two or more Sequential # Agents as sub agents, each with two Llm Agents as sub agents + # Use same invocation_id as context for branch filtering to work + inv_id = invocation_context.invocation_id events = [ Event( - invocation_id="inv1", - branch="root", + invocation_id=inv_id, author="user", content=types.UserContent("First user message"), + branch=Branch(), # Root branch - visible to all ), Event( - invocation_id="inv1", - branch="root.parent_agent", + invocation_id=inv_id, author="sibling_agent", content=types.ModelContent("Sibling agent response"), + branch=Branch(tokens=frozenset({1})), # Same branch - visible ), Event( - invocation_id="inv1", - branch="root.uncle_agent", + invocation_id=inv_id, author="cousin_agent", content=types.ModelContent("Cousin agent response"), + branch=Branch( + tokens=frozenset({2}) + ), # Different branch - not visible ), ] invocation_context.session.events = events diff --git a/tests/unittests/flows/llm_flows/test_contents_branch.py b/tests/unittests/flows/llm_flows/test_contents_branch.py index 2347354127..a82d00b853 100644 --- a/tests/unittests/flows/llm_flows/test_contents_branch.py +++ b/tests/unittests/flows/llm_flows/test_contents_branch.py @@ -18,6 +18,7 @@ Child agents can see parent agents' events, but not sibling agents' events. """ +from google.adk.agents.branch import Branch from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event from google.adk.flows.llm_flows.contents import request_processor @@ -36,39 +37,50 @@ async def test_branch_filtering_child_sees_parent(): invocation_context = await testing_utils.create_invocation_context( agent=agent ) - # Set current branch as child of "parent_agent" - invocation_context.branch = "parent_agent.child_agent" + # Set current branch as child - child has tokens {1, 2} (inherited 1 from parent, got 2 from fork) + invocation_context.branch = Branch(tokens=frozenset({1, 2})) # Add events from parent and child levels + # Using same invocation_id for all events to test branch filtering within invocation + inv_id = invocation_context.invocation_id events = [ Event( - invocation_id="inv1", + invocation_id=inv_id, author="user", content=types.UserContent("User message"), + branch=Branch(), # Root branch - visible to all ), Event( - invocation_id="inv2", + invocation_id=inv_id, author="parent_agent", content=types.ModelContent("Parent agent response"), - branch="parent_agent", # Parent branch - should be included + branch=Branch( + tokens=frozenset({1}) + ), # Parent branch - should be included ({1} ⊆ {1,2}) ), Event( - invocation_id="inv3", + invocation_id=inv_id, author="child_agent", content=types.ModelContent("Child agent response"), - branch="parent_agent.child_agent", # Current branch - should be included + branch=Branch( + tokens=frozenset({1, 2}) + ), # Current branch - should be included ), Event( - invocation_id="inv4", + invocation_id=inv_id, author="child_agent", content=types.ModelContent("Excluded response 1"), - branch="parent_agent.child_agent000", # Prefix match BUT not itself/ancestor - should be excluded + branch=Branch( + tokens=frozenset({1, 3}) + ), # Sibling branch - should be excluded ({1,3} ⊄ {1,2}) ), Event( - invocation_id="inv5", + invocation_id=inv_id, author="child_agent", content=types.ModelContent("Excluded response 2"), - branch="parent_agent.child", # Prefix match BUT not itself/ancestor - should be excluded + branch=Branch( + tokens=frozenset({3}) + ), # Different branch - should be excluded ({3} ⊄ {1,2}) ), ] invocation_context.session.events = events @@ -96,33 +108,41 @@ async def test_branch_filtering_excludes_sibling_agents(): invocation_context = await testing_utils.create_invocation_context( agent=agent ) - # Set current branch as first child - invocation_context.branch = "parent_agent.child_agent1" + # Set current branch as first child - has tokens {1, 2} (inherited 1 from parent, got 2 from fork) + invocation_context.branch = Branch(tokens=frozenset({1, 2})) # Add events from parent, current child, and sibling child + inv_id = invocation_context.invocation_id events = [ Event( - invocation_id="inv1", + invocation_id=inv_id, author="user", content=types.UserContent("User message"), + branch=Branch(), # Root - visible to all ), Event( - invocation_id="inv2", + invocation_id=inv_id, author="parent_agent", content=types.ModelContent("Parent response"), - branch="parent_agent", # Parent - should be included + branch=Branch( + tokens=frozenset({1}) + ), # Parent - should be included ({1} ⊆ {1,2}) ), Event( - invocation_id="inv3", + invocation_id=inv_id, author="child_agent1", content=types.ModelContent("Child1 response"), - branch="parent_agent.child_agent1", # Current - should be included + branch=Branch( + tokens=frozenset({1, 2}) + ), # Current - should be included ), Event( - invocation_id="inv4", + invocation_id=inv_id, author="child_agent2", content=types.ModelContent("Sibling response"), - branch="parent_agent.child_agent2", # Sibling - should be excluded + branch=Branch( + tokens=frozenset({1, 3}) + ), # Sibling - should be excluded ({1,3} ⊄ {1,2}) ), ] invocation_context.session.events = events @@ -150,28 +170,29 @@ async def test_branch_filtering_no_branch_allows_all(): invocation_context = await testing_utils.create_invocation_context( agent=agent ) - # No current branch set (None) - invocation_context.branch = None + # Root branch (empty tokens) - can see all events + invocation_context.branch = Branch() - # Add events with and without branches + # Add events with various branches + inv_id = invocation_context.invocation_id events = [ Event( - invocation_id="inv1", + invocation_id=inv_id, author="user", content=types.UserContent("No branch message"), - branch=None, + branch=Branch(), # Root - visible ), Event( - invocation_id="inv2", + invocation_id=inv_id, author="agent1", content=types.ModelContent("Agent with branch"), - branch="agent1", + branch=Branch(tokens=frozenset({1})), # Not visible ({1} ⊄ {}) ), Event( - invocation_id="inv3", + invocation_id=inv_id, author="user", content=types.UserContent("Another no branch"), - branch=None, + branch=Branch(), # Root - visible ), ] invocation_context.session.events = events @@ -180,15 +201,10 @@ async def test_branch_filtering_no_branch_allows_all(): async for _ in request_processor.run_async(invocation_context, llm_request): pass - # Verify all events are included when no current branch - assert len(llm_request.contents) == 3 + # Verify only root events are visible (root can't see events with tokens) + assert len(llm_request.contents) == 2 assert llm_request.contents[0] == types.UserContent("No branch message") - assert llm_request.contents[1].role == "user" - assert llm_request.contents[1].parts == [ - types.Part(text="For context:"), - types.Part(text="[agent1] said: Agent with branch"), - ] - assert llm_request.contents[2] == types.UserContent("Another no branch") + assert llm_request.contents[1] == types.UserContent("Another no branch") @pytest.mark.asyncio @@ -199,34 +215,44 @@ async def test_branch_filtering_grandchild_sees_grandparent(): invocation_context = await testing_utils.create_invocation_context( agent=agent ) - # Set deeply nested branch: grandparent.parent.grandchild - invocation_context.branch = "grandparent_agent.parent_agent.grandchild_agent" + # Set deeply nested branch: grandchild has tokens {1, 2, 3} + # (inherited 1 from grandparent, 2 from parent, got 3 from its own fork) + invocation_context.branch = Branch(tokens=frozenset({1, 2, 3})) # Add events from all levels of hierarchy + inv_id = invocation_context.invocation_id events = [ Event( - invocation_id="inv1", + invocation_id=inv_id, author="grandparent_agent", content=types.ModelContent("Grandparent response"), - branch="grandparent_agent", + branch=Branch( + tokens=frozenset({1}) + ), # Should be visible ({1} ⊆ {1,2,3}) ), Event( - invocation_id="inv2", + invocation_id=inv_id, author="parent_agent", content=types.ModelContent("Parent response"), - branch="grandparent_agent.parent_agent", + branch=Branch( + tokens=frozenset({1, 2}) + ), # Should be visible ({1,2} ⊆ {1,2,3}) ), Event( - invocation_id="inv3", + invocation_id=inv_id, author="grandchild_agent", content=types.ModelContent("Grandchild response"), - branch="grandparent_agent.parent_agent.grandchild_agent", + branch=Branch( + tokens=frozenset({1, 2, 3}) + ), # Should be visible (same) ), Event( - invocation_id="inv4", + invocation_id=inv_id, author="sibling_agent", content=types.ModelContent("Sibling response"), - branch="grandparent_agent.parent_agent.sibling_agent", + branch=Branch( + tokens=frozenset({1, 2, 4}) + ), # Should be excluded ({1,2,4} ⊄ {1,2,3}) ), ] invocation_context.session.events = events @@ -258,33 +284,39 @@ async def test_branch_filtering_parent_cannot_see_child(): invocation_context = await testing_utils.create_invocation_context( agent=agent ) - # Set current branch as parent - invocation_context.branch = "parent_agent" + # Set current branch as parent with token {1} + invocation_context.branch = Branch(tokens=frozenset({1})) # Add events from parent and its children + inv_id = invocation_context.invocation_id events = [ Event( - invocation_id="inv1", + invocation_id=inv_id, author="user", content=types.UserContent("User message"), + branch=Branch(), # Root - visible to all ), Event( - invocation_id="inv2", + invocation_id=inv_id, author="parent_agent", content=types.ModelContent("Parent response"), - branch="parent_agent", + branch=Branch(tokens=frozenset({1})), # Should be visible (same) ), Event( - invocation_id="inv3", + invocation_id=inv_id, author="child_agent", content=types.ModelContent("Child response"), - branch="parent_agent.child_agent", + branch=Branch( + tokens=frozenset({1, 2}) + ), # Should be excluded ({1,2} ⊄ {1}) ), Event( - invocation_id="inv4", + invocation_id=inv_id, author="grandchild_agent", content=types.ModelContent("Grandchild response"), - branch="parent_agent.child_agent.grandchild_agent", + branch=Branch( + tokens=frozenset({1, 2, 3}) + ), # Should be excluded ({1,2,3} ⊄ {1}) ), ] invocation_context.session.events = events diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index 9fa1151387..a7617fdb70 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -16,6 +16,7 @@ from typing import Any from typing import Callable +from google.adk.agents.branch import Branch from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event from google.adk.flows.llm_flows.functions import find_matching_function_call @@ -1009,7 +1010,7 @@ def test_merge_parallel_function_response_events_preserves_other_attributes(): """Test that merge_parallel_function_response_events preserves other attributes from base event.""" invocation_id = 'base_invocation_123' base_author = 'base_agent' - base_branch = 'main_branch' + base_branch = Branch(tokens=frozenset({1})) function_response1 = types.FunctionResponse( id='func_123', name='test_function1', response={'result': 'success1'} @@ -1031,7 +1032,6 @@ def test_merge_parallel_function_response_events_preserves_other_attributes(): event2 = Event( invocation_id='different_invocation_456', author='different_agent', # Different author - branch='different_branch', # Different branch content=types.Content( role='user', parts=[types.Part(function_response=function_response2)] ), diff --git a/tests/unittests/flows/llm_flows/test_instructions.py b/tests/unittests/flows/llm_flows/test_instructions.py index dc6fe17638..6fca19f5e5 100644 --- a/tests/unittests/flows/llm_flows/test_instructions.py +++ b/tests/unittests/flows/llm_flows/test_instructions.py @@ -15,6 +15,7 @@ from typing import Any from typing import Optional +from google.adk.agents.branch import Branch from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import Agent from google.adk.agents.llm_agent import LlmAgent @@ -47,7 +48,7 @@ async def _create_invocation_context( session=session, session_service=session_service, run_config=RunConfig(), - branch="main", + branch=Branch(), ) diff --git a/tests/unittests/runners/test_run_tool_confirmation.py b/tests/unittests/runners/test_run_tool_confirmation.py index d6acb66959..8b841e630a 100644 --- a/tests/unittests/runners/test_run_tool_confirmation.py +++ b/tests/unittests/runners/test_run_tool_confirmation.py @@ -19,6 +19,7 @@ from google.adk.agents.base_agent import BaseAgent from google.adk.agents.base_agent import BaseAgentState +from google.adk.agents.branch import Branch from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.parallel_agent import ParallelAgent from google.adk.agents.sequential_agent import SequentialAgent @@ -753,16 +754,18 @@ async def test_pause_and_resume_on_request_confirmation( # Verify that each branch is paused after the long running tool call. # So that no intermediate llm response is generated. - root_agent_events = [event for event in events if event.branch is None] + # Root events have empty token set (root branch) + root_agent_events = [event for event in events if event.branch == Branch()] + # Sub-agent events have specific branch tokens sub_agent1_branch_events = [ event for event in events - if event.branch == f"{agent.name}.{sub_agent1.name}" + if event.branch != Branch() and event.author == sub_agent1.name ] sub_agent2_branch_events = [ event for event in events - if event.branch == f"{agent.name}.{sub_agent2.name}" + if event.branch != Branch() and event.author == sub_agent2.name ] assert testing_utils.simplify_resumable_app_events( copy.deepcopy(root_agent_events) @@ -883,16 +886,16 @@ async def test_pause_and_resume_on_request_confirmation( for event in events: assert event.invocation_id == invocation_id - root_agent_events = [event for event in events if event.branch is None] + root_agent_events = [event for event in events if event.branch == Branch()] sub_agent1_branch_events = [ event for event in events - if event.branch == f"{agent.name}.{sub_agent1.name}" + if event.branch != Branch() and event.author == sub_agent1.name ] sub_agent2_branch_events = [ event for event in events - if event.branch == f"{agent.name}.{sub_agent2.name}" + if event.branch != Branch() and event.author == sub_agent2.name ] # Verify that sub_agent1 is resumed and final; sub_agent2 is still paused; diff --git a/tests/unittests/sessions/test_vertex_ai_session_service.py b/tests/unittests/sessions/test_vertex_ai_session_service.py index 14d2b15b6e..d7b33f4faf 100644 --- a/tests/unittests/sessions/test_vertex_ai_session_service.py +++ b/tests/unittests/sessions/test_vertex_ai_session_service.py @@ -23,6 +23,7 @@ from dateutil.parser import isoparse from fastapi.openapi import models as openapi_models +from google.adk.agents.branch import Branch from google.adk.auth import auth_schemes from google.adk.auth.auth_tool import AuthConfig from google.adk.events.event import Event @@ -182,7 +183,6 @@ def _generate_mock_events_for_session_5(num_events): partial=False, turn_complete=True, interrupted=False, - branch='', long_running_tool_ids={'tool1'}, ), ], @@ -689,7 +689,7 @@ async def test_append_event(): ), error_code='1', error_message='test_error', - branch='test_branch', + branch=Branch(), custom_metadata={'custom': 'data'}, long_running_tool_ids={'tool2'}, )