11from inspect import isfunction
22from importlib import import_module
3- from typing import Any , Optional
3+ from typing import Any
44
55import numpy as np
66from pyiron_workflow import function_node , Workflow
77from pyiron_workflow .api import Function
88
99from python_workflow_definition .models import PythonWorkflowDefinitionWorkflow
1010from python_workflow_definition .shared import (
11- convert_nodes_list_to_dict ,
1211 update_node_names ,
1312 set_result_node ,
1413 remove_result ,
@@ -80,86 +79,6 @@ def get_edges(graph_dict, node_mapping_dict, nodes_links_dict):
8079 return edges_lst
8180
8281
83- def create_input_nodes (nodes_dict , edges_lst ):
84- node_conversion_dict = {
85- ed [SOURCE_LABEL ]: ed [TARGET_PORT_LABEL ]
86- for ed in edges_lst
87- if ed [SOURCE_PORT_LABEL ] is None
88- }
89- nodes_to_create_dict = {v : nodes_dict [k ] for k , v in node_conversion_dict .items ()}
90- return nodes_to_create_dict , node_conversion_dict
91-
92-
93- def set_input_nodes (workflow , nodes_to_create_dict ):
94- for k , v in nodes_to_create_dict .items ():
95- workflow .__setattr__ (k , v )
96- return workflow
97-
98-
99- def get_source_handles (edges_lst ):
100- source_handle_dict = {}
101- for ed in edges_lst :
102- if ed [SOURCE_LABEL ] not in source_handle_dict .keys ():
103- source_handle_dict [ed [SOURCE_LABEL ]] = [ed [SOURCE_PORT_LABEL ]]
104- else :
105- source_handle_dict [ed [SOURCE_LABEL ]].append (ed [SOURCE_PORT_LABEL ])
106- return source_handle_dict
107-
108-
109- def get_function_nodes (nodes_dict , source_handle_dict ):
110- function_dict = {}
111- for k , v in nodes_dict .items ():
112- if isfunction (v ):
113- function_dict [k ] = {"node_function" : v }
114- return function_dict
115-
116-
117- def get_kwargs (lst ):
118- return {
119- t [TARGET_PORT_LABEL ]: {
120- SOURCE_LABEL : t [SOURCE_LABEL ],
121- SOURCE_PORT_LABEL : t [SOURCE_PORT_LABEL ],
122- }
123- for t in lst
124- }
125-
126-
127- def group_edges (edges_lst ):
128- edges_sorted_lst = sorted (edges_lst , key = lambda x : x [TARGET_LABEL ], reverse = True )
129- total_dict = {}
130- tmp_lst = []
131- target_id = edges_sorted_lst [0 ][TARGET_LABEL ]
132- for ed in edges_sorted_lst :
133- if target_id == ed [TARGET_LABEL ]:
134- tmp_lst .append (ed )
135- else :
136- total_dict [target_id ] = get_kwargs (lst = tmp_lst )
137- target_id = ed [TARGET_LABEL ]
138- tmp_lst = [ed ]
139- total_dict [target_id ] = get_kwargs (lst = tmp_lst )
140- return total_dict
141-
142-
143- def build_workflow (workflow , function_dict , total_dict , node_conversion_dict ):
144- for k , v in function_dict .items ():
145- kwargs_link_dict = total_dict [k ]
146- kwargs_dict = {}
147- for kw , vw in kwargs_link_dict .items ():
148- if vw [SOURCE_LABEL ] in node_conversion_dict .keys ():
149- kwargs_dict [kw ] = workflow .__getattribute__ (
150- node_conversion_dict [vw [SOURCE_LABEL ]]
151- )
152- else :
153- kwargs_dict [kw ] = workflow .__getattr__ (
154- "tmp_" + str (vw [SOURCE_LABEL ])
155- ).__getitem__ (vw [SOURCE_PORT_LABEL ])
156- v .update (kwargs_dict )
157- workflow .__setattr__ (
158- "tmp_" + str (k ), function_node (** v , validate_output_labels = False )
159- )
160- return workflow , "tmp_" + str (k )
161-
162-
16382def write_workflow_json (graph_as_dict : dict , file_name : str = "workflow.json" ):
16483 nodes_dict , node_mapping_dict , input_dict = get_linked_nodes (
16584 graph_dict = graph_as_dict
@@ -276,45 +195,6 @@ def write_workflow_json(graph_as_dict: dict, file_name: str = "workflow.json"):
276195 ).dump_json_file (file_name = file_name , indent = 2 )
277196
278197
279- def load_workflow_json (file_name : str , workflow : Optional [Workflow ] = None ):
280- content = remove_result (
281- workflow_dict = PythonWorkflowDefinitionWorkflow .load_json_file (
282- file_name = file_name
283- )
284- )
285- edges_lst = content [EDGES_LABEL ]
286-
287- nodes_new_dict = {}
288- for k , v in convert_nodes_list_to_dict (nodes_list = content [NODES_LABEL ]).items ():
289- if isinstance (v , str ) and "." in v :
290- p , m = v .rsplit ("." , 1 )
291- mod = import_module (p )
292- nodes_new_dict [int (k )] = getattr (mod , m )
293- else :
294- nodes_new_dict [int (k )] = v
295-
296- if workflow is None :
297- workflow = Workflow (file_name .split ("." )[0 ])
298-
299- nodes_to_create_dict , node_conversion_dict = create_input_nodes (
300- nodes_dict = nodes_new_dict , edges_lst = edges_lst
301- )
302- wf = set_input_nodes (workflow = workflow , nodes_to_create_dict = nodes_to_create_dict )
303-
304- source_handle_dict = get_source_handles (edges_lst = edges_lst )
305- function_dict = get_function_nodes (
306- nodes_dict = nodes_new_dict , source_handle_dict = source_handle_dict
307- )
308- total_dict = group_edges (edges_lst = edges_lst )
309-
310- return build_workflow (
311- workflow = wf ,
312- function_dict = function_dict ,
313- total_dict = total_dict ,
314- node_conversion_dict = node_conversion_dict ,
315- )
316-
317-
318198def import_from_string (library_path : str ) -> Any :
319199 # Copied from bagofholding
320200 split_path = library_path .split ("." , 1 )
@@ -328,7 +208,7 @@ def import_from_string(library_path: str) -> Any:
328208 return obj
329209
330210
331- def build_function_dag_workflow (file_name : str ) -> Workflow :
211+ def load_workflow_json (file_name : str ) -> Workflow :
332212 content = remove_result (
333213 PythonWorkflowDefinitionWorkflow .load_json_file (file_name = file_name )
334214 )
0 commit comments