From 3e6697899f307ed8c7f4a4eb9b34de877b8da151 Mon Sep 17 00:00:00 2001 From: Sakshar Thakkar Date: Mon, 29 Dec 2025 01:15:11 -0800 Subject: [PATCH 1/3] feat: add ToolRuntime injection for interruptible tools MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Escalation tools require `runtime.tool_call_id` to build proper ToolMessage responses in Command objects. Without this, LangGraph cannot match the response to the original tool call. Changes: - Detect tools with `runtime: ToolRuntime` param via signature inspection - Inject ToolRuntime with tool_call_id, state, config when needed - Regular tools continue to work without runtime injection 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/uipath_langchain/agent/tools/tool_node.py | 60 ++++- tests/agent/tools/test_tool_node.py | 233 +++++++++++++++++- 2 files changed, 279 insertions(+), 14 deletions(-) diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index 6391b270..25870e34 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -2,10 +2,12 @@ from collections.abc import Sequence from inspect import signature -from typing import Any, Awaitable, Callable, Literal +from typing import Any, Awaitable, Callable, Literal, get_type_hints +from langchain.tools import ToolRuntime from langchain_core.messages.ai import AIMessage from langchain_core.messages.tool import ToolCall, ToolMessage +from langchain_core.runnables.config import RunnableConfig from langchain_core.tools import BaseTool from langgraph._internal._runnable import RunnableCallable from langgraph.types import Command @@ -47,8 +49,42 @@ def __init__( self.tool = tool self.wrapper = wrapper self.awrapper = awrapper - - def _func(self, state: Any) -> OutputType: + self._needs_runtime = self._detect_runtime_injection(tool) + + def _detect_runtime_injection(self, tool: BaseTool) -> bool: + """Check if tool's func/coroutine expects ToolRuntime injection. + + Inspects func signature for a 'runtime' param typed as ToolRuntime. + Returns True if found, False otherwise. + """ + func = getattr(tool, "func", None) or getattr(tool, "coroutine", None) + if not func: + return False + try: + hints = get_type_hints(func) + return "runtime" in hints and hints["runtime"] is ToolRuntime + except Exception: + return False + + def _build_tool_runtime( + self, + call: ToolCall, + state: Any, + config: RunnableConfig | None, + ) -> ToolRuntime: + """Construct ToolRuntime for injection into tool functions.""" + return ToolRuntime( + state=state, + tool_call_id=call["id"], + config=config or {}, + context=None, + stream_writer=None, + store=None, + ) + + def _func( + self, state: Any, config: RunnableConfig | None = None + ) -> OutputType: call = self._extract_tool_call(state) if call is None: return None @@ -56,11 +92,18 @@ def _func(self, state: Any) -> OutputType: filtered_state = self._filter_state(state, self.wrapper) result = self.wrapper(self.tool, call, filtered_state) else: - result = self.tool.invoke(call["args"]) + if self._needs_runtime: + tool_runtime = self._build_tool_runtime(call, state, config) + args = {**call["args"], "runtime": tool_runtime} + result = self.tool.invoke(args) + else: + result = self.tool.invoke(call["args"]) return self._process_result(call, result) - async def _afunc(self, state: Any) -> OutputType: + async def _afunc( + self, state: Any, config: RunnableConfig | None = None + ) -> OutputType: call = self._extract_tool_call(state) if call is None: return None @@ -68,7 +111,12 @@ async def _afunc(self, state: Any) -> OutputType: filtered_state = self._filter_state(state, self.awrapper) result = await self.awrapper(self.tool, call, filtered_state) else: - result = await self.tool.ainvoke(call["args"]) + if self._needs_runtime: + tool_runtime = self._build_tool_runtime(call, state, config) + args = {**call["args"], "runtime": tool_runtime} + result = await self.tool.ainvoke(args) + else: + result = await self.tool.ainvoke(call["args"]) return self._process_result(call, result) diff --git a/tests/agent/tools/test_tool_node.py b/tests/agent/tools/test_tool_node.py index 3d5633ed..ec504ce7 100644 --- a/tests/agent/tools/test_tool_node.py +++ b/tests/agent/tools/test_tool_node.py @@ -1,19 +1,46 @@ """Tests for tool_node.py module.""" +import importlib.util +import sys from typing import Any, Dict import pytest +from langchain.tools import ToolRuntime from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages.tool import ToolCall, ToolMessage -from langchain_core.tools import BaseTool +from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command -from pydantic import BaseModel - -from uipath_langchain.agent.tools.tool_node import ( - ToolWrapperMixin, - UiPathToolNode, - create_tool_node, -) +from pydantic import BaseModel, Field + + +# Import directly from module file to avoid circular import through __init__.py +def _import_tool_node(): + """Import tool_node module directly to bypass circular import.""" + import os + + module_path = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "..", + "src", + "uipath_langchain", + "agent", + "tools", + "tool_node.py", + ) + module_path = os.path.abspath(module_path) + spec = importlib.util.spec_from_file_location("tool_node", module_path) + module = importlib.util.module_from_spec(spec) + sys.modules["tool_node"] = module + spec.loader.exec_module(module) + return module + + +_tool_node_module = _import_tool_node() +ToolWrapperMixin = _tool_node_module.ToolWrapperMixin +UiPathToolNode = _tool_node_module.UiPathToolNode +create_tool_node = _tool_node_module.create_tool_node class MockTool(BaseTool): @@ -317,3 +344,193 @@ def test_create_tool_node_empty_tools(self): result = create_tool_node([]) assert result == {} + + +class RuntimeToolInput(BaseModel): + """Input schema for runtime tools.""" + + input_text: str = Field(default="", description="Input text") + + +class RegularToolInput(BaseModel): + """Input schema for regular tools.""" + + input_text: str = Field(default="", description="Input text") + + +class TestRuntimeInjection: + """Test cases for ToolRuntime injection feature.""" + + @pytest.fixture + def mock_state(self): + """Fixture for mock state with tool call.""" + tool_call = { + "name": "runtime_tool", + "args": {"input_text": "test input"}, + "id": "test_call_id_123", + } + ai_message = AIMessage(content="Using tool", tool_calls=[tool_call]) + return MockState(messages=[ai_message]) + + @pytest.fixture + def mock_state_regular(self): + """Fixture for mock state with regular tool call.""" + tool_call = { + "name": "regular_tool", + "args": {"input_text": "test input"}, + "id": "regular_call_id", + } + ai_message = AIMessage(content="Using tool", tool_calls=[tool_call]) + return MockState(messages=[ai_message]) + + def test_detect_runtime_injection_true(self): + """Test _detect_runtime_injection returns True for tools with runtime param.""" + + async def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str: + return f"Got runtime with tool_call_id: {runtime.tool_call_id}" + + tool = StructuredTool( + name="runtime_tool", + description="Tool that requires runtime", + args_schema=RuntimeToolInput, + coroutine=tool_with_runtime, + ) + node = UiPathToolNode(tool) + assert node._needs_runtime is True + + def test_detect_runtime_injection_false(self): + """Test _detect_runtime_injection returns False for tools without runtime.""" + + async def tool_without_runtime(input_text: str = "") -> str: + return f"Result: {input_text}" + + tool = StructuredTool( + name="regular_tool", + description="Tool without runtime", + args_schema=RegularToolInput, + coroutine=tool_without_runtime, + ) + node = UiPathToolNode(tool) + assert node._needs_runtime is False + + def test_detect_runtime_injection_base_tool(self): + """Test _detect_runtime_injection returns False for BaseTool subclass.""" + tool = MockTool() + node = UiPathToolNode(tool) + assert node._needs_runtime is False + + async def test_async_tool_execution_with_runtime_injection(self, mock_state): + """Test async tool execution with runtime injection.""" + captured_runtime = {} + + async def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str: + captured_runtime["tool_call_id"] = runtime.tool_call_id + captured_runtime["state"] = runtime.state + return f"Success with id: {runtime.tool_call_id}" + + tool = StructuredTool( + name="runtime_tool", + description="Tool that requires runtime", + args_schema=RuntimeToolInput, + coroutine=tool_with_runtime, + ) + node = UiPathToolNode(tool) + + result = await node._afunc(mock_state) + + assert result is not None + assert "messages" in result + tool_message = result["messages"][0] + assert "Success with id: test_call_id_123" in tool_message.content + assert captured_runtime["tool_call_id"] == "test_call_id_123" + assert captured_runtime["state"] == mock_state + + def test_sync_tool_execution_with_runtime_injection(self, mock_state): + """Test sync tool execution with runtime injection.""" + captured_runtime = {} + + def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str: + captured_runtime["tool_call_id"] = runtime.tool_call_id + return f"Sync success with id: {runtime.tool_call_id}" + + tool = StructuredTool( + name="runtime_tool", + description="Tool that requires runtime", + args_schema=RuntimeToolInput, + func=tool_with_runtime, + ) + node = UiPathToolNode(tool) + + result = node._func(mock_state) + + assert result is not None + assert "messages" in result + tool_message = result["messages"][0] + assert "Sync success with id: test_call_id_123" in tool_message.content + assert captured_runtime["tool_call_id"] == "test_call_id_123" + + async def test_regular_tool_no_runtime_injection(self, mock_state_regular): + """Test regular tool execution without runtime injection.""" + + async def regular_tool(input_text: str = "") -> str: + return f"Regular result: {input_text}" + + tool = StructuredTool( + name="regular_tool", + description="Regular tool", + args_schema=RegularToolInput, + coroutine=regular_tool, + ) + node = UiPathToolNode(tool) + + result = await node._afunc(mock_state_regular) + + assert result is not None + assert "messages" in result + tool_message = result["messages"][0] + assert "Regular result: test input" in tool_message.content + + async def test_tool_returning_command_with_runtime(self, mock_state): + """Test tool with runtime returning a Command.""" + + async def tool_with_command(runtime: ToolRuntime, **kwargs: Any) -> Command: + return Command( + update={ + "messages": [ + ToolMessage( + content="Completed", + tool_call_id=runtime.tool_call_id, + ) + ] + }, + goto="next_node", + ) + + tool = StructuredTool( + name="runtime_tool", + description="Tool with runtime returning command", + args_schema=RuntimeToolInput, + coroutine=tool_with_command, + ) + node = UiPathToolNode(tool) + + result = await node._afunc(mock_state) + + assert isinstance(result, Command) + assert result.goto == "next_node" + assert result.update["messages"][0].tool_call_id == "test_call_id_123" + + def test_build_tool_runtime(self): + """Test _build_tool_runtime constructs ToolRuntime correctly.""" + tool = MockTool() + node = UiPathToolNode(tool) + + call = {"name": "test", "args": {}, "id": "call_123"} + state = MockState(messages=[]) + config = {"configurable": {"thread_id": "test_thread"}} + + runtime = node._build_tool_runtime(call, state, config) + + assert runtime.tool_call_id == "call_123" + assert runtime.state == state + assert runtime.config == config From 9a2cdb256dfa43402f515b8bb30ebb45d6cf72ef Mon Sep 17 00:00:00 2001 From: Sakshar Thakkar Date: Mon, 29 Dec 2025 01:22:17 -0800 Subject: [PATCH 2/3] fix: resolve mypy type errors in tool_node and tests --- src/uipath_langchain/agent/tools/tool_node.py | 2 +- tests/agent/tools/test_tool_node.py | 18 +++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index 25870e34..3f835cbd 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -78,7 +78,7 @@ def _build_tool_runtime( tool_call_id=call["id"], config=config or {}, context=None, - stream_writer=None, + stream_writer=lambda _: None, # no-op writer store=None, ) diff --git a/tests/agent/tools/test_tool_node.py b/tests/agent/tools/test_tool_node.py index ec504ce7..919202d2 100644 --- a/tests/agent/tools/test_tool_node.py +++ b/tests/agent/tools/test_tool_node.py @@ -14,7 +14,7 @@ # Import directly from module file to avoid circular import through __init__.py -def _import_tool_node(): +def _import_tool_node() -> Any: """Import tool_node module directly to bypass circular import.""" import os @@ -31,6 +31,7 @@ def _import_tool_node(): ) module_path = os.path.abspath(module_path) spec = importlib.util.spec_from_file_location("tool_node", module_path) + assert spec is not None and spec.loader is not None module = importlib.util.module_from_spec(spec) sys.modules["tool_node"] = module spec.loader.exec_module(module) @@ -38,9 +39,9 @@ def _import_tool_node(): _tool_node_module = _import_tool_node() -ToolWrapperMixin = _tool_node_module.ToolWrapperMixin -UiPathToolNode = _tool_node_module.UiPathToolNode -create_tool_node = _tool_node_module.create_tool_node +ToolWrapperMixin: Any = _tool_node_module.ToolWrapperMixin +UiPathToolNode: Any = _tool_node_module.UiPathToolNode +create_tool_node: Any = _tool_node_module.create_tool_node class MockTool(BaseTool): @@ -421,7 +422,7 @@ def test_detect_runtime_injection_base_tool(self): async def test_async_tool_execution_with_runtime_injection(self, mock_state): """Test async tool execution with runtime injection.""" - captured_runtime = {} + captured_runtime: Dict[str, Any] = {} async def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str: captured_runtime["tool_call_id"] = runtime.tool_call_id @@ -447,7 +448,7 @@ async def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str: def test_sync_tool_execution_with_runtime_injection(self, mock_state): """Test sync tool execution with runtime injection.""" - captured_runtime = {} + captured_runtime: Dict[str, Any] = {} def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str: captured_runtime["tool_call_id"] = runtime.tool_call_id @@ -493,7 +494,9 @@ async def regular_tool(input_text: str = "") -> str: async def test_tool_returning_command_with_runtime(self, mock_state): """Test tool with runtime returning a Command.""" - async def tool_with_command(runtime: ToolRuntime, **kwargs: Any) -> Command: + async def tool_with_command( + runtime: ToolRuntime, **kwargs: Any + ) -> Command[Any]: return Command( update={ "messages": [ @@ -518,6 +521,7 @@ async def tool_with_command(runtime: ToolRuntime, **kwargs: Any) -> Command: assert isinstance(result, Command) assert result.goto == "next_node" + assert result.update is not None assert result.update["messages"][0].tool_call_id == "test_call_id_123" def test_build_tool_runtime(self): From 450d67df386420ca4826a20d1273f7c7a7101a04 Mon Sep 17 00:00:00 2001 From: Sakshar Thakkar Date: Mon, 29 Dec 2025 01:29:45 -0800 Subject: [PATCH 3/3] refactor: simplify ToolRuntime injection logic --- src/uipath_langchain/agent/tools/tool_node.py | 55 +++++++------------ tests/agent/tools/test_tool_node.py | 26 ++++++--- 2 files changed, 38 insertions(+), 43 deletions(-) diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index 3f835cbd..14124cb0 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -52,35 +52,32 @@ def __init__( self._needs_runtime = self._detect_runtime_injection(tool) def _detect_runtime_injection(self, tool: BaseTool) -> bool: - """Check if tool's func/coroutine expects ToolRuntime injection. - - Inspects func signature for a 'runtime' param typed as ToolRuntime. - Returns True if found, False otherwise. - """ + """Check if tool's func/coroutine expects ToolRuntime injection.""" func = getattr(tool, "func", None) or getattr(tool, "coroutine", None) if not func: return False try: hints = get_type_hints(func) - return "runtime" in hints and hints["runtime"] is ToolRuntime + return hints.get("runtime") is ToolRuntime except Exception: return False - def _build_tool_runtime( - self, - call: ToolCall, - state: Any, - config: RunnableConfig | None, - ) -> ToolRuntime: - """Construct ToolRuntime for injection into tool functions.""" - return ToolRuntime( - state=state, - tool_call_id=call["id"], - config=config or {}, - context=None, - stream_writer=lambda _: None, # no-op writer - store=None, - ) + def _get_tool_args( + self, call: ToolCall, state: Any, config: RunnableConfig | None + ) -> dict[str, Any]: + """Get tool args, injecting ToolRuntime if needed.""" + args = call["args"] + if self._needs_runtime: + runtime = ToolRuntime( + state=state, + tool_call_id=call["id"], + config=config or {}, + context=None, + stream_writer=lambda _: None, + store=None, + ) + args = {**args, "runtime": runtime} + return args def _func( self, state: Any, config: RunnableConfig | None = None @@ -92,13 +89,7 @@ def _func( filtered_state = self._filter_state(state, self.wrapper) result = self.wrapper(self.tool, call, filtered_state) else: - if self._needs_runtime: - tool_runtime = self._build_tool_runtime(call, state, config) - args = {**call["args"], "runtime": tool_runtime} - result = self.tool.invoke(args) - else: - result = self.tool.invoke(call["args"]) - + result = self.tool.invoke(self._get_tool_args(call, state, config)) return self._process_result(call, result) async def _afunc( @@ -111,13 +102,7 @@ async def _afunc( filtered_state = self._filter_state(state, self.awrapper) result = await self.awrapper(self.tool, call, filtered_state) else: - if self._needs_runtime: - tool_runtime = self._build_tool_runtime(call, state, config) - args = {**call["args"], "runtime": tool_runtime} - result = await self.tool.ainvoke(args) - else: - result = await self.tool.ainvoke(call["args"]) - + result = await self.tool.ainvoke(self._get_tool_args(call, state, config)) return self._process_result(call, result) def _extract_tool_call(self, state: Any) -> ToolCall | None: diff --git a/tests/agent/tools/test_tool_node.py b/tests/agent/tools/test_tool_node.py index 919202d2..71680857 100644 --- a/tests/agent/tools/test_tool_node.py +++ b/tests/agent/tools/test_tool_node.py @@ -524,17 +524,27 @@ async def tool_with_command( assert result.update is not None assert result.update["messages"][0].tool_call_id == "test_call_id_123" - def test_build_tool_runtime(self): - """Test _build_tool_runtime constructs ToolRuntime correctly.""" - tool = MockTool() + def test_get_tool_args_with_runtime(self): + """Test _get_tool_args injects ToolRuntime when needed.""" + + async def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str: + return "ok" + + tool = StructuredTool( + name="runtime_tool", + description="Tool with runtime", + args_schema=RuntimeToolInput, + coroutine=tool_with_runtime, + ) node = UiPathToolNode(tool) - call = {"name": "test", "args": {}, "id": "call_123"} + call = {"name": "test", "args": {"input_text": "hi"}, "id": "call_123"} state = MockState(messages=[]) config = {"configurable": {"thread_id": "test_thread"}} - runtime = node._build_tool_runtime(call, state, config) + args = node._get_tool_args(call, state, config) - assert runtime.tool_call_id == "call_123" - assert runtime.state == state - assert runtime.config == config + assert "runtime" in args + assert args["runtime"].tool_call_id == "call_123" + assert args["runtime"].state == state + assert args["input_text"] == "hi"