17 changes: 8 additions & 9 deletions example_workflows/arithmetic/workflow.json
@@ -1,14 +1,13 @@
{
"nodes": [
{"id": 0, "function": "workflow.get_prod_and_div"},
{"id": 1, "function": "workflow.get_sum"},
{"id": 2, "value": 1},
{"id": 3, "value": 2}
{ "id": 0, "function": "workflow.get_prod_and_div" },
{ "id": 1, "function": "workflow.get_sum" }
],
"edges": [
{"target": 0, "targetPort": "x", "source": 2, "sourcePort": null},
{"target": 0, "targetPort": "y", "source": 3, "sourcePort": null},
{"target": 1, "targetPort": "x", "source": 0, "sourcePort": "prod"},
{"target": 1, "targetPort": "y", "source": 0, "sourcePort": "div"}
{ "target": 0, "targetPort": "x", "source": null, "sourcePort": null },
{ "target": 0, "targetPort": "y", "source": null, "sourcePort": null },
{ "target": 1, "targetPort": "x", "source": 0, "sourcePort": "prod" },
{ "target": 1, "targetPort": "y", "source": 0, "sourcePort": "div" },
{ "target": null, "targetPort": null, "source": 1, "sourcePort": "result" }
]
}
}
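
The key schema change: the literal value nodes (ids 2 and 3) are gone, and workflow-level inputs and outputs are instead expressed as edges with a null endpoint. A null source marks an external input feeding a task port, and a null target exposes a task output as a workflow result. A minimal sketch of reading edges under this convention (the helper name and file path are hypothetical):

import json

def partition_edges(path):
    """Split a workflow's edges into external inputs, exposed outputs, and internal links."""
    with open(path) as f:
        edges = json.load(f)["edges"]
    inputs = [e for e in edges if e["source"] is None]    # fed from outside the workflow
    outputs = [e for e in edges if e["target"] is None]   # exposed as workflow results
    internal = [e for e in edges if e["source"] is not None and e["target"] is not None]
    return inputs, outputs, internal
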
235 changes: 136 additions & 99 deletions python_workflow_definition/src/python_workflow_definition/aiida.py
@@ -1,11 +1,6 @@
from importlib import import_module
import json
import traceback

from aiida import orm
from aiida_pythonjob.data.serializer import general_serializer
from aiida_workgraph import WorkGraph, task
from aiida_workgraph.socket import TaskSocketNamespace

from python_workflow_definition.shared import (
convert_nodes_list_to_dict,
@@ -18,121 +13,163 @@
)


def load_workflow_json(file_name: str) -> WorkGraph:
def load_workflow_json(file_name: str) -> "WorkGraph":
"""
Load a workflow JSON file and convert it to a WorkGraph.

Args:
file_name: Path to the JSON workflow file

Returns:
A populated WorkGraph object
"""
with open(file_name) as f:
data = json.load(f)

from aiida_workgraph import WorkGraph
wg = WorkGraph()
task_name_mapping = {}

for id, identifier in convert_nodes_list_to_dict(
nodes_list=data[NODES_LABEL]
).items():
# Create all tasks first
for task_id, identifier in convert_nodes_list_to_dict(data[NODES_LABEL]).items():
if isinstance(identifier, str) and "." in identifier:
p, m = identifier.rsplit(".", 1)
mod = import_module(p)
func = getattr(mod, m)
# Import the function dynamically
package_path, function_name = identifier.rsplit(".", 1)
module = import_module(package_path)
func = getattr(module, function_name)

# Add task and prepare for linking
wg.add_task(func)
# Remove the default result output, because we will add the outputs later from the data in the link
del wg.tasks[-1].outputs["result"]
task_name_mapping[id] = wg.tasks[-1]
else:
# data task
data_node = general_serializer(identifier)
task_name_mapping[id] = data_node

# add links
current_task = wg.tasks[-1]

# Remove default output as we'll add custom outputs later
del current_task.outputs["result"]
task_name_mapping[task_id] = current_task

# Add all connections between tasks
for link in data[EDGES_LABEL]:
to_task = task_name_mapping[str(link[TARGET_LABEL])]
# if the input socket does not exist, the data is passed via kwargs;
# in that case, add the input socket
if link[TARGET_PORT_LABEL] not in to_task.inputs:
to_socket = to_task.add_input("workgraph.any", name=link[TARGET_PORT_LABEL])
else:
to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
from_task = task_name_mapping[str(link[SOURCE_LABEL])]
if isinstance(from_task, orm.Data):
to_socket.value = from_task
else:
try:
if link[SOURCE_PORT_LABEL] is None:
link[SOURCE_PORT_LABEL] = "result"
# because the outputs are not defined explicitly during PythonJob creation,
# we add the socket here and assume the output exists
if link[SOURCE_PORT_LABEL] not in from_task.outputs:
# if str(link["sourcePort"]) not in from_task.outputs:
from_socket = from_task.add_output(
"workgraph.any",
name=link[SOURCE_PORT_LABEL],
# name=str(link["sourcePort"]),
metadata={"is_function_output": True},
)
else:
from_socket = from_task.outputs[link[SOURCE_PORT_LABEL]]

wg.add_link(from_socket, to_socket)
except Exception as e:
traceback.print_exc()
print("Failed to link", link, "with error:", e)
source_id = link[SOURCE_LABEL]
target_id = link[TARGET_LABEL]
source_port = link[SOURCE_PORT_LABEL]
target_port = link[TARGET_PORT_LABEL]

# Handle task-to-task connections
if source_id is not None and target_id is not None:
from_task = task_name_mapping[str(source_id)]
to_task = task_name_mapping[str(target_id)]

# Create output socket on source task
from_socket = from_task.add_output("workgraph.any", name=source_port, metadata={"is_function_output": True})

# Create or get input socket on target task
if target_port not in to_task.inputs:
to_socket = to_task.add_input("workgraph.any", name=target_port)
else:
to_socket = to_task.inputs[target_port]

# Connect the tasks
wg.add_link(from_socket, to_socket)

# Handle dangling outputs (no target)
elif source_id is not None and target_id is None:
from_task = task_name_mapping[str(source_id)]
from_task.add_output("workgraph.any", name=source_port, metadata={"is_function_output": True})

return wg
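
For orientation, a usage sketch of the loader, assuming aiida_workgraph is installed and an AiiDA profile is loaded; the example path refers to the arithmetic workflow above:

from python_workflow_definition.aiida import load_workflow_json

# Build the WorkGraph from the JSON definition above.
wg = load_workflow_json("example_workflows/arithmetic/workflow.json")
print([t.name for t in wg.tasks])  # the two function tasks
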


def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
def write_workflow_json(wg: "WorkGraph", file_name: str) -> dict:
"""
Write a WorkGraph to a JSON file.

Args:
wg: WorkGraph object to serialize
file_name: Path where the JSON file will be written

Returns:
Dictionary representation of the serialized workflow
"""
data = {NODES_LABEL: [], EDGES_LABEL: []}
node_name_mapping = {}
data_node_name_mapping = {}
i = 0
for node in wg.tasks:
executor = node.get_executor()

# Add all process nodes first
for i, node in enumerate(wg.tasks):
# Store node index for later reference
node_name_mapping[node.name] = i

callable_name = executor["callable_name"]
callable_name = f"{executor['module_path']}.{callable_name}"
# Get executor info and build full callable name
executor = node.get_executor()
callable_name = f"{executor['module_path']}.{executor['callable_name']}"

# Add node to data structure
data[NODES_LABEL].append({"id": i, "function": callable_name})
i += 1

# Create edges from WorkGraph links
for link in wg.links:
link_data = link.to_dict()
# if the from socket is the default result, we set it to None
if link_data["from_socket"] == "result":
link_data["from_socket"] = None
link_data[TARGET_LABEL] = node_name_mapping[link_data.pop("to_node")]
link_data[TARGET_PORT_LABEL] = link_data.pop("to_socket")
link_data[SOURCE_LABEL] = node_name_mapping[link_data.pop("from_node")]
link_data[SOURCE_PORT_LABEL] = link_data.pop("from_socket")
data[EDGES_LABEL].append(link_data)

for node in wg.tasks:
for input in node.inputs:
# assume namespace is not used as input
if isinstance(input, TaskSocketNamespace):
continue
if isinstance(input.value, orm.Data):
if input.value.uuid not in data_node_name_mapping:
if isinstance(input.value, orm.List):
raw_value = input.value.get_list()
elif isinstance(input.value, orm.Dict):
raw_value = input.value.get_dict()
# for an unknown reason, the dict contains a "node_type" key
raw_value.pop("node_type", None)
else:
raw_value = input.value.value
data[NODES_LABEL].append({"id": i, "value": raw_value})
input_node_name = i
data_node_name_mapping[input.value.uuid] = input_node_name
i += 1
else:
input_node_name = data_node_name_mapping[input.value.uuid]
data[EDGES_LABEL].append(
{
TARGET_LABEL: node_name_mapping[node.name],
TARGET_PORT_LABEL: input._name,
SOURCE_LABEL: input_node_name,
SOURCE_PORT_LABEL: None,
}
)

# Handle default result case
from_socket = link_data.pop("from_socket")
if from_socket == "result":
from_socket = None

# Convert to expected format
edge = {
SOURCE_LABEL: node_name_mapping[link_data.pop("from_node")],
SOURCE_PORT_LABEL: from_socket,
TARGET_LABEL: node_name_mapping[link_data.pop("to_node")],
TARGET_PORT_LABEL: link_data.pop("to_socket"),
}

data[EDGES_LABEL].append(edge)

# Define sockets to ignore when adding workflow I/O connections
wg_input_sockets = {
"_wait",
"metadata",
"function_data",
"process_label",
"function_inputs",
"deserializers",
"serializers",
}
wg_output_sockets = {"_wait", "_outputs", "exit_code"}

# Handle first node's inputs (external inputs to workflow)
first_node = wg.tasks[0]
first_node_id = node_name_mapping[first_node.name]

for input_name in first_node.inputs._sockets:
if input_name not in wg_input_sockets:
data[EDGES_LABEL].append(
{
SOURCE_LABEL: None,
SOURCE_PORT_LABEL: None,
TARGET_LABEL: first_node_id,
TARGET_PORT_LABEL: input_name,
}
)

# Handle last node's outputs (workflow outputs)
last_node = wg.tasks[-1]
last_node_id = node_name_mapping[last_node.name]

for output_name in last_node.outputs._sockets:
if output_name not in wg_output_sockets:
data[EDGES_LABEL].append(
{
SOURCE_LABEL: last_node_id,
SOURCE_PORT_LABEL: output_name,
TARGET_LABEL: None,
TARGET_PORT_LABEL: None,
}
)

# Write the data to file
with open(file_name, "w") as f:
json.dump(data, f, indent=2)

return data
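
And a round-trip sketch for the writer, under the same assumptions; the output path is hypothetical, and the assertion relies on NODES_LABEL and EDGES_LABEL mapping to "nodes" and "edges" as in the example JSON above:

from python_workflow_definition.aiida import load_workflow_json, write_workflow_json

wg = load_workflow_json("example_workflows/arithmetic/workflow.json")
data = write_workflow_json(wg, "roundtrip.json")
assert {"nodes", "edges"} <= data.keys()
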