Skip to content

Commit 9739b47

Browse files
committed
merge Liam's changes
1 parent 038bc85 commit 9739b47

File tree

1 file changed

+2
-122
lines changed

1 file changed

+2
-122
lines changed

python_workflow_definition/src/python_workflow_definition/pyiron_workflow.py

Lines changed: 2 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from inspect import isfunction
22
from importlib import import_module
3-
from typing import Any, Optional
3+
from typing import Any
44

55
import numpy as np
66
from pyiron_workflow import function_node, Workflow
77
from pyiron_workflow.api import Function
88

99
from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow
1010
from python_workflow_definition.shared import (
11-
convert_nodes_list_to_dict,
1211
update_node_names,
1312
set_result_node,
1413
remove_result,
@@ -80,86 +79,6 @@ def get_edges(graph_dict, node_mapping_dict, nodes_links_dict):
8079
return edges_lst
8180

8281

83-
def create_input_nodes(nodes_dict, edges_lst):
84-
node_conversion_dict = {
85-
ed[SOURCE_LABEL]: ed[TARGET_PORT_LABEL]
86-
for ed in edges_lst
87-
if ed[SOURCE_PORT_LABEL] is None
88-
}
89-
nodes_to_create_dict = {v: nodes_dict[k] for k, v in node_conversion_dict.items()}
90-
return nodes_to_create_dict, node_conversion_dict
91-
92-
93-
def set_input_nodes(workflow, nodes_to_create_dict):
94-
for k, v in nodes_to_create_dict.items():
95-
workflow.__setattr__(k, v)
96-
return workflow
97-
98-
99-
def get_source_handles(edges_lst):
100-
source_handle_dict = {}
101-
for ed in edges_lst:
102-
if ed[SOURCE_LABEL] not in source_handle_dict.keys():
103-
source_handle_dict[ed[SOURCE_LABEL]] = [ed[SOURCE_PORT_LABEL]]
104-
else:
105-
source_handle_dict[ed[SOURCE_LABEL]].append(ed[SOURCE_PORT_LABEL])
106-
return source_handle_dict
107-
108-
109-
def get_function_nodes(nodes_dict, source_handle_dict):
110-
function_dict = {}
111-
for k, v in nodes_dict.items():
112-
if isfunction(v):
113-
function_dict[k] = {"node_function": v}
114-
return function_dict
115-
116-
117-
def get_kwargs(lst):
118-
return {
119-
t[TARGET_PORT_LABEL]: {
120-
SOURCE_LABEL: t[SOURCE_LABEL],
121-
SOURCE_PORT_LABEL: t[SOURCE_PORT_LABEL],
122-
}
123-
for t in lst
124-
}
125-
126-
127-
def group_edges(edges_lst):
128-
edges_sorted_lst = sorted(edges_lst, key=lambda x: x[TARGET_LABEL], reverse=True)
129-
total_dict = {}
130-
tmp_lst = []
131-
target_id = edges_sorted_lst[0][TARGET_LABEL]
132-
for ed in edges_sorted_lst:
133-
if target_id == ed[TARGET_LABEL]:
134-
tmp_lst.append(ed)
135-
else:
136-
total_dict[target_id] = get_kwargs(lst=tmp_lst)
137-
target_id = ed[TARGET_LABEL]
138-
tmp_lst = [ed]
139-
total_dict[target_id] = get_kwargs(lst=tmp_lst)
140-
return total_dict
141-
142-
143-
def build_workflow(workflow, function_dict, total_dict, node_conversion_dict):
144-
for k, v in function_dict.items():
145-
kwargs_link_dict = total_dict[k]
146-
kwargs_dict = {}
147-
for kw, vw in kwargs_link_dict.items():
148-
if vw[SOURCE_LABEL] in node_conversion_dict.keys():
149-
kwargs_dict[kw] = workflow.__getattribute__(
150-
node_conversion_dict[vw[SOURCE_LABEL]]
151-
)
152-
else:
153-
kwargs_dict[kw] = workflow.__getattr__(
154-
"tmp_" + str(vw[SOURCE_LABEL])
155-
).__getitem__(vw[SOURCE_PORT_LABEL])
156-
v.update(kwargs_dict)
157-
workflow.__setattr__(
158-
"tmp_" + str(k), function_node(**v, validate_output_labels=False)
159-
)
160-
return workflow, "tmp_" + str(k)
161-
162-
16382
def write_workflow_json(graph_as_dict: dict, file_name: str = "workflow.json"):
16483
nodes_dict, node_mapping_dict, input_dict = get_linked_nodes(
16584
graph_dict=graph_as_dict
@@ -276,45 +195,6 @@ def write_workflow_json(graph_as_dict: dict, file_name: str = "workflow.json"):
276195
).dump_json_file(file_name=file_name, indent=2)
277196

278197

279-
def load_workflow_json(file_name: str, workflow: Optional[Workflow] = None):
280-
content = remove_result(
281-
workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file(
282-
file_name=file_name
283-
)
284-
)
285-
edges_lst = content[EDGES_LABEL]
286-
287-
nodes_new_dict = {}
288-
for k, v in convert_nodes_list_to_dict(nodes_list=content[NODES_LABEL]).items():
289-
if isinstance(v, str) and "." in v:
290-
p, m = v.rsplit(".", 1)
291-
mod = import_module(p)
292-
nodes_new_dict[int(k)] = getattr(mod, m)
293-
else:
294-
nodes_new_dict[int(k)] = v
295-
296-
if workflow is None:
297-
workflow = Workflow(file_name.split(".")[0])
298-
299-
nodes_to_create_dict, node_conversion_dict = create_input_nodes(
300-
nodes_dict=nodes_new_dict, edges_lst=edges_lst
301-
)
302-
wf = set_input_nodes(workflow=workflow, nodes_to_create_dict=nodes_to_create_dict)
303-
304-
source_handle_dict = get_source_handles(edges_lst=edges_lst)
305-
function_dict = get_function_nodes(
306-
nodes_dict=nodes_new_dict, source_handle_dict=source_handle_dict
307-
)
308-
total_dict = group_edges(edges_lst=edges_lst)
309-
310-
return build_workflow(
311-
workflow=wf,
312-
function_dict=function_dict,
313-
total_dict=total_dict,
314-
node_conversion_dict=node_conversion_dict,
315-
)
316-
317-
318198
def import_from_string(library_path: str) -> Any:
319199
# Copied from bagofholding
320200
split_path = library_path.split(".", 1)
@@ -328,7 +208,7 @@ def import_from_string(library_path: str) -> Any:
328208
return obj
329209

330210

331-
def build_function_dag_workflow(file_name: str) -> Workflow:
211+
def load_workflow_json(file_name: str) -> Workflow:
332212
content = remove_result(
333213
PythonWorkflowDefinitionWorkflow.load_json_file(file_name=file_name)
334214
)

0 commit comments

Comments
 (0)