Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ def __init__(
if self._session_manager:
self.hooks.add_hook(self._session_manager)

# Allow conversation_managers to subscribe to hooks
self.hooks.add_hook(self.conversation_manager)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

session manager uses conversation manager, right? Should we put conversation_manager hooks before it in case there is something the conversation manager needs to do earlier?

Also what are you thinking for documentation? It is a bit unclear now that conversation manager is now both called manually and via hooks

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, what is the story for transitioning from "manual" conversation management, meaning conversation_manager.apply_management, to using only hooks.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we put conversation_manager hooks before it in case there is something the conversation manager needs to do earlier?

I actually think we want to do it after, in case the conversation_manager wants to trim after the session-manager


Also what are you thinking for documentation?

I will have to do a docs PR for this; specifically I need to update it to account for the per_turn parameter AND conversation managers being able to register for hooks.

Also, what is the story for transitioning from "manual" conversation management, meaning conversation_manager.apply_management, to using only hooks.

I'll cut an issue - though IMHO it's low priority


self.tool_executor = tool_executor or ConcurrentToolExecutor()

if hooks:
Expand Down
34 changes: 33 additions & 1 deletion src/strands/agent/conversation_manager/conversation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional

from ...hooks.registry import HookProvider, HookRegistry
from ...types.content import Message

if TYPE_CHECKING:
from ...agent.agent import Agent


class ConversationManager(ABC):
class ConversationManager(ABC, HookProvider):
"""Abstract base class for managing conversation history.

This class provides an interface for implementing conversation management strategies to control the size of message
Expand All @@ -18,6 +19,18 @@ class ConversationManager(ABC):
- Manage memory usage
- Control context length
- Maintain relevant conversation state

ConversationManager implements the HookProvider protocol, allowing derived classes to register hooks for agent
lifecycle events. Derived classes that override register_hooks must call the base implementation to ensure proper
hook registration.

Example:
```python
class MyConversationManager(ConversationManager):
def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
super().register_hooks(registry, **kwargs)
# Register additional hooks here
```
"""

def __init__(self) -> None:
Expand All @@ -30,6 +43,25 @@ def __init__(self) -> None:
"""
self.removed_message_count = 0

def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
"""Register hooks for agent lifecycle events.

Derived classes that override this method must call the base implementation to ensure proper hook
registration chain.

Args:
registry: The hook registry to register callbacks with.
**kwargs: Additional keyword arguments for future extensibility.

Example:
```python
def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
super().register_hooks(registry, **kwargs)
registry.add_callback(SomeEvent, self.on_some_event)
```
"""
pass

def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]:
"""Restore the Conversation Manager's state from a session.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
if TYPE_CHECKING:
from ...agent.agent import Agent

from ...hooks import BeforeModelCallEvent, HookRegistry
from ...types.content import Messages
from ...types.exceptions import ContextWindowOverflowException
from .conversation_manager import ConversationManager
Expand All @@ -18,19 +19,107 @@ class SlidingWindowConversationManager(ConversationManager):

This class handles the logic of maintaining a conversation window that preserves tool usage pairs and avoids
invalid window states.

Supports proactive management during agent loop execution via the per_turn parameter.
"""

def __init__(self, window_size: int = 40, should_truncate_results: bool = True):
def __init__(self, window_size: int = 40, should_truncate_results: bool = True, *, per_turn: bool | int = False):
"""Initialize the sliding window conversation manager.

Args:
window_size: Maximum number of messages to keep in the agent's history.
Defaults to 40 messages.
should_truncate_results: Truncate tool results when a message is too large for the model's context window
per_turn: Controls when to apply message management during agent execution.
- False (default): Only apply management at the end (default behavior)
- True: Apply management before every model call
- int (e.g., 3): Apply management before every N model calls

When to use per_turn: If your agent performs many tool operations in loops
(e.g., web browsing with frequent screenshots), enable per_turn to proactively
manage message history and prevent the agent loop from slowing down. Start with
per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed
for performance tuning.

Raises:
ValueError: If per_turn is 0 or a negative integer.
"""
super().__init__()

# Validate per_turn parameter
# Note: Must check bool before int since bool is a subclass of int in Python
if not isinstance(per_turn, bool) and isinstance(per_turn, int) and per_turn <= 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In our SDK I don't think we validate input parameters

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we against it or just that this is contrary to other places?

Personally I don't see a problem with enforcing it for new things, but can remove it if we want

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While validation wouldn't hurt, we intentionally avoid it. @pgrayy mentioned it will make SDK cumbersome, and yes, if we do it here, we can do else where.

raise ValueError(f"per_turn must be True, False, or a positive integer, got {per_turn}")

self.window_size = window_size
self.should_truncate_results = should_truncate_results
self.per_turn = per_turn
self.model_call_count = 0

def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None:
"""Register hook callbacks for per-turn conversation management.

Args:
registry: The hook registry to register callbacks with.
**kwargs: Additional keyword arguments for future extensibility.
"""
super().register_hooks(registry, **kwargs)

# Always register the callback - per_turn check happens in the callback
registry.add_callback(BeforeModelCallEvent, self._on_before_model_call)

def _on_before_model_call(self, event: BeforeModelCallEvent) -> None:
"""Handle before model call event for per-turn management.

This callback is invoked before each model call. It tracks the model call count and applies message management
based on the per_turn configuration.

Args:
event: The before model call event containing the agent and model execution details.
"""
# Check if per_turn is enabled
if self.per_turn is False:
return

self.model_call_count += 1

# Determine if we should apply management
should_apply = False
if self.per_turn is True:
should_apply = True
elif isinstance(self.per_turn, int) and self.per_turn > 0:
should_apply = self.model_call_count % self.per_turn == 0

if should_apply:
logger.debug(
"model_call_count=<%d>, per_turn=<%s> | applying per-turn conversation management",
self.model_call_count,
self.per_turn,
)
self.apply_management(event.agent)

def get_state(self) -> dict[str, Any]:
"""Get the current state of the conversation manager.

Returns:
Dictionary containing the manager's state, including model call count for per-turn tracking.
"""
state = super().get_state()
state["model_call_count"] = self.model_call_count
return state

def restore_from_session(self, state: dict[str, Any]) -> Optional[list]:
"""Restore the conversation manager's state from a session.

Args:
state: Previous state of the conversation manager

Returns:
Optional list of messages to prepend to the agent's messages.
"""
result = super().restore_from_session(state)
self.model_call_count = state.get("model_call_count", 0)
return result

def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
"""Apply the sliding window to the agent's messages array to maintain a manageable history size.
Expand Down
3 changes: 2 additions & 1 deletion src/strands/hooks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import inspect
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar, runtime_checkable

from ..interrupt import Interrupt, InterruptException

Expand Down Expand Up @@ -84,6 +84,7 @@ class HookEvent(BaseHookEvent):
"""Generic for invoking events - non-contravariant to enable returning events."""


@runtime_checkable
class HookProvider(Protocol):
"""Protocol for objects that provide hook callbacks to an agent.
Expand Down
180 changes: 180 additions & 0 deletions tests/strands/agent/test_conversation_manager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from unittest.mock import MagicMock, patch

import pytest

from strands import tool
from strands.agent.agent import Agent
from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
from strands.hooks.events import BeforeModelCallEvent
from strands.hooks.registry import HookProvider, HookRegistry
from strands.types.exceptions import ContextWindowOverflowException
from tests.fixtures.mocked_model_provider import MockedModelProvider


@pytest.fixture
Expand Down Expand Up @@ -246,3 +252,177 @@ def test_null_conversation_does_not_restore_with_incorrect_state():

with pytest.raises(ValueError):
manager.restore_from_session({})


# ==============================================================================
# Per-Turn Management Tests
# ==============================================================================


def test_per_turn_parameter_validation():
"""Test per_turn parameter validation."""
# Valid values
assert SlidingWindowConversationManager(per_turn=False).per_turn is False
assert SlidingWindowConversationManager(per_turn=True).per_turn is True
assert SlidingWindowConversationManager(per_turn=3).per_turn == 3

# Invalid values
with pytest.raises(ValueError):
SlidingWindowConversationManager(per_turn=0)
with pytest.raises(ValueError):
SlidingWindowConversationManager(per_turn=-1)


def test_conversation_manager_is_hook_provider():
"""Test that ConversationManager implements HookProvider protocol."""
manager = NullConversationManager()
assert isinstance(manager, HookProvider)


def test_derived_class_does_not_need_to_implement_register_hooks():
"""Test that derived classes don't need to override register_hooks for backwards compatibility."""
from strands.agent.conversation_manager.conversation_manager import ConversationManager

class MinimalConversationManager(ConversationManager):
"""A minimal implementation that only implements abstract methods."""

def apply_management(self, agent, **kwargs):
pass

def reduce_context(self, agent, e=None, **kwargs):
pass

# Should be able to instantiate without implementing register_hooks
manager = MinimalConversationManager()
registry = HookRegistry()

# Should work without error
manager.register_hooks(registry)
assert not registry.has_callbacks()


def test_per_turn_hooks_registration():
"""Test that hooks are registered when conversation_manager implements HookProvider."""
manager = SlidingWindowConversationManager(per_turn=True)
assert isinstance(manager, HookProvider)

registry = HookRegistry()
manager.register_hooks(registry)
assert registry.has_callbacks()


def test_per_turn_false_no_management_during_loop():
"""Test that per_turn=False only manages in finally block."""
manager = SlidingWindowConversationManager(per_turn=False, window_size=100)
responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3
model = MockedModelProvider(responses)
agent = Agent(model=model, conversation_manager=manager)

with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock:
agent("Test")
# Should only be called once in finally block (per_turn disabled)
assert mock.call_count == 1


def test_per_turn_true_manages_each_model_call():
"""Test that per_turn=True applies management before each model call."""
manager = SlidingWindowConversationManager(per_turn=True, window_size=100)
responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3
model = MockedModelProvider(responses)
agent = Agent(model=model, conversation_manager=manager)

with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock:
agent("Test")
# Should be called for each model call + finally block
# With simple text responses, agent makes 1 model call then stops
assert mock.call_count >= 1


def test_per_turn_integer_manages_every_n_calls():
"""Test that per_turn=N applies management every N model calls."""
manager = SlidingWindowConversationManager(per_turn=2, window_size=100)
# Create responses that trigger multiple model calls
responses = [
{"role": "assistant", "content": [{"toolUse": {"toolUseId": f"{i}", "name": "test", "input": {}}}]}
for i in range(5)
] + [{"role": "assistant", "content": [{"text": "Done"}]}]
model = MockedModelProvider(responses)

@tool(name="test")
def test_tool(query: str = "") -> str:
return "result"

agent = Agent(model=model, conversation_manager=manager, tools=[test_tool])

with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock:
agent("Test")
# With 6 model calls and per_turn=2: called on 2nd, 4th, 6th + finally
assert mock.call_count == 4


def test_per_turn_dynamic_change():
"""Test that per_turn can be changed dynamically."""
manager = SlidingWindowConversationManager(per_turn=False)
registry = HookRegistry()
manager.register_hooks(registry)

mock_agent = MagicMock()
mock_agent.messages = []
event = BeforeModelCallEvent(agent=mock_agent)

# Initially disabled
with patch.object(manager, "apply_management") as mock_apply:
registry.invoke_callbacks(event)
assert mock_apply.call_count == 0

# Enable dynamically
manager.per_turn = True
with patch.object(manager, "apply_management") as mock_apply:
registry.invoke_callbacks(event)
assert mock_apply.call_count == 1


def test_per_turn_reduces_message_count():
"""Test that per_turn actually reduces message count during execution."""
manager = SlidingWindowConversationManager(per_turn=1, window_size=4)
responses = [{"role": "assistant", "content": [{"text": f"Response {i}"}]} for i in range(10)]
model = MockedModelProvider(responses)
agent = Agent(model=model, conversation_manager=manager)

message_counts = []
original_apply = manager.apply_management

def track_apply(agent_instance):
message_counts.append(len(agent_instance.messages))
return original_apply(agent_instance)

with patch.object(manager, "apply_management", side_effect=track_apply):
agent("Test")

# Verify message count stayed around window_size
assert any(count <= manager.window_size for count in message_counts)


def test_per_turn_state_persistence():
"""Test that model_call_count is persisted in state."""
manager = SlidingWindowConversationManager(per_turn=3)
manager.model_call_count = 7

state = manager.get_state()
assert state["model_call_count"] == 7

new_manager = SlidingWindowConversationManager(per_turn=3)
new_manager.restore_from_session(state)
assert new_manager.model_call_count == 7


def test_per_turn_backward_compatibility():
"""Test that existing code without per_turn still works."""
manager = SlidingWindowConversationManager(window_size=40)
assert manager.per_turn is False

responses = [{"role": "assistant", "content": [{"text": "Hello"}]}]
model = MockedModelProvider(responses)
agent = Agent(model=model, conversation_manager=manager)
result = agent("Hello")
assert result is not None
Loading