Skip to content

Commit 2bc578b

Browse files
committed
add function to load workflow from JSOn - not working
1 parent 1df9b19 commit 2bc578b

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed

python_workflow_definition/src/python_workflow_definition/pyiron_workflow.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from inspect import isfunction
2+
from importlib import import_module
3+
from typing import Optional
24

35
import numpy as np
6+
from pyiron_workflow import function_node, Workflow
47

58
from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow
69
from python_workflow_definition.shared import (
10+
convert_nodes_list_to_dict,
711
update_node_names,
812
set_result_node,
13+
remove_result,
914
NODES_LABEL,
1015
EDGES_LABEL,
1116
SOURCE_LABEL,
@@ -74,6 +79,76 @@ def get_edges(graph_dict, node_mapping_dict, nodes_links_dict):
7479
return edges_lst
7580

7681

82+
def create_input_nodes(nodes_dict, edges_lst):
83+
node_conversion_dict = {
84+
ed[SOURCE_LABEL]: ed[TARGET_PORT_LABEL]
85+
for ed in edges_lst if ed[SOURCE_PORT_LABEL] is None
86+
}
87+
nodes_to_create_dict = {
88+
v: nodes_dict[k]
89+
for k, v in node_conversion_dict.items()
90+
}
91+
return nodes_to_create_dict, node_conversion_dict
92+
93+
94+
def set_input_nodes(workflow, nodes_to_create_dict):
95+
for k, v in nodes_to_create_dict.items():
96+
workflow.__setattr__(k, v)
97+
return workflow
98+
99+
100+
def get_source_handles(edges_lst):
101+
source_handle_dict = {}
102+
for ed in edges_lst:
103+
if ed[SOURCE_LABEL] not in source_handle_dict.keys():
104+
source_handle_dict[ed[SOURCE_LABEL]] = [ed[SOURCE_PORT_LABEL]]
105+
else:
106+
source_handle_dict[ed[SOURCE_LABEL]].append(ed[SOURCE_PORT_LABEL])
107+
return source_handle_dict
108+
109+
110+
def get_function_nodes(nodes_dict, source_handle_dict):
111+
function_dict = {}
112+
for k, v in nodes_dict.items():
113+
if isfunction(v):
114+
function_dict[k] = {"node_function": v}
115+
return function_dict
116+
117+
118+
def get_kwargs(lst):
119+
return {t[TARGET_PORT_LABEL]: {SOURCE_LABEL: t[SOURCE_LABEL], SOURCE_PORT_LABEL: t[SOURCE_PORT_LABEL]} for t in lst}
120+
121+
122+
def group_edges(edges_lst):
123+
edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True)
124+
total_dict = {}
125+
tmp_lst = []
126+
target_id = edges_sorted_lst[0][TARGET_LABEL]
127+
for ed in edges_sorted_lst:
128+
if target_id == ed[TARGET_LABEL]:
129+
tmp_lst.append(ed)
130+
else:
131+
total_dict[target_id] = get_kwargs(lst=tmp_lst)
132+
target_id = ed[TARGET_LABEL]
133+
tmp_lst = [ed]
134+
total_dict[target_id] = get_kwargs(lst=tmp_lst)
135+
return total_dict
136+
137+
138+
def build_workflow(workflow, function_dict, total_dict, node_conversion_dict):
139+
for k, v in function_dict.items():
140+
kwargs_link_dict = total_dict[k]
141+
kwargs_dict = {}
142+
for kw, vw in kwargs_link_dict.items():
143+
if vw[SOURCE_LABEL] in node_conversion_dict.keys():
144+
kwargs_dict[kw] = workflow.__getattribute__(node_conversion_dict[vw[SOURCE_LABEL]])
145+
else:
146+
kwargs_dict[kw] = workflow.__getattr__("tmp_" + str(vw[SOURCE_LABEL])).__getitem__(vw[SOURCE_PORT_LABEL])
147+
v.update(kwargs_dict)
148+
workflow.__setattr__("tmp_" + str(k), function_node(**v, validate_output_labels=False))
149+
return workflow, "tmp_" + str(k)
150+
151+
77152
def write_workflow_json(graph_as_dict: dict, file_name: str = "workflow.json"):
78153
nodes_dict, node_mapping_dict, input_dict = get_linked_nodes(
79154
graph_dict=graph_as_dict
@@ -188,3 +263,38 @@ def write_workflow_json(graph_as_dict: dict, file_name: str = "workflow.json"):
188263
)
189264
)
190265
).dump_json_file(file_name=file_name, indent=2)
266+
267+
268+
def load_workflow_json(file_name: str, workflow: Optional[Workflow] = None):
269+
content = remove_result(
270+
workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file(
271+
file_name=file_name
272+
)
273+
)
274+
edges_lst = content[EDGES_LABEL]
275+
276+
nodes_new_dict = {}
277+
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
278+
if isinstance(v, str) and "." in v:
279+
p, m = v.rsplit('.', 1)
280+
mod = import_module(p)
281+
nodes_new_dict[int(k)] = getattr(mod, m)
282+
else:
283+
nodes_new_dict[int(k)] = v
284+
285+
if workflow is None:
286+
workflow = Workflow(file_name.split(".")[0])
287+
288+
nodes_to_create_dict, node_conversion_dict = create_input_nodes(nodes_dict=nodes_new_dict, edges_lst=edges_lst)
289+
wf = set_input_nodes(workflow=workflow, nodes_to_create_dict=nodes_to_create_dict)
290+
291+
source_handle_dict = get_source_handles(edges_lst=edges_lst)
292+
function_dict = get_function_nodes(nodes_dict=nodes_new_dict, source_handle_dict=source_handle_dict)
293+
total_dict = group_edges(edges_lst=edges_lst)
294+
295+
return build_workflow(
296+
workflow=wf,
297+
function_dict=function_dict,
298+
total_dict=total_dict,
299+
node_conversion_dict=node_conversion_dict,
300+
)

0 commit comments

Comments
 (0)