From b94a0c23349441661a811b5213cf7244bcf6ed7b Mon Sep 17 00:00:00 2001
From: Andrei
Date: Mon, 29 Dec 2025 14:35:21 +0200
Subject: [PATCH] feat: execute parallel tool calls sequentially

When the LLM emits several tool calls in one message, a new orchestrator
node now dispatches them one at a time, tracking progress through
current_tool_call_index in the graph state instead of fanning out to all
tool nodes in parallel.
---
 pyproject.toml                                |   2 +-
 src/uipath_langchain/agent/react/agent.py     |  13 +-
 src/uipath_langchain/agent/react/router.py    | 117 ++++-----
 src/uipath_langchain/agent/react/types.py     |   5 +
 src/uipath_langchain/agent/react/utils.py     |  17 +-
 .../agent/tools/orchestrator_node.py          | 111 +++++++++
 src/uipath_langchain/agent/tools/tool_node.py |  29 ++-
 tests/agent/tools/test_orchestrator_node.py   | 230 ++++++++++++++++++
 tests/agent/tools/test_tool_node.py           |  35 +--
 uv.lock                                       |   2 +-
 10 files changed, 461 insertions(+), 100 deletions(-)
 create mode 100644 src/uipath_langchain/agent/tools/orchestrator_node.py
 create mode 100644 tests/agent/tools/test_orchestrator_node.py
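
Notes:

The graph now cycles AGENT -> ORCHESTRATOR -> <tool node> -> ORCHESTRATOR
-> ... -> AGENT, draining one batch of parallel tool calls a call at a time
via current_tool_call_index. A minimal sketch of that index bookkeeping;
the function below is an illustrative stand-in with plain data, not the
module's API:

    def orchestrate(index: int | None, n_calls: int) -> int | None:
        if index is None:
            return 0 if n_calls else None      # fresh batch: start at call 0
        nxt = index + 1
        return nxt if nxt < n_calls else None  # advance, or finish the batch

    trace, index = [], orchestrate(None, 2)    # the LLM emitted 2 parallel calls
    while index is not None:
        trace.append(f"tool[{index}]")         # the router dispatches one tool node
        index = orchestrate(index, 2)          # control returns via the orchestrator
    trace.append("agent")                      # batch drained: back to the LLM
    assert trace == ["tool[0]", "tool[1]", "agent"]
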
diff --git a/pyproject.toml b/pyproject.toml
index 0af0c1f7..768154cc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "uipath-langchain"
-version = "0.2.0"
+version = "0.2.1"
 description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
 readme = { file = "README.md", content-type = "text/markdown" }
 requires-python = ">=3.11"
diff --git a/src/uipath_langchain/agent/react/agent.py b/src/uipath_langchain/agent/react/agent.py
index 5f63a033..973904ab 100644
--- a/src/uipath_langchain/agent/react/agent.py
+++ b/src/uipath_langchain/agent/react/agent.py
@@ -11,6 +11,7 @@
 from ..guardrails.actions import GuardrailAction
 from ..tools import create_tool_node
+from ..tools.orchestrator_node import create_orchestrator_node
 from .guardrails.guardrails_subgraph import (
     create_agent_init_guardrails_subgraph,
     create_agent_terminate_guardrails_subgraph,
@@ -105,6 +106,9 @@ def create_agent(
     )
     builder.add_node(AgentGraphNode.TERMINATE, terminate_with_guardrails_subgraph)
 
+    orchestrator_node = create_orchestrator_node(config.thinking_messages_limit)
+    builder.add_node(AgentGraphNode.ORCHESTRATOR, orchestrator_node)
+
     builder.add_edge(START, AgentGraphNode.INIT)
 
     llm_node = create_llm_node(model, llm_tools, config.thinking_messages_limit)
@@ -114,16 +118,19 @@ def create_agent(
     builder.add_node(AgentGraphNode.AGENT, llm_with_guardrails_subgraph)
     builder.add_edge(AgentGraphNode.INIT, AgentGraphNode.AGENT)
 
+    builder.add_edge(AgentGraphNode.AGENT, AgentGraphNode.ORCHESTRATOR)
+
     tool_node_names = list(tool_nodes_with_guardrails.keys())
-    route_agent = create_route_agent(config.thinking_messages_limit)
+    route_agent = create_route_agent()
+
     builder.add_conditional_edges(
-        AgentGraphNode.AGENT,
+        AgentGraphNode.ORCHESTRATOR,
         route_agent,
         [AgentGraphNode.AGENT, *tool_node_names, AgentGraphNode.TERMINATE],
     )
 
     for tool_name in tool_node_names:
-        builder.add_edge(tool_name, AgentGraphNode.AGENT)
+        builder.add_edge(tool_name, AgentGraphNode.ORCHESTRATOR)
 
     builder.add_edge(AgentGraphNode.TERMINATE, END)
diff --git a/src/uipath_langchain/agent/react/router.py b/src/uipath_langchain/agent/react/router.py
index dfcbf57a..71cd526f 100644
--- a/src/uipath_langchain/agent/react/router.py
+++ b/src/uipath_langchain/agent/react/router.py
@@ -2,56 +2,17 @@
 
 from typing import Literal
 
-from langchain_core.messages import AIMessage, AnyMessage, ToolCall
-from uipath.agent.react import END_EXECUTION_TOOL, RAISE_ERROR_TOOL
-
 from ..exceptions import AgentNodeRoutingException
-from .types import AgentGraphNode, AgentGraphState
-from .utils import count_consecutive_thinking_messages
-
-FLOW_CONTROL_TOOLS = [END_EXECUTION_TOOL.name, RAISE_ERROR_TOOL.name]
-
-
-def __filter_control_flow_tool_calls(
-    tool_calls: list[ToolCall],
-) -> list[ToolCall]:
-    """Remove control flow tools when multiple tool calls exist."""
-    if len(tool_calls) <= 1:
-        return tool_calls
-
-    return [tc for tc in tool_calls if tc.get("name") not in FLOW_CONTROL_TOOLS]
-
-
-def __has_control_flow_tool(tool_calls: list[ToolCall]) -> bool:
-    """Check if any tool call is of a control flow tool."""
-    return any(tc.get("name") in FLOW_CONTROL_TOOLS for tc in tool_calls)
-
-
-def __validate_last_message_is_AI(messages: list[AnyMessage]) -> AIMessage:
-    """Validate and return the last message from state.
-
-    Raises:
-        AgentNodeRoutingException: If messages are empty or the last message is not an AIMessage
-    """
-    if not messages:
-        raise AgentNodeRoutingException(
-            "No messages in state - cannot route after agent"
-        )
-
-    last_message = messages[-1]
-    if not isinstance(last_message, AIMessage):
-        raise AgentNodeRoutingException(
-            f"Last message is not AIMessage (type: {type(last_message).__name__}) - cannot route after agent"
-        )
-
-    return last_message
+from .types import (
+    FLOW_CONTROL_TOOLS,
+    AgentGraphNode,
+    AgentGraphState,
+)
+from .utils import find_latest_ai_message
 
 
-def create_route_agent(thinking_messages_limit: int = 0):
-    """Create a routing function configured with thinking_messages_limit.
-
-    Args:
-        thinking_messages_limit: Max consecutive thinking messages before error
+def create_route_agent():
+    """Create a routing function for LangGraph conditional edges.
 
     Returns:
         Routing function for LangGraph conditional edges
@@ -59,50 +20,58 @@ def route_agent(
         state: AgentGraphState,
-    ) -> list[str] | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]:
-        """Route after agent: handles all routing logic including control flow detection.
+    ) -> str | Literal[AgentGraphNode.AGENT, AgentGraphNode.TERMINATE]:
+        """Route after agent: look at the current tool call index and route to the corresponding tool node.
 
         Routing logic:
-        1. If multiple tool calls exist, filter out control flow tools (EndExecution, RaiseError)
-        2. If control flow tool(s) remain, route to TERMINATE
-        3. If regular tool calls remain, route to specific tool nodes (return list of tool names)
-        4. If no tool calls, handle consecutive completions
+        1. If current_tool_call_index is None, route back to the LLM
+        2. If current_tool_call_index is set, route to the corresponding tool node
+        3. If the current tool call is a control flow tool, route to TERMINATE
 
         Returns:
-            - list[str]: Tool node names for parallel execution
-            - AgentGraphNode.AGENT: For consecutive completions
+            - str: Tool node name for single tool execution
+            - AgentGraphNode.AGENT: When no tool call is in progress
             - AgentGraphNode.TERMINATE: For control flow termination
 
         Raises:
-            AgentNodeRoutingException: When encountering unexpected state (empty messages, non-AIMessage, or excessive completions)
+            AgentNodeRoutingException: When encountering unexpected state
         """
+        current_index = state.current_tool_call_index
+
+        # no tool call in progress, route back to the LLM
+        if current_index is None:
+            return AgentGraphNode.AGENT
+
         messages = state.messages
-        last_message = __validate_last_message_is_AI(messages)
-        tool_calls = list(last_message.tool_calls) if last_message.tool_calls else []
-        tool_calls = __filter_control_flow_tool_calls(tool_calls)
+        if not messages:
+            raise AgentNodeRoutingException(
+                "No messages in state - cannot route after agent"
+            )
 
-        if tool_calls and __has_control_flow_tool(tool_calls):
-            return AgentGraphNode.TERMINATE
+        latest_ai_message = find_latest_ai_message(messages)
 
-        if tool_calls:
-            return [tc["name"] for tc in tool_calls]
+        if latest_ai_message is None:
+            raise AgentNodeRoutingException(
+                "No AIMessage found in messages - cannot route after agent"
+            )
 
-        consecutive_thinking_messages = count_consecutive_thinking_messages(messages)
+        tool_calls = (
+            list(latest_ai_message.tool_calls) if latest_ai_message.tool_calls else []
+        )
 
-        if consecutive_thinking_messages > thinking_messages_limit:
+        if current_index >= len(tool_calls):
             raise AgentNodeRoutingException(
-                f"Agent exceeded consecutive completions limit without producing tool calls "
-                f"(completions: {consecutive_thinking_messages}, max: {thinking_messages_limit}). "
-                f"This should not happen as tool_choice='required' is enforced at the limit."
+                f"Current tool call index {current_index} exceeds available tool calls ({len(tool_calls)})"
             )
 
-        if last_message.content:
-            return AgentGraphNode.AGENT
+        current_tool_call = tool_calls[current_index]
+        tool_name = current_tool_call["name"]
 
-        raise AgentNodeRoutingException(
-            f"Agent produced empty response without tool calls "
-            f"(completions: {consecutive_thinking_messages}, has_content: False)"
-        )
+        # handle control flow tools for termination
+        if tool_name in FLOW_CONTROL_TOOLS:
+            return AgentGraphNode.TERMINATE
+
+        return tool_name
 
     return route_agent
diff --git a/src/uipath_langchain/agent/react/types.py b/src/uipath_langchain/agent/react/types.py
index e0b68b38..b1be2b11 100644
--- a/src/uipath_langchain/agent/react/types.py
+++ b/src/uipath_langchain/agent/react/types.py
@@ -4,10 +4,13 @@
 from langchain_core.messages import AnyMessage
 from langgraph.graph.message import add_messages
 from pydantic import BaseModel, Field
+from uipath.agent.react import END_EXECUTION_TOOL, RAISE_ERROR_TOOL
 from uipath.platform.attachments import Attachment
 
 from uipath_langchain.agent.react.utils import add_job_attachments
 
+FLOW_CONTROL_TOOLS = [END_EXECUTION_TOOL.name, RAISE_ERROR_TOOL.name]
+
 
 class AgentTerminationSource(StrEnum):
     ESCALATION = "escalation"
@@ -27,6 +30,7 @@ class AgentGraphState(BaseModel):
     messages: Annotated[list[AnyMessage], add_messages] = []
     job_attachments: Annotated[dict[str, Attachment], add_job_attachments] = {}
     termination: AgentTermination | None = None
+    current_tool_call_index: int | None = None
 
 
 class AgentGuardrailsGraphState(AgentGraphState):
@@ -41,6 +45,7 @@ class AgentGraphNode(StrEnum):
     GUARDED_INIT = "guarded-init"
     AGENT = "agent"
     LLM = "llm"
+    ORCHESTRATOR = "orchestrator"
     TOOLS = "tools"
     TERMINATE = "terminate"
     GUARDED_TERMINATE = "guarded-terminate"
diff --git a/src/uipath_langchain/agent/react/utils.py b/src/uipath_langchain/agent/react/utils.py
index 94244c87..d2083a9c 100644
--- a/src/uipath_langchain/agent/react/utils.py
+++ b/src/uipath_langchain/agent/react/utils.py
@@ -2,7 +2,7 @@
 
 from typing import Any, Sequence
 
-from langchain_core.messages import AIMessage, BaseMessage
+from langchain_core.messages import AIMessage, AnyMessage, BaseMessage
 from pydantic import BaseModel
 from uipath.agent.react import END_EXECUTION_TOOL
 from uipath.platform.attachments import Attachment
@@ -73,3 +73,18 @@ def add_job_attachments(
         return right
 
     return {**left, **right}
+
+
+def find_latest_ai_message(messages: list[AnyMessage]) -> AIMessage | None:
+    """Find and return the latest AIMessage from a list of messages.
+
+    Args:
+        messages: List of messages to search through
+
+    Returns:
+        The latest AIMessage found, or None if no AIMessage exists
+    """
+    for message in reversed(messages):
+        if isinstance(message, AIMessage):
+            return message
+    return None
diff --git a/src/uipath_langchain/agent/tools/orchestrator_node.py b/src/uipath_langchain/agent/tools/orchestrator_node.py
new file mode 100644
index 00000000..5ecf5b09
--- /dev/null
+++ b/src/uipath_langchain/agent/tools/orchestrator_node.py
@@ -0,0 +1,111 @@
+from typing import Any
+
+from langchain_core.messages import ToolCall
+
+from uipath_langchain.agent.exceptions import AgentNodeRoutingException
+from uipath_langchain.agent.react.types import FLOW_CONTROL_TOOLS, AgentGraphState
+from uipath_langchain.agent.react.utils import (
+    count_consecutive_thinking_messages,
+    find_latest_ai_message,
+)
+
+
+def __filter_control_flow_tool_calls(tool_calls: list[ToolCall]) -> list[ToolCall]:
+    """Remove control flow tools when multiple tool calls exist."""
+    if len(tool_calls) <= 1:
+        return tool_calls
+
+    return [tc for tc in tool_calls if tc.get("name") not in FLOW_CONTROL_TOOLS]
+
+
+def create_orchestrator_node(thinking_messages_limit: int = 0):
+    """Create an orchestrator node responsible for sequencing tool calls.
+
+    Args:
+        thinking_messages_limit: Max consecutive thinking messages before error
+    """
+
+    def orchestrator_node(state: AgentGraphState) -> dict[str, Any]:
+        current_index = state.current_tool_call_index
+
+        if current_index is None:
+            # new batch of tool calls
+            if not state.messages:
+                raise AgentNodeRoutingException(
+                    "No messages in state - cannot process tool calls"
+                )
+
+            # check the consecutive thinking messages limit
+            if thinking_messages_limit >= 0:
+                consecutive_thinking = count_consecutive_thinking_messages(
+                    state.messages
+                )
+                if consecutive_thinking > thinking_messages_limit:
+                    raise AgentNodeRoutingException(
+                        f"Too many consecutive thinking messages ({consecutive_thinking}). "
+                        f"Limit is {thinking_messages_limit}. Agent must use tools."
+                    )
+
+            latest_ai_message = find_latest_ai_message(state.messages)
+
+            if latest_ai_message is None or not latest_ai_message.tool_calls:
+                return {"current_tool_call_index": None}
+
+            # apply flow control tool filtering
+            original_tool_calls = list(latest_ai_message.tool_calls)
+            filtered_tool_calls = __filter_control_flow_tool_calls(original_tool_calls)
+
+            if len(filtered_tool_calls) != len(original_tool_calls):
+                modified_message = latest_ai_message.model_copy()
+                modified_message.tool_calls = filtered_tool_calls
+
+                # filter the message content as well; otherwise the LLM provider
+                # rejects tool_use blocks whose tool calls were removed
+                filtered_tool_call_ids = {tc["id"] for tc in filtered_tool_calls}
+                if isinstance(modified_message.content, list):
+                    modified_message.content = [
+                        block
+                        for block in modified_message.content
+                        if (
+                            isinstance(block, dict)
+                            and (
+                                block.get("call_id") in filtered_tool_call_ids
+                                or block.get("call_id") is None  # keep non-tool blocks
+                            )
+                        )
+                        or not isinstance(block, dict)
+                    ]
+
+                return {
+                    "current_tool_call_index": 0 if filtered_tool_calls else None,
+                    "messages": [modified_message],
+                }
+
+            return {"current_tool_call_index": 0}
+
+        # in the middle of processing a batch
+        if not state.messages:
+            raise AgentNodeRoutingException(
+                "No messages in state during batch processing"
+            )
+
+        latest_ai_message = find_latest_ai_message(state.messages)
+
+        if latest_ai_message is None:
+            raise AgentNodeRoutingException(
+                "No AI message found during batch processing"
+            )
+
+        if not latest_ai_message.tool_calls:
+            raise AgentNodeRoutingException(
+                "No tool calls found in AI message during batch processing"
+            )
+
+        total_tool_calls = len(latest_ai_message.tool_calls)
+        next_index = current_index + 1
+
+        if next_index >= total_tool_calls:
+            return {"current_tool_call_index": None}
+        else:
+            return {"current_tool_call_index": next_index}
+
+    return orchestrator_node
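
Note: the content filtering above keeps non-dict blocks, dict blocks with
no call_id, and dict blocks whose call_id survived the tool call filtering.
The same predicate applied to made-up, Anthropic-style content blocks:

    kept_ids = {"call_1"}  # ids of the tool calls that survived filtering
    blocks = [
        {"type": "text", "text": "Using tools"},   # kept: no call_id
        {"type": "tool_use", "call_id": "call_1"},  # kept: id survived
        {"type": "tool_use", "call_id": "call_2"},  # dropped: id filtered out
    ]
    survivors = [
        b
        for b in blocks
        if not isinstance(b, dict)
        or b.get("call_id") is None
        or b.get("call_id") in kept_ids
    ]
    assert [b.get("call_id") for b in survivors] == [None, "call_1"]
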
+ ) + + latest_ai_message = find_latest_ai_message(state.messages) + + if latest_ai_message is None or not latest_ai_message.tool_calls: + return {"current_tool_call_index": None} + + # apply flow control tool filtering + original_tool_calls = list(latest_ai_message.tool_calls) + filtered_tool_calls = __filter_control_flow_tool_calls(original_tool_calls) + + if len(filtered_tool_calls) != len(original_tool_calls): + modified_message = latest_ai_message.model_copy() + modified_message.tool_calls = filtered_tool_calls + + # we need to filter out the content within the message as well, otherwise the LLM will raise an error + filtered_tool_call_ids = {tc["id"] for tc in filtered_tool_calls} + if isinstance(modified_message.content, list): + modified_message.content = [ + block + for block in modified_message.content + if ( + isinstance(block, dict) + and ( + block.get("call_id") in filtered_tool_call_ids + or block.get("call_id") is None # keep non-tool blocks + ) + ) + or not isinstance(block, dict) + ] + + return { + "current_tool_call_index": 0, + "messages": [modified_message], + } + + return {"current_tool_call_index": 0} + + # in the middle of processing a batch + if not state.messages: + raise AgentNodeRoutingException( + "No messages in state during batch processing" + ) + + latest_ai_message = find_latest_ai_message(state.messages) + + if latest_ai_message is None: + raise AgentNodeRoutingException( + "No AI message found during batch processing" + ) + + if not latest_ai_message.tool_calls: + raise AgentNodeRoutingException( + "No tool calls found in AI message during batch processing" + ) + + total_tool_calls = len(latest_ai_message.tool_calls) + next_index = current_index + 1 + + if next_index >= total_tool_calls: + return {"current_tool_call_index": None} + else: + return {"current_tool_call_index": next_index} + + return orchestrator_node diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index 6391b270..dd79d3b5 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -4,13 +4,14 @@ from inspect import signature from typing import Any, Awaitable, Callable, Literal -from langchain_core.messages.ai import AIMessage from langchain_core.messages.tool import ToolCall, ToolMessage from langchain_core.tools import BaseTool from langgraph._internal._runnable import RunnableCallable from langgraph.types import Command from pydantic import BaseModel +from uipath_langchain.agent.react.utils import find_latest_ai_message + # the type safety can be improved with generics ToolWrapperType = Callable[ [BaseTool, ToolCall, Any], dict[str, Any] | Command[Any] | None @@ -78,12 +79,30 @@ def _extract_tool_call(self, state: Any) -> ToolCall | None: if not hasattr(state, "messages"): raise ValueError("State does not have messages key") - last_message = state.messages[-1] - if not isinstance(last_message, AIMessage): - raise ValueError("Last message in message stack is not an AIMessage.") + latest_ai_message = find_latest_ai_message(state.messages) + if latest_ai_message is None or not latest_ai_message.tool_calls: + return None + + latest_ai_index = next( + i + for i, msg in enumerate(reversed(state.messages)) + if msg is latest_ai_message + ) + messages_after_ai = state.messages[len(state.messages) - latest_ai_index :] - for tool_call in last_message.tool_calls: + for tool_call in latest_ai_message.tool_calls: if tool_call["name"] == self.tool.name: + existing_tool_response_ids = { + 
diff --git a/tests/agent/tools/test_orchestrator_node.py b/tests/agent/tools/test_orchestrator_node.py
new file mode 100644
index 00000000..d1896156
--- /dev/null
+++ b/tests/agent/tools/test_orchestrator_node.py
@@ -0,0 +1,230 @@
+"""Tests for orchestrator_node.py module."""
+
+from typing import Any
+
+import pytest
+from langchain_core.messages import AIMessage, HumanMessage
+
+from uipath_langchain.agent.exceptions import AgentNodeRoutingException
+from uipath_langchain.agent.react.types import AgentGraphState
+from uipath_langchain.agent.tools.orchestrator_node import create_orchestrator_node
+
+
+class TestOrchestratorNode:
+    """Test cases for the orchestrator node."""
+
+    def test_no_messages_throws_exception(self):
+        """Test that empty messages raise an exception."""
+        orchestrator = create_orchestrator_node()
+        state = AgentGraphState(messages=[])
+
+        with pytest.raises(AgentNodeRoutingException):
+            orchestrator(state)
+
+    def test_no_ai_message_returns_none(self):
+        """Test that a missing AI message returns current_tool_call_index None to route back to the LLM."""
+        orchestrator = create_orchestrator_node()
+        human_message = HumanMessage(content="Hello")
+        state = AgentGraphState(messages=[human_message], current_tool_call_index=None)
+
+        result = orchestrator(state)
+
+        assert result == {"current_tool_call_index": None}
+
+    def test_new_tool_call_batch_sets_index_to_zero(self):
+        """Test that a new tool call batch sets current_tool_call_index to 0."""
+        orchestrator = create_orchestrator_node()
+        tool_call = {
+            "name": "test_tool",
+            "args": {"input": "test"},
+            "id": "call_1",
+        }
+        ai_message = AIMessage(content="Using tool", tool_calls=[tool_call])
+        state = AgentGraphState(messages=[ai_message], current_tool_call_index=None)
+
+        result = orchestrator(state)
+
+        assert result == {"current_tool_call_index": 0}
+
+    def test_thinking_messages_limit_enforcement(self):
+        """Test that the thinking messages limit is enforced."""
+        orchestrator = create_orchestrator_node(thinking_messages_limit=2)
+
+        # Create multiple AI messages without tool calls (thinking messages)
+        messages = [
+            AIMessage(content="Thinking 1"),
+            AIMessage(content="Thinking 2"),
+            AIMessage(content="Still thinking 3"),
+        ]
+        state = AgentGraphState(messages=messages, current_tool_call_index=None)
+
+        with pytest.raises(AgentNodeRoutingException):
+            orchestrator(state)
+
+    def test_thinking_messages_limit_zero_throws_immediately(self):
+        """Test that thinking_messages_limit=0 raises on any AI message without tool calls."""
+        orchestrator = create_orchestrator_node(thinking_messages_limit=0)
+
+        # A single AI message without tool calls should raise
+        messages = [AIMessage(content="Just thinking")]
+        state = AgentGraphState(messages=messages, current_tool_call_index=None)
+
+        with pytest.raises(AgentNodeRoutingException):
+            orchestrator(state)
+
+    def test_flow_control_tool_filtering_single_tool(self):
+        """Test that a single flow control tool is not filtered."""
+        orchestrator = create_orchestrator_node()
+        tool_call = {
+            "name": "end_execution",
+            "args": {"result": "done"},
+            "id": "call_1",
+        }
+        ai_message = AIMessage(content="Ending", tool_calls=[tool_call])
+        state = AgentGraphState(messages=[ai_message], current_tool_call_index=None)
+
+        result = orchestrator(state)
+
+        assert result == {"current_tool_call_index": 0}
+
+    def test_flow_control_tool_filtering_multiple_tools(self):
+        """Test that flow control tools are filtered out when multiple tools exist."""
+        orchestrator = create_orchestrator_node()
+        tool_calls = [
+            {
+                "name": "regular_tool",
+                "args": {"input": "test"},
+                "id": "call_1",
+            },
+            {
+                "name": "end_execution",
+                "args": {"result": "done"},
+                "id": "call_2",
+            },
+        ]
+        ai_message = AIMessage(content="Using tools", tool_calls=tool_calls)
+        state = AgentGraphState(messages=[ai_message], current_tool_call_index=None)
+
+        result = orchestrator(state)
+
+        assert "messages" in result
+        assert result["current_tool_call_index"] == 0
+
+        modified_message = result["messages"][0]
+        assert len(modified_message.tool_calls) == 1
+        assert modified_message.tool_calls[0]["name"] == "regular_tool"
+
+    def test_content_filtering_with_tool_calls(self):
+        """Test that content blocks are filtered along with tool calls."""
+        orchestrator = create_orchestrator_node()
+        tool_calls = [
+            {
+                "name": "regular_tool",
+                "args": {"input": "test"},
+                "id": "call_1",
+            },
+            {
+                "name": "end_execution",
+                "args": {"result": "done"},
+                "id": "call_2",
+            },
+        ]
+        content_blocks: list[str | dict[Any, Any]] = [
+            {"type": "text", "text": "Using tools"},
+            {"type": "tool_use", "call_id": "call_1", "name": "regular_tool"},
+            {"type": "tool_use", "call_id": "call_2", "name": "end_execution"},
+        ]
+        ai_message = AIMessage(content=content_blocks, tool_calls=tool_calls)
+        state = AgentGraphState(messages=[ai_message], current_tool_call_index=None)
+
+        result = orchestrator(state)
+
+        assert "messages" in result
+        modified_message = result["messages"][0]
+        assert len(modified_message.tool_calls) == 1
+        assert modified_message.tool_calls[0]["name"] == "regular_tool"
+
+        # Check that the content is filtered
+        filtered_content = modified_message.content
+        assert len(filtered_content) == 2  # text block + regular_tool block
+        tool_use_blocks = [
+            block
+            for block in filtered_content
+            if isinstance(block, dict) and block.get("call_id") is not None
+        ]
+        assert len(tool_use_blocks) == 1
+        assert tool_use_blocks[0]["call_id"] == "call_1"
+
+    def test_processing_batch_advancement(self):
+        """Test advancement through a tool call batch."""
+        orchestrator = create_orchestrator_node()
+        tool_calls = [
+            {
+                "name": "tool_1",
+                "args": {"input": "test1"},
+                "id": "call_1",
+            },
+            {
+                "name": "tool_2",
+                "args": {"input": "test2"},
+                "id": "call_2",
+            },
+        ]
+        ai_message = AIMessage(content="Using tools", tool_calls=tool_calls)
+        state = AgentGraphState(messages=[ai_message], current_tool_call_index=0)
+
+        result = orchestrator(state)
+
+        assert result == {"current_tool_call_index": 1}
+
+    def test_processing_batch_completion(self):
+        """Test completion of a tool call batch."""
+        orchestrator = create_orchestrator_node()
+        tool_calls = [
+            {
+                "name": "tool_1",
+                "args": {"input": "test1"},
+                "id": "call_1",
+            },
+            {
+                "name": "tool_2",
+                "args": {"input": "test2"},
+                "id": "call_2",
+            },
+        ]
+        ai_message = AIMessage(content="Using tools", tool_calls=tool_calls)
+        state = AgentGraphState(messages=[ai_message], current_tool_call_index=1)
+
+        result = orchestrator(state)
+
+        assert result == {"current_tool_call_index": None}
+
+    def test_no_latest_ai_message_in_batch_throws_exception(self):
+        """Test that a missing AI message during batch processing raises an exception."""
+        orchestrator = create_orchestrator_node()
+        human_message = HumanMessage(content="Hello")
+        state = AgentGraphState(messages=[human_message], current_tool_call_index=0)
+
+        with pytest.raises(AgentNodeRoutingException):
+            orchestrator(state)
+
+    def test_latest_ai_message_with_tool_responses_mixed(self):
+        """Test finding the latest AI message when mixed with tool responses."""
+        orchestrator = create_orchestrator_node()
+        tool_call = {
+            "name": "test_tool",
+            "args": {"input": "test"},
+            "id": "call_1",
+        }
+        messages = [
+            AIMessage(content="Using tool", tool_calls=[tool_call]),
+            HumanMessage(content="Some response"),
+            AIMessage(content="Another tool call", tool_calls=[tool_call]),
+        ]
+        state = AgentGraphState(messages=messages, current_tool_call_index=None)
+
+        result = orchestrator(state)
+
+        assert result == {"current_tool_call_index": 0}
diff --git a/tests/agent/tools/test_tool_node.py b/tests/agent/tools/test_tool_node.py
index 3d5633ed..edabd90e 100644
--- a/tests/agent/tools/test_tool_node.py
+++ b/tests/agent/tools/test_tool_node.py
@@ -3,7 +3,7 @@
 from typing import Any, Dict
 
 import pytest
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import AIMessage
 from langchain_core.messages.tool import ToolCall, ToolMessage
 from langchain_core.tools import BaseTool
 from langgraph.types import Command
@@ -110,10 +110,18 @@ def empty_state(self):
         return MockState(messages=[ai_message])
 
     @pytest.fixture
-    def non_ai_state(self):
-        """Fixture for state with non-AI message."""
-        human_message = HumanMessage(content="Hello")
-        return MockState(messages=[human_message])
+    def state_with_existing_tool_response(self):
+        """Fixture for state with an existing tool response."""
+        tool_call = {
+            "name": "mock_tool",
+            "args": {"input_text": "test input"},
+            "id": "test_call_id",
+        }
+        ai_message = AIMessage(content="Using tool", tool_calls=[tool_call])
+        tool_message = ToolMessage(
+            content="Previous result", name="mock_tool", tool_call_id="test_call_id"
+        )
+        return MockState(messages=[ai_message, tool_message])
 
     def test_basic_tool_execution(self, mock_tool, mock_state):
         """Test basic tool execution without wrappers."""
@@ -200,14 +208,14 @@ def test_no_tool_calls_returns_none(self, mock_tool, empty_state):
 
         assert result is None
 
-    def test_non_ai_message_raises_error(self, mock_tool, non_ai_state):
-        """Test that non-AI messages raise ValueError."""
+    def test_existing_tool_response_raises_error(
+        self, mock_tool, state_with_existing_tool_response
+    ):
+        """Test that existing tool responses raise ValueError."""
         node = UiPathToolNode(mock_tool)
 
-        with pytest.raises(
-            ValueError, match="Last message in message stack is not an AIMessage"
-        ):
-            node._func(non_ai_state)
+        with pytest.raises(ValueError):
+            node._func(state_with_existing_tool_response)
 
     def test_mismatched_tool_name_returns_none(self, mock_tool, mock_state):
         """Test that mismatched tool names return None."""
@@ -240,10 +248,7 @@ def invalid_wrapper(
 
         node = UiPathToolNode(mock_tool, wrapper=invalid_wrapper)
 
-        with pytest.raises(
-            ValueError,
-            match="Wrapper state parameter must be a pydantic BaseModel subclass",
-        ):
+        with pytest.raises(ValueError):
             node._func(mock_state)
diff --git a/uv.lock b/uv.lock
index 0c5cf0bb..f56b203e 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3260,7 +3260,7 @@ wheels = [
 
 [[package]]
 name = "uipath-langchain"
-version = "0.2.0"
+version = "0.2.1"
 source = { editable = "." }
 dependencies = [
     { name = "aiosqlite" },