Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions src/uipath_langchain/agent/tools/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from collections.abc import Sequence
from inspect import signature
from typing import Any, Awaitable, Callable, Literal
from typing import Any, Awaitable, Callable, Literal, get_type_hints

from langchain.tools import ToolRuntime
from langchain_core.messages.ai import AIMessage
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_core.tools import BaseTool
from langgraph._internal._runnable import RunnableCallable
from langgraph.types import Command
Expand Down Expand Up @@ -47,29 +49,60 @@ def __init__(
self.tool = tool
self.wrapper = wrapper
self.awrapper = awrapper
self._needs_runtime = self._detect_runtime_injection(tool)

def _detect_runtime_injection(self, tool: BaseTool) -> bool:
"""Check if tool's func/coroutine expects ToolRuntime injection."""
func = getattr(tool, "func", None) or getattr(tool, "coroutine", None)
if not func:
return False
try:
hints = get_type_hints(func)
return hints.get("runtime") is ToolRuntime
except Exception:
return False

def _get_tool_args(
self, call: ToolCall, state: Any, config: RunnableConfig | None
) -> dict[str, Any]:
"""Get tool args, injecting ToolRuntime if needed."""
args = call["args"]
if self._needs_runtime:
runtime = ToolRuntime(
state=state,
tool_call_id=call["id"],
config=config or {},
context=None,
stream_writer=lambda _: None,
store=None,
)
args = {**args, "runtime": runtime}
Comment on lines +65 to +79
Copy link
Contributor

@andreitava-uip andreitava-uip Dec 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally I would take an entirely different approach here.
I'd say tools themselves should never return messages or commands. A tool should only ever return its output schema.
Keep them completely graph-agnostic.
In the custom tool node design, the idea is for the wrappers to return commands if needed.

If anything, escalation tool should be refactored to no longer be coupled to the graph, which is something that should be in the works already.

return args

def _func(self, state: Any) -> OutputType:
def _func(
self, state: Any, config: RunnableConfig | None = None
) -> OutputType:
call = self._extract_tool_call(state)
if call is None:
return None
if self.wrapper:
filtered_state = self._filter_state(state, self.wrapper)
result = self.wrapper(self.tool, call, filtered_state)
else:
result = self.tool.invoke(call["args"])

result = self.tool.invoke(self._get_tool_args(call, state, config))
return self._process_result(call, result)

async def _afunc(self, state: Any) -> OutputType:
async def _afunc(
self, state: Any, config: RunnableConfig | None = None
) -> OutputType:
call = self._extract_tool_call(state)
if call is None:
return None
if self.awrapper:
filtered_state = self._filter_state(state, self.awrapper)
result = await self.awrapper(self.tool, call, filtered_state)
else:
result = await self.tool.ainvoke(call["args"])

result = await self.tool.ainvoke(self._get_tool_args(call, state, config))
return self._process_result(call, result)

def _extract_tool_call(self, state: Any) -> ToolCall | None:
Expand Down
247 changes: 239 additions & 8 deletions tests/agent/tools/test_tool_node.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,47 @@
"""Tests for tool_node.py module."""

import importlib.util
import sys
from typing import Any, Dict

import pytest
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages.tool import ToolCall, ToolMessage
from langchain_core.tools import BaseTool
from langchain_core.tools import BaseTool, StructuredTool
from langgraph.types import Command
from pydantic import BaseModel

from uipath_langchain.agent.tools.tool_node import (
ToolWrapperMixin,
UiPathToolNode,
create_tool_node,
)
from pydantic import BaseModel, Field


# Import directly from module file to avoid circular import through __init__.py
def _import_tool_node() -> Any:
"""Import tool_node module directly to bypass circular import."""
import os

module_path = os.path.join(
os.path.dirname(__file__),
"..",
"..",
"..",
"src",
"uipath_langchain",
"agent",
"tools",
"tool_node.py",
)
module_path = os.path.abspath(module_path)
spec = importlib.util.spec_from_file_location("tool_node", module_path)
assert spec is not None and spec.loader is not None
module = importlib.util.module_from_spec(spec)
sys.modules["tool_node"] = module
spec.loader.exec_module(module)
return module


_tool_node_module = _import_tool_node()
Comment on lines +17 to +41
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm really not a fan of introducing such hacks to "solve" circular imports.
If we have circular imports we should just restructure our modules in order to avoid them.

ToolWrapperMixin: Any = _tool_node_module.ToolWrapperMixin
UiPathToolNode: Any = _tool_node_module.UiPathToolNode
create_tool_node: Any = _tool_node_module.create_tool_node


class MockTool(BaseTool):
Expand Down Expand Up @@ -317,3 +345,206 @@ def test_create_tool_node_empty_tools(self):
result = create_tool_node([])

assert result == {}


class RuntimeToolInput(BaseModel):
"""Input schema for runtime tools."""

input_text: str = Field(default="", description="Input text")


class RegularToolInput(BaseModel):
"""Input schema for regular tools."""

input_text: str = Field(default="", description="Input text")


class TestRuntimeInjection:
"""Test cases for ToolRuntime injection feature."""

@pytest.fixture
def mock_state(self):
"""Fixture for mock state with tool call."""
tool_call = {
"name": "runtime_tool",
"args": {"input_text": "test input"},
"id": "test_call_id_123",
}
ai_message = AIMessage(content="Using tool", tool_calls=[tool_call])
return MockState(messages=[ai_message])

@pytest.fixture
def mock_state_regular(self):
"""Fixture for mock state with regular tool call."""
tool_call = {
"name": "regular_tool",
"args": {"input_text": "test input"},
"id": "regular_call_id",
}
ai_message = AIMessage(content="Using tool", tool_calls=[tool_call])
return MockState(messages=[ai_message])

def test_detect_runtime_injection_true(self):
"""Test _detect_runtime_injection returns True for tools with runtime param."""

async def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str:
return f"Got runtime with tool_call_id: {runtime.tool_call_id}"

tool = StructuredTool(
name="runtime_tool",
description="Tool that requires runtime",
args_schema=RuntimeToolInput,
coroutine=tool_with_runtime,
)
node = UiPathToolNode(tool)
assert node._needs_runtime is True

def test_detect_runtime_injection_false(self):
"""Test _detect_runtime_injection returns False for tools without runtime."""

async def tool_without_runtime(input_text: str = "") -> str:
return f"Result: {input_text}"

tool = StructuredTool(
name="regular_tool",
description="Tool without runtime",
args_schema=RegularToolInput,
coroutine=tool_without_runtime,
)
node = UiPathToolNode(tool)
assert node._needs_runtime is False

def test_detect_runtime_injection_base_tool(self):
"""Test _detect_runtime_injection returns False for BaseTool subclass."""
tool = MockTool()
node = UiPathToolNode(tool)
assert node._needs_runtime is False

async def test_async_tool_execution_with_runtime_injection(self, mock_state):
"""Test async tool execution with runtime injection."""
captured_runtime: Dict[str, Any] = {}

async def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str:
captured_runtime["tool_call_id"] = runtime.tool_call_id
captured_runtime["state"] = runtime.state
return f"Success with id: {runtime.tool_call_id}"

tool = StructuredTool(
name="runtime_tool",
description="Tool that requires runtime",
args_schema=RuntimeToolInput,
coroutine=tool_with_runtime,
)
node = UiPathToolNode(tool)

result = await node._afunc(mock_state)

assert result is not None
assert "messages" in result
tool_message = result["messages"][0]
assert "Success with id: test_call_id_123" in tool_message.content
assert captured_runtime["tool_call_id"] == "test_call_id_123"
assert captured_runtime["state"] == mock_state

def test_sync_tool_execution_with_runtime_injection(self, mock_state):
"""Test sync tool execution with runtime injection."""
captured_runtime: Dict[str, Any] = {}

def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str:
captured_runtime["tool_call_id"] = runtime.tool_call_id
return f"Sync success with id: {runtime.tool_call_id}"

tool = StructuredTool(
name="runtime_tool",
description="Tool that requires runtime",
args_schema=RuntimeToolInput,
func=tool_with_runtime,
)
node = UiPathToolNode(tool)

result = node._func(mock_state)

assert result is not None
assert "messages" in result
tool_message = result["messages"][0]
assert "Sync success with id: test_call_id_123" in tool_message.content
assert captured_runtime["tool_call_id"] == "test_call_id_123"

async def test_regular_tool_no_runtime_injection(self, mock_state_regular):
"""Test regular tool execution without runtime injection."""

async def regular_tool(input_text: str = "") -> str:
return f"Regular result: {input_text}"

tool = StructuredTool(
name="regular_tool",
description="Regular tool",
args_schema=RegularToolInput,
coroutine=regular_tool,
)
node = UiPathToolNode(tool)

result = await node._afunc(mock_state_regular)

assert result is not None
assert "messages" in result
tool_message = result["messages"][0]
assert "Regular result: test input" in tool_message.content

async def test_tool_returning_command_with_runtime(self, mock_state):
"""Test tool with runtime returning a Command."""

async def tool_with_command(
runtime: ToolRuntime, **kwargs: Any
) -> Command[Any]:
return Command(
update={
"messages": [
ToolMessage(
content="Completed",
tool_call_id=runtime.tool_call_id,
)
]
},
goto="next_node",
)

tool = StructuredTool(
name="runtime_tool",
description="Tool with runtime returning command",
args_schema=RuntimeToolInput,
coroutine=tool_with_command,
)
node = UiPathToolNode(tool)

result = await node._afunc(mock_state)

assert isinstance(result, Command)
assert result.goto == "next_node"
assert result.update is not None
assert result.update["messages"][0].tool_call_id == "test_call_id_123"

def test_get_tool_args_with_runtime(self):
"""Test _get_tool_args injects ToolRuntime when needed."""

async def tool_with_runtime(runtime: ToolRuntime, **kwargs: Any) -> str:
return "ok"

tool = StructuredTool(
name="runtime_tool",
description="Tool with runtime",
args_schema=RuntimeToolInput,
coroutine=tool_with_runtime,
)
node = UiPathToolNode(tool)

call = {"name": "test", "args": {"input_text": "hi"}, "id": "call_123"}
state = MockState(messages=[])
config = {"configurable": {"thread_id": "test_thread"}}

args = node._get_tool_args(call, state, config)

assert "runtime" in args
assert args["runtime"].tool_call_id == "call_123"
assert args["runtime"].state == state
assert args["input_text"] == "hi"
Loading