Skip to content

Commit d474622

Browse files
committed
round-trip working
1 parent 424a068 commit d474622

File tree

4 files changed

+419
-53
lines changed

4 files changed

+419
-53
lines changed
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""
2+
Round-trip test for nested workflows.
3+
4+
This script demonstrates that:
5+
1. Loading a nested workflow JSON preserves all structure and values
6+
2. Exporting a loaded workflow produces identical JSON
7+
3. Multiple round-trips are stable (load -> export -> load -> export produces identical results)
8+
"""
9+
10+
import json
11+
from pathlib import Path
12+
from python_workflow_definition.aiida import load_workflow_json, write_workflow_json
13+
from aiida import load_profile
14+
15+
# Load AiiDA profile
16+
load_profile()
17+
18+
19+
def compare_json_files(file1: str, file2: str) -> bool:
20+
"""Compare two JSON files for structural equality."""
21+
with open(file1) as f1, open(file2) as f2:
22+
data1 = json.load(f1)
23+
data2 = json.load(f2)
24+
# Compare as sorted JSON strings to ignore ordering
25+
return json.dumps(data1, sort_keys=True) == json.dumps(data2, sort_keys=True)
26+
27+
28+
def print_workflow_info(wg, name: str):
29+
"""Print information about a loaded workflow."""
30+
print(f"\n{name}:")
31+
32+
# Count tasks (excluding internal graph tasks)
33+
task_count = len([t for t in wg.tasks if t.name not in ["graph_inputs", "graph_outputs", "graph_ctx"]])
34+
print(f" Tasks: {task_count}")
35+
36+
# Show inputs
37+
if hasattr(wg.inputs, '_sockets'):
38+
print(" Inputs:")
39+
for name, socket in wg.inputs._sockets.items():
40+
if not name.startswith('_') and name != 'metadata':
41+
if hasattr(socket, 'value') and socket.value is not None:
42+
value = socket.value.value if hasattr(socket.value, 'value') else socket.value
43+
print(f" {name} = {value}")
44+
45+
# Show outputs
46+
if hasattr(wg.outputs, '_sockets'):
47+
output_names = [name for name in wg.outputs._sockets.keys()
48+
if not name.startswith('_') and name != 'metadata']
49+
if output_names:
50+
print(f" Outputs: {', '.join(output_names)}")
51+
52+
# Check for nested workflows
53+
nested_count = 0
54+
for task in wg.tasks:
55+
if hasattr(task, 'tasks'):
56+
nested_tasks = [t for t in task.tasks if t.name not in ['graph_inputs', 'graph_outputs', 'graph_ctx']]
57+
if len(nested_tasks) > 0:
58+
nested_count += 1
59+
print(f" Nested workflow '{task.name}' with {len(nested_tasks)} tasks")
60+
# Show nested workflow defaults
61+
for subtask in task.tasks:
62+
if subtask.name == 'graph_inputs' and hasattr(subtask, 'outputs'):
63+
print(" Default inputs:")
64+
for out in subtask.outputs:
65+
if hasattr(out, '_name') and not out._name.startswith('_'):
66+
value = out.value.value if hasattr(out.value, 'value') else out.value
67+
print(f" {out._name} = {value}")
68+
69+
70+
def main():
71+
print("=" * 70)
72+
print("NESTED WORKFLOW ROUND-TRIP TEST")
73+
print("=" * 70)
74+
75+
# Define file paths
76+
original_file = "main.pwd.json"
77+
roundtrip1_file = "roundtrip1.pwd.json"
78+
roundtrip2_file = "roundtrip2.pwd.json"
79+
nested_original = "prod_div.json"
80+
nested_export = "nested_1.json"
81+
82+
# Test 1: Load original workflow
83+
print("\n[1] Loading original workflow...")
84+
wg_original = load_workflow_json(original_file)
85+
print_workflow_info(wg_original, "Original workflow")
86+
87+
# Test 2: Export to roundtrip1
88+
print("\n[2] Exporting to roundtrip1.pwd.json...")
89+
write_workflow_json(wg_original, roundtrip1_file)
90+
print(f" Exported main workflow to {roundtrip1_file}")
91+
if Path(nested_export).exists():
92+
print(f" Exported nested workflow to {nested_export}")
93+
94+
# Test 3: Load roundtrip1
95+
print("\n[3] Loading roundtrip1.pwd.json...")
96+
wg_roundtrip1 = load_workflow_json(roundtrip1_file)
97+
print_workflow_info(wg_roundtrip1, "Roundtrip 1 workflow")
98+
99+
# Test 4: Export to roundtrip2
100+
print("\n[4] Exporting to roundtrip2.pwd.json...")
101+
write_workflow_json(wg_roundtrip1, roundtrip2_file)
102+
print(f" Exported to {roundtrip2_file}")
103+
104+
# Test 5: Compare files
105+
print("\n[5] Comparing JSON files...")
106+
print("-" * 70)
107+
108+
# Compare main workflows
109+
main_match = compare_json_files(roundtrip1_file, roundtrip2_file)
110+
print(f" roundtrip1 == roundtrip2: {'PASS' if main_match else 'FAIL'}")
111+
112+
# Compare nested workflows
113+
if Path(nested_original).exists() and Path(nested_export).exists():
114+
nested_match = compare_json_files(nested_original, nested_export)
115+
print(f" {nested_original} == {nested_export}: {'PASS' if nested_match else 'FAIL'}")
116+
else:
117+
nested_match = True # If files don't exist, consider it a pass
118+
119+
# Test 6: Load roundtrip2 and verify
120+
print("\n[6] Loading roundtrip2.pwd.json for verification...")
121+
wg_roundtrip2 = load_workflow_json(roundtrip2_file)
122+
print_workflow_info(wg_roundtrip2, "Roundtrip 2 workflow")
123+
124+
# Final verdict
125+
print("\n" + "=" * 70)
126+
if main_match and nested_match:
127+
print("RESULT: ALL TESTS PASSED")
128+
print(" - Workflow structure preserved")
129+
print(" - Input/output values preserved")
130+
print(" - Nested workflow defaults preserved")
131+
print(" - Round-trip is stable and idempotent")
132+
result = 0
133+
else:
134+
print("RESULT: SOME TESTS FAILED")
135+
result = 1
136+
print("=" * 70)
137+
138+
# Cleanup
139+
print("\nCleaning up temporary files...")
140+
for temp_file in [roundtrip1_file, roundtrip2_file, nested_export]:
141+
if Path(temp_file).exists():
142+
Path(temp_file).unlink()
143+
print(f" Removed {temp_file}")
144+
145+
return result
146+
147+
148+
if __name__ == "__main__":
149+
exit(main())
Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,83 @@
1-
from aiida_workgraph import task
2-
from aiida import load_profile
1+
from aiida_workgraph import task, WorkGraph, namespace
2+
from aiida import load_profile, orm
3+
from python_workflow_definition.aiida import write_workflow_json
4+
from workflow import get_prod_and_div as _get_prod_and_div, get_sum as _get_sum, get_square as _get_square
35

46
load_profile()
57

68

7-
@task(outputs=["prod", "div"])
8-
def get_prod_and_div(x, y):
9-
return {"prod": x * y, "div": x / y}
9+
# Wrap the functions with @task decorator
10+
get_prod_and_div = task(outputs=["prod", "div"])(_get_prod_and_div)
11+
get_sum = task(_get_sum)
12+
get_square = task(_get_square)
1013

1114

12-
@task
13-
def get_sum(x, y):
14-
return x + y
15+
# Create nested workflow manually (corresponds to prod_div.json)
16+
nested_wg = WorkGraph(
17+
name="nested_workflow",
18+
inputs=namespace(x=namespace, y=namespace),
19+
outputs=namespace(result=namespace),
20+
)
1521

22+
# Add tasks to nested workflow
23+
t1 = nested_wg.add_task(get_prod_and_div)
24+
t2 = nested_wg.add_task(get_sum)
25+
t3 = nested_wg.add_task(get_square)
1626

17-
@task
18-
def get_square(x):
19-
return x**2
27+
# Connect nested workflow inputs to first task
28+
nested_wg.add_link(nested_wg.inputs.x, t1.inputs.x)
29+
nested_wg.add_link(nested_wg.inputs.y, t1.inputs.y)
2030

31+
# Connect tasks within nested workflow
32+
nested_wg.add_link(t1.outputs.prod, t2.inputs.x)
33+
nested_wg.add_link(t1.outputs.div, t2.inputs.y)
34+
nested_wg.add_link(t2.outputs.result, t3.inputs.x)
2135

22-
@task.graph
23-
def nested_workflow(x, y):
24-
"""Inner workflow from prod_div.json:
25-
- get_prod_and_div(x, y) → prod, div
26-
- get_sum(prod, div) → result
27-
- get_square(result) → result
28-
"""
29-
prod_and_div = get_prod_and_div(x=x, y=y)
30-
sum_result = get_sum(x=prod_and_div.prod, y=prod_and_div.div)
31-
square_result = get_square(x=sum_result.result)
32-
return square_result.result
36+
# Connect nested workflow output
37+
nested_wg.outputs.result = t3.outputs.result
3338

39+
# Set default values for nested workflow inputs
40+
nested_wg.inputs.x.value = orm.Float(1)
41+
nested_wg.inputs.y.value = orm.Float(2)
3442

35-
@task.graph
36-
def main_workflow(a, b, c):
37-
"""Outer workflow from main.pwd.json:
38-
- Pre-processing: get_prod_and_div(a, c) → prod, div
39-
- Nested workflow: nested_workflow(prod, div) → result
40-
- Post-processing: get_sum(result, b) → final_result
41-
"""
42-
# Pre-processing step
43-
preprocessing = get_prod_and_div(x=a, y=c)
4443

45-
# Nested workflow
46-
nested_result = nested_workflow(x=preprocessing.prod, y=preprocessing.div)
44+
# Create main workflow (corresponds to main.pwd.json)
45+
main_wg = WorkGraph(
46+
name="main_workflow",
47+
inputs=namespace(a=namespace, b=namespace, c=namespace),
48+
outputs=namespace(final_result=namespace),
49+
)
4750

48-
# Post-processing step
49-
final_result = get_sum(x=nested_result.result, y=b)
51+
# Add tasks to main workflow
52+
preprocessing = main_wg.add_task(get_prod_and_div)
53+
nested_task = main_wg.add_task(nested_wg) # Add the nested workflow as a task
54+
postprocessing = main_wg.add_task(get_sum)
5055

51-
return final_result.result
56+
# Connect main workflow inputs to preprocessing
57+
main_wg.add_link(main_wg.inputs.a, preprocessing.inputs.x)
58+
main_wg.add_link(main_wg.inputs.c, preprocessing.inputs.y)
5259

60+
# Connect preprocessing to nested workflow
61+
main_wg.add_link(preprocessing.outputs.prod, nested_task.inputs.x)
62+
main_wg.add_link(preprocessing.outputs.div, nested_task.inputs.y)
5363

54-
wg = main_workflow.build(a=3, b=2, c=4)
55-
wg.run()
64+
# Connect nested workflow to postprocessing
65+
main_wg.add_link(nested_task.outputs.result, postprocessing.inputs.x)
66+
main_wg.add_link(main_wg.inputs.b, postprocessing.inputs.y)
67+
68+
# Connect main workflow output
69+
main_wg.outputs.final_result = postprocessing.outputs.result
70+
71+
# Set default values for main workflow inputs
72+
main_wg.inputs.a.value = orm.Float(3)
73+
main_wg.inputs.b.value = orm.Float(2)
74+
main_wg.inputs.c.value = orm.Float(4)
75+
76+
77+
# Export to JSON (will create main_generated.pwd.json and nested_1.json)
78+
print("Exporting workflow to JSON files...")
79+
write_workflow_json(wg=main_wg, file_name="main_generated.pwd.json")
80+
print("✓ Exported to main_generated.pwd.json and nested_1.json")
81+
82+
# Optionally run the workflow
83+
# main_wg.run()

0 commit comments

Comments
 (0)