Skip to content

Commit eebb50e

Browse files
committed
properly pass through default values to inner workgraph
1 parent 630d69a commit eebb50e

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

src/python_workflow_definition/aiida.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def load_workflow_json_nested(file_name: str) -> WorkGraph:
5555

5656
wg = WorkGraph(
5757
inputs=namespace(**inputs_ns) if inputs_ns else None,
58-
outputs=namespace(**outputs_ns) if outputs_ns else None
58+
outputs=namespace(**outputs_ns) if outputs_ns else None,
5959
)
6060
else:
6161
wg = WorkGraph()
@@ -95,10 +95,14 @@ def load_workflow_json_nested(file_name: str) -> WorkGraph:
9595
elif node_type == "input":
9696
# Store input node info for later connection to wg.inputs
9797
input_node_mapping[node_id] = node["name"]
98-
# Also create a data node for direct value setting if needed
98+
# Set default value on the workflow's exposed input if provided
9999
if "value" in node and node["value"] is not None:
100100
value = node["value"]
101101
data_node = general_serializer(value)
102+
# Set the default on the workflow's exposed input
103+
if hasattr(wg.inputs, node["name"]):
104+
setattr(wg.inputs, node["name"], data_node)
105+
# Also store in mapping for direct connections in non-nested contexts
102106
task_name_mapping[node_id] = data_node
103107

104108
elif node_type == "output":
@@ -113,15 +117,19 @@ def load_workflow_json_nested(file_name: str) -> WorkGraph:
113117
target_port = link[TARGET_PORT_LABEL]
114118

115119
# Handle output node connections
116-
target_node = next((n for n in data[NODES_LABEL] if str(n["id"]) == target_id), None)
120+
target_node = next(
121+
(n for n in data[NODES_LABEL] if str(n["id"]) == target_id), None
122+
)
117123
if target_node and target_node["type"] == "output":
118124
# This connects a task output to a workflow output
119125
from_task = task_name_mapping.get(source_id)
120126
if from_task and isinstance(from_task, Task):
121127
if source_port is None:
122128
source_port = "result"
123129
if source_port not in from_task.outputs:
124-
from_socket = from_task.add_output_spec("workgraph.any", name=source_port)
130+
from_socket = from_task.add_output_spec(
131+
"workgraph.any", name=source_port
132+
)
125133
else:
126134
from_socket = from_task.outputs[source_port]
127135

@@ -132,13 +140,17 @@ def load_workflow_json_nested(file_name: str) -> WorkGraph:
132140
continue
133141

134142
# Handle input node connections
135-
source_node = next((n for n in data[NODES_LABEL] if str(n["id"]) == source_id), None)
143+
source_node = next(
144+
(n for n in data[NODES_LABEL] if str(n["id"]) == source_id), None
145+
)
136146
if source_node and source_node["type"] == "input":
137147
to_task = task_name_mapping.get(target_id)
138148
if to_task and isinstance(to_task, Task):
139149
# Add target socket if it doesn't exist
140150
if target_port not in to_task.inputs:
141-
to_socket = to_task.add_input_spec("workgraph.any", name=target_port)
151+
to_socket = to_task.add_input_spec(
152+
"workgraph.any", name=target_port
153+
)
142154
else:
143155
to_socket = to_task.inputs[target_port]
144156

@@ -177,7 +189,9 @@ def load_workflow_json_nested(file_name: str) -> WorkGraph:
177189

178190
# Add source socket if needed
179191
if source_port not in from_task.outputs:
180-
from_socket = from_task.add_output_spec("workgraph.any", name=source_port)
192+
from_socket = from_task.add_output_spec(
193+
"workgraph.any", name=source_port
194+
)
181195
else:
182196
from_socket = from_task.outputs[source_port]
183197

@@ -188,6 +202,7 @@ def load_workflow_json_nested(file_name: str) -> WorkGraph:
188202

189203
return wg
190204

205+
191206
def load_workflow_json(file_name: str) -> WorkGraph:
192207

193208
data = PythonWorkflowDefinitionWorkflow.load_json_file(file_name=file_name)

0 commit comments

Comments
 (0)