Skip to content

Commit 0665821

Browse files
committed
Modify write and load AiiDA methods to work without data nodes
1 parent d3ae576 commit 0665821

File tree

2 files changed

+144
-108
lines changed
  • example_workflows/arithmetic
  • python_workflow_definition/src/python_workflow_definition

2 files changed

+144
-108
lines changed
Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
{
22
"nodes": [
3-
{"id": 0, "function": "workflow.get_prod_and_div"},
4-
{"id": 1, "function": "workflow.get_sum"},
5-
{"id": 2, "value": 1},
6-
{"id": 3, "value": 2}
3+
{ "id": 0, "function": "arithmetic_workflow.get_prod_and_div" },
4+
{ "id": 1, "function": "arithmetic_workflow.get_sum" }
75
],
86
"edges": [
9-
{"target": 0, "targetPort": "x", "source": 2, "sourcePort": null},
10-
{"target": 0, "targetPort": "y", "source": 3, "sourcePort": null},
11-
{"target": 1, "targetPort": "x", "source": 0, "sourcePort": "prod"},
12-
{"target": 1, "targetPort": "y", "source": 0, "sourcePort": "div"}
7+
{ "target": 0, "targetPort": "x", "source": null, "sourcePort": null },
8+
{ "target": 0, "targetPort": "y", "source": null, "sourcePort": null },
9+
{ "target": 1, "targetPort": "x", "source": 0, "sourcePort": "prod" },
10+
{ "target": 1, "targetPort": "y", "source": 0, "sourcePort": "div" },
11+
{ "target": null, "targetPort": null, "source": 1, "sourcePort": "result" }
1312
]
14-
}
13+
}
Lines changed: 136 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
11
from importlib import import_module
22
import json
3-
import traceback
43

5-
from aiida import orm
6-
from aiida_pythonjob.data.serializer import general_serializer
7-
from aiida_workgraph import WorkGraph, task
8-
from aiida_workgraph.socket import TaskSocketNamespace
94

105
from python_workflow_definition.shared import (
116
convert_nodes_list_to_dict,
@@ -18,121 +13,163 @@
1813
)
1914

2015

# Return type is WorkGraph, but the import is deferred (see below) so this
# module can be imported without aiida-workgraph installed; hence no
# annotation on the signature.
def load_workflow_json(file_name: str):
    """
    Load a workflow JSON file and convert it to a WorkGraph.

    Args:
        file_name: Path to the JSON workflow file.

    Returns:
        A populated WorkGraph object.
    """
    with open(file_name) as f:
        data = json.load(f)

    # Imported lazily so that merely importing this module does not require
    # an AiiDA installation.
    from aiida_workgraph import WorkGraph

    wg = WorkGraph()
    task_name_mapping = {}

    # Create all tasks first
    for task_id, identifier in convert_nodes_list_to_dict(data[NODES_LABEL]).items():
        if isinstance(identifier, str) and "." in identifier:
            # Import the function dynamically from its dotted path
            package_path, function_name = identifier.rsplit(".", 1)
            module = import_module(package_path)
            func = getattr(module, function_name)

            # Add task and prepare for linking
            wg.add_task(func)
            current_task = wg.tasks[-1]

            # Remove the default output; outputs are recreated below from
            # the edge data so only the ports the workflow uses exist.
            del current_task.outputs["result"]
            task_name_mapping[task_id] = current_task

    # Add all connections between tasks
    for link in data[EDGES_LABEL]:
        source_id = link[SOURCE_LABEL]
        target_id = link[TARGET_LABEL]
        # A null source port refers to the task's default "result" output.
        source_port = link[SOURCE_PORT_LABEL] or "result"
        target_port = link[TARGET_PORT_LABEL]

        if source_id is None:
            # Dangling input (external workflow input): no source task to
            # link from, so there is nothing to do here.
            continue

        from_task = task_name_mapping[str(source_id)]

        # Create the output socket on the source task only once, even when
        # the same output fans out to several edges (e.g. feeding two
        # targets, or feeding a target and also being a workflow output).
        if source_port not in from_task.outputs:
            from_socket = from_task.add_output(
                "workgraph.any",
                name=source_port,
                metadata={"is_function_output": True},
            )
        else:
            from_socket = from_task.outputs[source_port]

        if target_id is None:
            # Dangling output (workflow output): the socket was ensured
            # above; there is no target to connect it to.
            continue

        to_task = task_name_mapping[str(target_id)]

        # Create or reuse the input socket on the target task
        if target_port not in to_task.inputs:
            to_socket = to_task.add_input("workgraph.any", name=target_port)
        else:
            to_socket = to_task.inputs[target_port]

        # Connect the tasks
        wg.add_link(from_socket, to_socket)

    return wg
