diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8fc5be6ca..256c74415 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -250,6 +250,9 @@ def __init__( if self._session_manager: self.hooks.add_hook(self._session_manager) + # Allow conversation_managers to subscribe to hooks + self.hooks.add_hook(self.conversation_manager) + self.tool_executor = tool_executor or ConcurrentToolExecutor() if hooks: diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 2c1ee7847..47b761abc 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -3,13 +3,14 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any, Optional +from ...hooks.registry import HookProvider, HookRegistry from ...types.content import Message if TYPE_CHECKING: from ...agent.agent import Agent -class ConversationManager(ABC): +class ConversationManager(ABC, HookProvider): """Abstract base class for managing conversation history. This class provides an interface for implementing conversation management strategies to control the size of message @@ -18,6 +19,18 @@ class ConversationManager(ABC): - Manage memory usage - Control context length - Maintain relevant conversation state + + ConversationManager implements the HookProvider protocol, allowing derived classes to register hooks for agent + lifecycle events. Derived classes that override register_hooks must call the base implementation to ensure proper + hook registration. + + Example: + ```python + class MyConversationManager(ConversationManager): + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + super().register_hooks(registry, **kwargs) + # Register additional hooks here + ``` """ def __init__(self) -> None: @@ -30,6 +43,25 @@ def __init__(self) -> None: """ self.removed_message_count = 0 + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + """Register hooks for agent lifecycle events. + + Derived classes that override this method must call the base implementation to ensure proper hook + registration chain. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + + Example: + ```python + def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: + super().register_hooks(registry, **kwargs) + registry.add_callback(SomeEvent, self.on_some_event) + ``` + """ + pass + def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]: """Restore the Conversation Manager's state from a session. diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index e082abe8e..5ed3a6f79 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -6,6 +6,7 @@ if TYPE_CHECKING: from ...agent.agent import Agent +from ...hooks import BeforeModelCallEvent, HookRegistry from ...types.content import Messages from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -18,19 +19,107 @@ class SlidingWindowConversationManager(ConversationManager): This class handles the logic of maintaining a conversation window that preserves tool usage pairs and avoids invalid window states. + + Supports proactive management during agent loop execution via the per_turn parameter. """ - def __init__(self, window_size: int = 40, should_truncate_results: bool = True): + def __init__(self, window_size: int = 40, should_truncate_results: bool = True, *, per_turn: bool | int = False): """Initialize the sliding window conversation manager. Args: window_size: Maximum number of messages to keep in the agent's history. Defaults to 40 messages. should_truncate_results: Truncate tool results when a message is too large for the model's context window + per_turn: Controls when to apply message management during agent execution. + - False (default): Only apply management at the end (default behavior) + - True: Apply management before every model call + - int (e.g., 3): Apply management before every N model calls + + When to use per_turn: If your agent performs many tool operations in loops + (e.g., web browsing with frequent screenshots), enable per_turn to proactively + manage message history and prevent the agent loop from slowing down. Start with + per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed + for performance tuning. + + Raises: + ValueError: If per_turn is 0 or a negative integer. """ super().__init__() + + # Validate per_turn parameter + # Note: Must check bool before int since bool is a subclass of int in Python + if not isinstance(per_turn, bool) and isinstance(per_turn, int) and per_turn <= 0: + raise ValueError(f"per_turn must be True, False, or a positive integer, got {per_turn}") + self.window_size = window_size self.should_truncate_results = should_truncate_results + self.per_turn = per_turn + self.model_call_count = 0 + + def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None: + """Register hook callbacks for per-turn conversation management. + + Args: + registry: The hook registry to register callbacks with. + **kwargs: Additional keyword arguments for future extensibility. + """ + super().register_hooks(registry, **kwargs) + + # Always register the callback - per_turn check happens in the callback + registry.add_callback(BeforeModelCallEvent, self._on_before_model_call) + + def _on_before_model_call(self, event: BeforeModelCallEvent) -> None: + """Handle before model call event for per-turn management. + + This callback is invoked before each model call. It tracks the model call count and applies message management + based on the per_turn configuration. + + Args: + event: The before model call event containing the agent and model execution details. + """ + # Check if per_turn is enabled + if self.per_turn is False: + return + + self.model_call_count += 1 + + # Determine if we should apply management + should_apply = False + if self.per_turn is True: + should_apply = True + elif isinstance(self.per_turn, int) and self.per_turn > 0: + should_apply = self.model_call_count % self.per_turn == 0 + + if should_apply: + logger.debug( + "model_call_count=<%d>, per_turn=<%s> | applying per-turn conversation management", + self.model_call_count, + self.per_turn, + ) + self.apply_management(event.agent) + + def get_state(self) -> dict[str, Any]: + """Get the current state of the conversation manager. + + Returns: + Dictionary containing the manager's state, including model call count for per-turn tracking. + """ + state = super().get_state() + state["model_call_count"] = self.model_call_count + return state + + def restore_from_session(self, state: dict[str, Any]) -> Optional[list]: + """Restore the conversation manager's state from a session. + + Args: + state: Previous state of the conversation manager + + Returns: + Optional list of messages to prepend to the agent's messages. + """ + result = super().restore_from_session(state) + self.model_call_count = state.get("model_call_count", 0) + return result def apply_management(self, agent: "Agent", **kwargs: Any) -> None: """Apply the sliding window to the agent's messages array to maintain a manageable history size. diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 1efc0bf5b..9edf7ffa7 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -10,7 +10,7 @@ import inspect import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar, runtime_checkable from ..interrupt import Interrupt, InterruptException @@ -84,6 +84,7 @@ class HookEvent(BaseHookEvent): """Generic for invoking events - non-contravariant to enable returning events.""" +@runtime_checkable class HookProvider(Protocol): """Protocol for objects that provide hook callbacks to an agent. diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 77d7dcce8..b2f8f5432 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -1,9 +1,15 @@ +from unittest.mock import MagicMock, patch + import pytest +from strands import tool from strands.agent.agent import Agent from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager +from strands.hooks.events import BeforeModelCallEvent +from strands.hooks.registry import HookProvider, HookRegistry from strands.types.exceptions import ContextWindowOverflowException +from tests.fixtures.mocked_model_provider import MockedModelProvider @pytest.fixture @@ -246,3 +252,177 @@ def test_null_conversation_does_not_restore_with_incorrect_state(): with pytest.raises(ValueError): manager.restore_from_session({}) + + +# ============================================================================== +# Per-Turn Management Tests +# ============================================================================== + + +def test_per_turn_parameter_validation(): + """Test per_turn parameter validation.""" + # Valid values + assert SlidingWindowConversationManager(per_turn=False).per_turn is False + assert SlidingWindowConversationManager(per_turn=True).per_turn is True + assert SlidingWindowConversationManager(per_turn=3).per_turn == 3 + + # Invalid values + with pytest.raises(ValueError): + SlidingWindowConversationManager(per_turn=0) + with pytest.raises(ValueError): + SlidingWindowConversationManager(per_turn=-1) + + +def test_conversation_manager_is_hook_provider(): + """Test that ConversationManager implements HookProvider protocol.""" + manager = NullConversationManager() + assert isinstance(manager, HookProvider) + + +def test_derived_class_does_not_need_to_implement_register_hooks(): + """Test that derived classes don't need to override register_hooks for backwards compatibility.""" + from strands.agent.conversation_manager.conversation_manager import ConversationManager + + class MinimalConversationManager(ConversationManager): + """A minimal implementation that only implements abstract methods.""" + + def apply_management(self, agent, **kwargs): + pass + + def reduce_context(self, agent, e=None, **kwargs): + pass + + # Should be able to instantiate without implementing register_hooks + manager = MinimalConversationManager() + registry = HookRegistry() + + # Should work without error + manager.register_hooks(registry) + assert not registry.has_callbacks() + + +def test_per_turn_hooks_registration(): + """Test that hooks are registered when conversation_manager implements HookProvider.""" + manager = SlidingWindowConversationManager(per_turn=True) + assert isinstance(manager, HookProvider) + + registry = HookRegistry() + manager.register_hooks(registry) + assert registry.has_callbacks() + + +def test_per_turn_false_no_management_during_loop(): + """Test that per_turn=False only manages in finally block.""" + manager = SlidingWindowConversationManager(per_turn=False, window_size=100) + responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3 + model = MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: + agent("Test") + # Should only be called once in finally block (per_turn disabled) + assert mock.call_count == 1 + + +def test_per_turn_true_manages_each_model_call(): + """Test that per_turn=True applies management before each model call.""" + manager = SlidingWindowConversationManager(per_turn=True, window_size=100) + responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3 + model = MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: + agent("Test") + # Should be called for each model call + finally block + # With simple text responses, agent makes 1 model call then stops + assert mock.call_count >= 1 + + +def test_per_turn_integer_manages_every_n_calls(): + """Test that per_turn=N applies management every N model calls.""" + manager = SlidingWindowConversationManager(per_turn=2, window_size=100) + # Create responses that trigger multiple model calls + responses = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": f"{i}", "name": "test", "input": {}}}]} + for i in range(5) + ] + [{"role": "assistant", "content": [{"text": "Done"}]}] + model = MockedModelProvider(responses) + + @tool(name="test") + def test_tool(query: str = "") -> str: + return "result" + + agent = Agent(model=model, conversation_manager=manager, tools=[test_tool]) + + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: + agent("Test") + # With 6 model calls and per_turn=2: called on 2nd, 4th, 6th + finally + assert mock.call_count == 4 + + +def test_per_turn_dynamic_change(): + """Test that per_turn can be changed dynamically.""" + manager = SlidingWindowConversationManager(per_turn=False) + registry = HookRegistry() + manager.register_hooks(registry) + + mock_agent = MagicMock() + mock_agent.messages = [] + event = BeforeModelCallEvent(agent=mock_agent) + + # Initially disabled + with patch.object(manager, "apply_management") as mock_apply: + registry.invoke_callbacks(event) + assert mock_apply.call_count == 0 + + # Enable dynamically + manager.per_turn = True + with patch.object(manager, "apply_management") as mock_apply: + registry.invoke_callbacks(event) + assert mock_apply.call_count == 1 + + +def test_per_turn_reduces_message_count(): + """Test that per_turn actually reduces message count during execution.""" + manager = SlidingWindowConversationManager(per_turn=1, window_size=4) + responses = [{"role": "assistant", "content": [{"text": f"Response {i}"}]} for i in range(10)] + model = MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + + message_counts = [] + original_apply = manager.apply_management + + def track_apply(agent_instance): + message_counts.append(len(agent_instance.messages)) + return original_apply(agent_instance) + + with patch.object(manager, "apply_management", side_effect=track_apply): + agent("Test") + + # Verify message count stayed around window_size + assert any(count <= manager.window_size for count in message_counts) + + +def test_per_turn_state_persistence(): + """Test that model_call_count is persisted in state.""" + manager = SlidingWindowConversationManager(per_turn=3) + manager.model_call_count = 7 + + state = manager.get_state() + assert state["model_call_count"] == 7 + + new_manager = SlidingWindowConversationManager(per_turn=3) + new_manager.restore_from_session(state) + assert new_manager.model_call_count == 7 + + +def test_per_turn_backward_compatibility(): + """Test that existing code without per_turn still works.""" + manager = SlidingWindowConversationManager(window_size=40) + assert manager.per_turn is False + + responses = [{"role": "assistant", "content": [{"text": "Hello"}]}] + model = MockedModelProvider(responses) + agent = Agent(model=model, conversation_manager=manager) + result = agent("Hello") + assert result is not None