Skip to content

Commit 424a068

Browse files
committed
add code to generate nested aiida-workflow
1 parent 97b36ec commit 424a068

File tree

2 files changed

+57
-1
lines changed

2 files changed

+57
-1
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from aiida_workgraph import task
2+
from aiida import load_profile
3+
4+
load_profile()
5+
6+
7+
@task(outputs=["prod", "div"])
8+
def get_prod_and_div(x, y):
9+
return {"prod": x * y, "div": x / y}
10+
11+
12+
@task
13+
def get_sum(x, y):
14+
return x + y
15+
16+
17+
@task
18+
def get_square(x):
19+
return x**2
20+
21+
22+
@task.graph
23+
def nested_workflow(x, y):
24+
"""Inner workflow from prod_div.json:
25+
- get_prod_and_div(x, y) → prod, div
26+
- get_sum(prod, div) → result
27+
- get_square(result) → result
28+
"""
29+
prod_and_div = get_prod_and_div(x=x, y=y)
30+
sum_result = get_sum(x=prod_and_div.prod, y=prod_and_div.div)
31+
square_result = get_square(x=sum_result.result)
32+
return square_result.result
33+
34+
35+
@task.graph
36+
def main_workflow(a, b, c):
37+
"""Outer workflow from main.pwd.json:
38+
- Pre-processing: get_prod_and_div(a, c) → prod, div
39+
- Nested workflow: nested_workflow(prod, div) → result
40+
- Post-processing: get_sum(result, b) → final_result
41+
"""
42+
# Pre-processing step
43+
preprocessing = get_prod_and_div(x=a, y=c)
44+
45+
# Nested workflow
46+
nested_result = nested_workflow(x=preprocessing.prod, y=preprocessing.div)
47+
48+
# Post-processing step
49+
final_result = get_sum(x=nested_result.result, y=b)
50+
51+
return final_result.result
52+
53+
54+
wg = main_workflow.build(a=3, b=2, c=4)
55+
wg.run()

src/python_workflow_definition/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def check_value_format(cls, v: str):
6363
raise ValueError(msg)
6464
return v
6565

66+
6667
class PythonWorkflowDefinitionWorklowNode(PythonWorkflowDefinitionBaseNode):
6768
"""
6869
Model for function execution nodes.
@@ -90,7 +91,7 @@ def check_value_format(cls, v: str):
9091
PythonWorkflowDefinitionInputNode,
9192
PythonWorkflowDefinitionOutputNode,
9293
PythonWorkflowDefinitionFunctionNode,
93-
PythonWorkflowDefinitionWorklowNode
94+
PythonWorkflowDefinitionWorklowNode,
9495
],
9596
Field(discriminator="type"),
9697
]

0 commit comments

Comments
 (0)