7880

7981

def write_workflow_json(wg: "WorkGraph", file_name: str) -> dict:
    """
    Write a WorkGraph to a JSON file.

    Args:
        wg: WorkGraph object to serialize.
        file_name: Path where the JSON file will be written.

    Returns:
        Dictionary representation of the serialized workflow.
    """
    data = {NODES_LABEL: [], EDGES_LABEL: []}
    node_name_mapping = {}

    # Add all process nodes first
    for i, node in enumerate(wg.tasks):
        # Store node index for later reference when building edges
        node_name_mapping[node.name] = i

        # Get executor info and build the fully qualified callable name
        executor = node.get_executor()
        callable_name = f"{executor['module_path']}.{executor['callable_name']}"

        data[NODES_LABEL].append({"id": i, "function": callable_name})

    # Create edges from WorkGraph links
    for link in wg.links:
        link_data = link.to_dict()

        # The default "result" socket is encoded as null in the JSON format
        from_socket = link_data.pop("from_socket")
        if from_socket == "result":
            from_socket = None

        # Convert to the expected edge format
        edge = {
            SOURCE_LABEL: node_name_mapping[link_data.pop("from_node")],
            SOURCE_PORT_LABEL: from_socket,
            TARGET_LABEL: node_name_mapping[link_data.pop("to_node")],
            TARGET_PORT_LABEL: link_data.pop("to_socket"),
        }

        data[EDGES_LABEL].append(edge)

    # Internal sockets that are not part of the workflow's public I/O and
    # must not be written out as external inputs/outputs
    wg_input_sockets = {
        "_wait",
        "metadata",
        "function_data",
        "process_label",
        "function_inputs",
        "deserializers",
        "serializers",
    }
    wg_output_sockets = {"_wait", "_outputs", "exit_code"}

    # Handle first node's inputs (external inputs to the workflow).
    # NOTE(review): this assumes all external inputs enter through the
    # FIRST task and all workflow outputs leave through the LAST task —
    # confirm this holds for workflows with a different topology.
    first_node = wg.tasks[0]
    first_node_id = node_name_mapping[first_node.name]

    for input_name in first_node.inputs._sockets:
        if input_name not in wg_input_sockets:
            data[EDGES_LABEL].append(
                {
                    SOURCE_LABEL: None,
                    SOURCE_PORT_LABEL: None,
                    TARGET_LABEL: first_node_id,
                    TARGET_PORT_LABEL: input_name,
                }
            )

    # Handle last node's outputs (workflow outputs)
    last_node = wg.tasks[-1]
    last_node_id = node_name_mapping[last_node.name]

    for output_name in last_node.outputs._sockets:
        if output_name not in wg_output_sockets:
            data[EDGES_LABEL].append(
                {
                    SOURCE_LABEL: last_node_id,
                    SOURCE_PORT_LABEL: output_name,
                    TARGET_LABEL: None,
                    TARGET_PORT_LABEL: None,
                }
            )

    # Write the data to file.
    # NOTE(review): a leftover `import ipdb; ipdb.set_trace()` debugger
    # breakpoint was removed here — it would halt every call and pull in a
    # package that is not a declared dependency.
    with open(file_name, "w") as f:
        json.dump(data, f, indent=2)

    return data

0 commit comments

Comments
 (0)