diff --git a/src/uipath_langchain/agent/tools/tool_node.py b/src/uipath_langchain/agent/tools/tool_node.py index 6391b270..14124cb0 100644 --- a/src/uipath_langchain/agent/tools/tool_node.py +++ b/src/uipath_langchain/agent/tools/tool_node.py @@ -2,10 +2,12 @@ from collections.abc import Sequence from inspect import signature -from typing import Any, Awaitable, Callable, Literal +from typing import Any, Awaitable, Callable, Literal, get_type_hints +from langchain.tools import ToolRuntime from langchain_core.messages.ai import AIMessage from langchain_core.messages.tool import ToolCall, ToolMessage +from langchain_core.runnables.config import RunnableConfig from langchain_core.tools import BaseTool from langgraph._internal._runnable import RunnableCallable from langgraph.types import Command @@ -47,8 +49,39 @@ def __init__( self.tool = tool self.wrapper = wrapper self.awrapper = awrapper + self._needs_runtime = self._detect_runtime_injection(tool) + + def _detect_runtime_injection(self, tool: BaseTool) -> bool: + """Check if tool's func/coroutine expects ToolRuntime injection.""" + func = getattr(tool, "func", None) or getattr(tool, "coroutine", None) + if not func: + return False + try: + hints = get_type_hints(func) + return hints.get("runtime") is ToolRuntime + except Exception: + return False + + def _get_tool_args( + self, call: ToolCall, state: Any, config: RunnableConfig | None + ) -> dict[str, Any]: + """Get tool args, injecting ToolRuntime if needed.""" + args = call["args"] + if self._needs_runtime: + runtime = ToolRuntime( + state=state, + tool_call_id=call["id"], + config=config or {}, + context=None, + stream_writer=lambda _: None, + store=None, + ) + args = {**args, "runtime": runtime} + return args - def _func(self, state: Any) -> OutputType: + def _func( + self, state: Any, config: RunnableConfig | None = None + ) -> OutputType: call = self._extract_tool_call(state) if call is None: return None @@ -56,11 +89,12 @@ def _func(self, state: Any) -> OutputType: filtered_state = self._filter_state(state, self.wrapper) result = self.wrapper(self.tool, call, filtered_state) else: - result = self.tool.invoke(call["args"]) - + result = self.tool.invoke(self._get_tool_args(call, state, config)) return self._process_result(call, result) - async def _afunc(self, state: Any) -> OutputType: + async def _afunc( + self, state: Any, config: RunnableConfig | None = None + ) -> OutputType: call = self._extract_tool_call(state) if call is None: return None @@ -68,8 +102,7 @@ async def _afunc(self, state: Any) -> OutputType: filtered_state = self._filter_state(state, self.awrapper) result = await self.awrapper(self.tool, call, filtered_state) else: - result = await self.tool.ainvoke(call["args"]) - + result = await self.tool.ainvoke(self._get_tool_args(call, state, config)) return self._process_result(call, result) def _extract_tool_call(self, state: Any) -> ToolCall | None: diff --git a/tests/agent/tools/test_tool_node.py b/tests/agent/tools/test_tool_node.py index 3d5633ed..71680857 100644 --- a/tests/agent/tools/test_tool_node.py +++ b/tests/agent/tools/test_tool_node.py @@ -1,19 +1,47 @@ """Tests for tool_node.py module.""" +import importlib.util +import sys from typing import Any, Dict import pytest +from langchain.tools import ToolRuntime from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages.tool import ToolCall, ToolMessage -from langchain_core.tools import BaseTool +from langchain_core.tools import BaseTool, StructuredTool from langgraph.types import Command -from pydantic import BaseModel - -from uipath_langchain.agent.tools.tool_node import ( - ToolWrapperMixin, - UiPathToolNode, - create_tool_node, -) +from pydantic import BaseModel, Field + + +# Import directly from module file to avoid circular import through __init__.py +def _import_tool_node() -> Any: + """Import tool_node module directly to bypass circular import.""" + import os + + module_path = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "..", + "src", + "uipath_langchain", + "agent", + "tools", + "tool_node.py", + ) + module_path = os.path.abspath(module_path) + spec = importlib.util.spec_from_file_location("tool_node", module_path) + assert spec is not None and spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules["tool_node"] = module + spec.loader.exec_module(module) + return module + + +_tool_node_module = _import_tool_node() +ToolWrapperMixin: Any = _tool_node_module.ToolWrapperMixin +UiPathToolNode: Any = _tool_node_module.UiPathToolNode +create_tool_node: Any = _tool_node_module.create_tool_node class MockTool(BaseTool): @@ -317,3 +345,206 @@ def test_create_tool_node_empty_tools(self): result = create_tool_node([]) assert result == {} + + +class RuntimeToolInput(BaseModel): + """Input schema for runtime tools.""" + + input_text: str = Field(default="", description="Input text") + + +class RegularToolInput(BaseModel): + """Input schema for regular tools.""" + + input_text: str = Field(default="", description="Input text") + + +class TestRuntimeInjection: + """Test cases for ToolRuntime injection feature.""" + + @pytest.fixture + def mock_state(self): + """Fixture for mock state with tool call.""" + tool_call = { + "name": "runtime_tool", + "args": {"input_text": "test input"}, + "id": "test_call_id_123", + } + ai_message = AIMessage(content="Using tool", tool_calls=[tool_call]) + return MockState(messages=[ai_message]) + + @pytest.fixture + def mock_state_regular(self): + """Fixture for mock state with regular tool call.""" + tool_call = { + "name": "regular_tool", + "args": {"input_text": "test input"}, + "id": "regular_call_id", + } + ai_message = AIMessage(content="Using tool", tool_calls=[tool_call]) + return MockState(messages=[ai_message]) + + def test_detect_runtime_injection_true(self): + """Test _detect_runtime_injection returns True for tools with runtime param.""" + + async def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str: + return f"Got runtime with tool_call_id: {runtime.tool_call_id}" + + tool = StructuredTool( + name="runtime_tool", + description="Tool that requires runtime", + args_schema=RuntimeToolInput, + coroutine=tool_with_runtime, + ) + node = UiPathToolNode(tool) + assert node._needs_runtime is True + + def test_detect_runtime_injection_false(self): + """Test _detect_runtime_injection returns False for tools without runtime.""" + + async def tool_without_runtime(input_text: str = "") -> str: + return f"Result: {input_text}" + + tool = StructuredTool( + name="regular_tool", + description="Tool without runtime", + args_schema=RegularToolInput, + coroutine=tool_without_runtime, + ) + node = UiPathToolNode(tool) + assert node._needs_runtime is False + + def test_detect_runtime_injection_base_tool(self): + """Test _detect_runtime_injection returns False for BaseTool subclass.""" + tool = MockTool() + node = UiPathToolNode(tool) + assert node._needs_runtime is False + + async def test_async_tool_execution_with_runtime_injection(self, mock_state): + """Test async tool execution with runtime injection.""" + captured_runtime: Dict[str, Any] = {} + + async def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str: + captured_runtime["tool_call_id"] = runtime.tool_call_id + captured_runtime["state"] = runtime.state + return f"Success with id: {runtime.tool_call_id}" + + tool = StructuredTool( + name="runtime_tool", + description="Tool that requires runtime", + args_schema=RuntimeToolInput, + coroutine=tool_with_runtime, + ) + node = UiPathToolNode(tool) + + result = await node._afunc(mock_state) + + assert result is not None + assert "messages" in result + tool_message = result["messages"][0] + assert "Success with id: test_call_id_123" in tool_message.content + assert captured_runtime["tool_call_id"] == "test_call_id_123" + assert captured_runtime["state"] == mock_state + + def test_sync_tool_execution_with_runtime_injection(self, mock_state): + """Test sync tool execution with runtime injection.""" + captured_runtime: Dict[str, Any] = {} + + def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str: + captured_runtime["tool_call_id"] = runtime.tool_call_id + return f"Sync success with id: {runtime.tool_call_id}" + + tool = StructuredTool( + name="runtime_tool", + description="Tool that requires runtime", + args_schema=RuntimeToolInput, + func=tool_with_runtime, + ) + node = UiPathToolNode(tool) + + result = node._func(mock_state) + + assert result is not None + assert "messages" in result + tool_message = result["messages"][0] + assert "Sync success with id: test_call_id_123" in tool_message.content + assert captured_runtime["tool_call_id"] == "test_call_id_123" + + async def test_regular_tool_no_runtime_injection(self, mock_state_regular): + """Test regular tool execution without runtime injection.""" + + async def regular_tool(input_text: str = "") -> str: + return f"Regular result: {input_text}" + + tool = StructuredTool( + name="regular_tool", + description="Regular tool", + args_schema=RegularToolInput, + coroutine=regular_tool, + ) + node = UiPathToolNode(tool) + + result = await node._afunc(mock_state_regular) + + assert result is not None + assert "messages" in result + tool_message = result["messages"][0] + assert "Regular result: test input" in tool_message.content + + async def test_tool_returning_command_with_runtime(self, mock_state): + """Test tool with runtime returning a Command.""" + + async def tool_with_command( + runtime: ToolRuntime, **kwargs: Any + ) -> Command[Any]: + return Command( + update={ + "messages": [ + ToolMessage( + content="Completed", + tool_call_id=runtime.tool_call_id, + ) + ] + }, + goto="next_node", + ) + + tool = StructuredTool( + name="runtime_tool", + description="Tool with runtime returning command", + args_schema=RuntimeToolInput, + coroutine=tool_with_command, + ) + node = UiPathToolNode(tool) + + result = await node._afunc(mock_state) + + assert isinstance(result, Command) + assert result.goto == "next_node" + assert result.update is not None + assert result.update["messages"][0].tool_call_id == "test_call_id_123" + + def test_get_tool_args_with_runtime(self): + """Test _get_tool_args injects ToolRuntime when needed.""" + + async def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str: + return "ok" + + tool = StructuredTool( + name="runtime_tool", + description="Tool with runtime", + args_schema=RuntimeToolInput, + coroutine=tool_with_runtime, + ) + node = UiPathToolNode(tool) + + call = {"name": "test", "args": {"input_text": "hi"}, "id": "call_123"} + state = MockState(messages=[]) + config = {"configurable": {"thread_id": "test_thread"}} + + args = node._get_tool_args(call, state, config) + + assert "runtime" in args + assert args["runtime"].tool_call_id == "call_123" + assert args["runtime"].state == state + assert args["input_text"] == "hi"