Skip to content

Commit 373df59

Browse files
committed
update write_workflow_json()
1 parent 68e145f commit 373df59

File tree

2 files changed

+125
-58
lines changed

2 files changed

+125
-58
lines changed

example_workflows/quantum_espresso/pyiron_workflow.ipynb

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

python_workflow_definition/src/python_workflow_definition/pyiron_workflow.py

Lines changed: 124 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@ def get_linked_nodes(graph_dict):
2828
node_mapping_dict = {}
2929
input_dict = {}
3030
for i, [k, v] in enumerate(graph_dict["nodes"].items()):
31-
nodes_dict[i] = v.node_function
31+
if "inputs_to_dict_factory" in str(type(v)):
32+
nodes_dict[i] = get_dict
33+
else:
34+
nodes_dict[i] = v.node_function
3235
node_mapping_dict[k] = i
3336
input_dict[k] = {
3437
con.full_label: con.value
@@ -41,9 +44,10 @@ def get_linked_nodes(graph_dict):
4144
def extend_nodes_dict(nodes_dict, input_dict):
4245
i = len(nodes_dict)
4346
nodes_links_dict = {}
47+
nodes_values_str_lst = [str(s) for s in nodes_dict.values()]
4448
for val_dict in input_dict.values():
4549
for k, v in val_dict.items():
46-
if v not in nodes_dict.values():
50+
if str(v) not in nodes_values_str_lst:
4751
nodes_dict[i] = v
4852
nodes_links_dict[k] = i
4953
i += 1
@@ -96,16 +100,93 @@ def write_workflow_json(graph_as_dict: dict, file_name: str = "workflow.json"):
96100
if isfunction(v) and "pyiron_workflow" in v.__module__:
97101
pyiron_workflow_modules[k] = v
98102

99-
target_dict = {}
103+
cache_mapping_dict, remap_dict = {}, {}
104+
for k, v in nodes_dict.items():
105+
if not isfunction(v) and str(v) not in cache_mapping_dict:
106+
cache_mapping_dict[str(v)] = k
107+
elif not isfunction(v):
108+
remap_dict[k] = cache_mapping_dict[str(v)]
109+
110+
item_node_lst = [
111+
e[SOURCE_LABEL] for e in edges_lst
112+
if e[TARGET_LABEL] in pyiron_workflow_modules.keys() and e[TARGET_PORT_LABEL] == "item"
113+
]
114+
115+
values_from_dict_lst = [
116+
k for k, v in nodes_dict.items()
117+
if isfunction(v) and v.__name__ == "get_values_from_dict"
118+
]
119+
120+
remap_get_list_dict = {}
121+
for e in edges_lst:
122+
if e[TARGET_LABEL] in values_from_dict_lst:
123+
remap_get_list_dict[e[SOURCE_LABEL]] = e[TARGET_LABEL]
124+
125+
nodes_remaining_dict = {
126+
k: v for k, v in nodes_dict.items()
127+
if k not in pyiron_workflow_modules.keys() and k not in remap_dict.keys() and k not in item_node_lst and k not in remap_get_list_dict.values()
128+
}
129+
130+
nodes_store_lst = []
131+
nodes_final_order_dict = {}
132+
for k, [i, v] in enumerate(nodes_remaining_dict.items()):
133+
if i in remap_get_list_dict:
134+
nodes_store_lst.append(
135+
{"id": k, "type": "function", "value": "python_workflow_definition.shared.get_list"}
136+
)
137+
elif isfunction(v):
138+
mod = v.__module__
139+
if mod == "python_workflow_definition.pyiron_workflow":
140+
mod = "python_workflow_definition.shared"
141+
nodes_store_lst.append(
142+
{"id": k, "type": "function", "value": mod + "." + v.__name__}
143+
)
144+
elif isinstance(v, np.ndarray):
145+
nodes_store_lst.append({"id": k, "type": "input", "value": v.tolist()})
146+
else:
147+
nodes_store_lst.append({"id": k, "type": "input", "value": v})
148+
nodes_final_order_dict[i] = k
149+
150+
remap_get_list_remove_edges = [
151+
edge for edge in edges_lst
152+
if edge[TARGET_LABEL] in remap_get_list_dict.values()
153+
]
154+
155+
edge_get_list_updated_lst = []
100156
for edge in edges_lst:
157+
if edge[SOURCE_LABEL] in remap_get_list_dict.values():
158+
connected_edge = [
159+
edge_con for edge_con in remap_get_list_remove_edges
160+
if edge_con[TARGET_LABEL] == edge[SOURCE_LABEL]
161+
][-1]
162+
edge_updated = {
163+
TARGET_LABEL: edge[TARGET_LABEL],
164+
TARGET_PORT_LABEL: edge[TARGET_PORT_LABEL],
165+
SOURCE_LABEL: connected_edge[SOURCE_LABEL],
166+
SOURCE_PORT_LABEL: connected_edge[SOURCE_PORT_LABEL],
167+
}
168+
edge_get_list_updated_lst.append(edge_updated)
169+
elif edge[SOURCE_LABEL] in remap_dict.keys():
170+
edge_updated = {
171+
TARGET_LABEL: edge[TARGET_LABEL],
172+
TARGET_PORT_LABEL: edge[TARGET_PORT_LABEL],
173+
SOURCE_LABEL: remap_dict[edge[SOURCE_LABEL]],
174+
SOURCE_PORT_LABEL: edge[SOURCE_PORT_LABEL],
175+
}
176+
edge_get_list_updated_lst.append(edge_updated)
177+
elif edge[TARGET_LABEL] not in remap_get_list_dict.values():
178+
edge_get_list_updated_lst.append(edge)
179+
180+
target_dict = {}
181+
for edge in edge_get_list_updated_lst:
101182
for k in pyiron_workflow_modules.keys():
102183
if k == edge[TARGET_LABEL]:
103184
if k not in target_dict:
104185
target_dict[k] = []
105186
target_dict[k].append(edge)
106187

107188
source_dict = {}
108-
for edge in edges_lst:
189+
for edge in edge_get_list_updated_lst:
109190
for k in pyiron_workflow_modules.keys():
110191
if k == edge[SOURCE_LABEL]:
111192
if k not in source_dict:
@@ -121,67 +202,53 @@ def write_workflow_json(graph_as_dict: dict, file_name: str = "workflow.json"):
121202
nodes_to_delete.append(edge[SOURCE_LABEL])
122203
else:
123204
source = edge[SOURCE_LABEL]
124-
edge_new_lst.append(
125-
{
126-
SOURCE_LABEL: source,
127-
SOURCE_PORT_LABEL: sourcehandle,
128-
TARGET_LABEL: source_dict[k][-1][TARGET_LABEL],
129-
TARGET_PORT_LABEL: source_dict[k][-1][TARGET_PORT_LABEL],
130-
}
131-
)
132-
133-
nodes_to_skip = nodes_to_delete + list(pyiron_workflow_modules.keys())
134-
nodes_new_dict = {k: v for k, v in nodes_dict.items() if k not in nodes_to_skip}
135-
136-
nodes_store_lst = []
137-
for k, v in enumerate(nodes_new_dict.values()):
138-
if isfunction(v):
139-
mod = v.__module__
140-
if mod == "python_workflow_definition.pyiron_workflow":
141-
mod = "python_workflow_definition.shared"
142-
nodes_store_lst.append(
143-
{"id": k, "type": "function", "value": mod + "." + v.__name__}
205+
if "s_" == source_dict[k][-1][TARGET_PORT_LABEL][:2]:
206+
edge_new_lst.append(
207+
{
208+
SOURCE_LABEL: nodes_final_order_dict[source],
209+
SOURCE_PORT_LABEL: sourcehandle,
210+
TARGET_LABEL: nodes_final_order_dict[source_dict[k][-1][TARGET_LABEL]],
211+
TARGET_PORT_LABEL: source_dict[k][-1][TARGET_PORT_LABEL][2:],
212+
}
144213
)
145-
elif isinstance(v, np.ndarray):
146-
nodes_store_lst.append({"id": k, "type": "input", "value": v.tolist()})
147214
else:
148-
nodes_store_lst.append({"id": k, "type": "input", "value": v})
215+
edge_new_lst.append(
216+
{
217+
SOURCE_LABEL: nodes_final_order_dict[source],
218+
SOURCE_PORT_LABEL: sourcehandle,
219+
TARGET_LABEL: nodes_final_order_dict[source_dict[k][-1][TARGET_LABEL]],
220+
TARGET_PORT_LABEL: source_dict[k][-1][TARGET_PORT_LABEL],
221+
}
222+
)
149223

150-
for edge in edges_lst:
224+
nodes_to_skip = nodes_to_delete + list(pyiron_workflow_modules.keys())
225+
for edge in edge_get_list_updated_lst:
151226
if (
152-
edge[TARGET_LABEL] not in nodes_to_skip
153-
and edge[SOURCE_LABEL] not in nodes_to_skip
227+
edge[TARGET_LABEL] not in nodes_to_skip
228+
and edge[SOURCE_LABEL] not in nodes_to_skip
154229
):
155-
edge_new_lst.append(edge)
156-
157-
nodes_updated_dict, mapping_dict = {}, {}
158-
i = 0
159-
for k, v in nodes_new_dict.items():
160-
nodes_updated_dict[i] = v
161-
mapping_dict[k] = i
162-
i += 1
163-
164-
edge_updated_lst = []
165-
for edge in edge_new_lst:
166-
source_node = nodes_new_dict[edge[SOURCE_LABEL]]
167-
if isfunction(source_node) and source_node.__name__ == edge[SOURCE_PORT_LABEL]:
168-
edge_updated_lst.append(
169-
{
170-
SOURCE_LABEL: mapping_dict[edge[SOURCE_LABEL]],
230+
source_node = nodes_remaining_dict[edge[SOURCE_LABEL]]
231+
if isfunction(source_node) and source_node.__name__ == edge[SOURCE_PORT_LABEL]:
232+
edge_new_lst.append({
233+
TARGET_LABEL: nodes_final_order_dict[edge[TARGET_LABEL]],
234+
TARGET_PORT_LABEL: edge[TARGET_PORT_LABEL],
235+
SOURCE_LABEL: nodes_final_order_dict[edge[SOURCE_LABEL]],
171236
SOURCE_PORT_LABEL: None,
172-
TARGET_LABEL: mapping_dict[edge[TARGET_LABEL]],
237+
})
238+
elif isfunction(source_node) and source_node.__name__ == "get_dict" and edge[SOURCE_PORT_LABEL] == "dict":
239+
edge_new_lst.append({
240+
TARGET_LABEL: nodes_final_order_dict[edge[TARGET_LABEL]],
173241
TARGET_PORT_LABEL: edge[TARGET_PORT_LABEL],
174-
}
175-
)
176-
else:
177-
edge_updated_lst.append(
178-
{
179-
SOURCE_LABEL: mapping_dict[edge[SOURCE_LABEL]],
180-
SOURCE_PORT_LABEL: edge[SOURCE_PORT_LABEL],
181-
TARGET_LABEL: mapping_dict[edge[TARGET_LABEL]],
242+
SOURCE_LABEL: nodes_final_order_dict[edge[SOURCE_LABEL]],
243+
SOURCE_PORT_LABEL: None,
244+
})
245+
else:
246+
edge_new_lst.append({
247+
TARGET_LABEL: nodes_final_order_dict[edge[TARGET_LABEL]],
182248
TARGET_PORT_LABEL: edge[TARGET_PORT_LABEL],
183-
}
184-
)
249+
SOURCE_LABEL: nodes_final_order_dict[edge[SOURCE_LABEL]],
250+
SOURCE_PORT_LABEL: edge[SOURCE_PORT_LABEL],
251+
})
185252

186253
PythonWorkflowDefinitionWorkflow(
187254
**set_result_node(

0 commit comments

Comments
 (0)