1+ import json
2+ from importlib import import_module
3+ from inspect import isfunction
4+ import numpy as np
5+
6+ from aiida .engine import calcfunction
7+ from aiida_workgraph import WorkGraph
8+ from aiida_workgraph .decorator import build_task_from_callable
9+
10+
11+ def get_edges_label_lst (work_graph_dict ):
12+ edges_label_lst = []
13+ for link_dict in work_graph_dict ["links" ]:
14+ if link_dict ['from_socket' ] == "result" :
15+ edges_label_lst .append (
16+ {
17+ 'target' : link_dict ['to_node' ],
18+ 'targetHandle' : link_dict ['to_socket' ],
19+ 'source' : link_dict ['from_node' ],
20+ 'sourceHandle' : None ,
21+ }
22+ )
23+ else :
24+ edges_label_lst .append (
25+ {
26+ 'target' : link_dict ['to_node' ],
27+ 'targetHandle' : link_dict ['to_socket' ],
28+ 'source' : link_dict ['from_node' ],
29+ 'sourceHandle' : link_dict ['from_socket' ],
30+ }
31+ )
32+
33+ return edges_label_lst
34+
35+
36+ def get_function_dict (work_graph_dict ):
37+ kwargs_dict , function_dict = {}, {}
38+ for task_name , task_dict in work_graph_dict ["tasks" ].items ():
39+ input_variables = [
40+ input_parameter
41+ for input_parameter in task_dict ['inputs' ].keys ()
42+ if not input_parameter .startswith ("metadata" ) and not input_parameter .startswith ("_wait" )
43+ ]
44+ input_kwargs = {
45+ input_parameter : task_dict ['inputs' ][input_parameter ]['property' ]["value" ].value if isinstance (
46+ task_dict ['inputs' ][input_parameter ]['property' ]["value" ], dict ) else
47+ task_dict ['inputs' ][input_parameter ]['property' ]["value" ]
48+ for input_parameter in input_variables
49+ }
50+ kwargs_dict [task_name ] = input_kwargs
51+ function_dict [task_name ] = task_dict ['executor' ]['callable' ].process_class ._func
52+ return kwargs_dict , function_dict
53+
54+
55+ def get_mapping (function_dict ):
56+ nodes_dict , mapping_dict = {}, {}
57+ for i , [k , v ] in enumerate (function_dict .items ()):
58+ nodes_dict [i ] = v
59+ mapping_dict [k ] = i
60+
61+ return nodes_dict , mapping_dict
62+
63+
64+ def get_value_dict (kwargs_dict ):
65+ value_dict = {}
66+ for func_name , val_dict in kwargs_dict .items ():
67+ for k , v in val_dict .items ():
68+ if v is not None :
69+ if func_name not in value_dict .keys ():
70+ value_dict [func_name ] = {}
71+ value_dict [func_name ][k ] = v
72+
73+ return value_dict
74+
75+
76+ def extend_mapping (nodes_dict , value_dict , mapping_dict ):
77+ i = len (nodes_dict )
78+ for val_dict in value_dict .values ():
79+ for k , v in val_dict .items ():
80+ nodes_dict [i ] = v
81+ mapping_dict [v ] = i
82+ i += 1
83+
84+ return nodes_dict , mapping_dict
85+
86+
87+ def extend_edges_label_lst (kwargs_dict , edges_label_lst ):
88+ for func_name , val_dict in kwargs_dict .items ():
89+ for k , v in val_dict .items ():
90+ if v is not None :
91+ edges_label_lst .append ({'target' : func_name , 'targetHandle' : k , 'source' : v , 'sourceHandle' : None })
92+ return edges_label_lst
93+
94+
95+ def get_edges_lst (edges_label_lst , mapping_dict ):
96+ edges_lst = []
97+ for edge in edges_label_lst :
98+ edges_lst .append ({'target' : mapping_dict [edge ['target' ]], 'targetHandle' : edge ['targetHandle' ],
99+ 'source' : mapping_dict [edge ['source' ]], 'sourceHandle' : edge ['sourceHandle' ]})
100+
101+ return edges_lst
102+
103+
104+ def get_kwargs (lst ):
105+ return {t ['targetHandle' ]: {'source' : t ['source' ], 'sourceHandle' : t ['sourceHandle' ]} for t in lst }
106+
107+
108+ def wrap_function (func , ** kwargs ):
109+ # First, apply the calcfunction decorator
110+ func_decorated = calcfunction (func )
111+ # Then, apply task decorator
112+ task_decorated = build_task_from_callable (
113+ func_decorated ,
114+ inputs = kwargs .get ("inputs" , []),
115+ outputs = kwargs .get ("outputs" , []),
116+ )
117+ identifier = kwargs .get ("identifier" , None )
118+ func_decorated .identifier = identifier if identifier else func .__name__
119+ func_decorated .task = func_decorated .node = task_decorated
120+ return func_decorated
121+
122+
123+ def group_edges (edges_lst ):
124+ # edges_sorted_lst = sorted(edges_lst, key=lambda x: x['target'], reverse=True)
125+ total_dict = {}
126+ tmp_lst = []
127+ target_id = edges_lst [0 ]['target' ]
128+ for ed in edges_lst :
129+ if target_id == ed ["target" ]:
130+ tmp_lst .append (ed )
131+ else :
132+ total_dict [target_id ] = get_kwargs (lst = tmp_lst )
133+ target_id = ed ["target" ]
134+ tmp_lst = [ed ]
135+ total_dict [target_id ] = get_kwargs (lst = tmp_lst )
136+ return total_dict
137+
138+
139+ def get_output_labels (edges_lst ):
140+ output_label_dict = {}
141+ for ed in edges_lst :
142+ if ed ['sourceHandle' ] is not None :
143+ if ed ["source" ] not in output_label_dict .keys ():
144+ output_label_dict [ed ["source" ]] = []
145+ output_label_dict [ed ["source" ]].append (ed ['sourceHandle' ])
146+ return output_label_dict
147+
148+
149+ def get_wrap_function_dict (nodes_dict , output_label_dict ):
150+ function_dict = {}
151+ for k , v in nodes_dict .items ():
152+ if isfunction (v ):
153+ if k in output_label_dict .keys ():
154+ kwargs = {"outputs" : [{"name" : label } for label in output_label_dict [k ]]}
155+ function_dict [k ] = wrap_function (func = v , ** kwargs )
156+ else :
157+ function_dict [k ] = wrap_function (func = v )
158+
159+ return function_dict
160+
161+
162+ def build_workgraph (function_dict , total_dict , nodes_dict , label = "my_workflow" ):
163+ wg = WorkGraph (label )
164+ mapping_dict = {}
165+ for k , v in function_dict .items ():
166+ name = v .__name__
167+ mapping_dict [k ] = name
168+ total_item_dict = total_dict [k ]
169+ kwargs = {}
170+ for tk , tv in total_item_dict .items ():
171+ if tv ['source' ] in mapping_dict .keys ():
172+ kwargs [tk ] = wg .tasks [mapping_dict [tv ['source' ]]].outputs [tv ['sourceHandle' ]]
173+ elif tv ['sourceHandle' ] is None :
174+ kwargs [tk ] = nodes_dict [tv ['source' ]]
175+ else :
176+ raise ValueError ()
177+ wg .add_task (v , name = name , ** kwargs )
178+ return wg
179+
180+
181+ def write_workflow_json (work_graph_dict , file_name = "workflow.json" ):
182+ edges_label_lst = get_edges_label_lst (work_graph_dict = work_graph_dict )
183+ kwargs_dict , function_dict = get_function_dict (work_graph_dict = work_graph_dict )
184+ nodes_dict , mapping_dict = get_mapping (function_dict = function_dict )
185+ value_dict = get_value_dict (kwargs_dict = kwargs_dict )
186+ nodes_dict , mapping_dict = extend_mapping (nodes_dict = nodes_dict , value_dict = value_dict , mapping_dict = mapping_dict )
187+ edges_label_lst = extend_edges_label_lst (kwargs_dict = kwargs_dict , edges_label_lst = edges_label_lst )
188+ edges_lst = get_edges_lst (edges_label_lst = edges_label_lst , mapping_dict = mapping_dict )
189+
190+ nodes_store_dict = {}
191+ for k , v in nodes_dict .items ():
192+ if isfunction (v ):
193+ nodes_store_dict [k ] = v .__module__ + "." + v .__name__
194+ elif isinstance (v , np .ndarray ):
195+ nodes_store_dict [k ] = v .tolist ()
196+ else :
197+ nodes_store_dict [k ] = v
198+
199+ with open (file_name , "w" ) as f :
200+ json .dump ({"nodes" : nodes_store_dict , "edges" : edges_lst }, f )
201+
202+
203+ def load_workflow_json (file_name , label = "my_workflow" ):
204+ with open (file_name , "r" ) as f :
205+ content = json .load (f )
206+
207+ nodes_new_dict = {}
208+ for k , v in content ["nodes" ].items ():
209+ if isinstance (v , str ) and "." in v :
210+ p , m = v .rsplit ('.' , 1 )
211+ if p == "python_workflow_definition.pyiron_base" :
212+ p = "python_workflow_definition.jobflow"
213+ mod = import_module (p )
214+ nodes_new_dict [int (k )] = getattr (mod , m )
215+ else :
216+ nodes_new_dict [int (k )] = v
217+
218+ output_label_dict = get_output_labels (edges_lst = content ["edges" ])
219+ total_dict = group_edges (edges_lst = content ["edges" ])
220+ function_dict = get_wrap_function_dict (nodes_dict = nodes_new_dict , output_label_dict = output_label_dict )
221+ return build_workgraph (
222+ function_dict = function_dict ,
223+ total_dict = total_dict ,
224+ nodes_dict = nodes_new_dict ,
225+ label = label ,
226+ )
0 commit comments