Skip to content

Commit d1f48c4

Browse files
committed
Fix the
- declare an empty outputs namespace for task - skip the link if the task is not a Task but a AiiDa data node
1 parent 295c39e commit d1f48c4

File tree

1 file changed

+23
-23
lines changed

1 file changed

+23
-23
lines changed

src/python_workflow_definition/aiida.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
from aiida import orm
55
from aiida_pythonjob.data.serializer import general_serializer
6-
from aiida_workgraph import WorkGraph, task
6+
from aiida_workgraph import WorkGraph, task, Task, namespace
77
from aiida_workgraph.socket import TaskSocketNamespace
8-
8+
from dataclasses import replace
9+
from node_graph.node_spec import SchemaSource
910
from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow
1011
from python_workflow_definition.shared import (
1112
convert_nodes_list_to_dict,
@@ -38,10 +39,12 @@ def load_workflow_json(file_name: str) -> WorkGraph:
3839
p, m = identifier.rsplit(".", 1)
3940
mod = import_module(p)
4041
func = getattr(mod, m)
41-
wg.add_task(func)
42-
# Remove the default result output, because we will add the outputs later from the data in the link
43-
del wg.tasks[-1].outputs["result"]
44-
task_name_mapping[id] = wg.tasks[-1]
42+
decorated_func = task(outputs=namespace())(func)
43+
new_task = wg.add_task(decorated_func)
44+
new_task.spec = replace(
45+
new_task.spec, schema_source=SchemaSource.EMBEDDED
46+
)
47+
task_name_mapping[id] = new_task
4548
else:
4649
# data task
4750
data_node = general_serializer(identifier)
@@ -53,10 +56,11 @@ def load_workflow_json(file_name: str) -> WorkGraph:
5356
to_task = task_name_mapping[str(link[TARGET_LABEL])]
5457
# if the input is not exit, it means we pass the data into to the kwargs
5558
# in this case, we add the input socket
56-
if link[TARGET_PORT_LABEL] not in to_task.inputs:
57-
to_socket = to_task.add_input_spec("workgraph.any", name=link[TARGET_PORT_LABEL])
58-
else:
59-
to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
59+
if isinstance(to_task, Task):
60+
if link[TARGET_PORT_LABEL] not in to_task.inputs:
61+
to_socket = to_task.add_input_spec("workgraph.any", name=link[TARGET_PORT_LABEL])
62+
else:
63+
to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
6064
from_task = task_name_mapping[str(link[SOURCE_LABEL])]
6165
if isinstance(from_task, orm.Data):
6266
to_socket.value = from_task
@@ -69,20 +73,16 @@ def load_workflow_json(file_name: str) -> WorkGraph:
6973
# link[SOURCE_PORT_LABEL] = "__result__"
7074
# because we are not define the outputs explicitly during the pythonjob creation
7175
# we add it here, and assume the output exit
72-
try:
73-
if link[SOURCE_PORT_LABEL] not in from_task.outputs:
74-
# if str(link["sourcePort"]) not in from_task.outputs:
75-
from_socket = from_task.add_output_spec(
76-
"workgraph.any",
77-
name=link[SOURCE_PORT_LABEL],
78-
)
79-
else:
80-
from_socket = from_task.outputs[link[SOURCE_PORT_LABEL]]
81-
76+
if link[SOURCE_PORT_LABEL] not in from_task.outputs:
77+
# if str(link["sourcePort"]) not in from_task.outputs:
78+
from_socket = from_task.add_output_spec(
79+
"workgraph.any",
80+
name=link[SOURCE_PORT_LABEL],
81+
)
82+
else:
83+
from_socket = from_task.outputs[link[SOURCE_PORT_LABEL]]
84+
if isinstance(to_task, Task):
8285
wg.add_link(from_socket, to_socket)
83-
except:
84-
breakpoint()
85-
pass
8686
except Exception as e:
8787
traceback.print_exc()
8888
print("Failed to link", link, "with error:", e)

0 commit comments

Comments
 (0)