Skip to content

Commit caafa1b

Browse files
committed
Add Aiida interface
1 parent b82b4a6 commit caafa1b

File tree

2 files changed

+228
-1
lines changed

2 files changed

+228
-1
lines changed

environment.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ dependencies:
99
- matplotlib=3.10.1
1010
- xmlschema=3.4.3
1111
- jobflow=0.1.19
12-
- pygraphviz=1.14
12+
- pygraphviz=1.14
13+
- aiida-workgraph 0.4.10
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import json
2+
from importlib import import_module
3+
from inspect import isfunction
4+
import numpy as np
5+
6+
from aiida.engine import calcfunction
7+
from aiida_workgraph import WorkGraph
8+
from aiida_workgraph.decorator import build_task_from_callable
9+
10+
11+
def get_edges_label_lst(work_graph_dict):
12+
edges_label_lst = []
13+
for link_dict in work_graph_dict["links"]:
14+
if link_dict['from_socket'] == "result":
15+
edges_label_lst.append(
16+
{
17+
'target': link_dict['to_node'],
18+
'targetHandle': link_dict['to_socket'],
19+
'source': link_dict['from_node'],
20+
'sourceHandle': None,
21+
}
22+
)
23+
else:
24+
edges_label_lst.append(
25+
{
26+
'target': link_dict['to_node'],
27+
'targetHandle': link_dict['to_socket'],
28+
'source': link_dict['from_node'],
29+
'sourceHandle': link_dict['from_socket'],
30+
}
31+
)
32+
33+
return edges_label_lst
34+
35+
36+
def get_function_dict(work_graph_dict):
37+
kwargs_dict, function_dict = {}, {}
38+
for task_name, task_dict in work_graph_dict["tasks"].items():
39+
input_variables = [
40+
input_parameter
41+
for input_parameter in task_dict['inputs'].keys()
42+
if not input_parameter.startswith("metadata") and not input_parameter.startswith("_wait")
43+
]
44+
input_kwargs = {
45+
input_parameter: task_dict['inputs'][input_parameter]['property']["value"].value if isinstance(
46+
task_dict['inputs'][input_parameter]['property']["value"], dict) else
47+
task_dict['inputs'][input_parameter]['property']["value"]
48+
for input_parameter in input_variables
49+
}
50+
kwargs_dict[task_name] = input_kwargs
51+
function_dict[task_name] = task_dict['executor']['callable'].process_class._func
52+
return kwargs_dict, function_dict
53+
54+
55+
def get_mapping(function_dict):
56+
nodes_dict, mapping_dict = {}, {}
57+
for i, [k, v] in enumerate(function_dict.items()):
58+
nodes_dict[i] = v
59+
mapping_dict[k] = i
60+
61+
return nodes_dict, mapping_dict
62+
63+
64+
def get_value_dict(kwargs_dict):
65+
value_dict = {}
66+
for func_name, val_dict in kwargs_dict.items():
67+
for k, v in val_dict.items():
68+
if v is not None:
69+
if func_name not in value_dict.keys():
70+
value_dict[func_name] = {}
71+
value_dict[func_name][k] = v
72+
73+
return value_dict
74+
75+
76+
def extend_mapping(nodes_dict, value_dict, mapping_dict):
77+
i = len(nodes_dict)
78+
for val_dict in value_dict.values():
79+
for k, v in val_dict.items():
80+
nodes_dict[i] = v
81+
mapping_dict[v] = i
82+
i += 1
83+
84+
return nodes_dict, mapping_dict
85+
86+
87+
def extend_edges_label_lst(kwargs_dict, edges_label_lst):
88+
for func_name, val_dict in kwargs_dict.items():
89+
for k, v in val_dict.items():
90+
if v is not None:
91+
edges_label_lst.append({'target': func_name, 'targetHandle': k, 'source': v, 'sourceHandle': None})
92+
return edges_label_lst
93+
94+
95+
def get_edges_lst(edges_label_lst, mapping_dict):
96+
edges_lst = []
97+
for edge in edges_label_lst:
98+
edges_lst.append({'target': mapping_dict[edge['target']], 'targetHandle': edge['targetHandle'],
99+
'source': mapping_dict[edge['source']], 'sourceHandle': edge['sourceHandle']})
100+
101+
return edges_lst
102+
103+
104+
def get_kwargs(lst):
105+
return {t['targetHandle']: {'source': t['source'], 'sourceHandle': t['sourceHandle']} for t in lst}
106+
107+
108+
def wrap_function(func, **kwargs):
109+
# First, apply the calcfunction decorator
110+
func_decorated = calcfunction(func)
111+
# Then, apply task decorator
112+
task_decorated = build_task_from_callable(
113+
func_decorated,
114+
inputs=kwargs.get("inputs", []),
115+
outputs=kwargs.get("outputs", []),
116+
)
117+
identifier = kwargs.get("identifier", None)
118+
func_decorated.identifier = identifier if identifier else func.__name__
119+
func_decorated.task = func_decorated.node = task_decorated
120+
return func_decorated
121+
122+
123+
def group_edges(edges_lst):
124+
# edges_sorted_lst = sorted(edges_lst, key=lambda x: x['target'], reverse=True)
125+
total_dict = {}
126+
tmp_lst = []
127+
target_id = edges_lst[0]['target']
128+
for ed in edges_lst:
129+
if target_id == ed["target"]:
130+
tmp_lst.append(ed)
131+
else:
132+
total_dict[target_id] = get_kwargs(lst=tmp_lst)
133+
target_id = ed["target"]
134+
tmp_lst = [ed]
135+
total_dict[target_id] = get_kwargs(lst=tmp_lst)
136+
return total_dict
137+
138+
139+
def get_output_labels(edges_lst):
140+
output_label_dict = {}
141+
for ed in edges_lst:
142+
if ed['sourceHandle'] is not None:
143+
if ed["source"] not in output_label_dict.keys():
144+
output_label_dict[ed["source"]] = []
145+
output_label_dict[ed["source"]].append(ed['sourceHandle'])
146+
return output_label_dict
147+
148+
149+
def get_wrap_function_dict(nodes_dict, output_label_dict):
150+
function_dict = {}
151+
for k, v in nodes_dict.items():
152+
if isfunction(v):
153+
if k in output_label_dict.keys():
154+
kwargs = {"outputs": [{"name": label} for label in output_label_dict[k]]}
155+
function_dict[k] = wrap_function(func=v, **kwargs)
156+
else:
157+
function_dict[k] = wrap_function(func=v)
158+
159+
return function_dict
160+
161+
162+
def build_workgraph(function_dict, total_dict, nodes_dict, label="my_workflow"):
163+
wg = WorkGraph(label)
164+
mapping_dict = {}
165+
for k, v in function_dict.items():
166+
name = v.__name__
167+
mapping_dict[k] = name
168+
total_item_dict = total_dict[k]
169+
kwargs = {}
170+
for tk, tv in total_item_dict.items():
171+
if tv['source'] in mapping_dict.keys():
172+
kwargs[tk] = wg.tasks[mapping_dict[tv['source']]].outputs[tv['sourceHandle']]
173+
elif tv['sourceHandle'] is None:
174+
kwargs[tk] = nodes_dict[tv['source']]
175+
else:
176+
raise ValueError()
177+
wg.add_task(v, name=name, **kwargs)
178+
return wg
179+
180+
181+
def write_workflow_json(work_graph_dict, file_name="workflow.json"):
182+
edges_label_lst = get_edges_label_lst(work_graph_dict=work_graph_dict)
183+
kwargs_dict, function_dict = get_function_dict(work_graph_dict=work_graph_dict)
184+
nodes_dict, mapping_dict = get_mapping(function_dict=function_dict)
185+
value_dict = get_value_dict(kwargs_dict=kwargs_dict)
186+
nodes_dict, mapping_dict = extend_mapping(nodes_dict=nodes_dict, value_dict=value_dict, mapping_dict=mapping_dict)
187+
edges_label_lst = extend_edges_label_lst(kwargs_dict=kwargs_dict, edges_label_lst=edges_label_lst)
188+
edges_lst = get_edges_lst(edges_label_lst=edges_label_lst, mapping_dict=mapping_dict)
189+
190+
nodes_store_dict = {}
191+
for k, v in nodes_dict.items():
192+
if isfunction(v):
193+
nodes_store_dict[k] = v.__module__ + "." + v.__name__
194+
elif isinstance(v, np.ndarray):
195+
nodes_store_dict[k] = v.tolist()
196+
else:
197+
nodes_store_dict[k] = v
198+
199+
with open(file_name, "w") as f:
200+
json.dump({"nodes": nodes_store_dict, "edges": edges_lst}, f)
201+
202+
203+
def load_workflow_json(file_name, label="my_workflow"):
204+
with open(file_name, "r") as f:
205+
content = json.load(f)
206+
207+
nodes_new_dict = {}
208+
for k, v in content["nodes"].items():
209+
if isinstance(v, str) and "." in v:
210+
p, m = v.rsplit('.', 1)
211+
if p == "python_workflow_definition.pyiron_base":
212+
p = "python_workflow_definition.jobflow"
213+
mod = import_module(p)
214+
nodes_new_dict[int(k)] = getattr(mod, m)
215+
else:
216+
nodes_new_dict[int(k)] = v
217+
218+
output_label_dict = get_output_labels(edges_lst=content["edges"])
219+
total_dict = group_edges(edges_lst=content["edges"])
220+
function_dict = get_wrap_function_dict(nodes_dict=nodes_new_dict, output_label_dict=output_label_dict)
221+
return build_workgraph(
222+
function_dict=function_dict,
223+
total_dict=total_dict,
224+
nodes_dict=nodes_new_dict,
225+
label=label,
226+
)

0 commit comments

Comments
 (0)