33
44from aiida import orm
55from aiida_pythonjob .data .serializer import general_serializer
6- from aiida_workgraph import WorkGraph , task
6+ from aiida_workgraph import WorkGraph , task , Task , namespace
77from aiida_workgraph .socket import TaskSocketNamespace
8-
8+ from dataclasses import replace
9+ from node_graph .node_spec import SchemaSource
910from python_workflow_definition .models import PythonWorkflowDefinitionWorkflow
1011from python_workflow_definition .shared import (
1112 convert_nodes_list_to_dict ,
@@ -38,10 +39,12 @@ def load_workflow_json(file_name: str) -> WorkGraph:
3839 p , m = identifier .rsplit ("." , 1 )
3940 mod = import_module (p )
4041 func = getattr (mod , m )
41- wg .add_task (func )
42- # Remove the default result output, because we will add the outputs later from the data in the link
43- del wg .tasks [- 1 ].outputs ["result" ]
44- task_name_mapping [id ] = wg .tasks [- 1 ]
42+ decorated_func = task (outputs = namespace ())(func )
43+ new_task = wg .add_task (decorated_func )
44+ new_task .spec = replace (
45+ new_task .spec , schema_source = SchemaSource .EMBEDDED
46+ )
47+ task_name_mapping [id ] = new_task
4548 else :
4649 # data task
4750 data_node = general_serializer (identifier )
@@ -53,10 +56,11 @@ def load_workflow_json(file_name: str) -> WorkGraph:
5356 to_task = task_name_mapping [str (link [TARGET_LABEL ])]
5457 # if the input is not exit, it means we pass the data into to the kwargs
5558 # in this case, we add the input socket
56- if link [TARGET_PORT_LABEL ] not in to_task .inputs :
57- to_socket = to_task .add_input_spec ("workgraph.any" , name = link [TARGET_PORT_LABEL ])
58- else :
59- to_socket = to_task .inputs [link [TARGET_PORT_LABEL ]]
59+ if isinstance (to_task , Task ):
60+ if link [TARGET_PORT_LABEL ] not in to_task .inputs :
61+ to_socket = to_task .add_input_spec ("workgraph.any" , name = link [TARGET_PORT_LABEL ])
62+ else :
63+ to_socket = to_task .inputs [link [TARGET_PORT_LABEL ]]
6064 from_task = task_name_mapping [str (link [SOURCE_LABEL ])]
6165 if isinstance (from_task , orm .Data ):
6266 to_socket .value = from_task
@@ -69,20 +73,16 @@ def load_workflow_json(file_name: str) -> WorkGraph:
6973 # link[SOURCE_PORT_LABEL] = "__result__"
7074 # because we are not define the outputs explicitly during the pythonjob creation
7175 # we add it here, and assume the output exit
72- try :
73- if link [SOURCE_PORT_LABEL ] not in from_task .outputs :
74- # if str(link["sourcePort"]) not in from_task.outputs:
75- from_socket = from_task .add_output_spec (
76- "workgraph.any" ,
77- name = link [SOURCE_PORT_LABEL ],
78- )
79- else :
80- from_socket = from_task .outputs [link [SOURCE_PORT_LABEL ]]
81-
76+ if link [SOURCE_PORT_LABEL ] not in from_task .outputs :
77+ # if str(link["sourcePort"]) not in from_task.outputs:
78+ from_socket = from_task .add_output_spec (
79+ "workgraph.any" ,
80+ name = link [SOURCE_PORT_LABEL ],
81+ )
82+ else :
83+ from_socket = from_task .outputs [link [SOURCE_PORT_LABEL ]]
84+ if isinstance (to_task , Task ):
8285 wg .add_link (from_socket , to_socket )
83- except :
84- breakpoint ()
85- pass
8686 except Exception as e :
8787 traceback .print_exc ()
8888 print ("Failed to link" , link , "with error:" , e )
0 commit comments