7 changes: 4 additions & 3 deletions pyproject.toml
@@ -1,13 +1,14 @@
[project]
name = "uipath-langchain"
version = "0.1.44"
version = "0.2.0"
description = "Python SDK that enables developers to build and deploy LangGraph agents to the UiPath Cloud Platform"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.11"
dependencies = [
"uipath>=2.2.44, <2.3.0",
"uipath>=2.3.0, <2.4.0",
"uipath-runtime>=0.3.2, <0.4.0",
"langgraph>=1.0.0, <2.0.0",
"langchain-core>=1.0.0, <2.0.0",
"langchain-core>=1.2.5, <2.0.0",
"aiosqlite==0.21.0",
"langgraph-checkpoint-sqlite>=3.0.0, <4.0.0",
"langchain-openai>=1.0.0, <2.0.0",
1 change: 1 addition & 0 deletions src/uipath_langchain/runtime/factory.py
@@ -275,6 +275,7 @@ async def _create_runtime_instance(
delegate=base_runtime,
storage=storage,
trigger_manager=trigger_manager,
runtime_id=runtime_id,
)

async def new_runtime(
50 changes: 21 additions & 29 deletions src/uipath_langchain/runtime/runtime.py
@@ -293,29 +293,7 @@ def _extract_graph_result(self, final_chunk: Any) -> Any:

def _is_interrupted(self, state: StateSnapshot) -> bool:
"""Check if execution was interrupted (static or dynamic)."""
# Check for static interrupts (interrupt_before/after)
if hasattr(state, "next") and state.next:
return True

# Check for dynamic interrupts (interrupt() inside node)
if hasattr(state, "tasks"):
for task in state.tasks:
if hasattr(task, "interrupts") and task.interrupts:
return True

return False

def _get_dynamic_interrupt(self, state: StateSnapshot) -> Interrupt | None:
"""Get the first dynamic interrupt if any."""
if not hasattr(state, "tasks"):
return None

for task in state.tasks:
if hasattr(task, "interrupts") and task.interrupts:
for interrupt in task.interrupts:
if isinstance(interrupt, Interrupt):
return interrupt
return None
return bool(state.next)

async def _create_runtime_result(
self,
@@ -344,13 +322,27 @@ async def _create_suspended_result(
graph_state: StateSnapshot,
) -> UiPathRuntimeResult:
"""Create result for suspended execution."""
# Check if it's a dynamic interrupt
dynamic_interrupt = self._get_dynamic_interrupt(graph_state)

if dynamic_interrupt:
# Dynamic interrupt - should create and save resume trigger
interrupt_map: dict[str, Any] = {}

# Get nodes that are still scheduled to run
next_nodes = set(graph_state.next) if graph_state.next else set()

if graph_state.interrupts:
for interrupt in graph_state.interrupts:
if isinstance(interrupt, Interrupt):
# Find which task this interrupt belongs to
for task in graph_state.tasks:
if task.interrupts and interrupt in task.interrupts:
# Only include if this task's node is still in next
if task.name in next_nodes:
interrupt_map[interrupt.id] = interrupt.value
break

# If we have dynamic interrupts, return suspended with interrupt map
# The output is used to create the resume triggers
if interrupt_map:
return UiPathRuntimeResult(
output=dynamic_interrupt.value,
output=interrupt_map,
status=UiPathRuntimeStatus.SUSPENDED,
)
else:
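
For context on the new suspended-result shape above: a rough sketch of how a caller might resume a run whose suspended output is the interrupt_id -> value map built in _create_suspended_result. This assumes LangGraph's Command(resume=...) accepting a dict keyed by interrupt id; the helper name and arguments are illustrative, not part of this PR.

# Illustrative sketch only — resuming a graph suspended with several dynamic interrupts.
# Assumes langgraph >= 1.0 semantics where Command(resume=...) can take a dict keyed
# by interrupt id when more than one interrupt is pending.
from langgraph.types import Command

async def resume_suspended_run(graph, config, interrupt_map: dict, answers: dict):
    """interrupt_map: the suspended result's output (interrupt_id -> interrupt payload).
    answers: the same interrupt ids mapped to the values supplied by the resuming side."""
    resume_values = {
        interrupt_id: answers[interrupt_id]
        for interrupt_id in interrupt_map
        if interrupt_id in answers
    }
    # Feed all resume values back in one call; LangGraph routes each value
    # to the task that raised the matching interrupt.
    return await graph.ainvoke(Command(resume=resume_values), config=config)
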
249 changes: 178 additions & 71 deletions src/uipath_langchain/runtime/storage.py
@@ -1,115 +1,222 @@
"""SQLite implementation of UiPathResumableStorageProtocol."""

import json
from typing import cast
from typing import Any, cast

from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from pydantic import BaseModel
from uipath.runtime import (
UiPathApiTrigger,
UiPathResumeTrigger,
UiPathResumeTriggerName,
UiPathResumeTriggerType,
)
from uipath.runtime import UiPathResumeTrigger


class SqliteResumableStorage:
"""SQLite storage for resume triggers."""
"""SQLite storage for resume triggers and arbitrary kv pairs."""

def __init__(
self, memory: AsyncSqliteSaver, table_name: str = "__uipath_resume_triggers"
self,
memory: AsyncSqliteSaver,
):
self.memory = memory
self.table_name = table_name
self.rs_table_name = "__uipath_resume_triggers"
self.kv_table_name = "__uipath_runtime_kv"
self._initialized = False

async def _ensure_table(self) -> None:
"""Create table if needed."""
"""Create tables if needed."""
if self._initialized:
return

await self.memory.setup()
async with self.memory.lock, self.memory.conn.cursor() as cur:
await cur.execute(f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
# Enable WAL mode for high concurrency
await cur.execute("PRAGMA journal_mode=WAL")

await cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.rs_table_name} (
id INTEGER PRIMARY KEY AUTOINCREMENT,
type TEXT NOT NULL,
name TEXT NOT NULL,
key TEXT,
folder_key TEXT,
folder_path TEXT,
payload TEXT,
runtime_id TEXT NOT NULL,
interrupt_id TEXT NOT NULL,
data TEXT NOT NULL,
timestamp DATETIME DEFAULT (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc'))
)
""")
await self.memory.conn.commit()
self._initialized = True
"""
)

async def save_trigger(self, trigger: UiPathResumeTrigger) -> None:
"""Save resume trigger to database."""
await self._ensure_table()
await cur.execute(
f"""
CREATE INDEX IF NOT EXISTS idx_{self.rs_table_name}_runtime_id
ON {self.rs_table_name}(runtime_id)
"""
)

trigger_key = (
trigger.api_resume.inbox_id if trigger.api_resume else trigger.item_key
)
payload = trigger.payload
if payload:
payload = (
(
payload.model_dump()
if isinstance(payload, BaseModel)
else json.dumps(payload)
await cur.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.kv_table_name} (
runtime_id TEXT NOT NULL,
namespace TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT,
timestamp DATETIME DEFAULT (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc')),
PRIMARY KEY (runtime_id, namespace, key)
)
if isinstance(payload, dict)
else str(payload)
"""
)

await self.memory.conn.commit()

self._initialized = True

async def save_triggers(
self, runtime_id: str, triggers: list[UiPathResumeTrigger]
) -> None:
"""Save resume triggers to database, replacing all existing triggers for this runtime_id."""
await self._ensure_table()

async with self.memory.lock, self.memory.conn.cursor() as cur:
# Delete all existing triggers for this runtime_id
await cur.execute(
f"INSERT INTO {self.table_name} (type, key, name, payload, folder_path, folder_key) VALUES (?, ?, ?, ?, ?, ?)",
(
trigger.trigger_type.value,
trigger_key,
trigger.trigger_name.value,
payload,
trigger.folder_path,
trigger.folder_key,
),
f"""
DELETE FROM {self.rs_table_name}
WHERE runtime_id = ?
""",
(runtime_id,),
)

# Insert new triggers
for trigger in triggers:
trigger_data = trigger.model_dump()
trigger_data["payload"] = trigger.payload
trigger_data["trigger_name"] = trigger.trigger_name

await cur.execute(
f"""
INSERT INTO {self.rs_table_name}
(runtime_id, interrupt_id, data)
VALUES (?, ?, ?)
""",
(
runtime_id,
trigger.interrupt_id,
json.dumps(trigger_data),
),
)
await self.memory.conn.commit()

async def get_latest_trigger(self) -> UiPathResumeTrigger | None:
"""Get most recent trigger from database."""
async def get_triggers(self, runtime_id: str) -> list[UiPathResumeTrigger] | None:
"""Get all triggers for runtime_id from database."""
await self._ensure_table()

async with self.memory.lock, self.memory.conn.cursor() as cur:
await cur.execute(f"""
SELECT type, key, name, folder_path, folder_key, payload
FROM {self.table_name}
ORDER BY timestamp DESC
LIMIT 1
""")
result = await cur.fetchone()
await cur.execute(
f"""
SELECT data
FROM {self.rs_table_name}
WHERE runtime_id = ?
ORDER BY timestamp ASC
""",
(runtime_id,),
)
results = await cur.fetchall()

if not results:
return None

if not result:
return None
triggers = []
for result in results:
data_text = cast(str, result[0])
trigger = UiPathResumeTrigger.model_validate_json(data_text)
triggers.append(trigger)

trigger_type, key, name, folder_path, folder_key, payload = cast(
tuple[str, str, str, str, str, str], tuple(result)
return triggers

async def delete_trigger(
self, runtime_id: str, trigger: UiPathResumeTrigger
) -> None:
"""Delete resume trigger from storage."""
await self._ensure_table()

async with self.memory.lock, self.memory.conn.cursor() as cur:
await cur.execute(
f"""
DELETE FROM {self.rs_table_name}
WHERE runtime_id = ? AND interrupt_id = ?
""",
(
runtime_id,
trigger.interrupt_id,
),
)
await self.memory.conn.commit()

async def set_value(
self,
runtime_id: str,
namespace: str,
key: str,
value: Any,
) -> None:
"""Save arbitrary key-value pair to database."""
if not (
isinstance(value, str)
or isinstance(value, dict)
or isinstance(value, BaseModel)
or value is None
):
raise TypeError("Value must be str, dict, BaseModel or None.")

await self._ensure_table()

resume_trigger = UiPathResumeTrigger(
trigger_type=UiPathResumeTriggerType(trigger_type),
trigger_name=UiPathResumeTriggerName(name),
item_key=key,
folder_path=folder_path,
folder_key=folder_key,
payload=payload,
value_text = self._dump_value(value)

async with self.memory.lock, self.memory.conn.cursor() as cur:
await cur.execute(
f"""
INSERT INTO {self.kv_table_name} (runtime_id, namespace, key, value)
VALUES (?, ?, ?, ?)
ON CONFLICT(runtime_id, namespace, key)
DO UPDATE SET
value = excluded.value,
timestamp = (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc'))
""",
(runtime_id, namespace, key, value_text),
)
await self.memory.conn.commit()

if resume_trigger.trigger_type == UiPathResumeTriggerType.API:
resume_trigger.api_resume = UiPathApiTrigger(
inbox_id=resume_trigger.item_key, request=resume_trigger.payload
)
async def get_value(self, runtime_id: str, namespace: str, key: str) -> Any:
"""Get arbitrary key-value pair from database (scoped by runtime_id + namespace)."""
await self._ensure_table()

return resume_trigger
async with self.memory.lock, self.memory.conn.cursor() as cur:
await cur.execute(
f"""
SELECT value
FROM {self.kv_table_name}
WHERE runtime_id = ? AND namespace = ? AND key = ?
LIMIT 1
""",
(runtime_id, namespace, key),
)
row = await cur.fetchone()

if not row:
return None

return self._load_value(cast(str | None, row[0]))

def _dump_value(self, value: str | dict[str, Any] | BaseModel | None) -> str | None:
if value is None:
return None
if isinstance(value, BaseModel):
return "j:" + json.dumps(value.model_dump())
if isinstance(value, dict):
return "j:" + json.dumps(value)
return "s:" + value

def _load_value(self, raw: str | None) -> Any:
if raw is None:
return None
if raw.startswith("s:"):
return raw[2:]
if raw.startswith("j:"):
return json.loads(raw[2:])
return raw
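
A minimal usage sketch of the reworked storage class, assuming an AsyncSqliteSaver opened via from_conn_string and triggers built elsewhere by the runtime layer; the runtime_id, database path, namespace, and key below are illustrative only.

# Illustrative sketch — exercising the per-runtime trigger store and the generic
# key-value store added in this PR. Identifiers like "run-42" are made up.
import asyncio

from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from uipath.runtime import UiPathResumeTrigger

from uipath_langchain.runtime.storage import SqliteResumableStorage


async def main() -> None:
    async with AsyncSqliteSaver.from_conn_string("state.db") as memory:
        storage = SqliteResumableStorage(memory)

        # Triggers are saved per runtime_id and keyed by interrupt_id;
        # save_triggers replaces whatever was stored for that runtime before.
        triggers: list[UiPathResumeTrigger] = []  # normally built by the runtime
        await storage.save_triggers("run-42", triggers)
        pending = await storage.get_triggers("run-42")

        # str/dict/BaseModel values round-trip through the kv table;
        # _dump_value prefixes them with "s:"/"j:" so get_value can decode them.
        await storage.set_value("run-42", "status", "phase", {"step": "review"})
        phase = await storage.get_value("run-42", "status", "phase")
        print(pending, phase)


asyncio.run(main())
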