Skip to content

Commit 9a5585b

Browse files
committed
Current WG working version
1 parent a55574d commit 9a5585b

File tree

3 files changed

+71
-43
lines changed

3 files changed

+71
-43
lines changed

python_workflow_definition/src/python_workflow_definition/aiida.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -171,27 +171,29 @@ def construct_wg_simple(add_x_and_y_func, add_x_and_y_and_z_func) -> WorkGraph:
171171

172172

173173
def construct_wg_qe(
174-
get_dict,
175-
get_list,
176174
get_bulk_structure,
177175
calculate_qe,
178176
generate_structures,
179177
plot_energy_volume_curve,
180178
strain_lst,
181179
):
182180

181+
from .shared import get_dict
182+
from .shared import get_list
183+
183184
# NOTE: `get_dict` is `get_input_dict`, to compile the input values for the calc tasks
184185
# NOTE: `add_link` must be from outputs to inputs
185-
wg = WorkGraph()
186+
wg = WorkGraph("wg-qe")
186187

187188
get_bulk_structure_task = wg.add_task(
188189
get_bulk_structure,
189-
name="bulk",
190+
name="get_bulk_structure",
190191
register_pickle_by_value=True,
191192
)
192193

193194
relax_task = wg.add_task(
194195
calculate_qe,
196+
# ! I don't like the `mini` name...
195197
name="mini",
196198
register_pickle_by_value=True,
197199
)
@@ -207,7 +209,6 @@ def construct_wg_qe(
207209

208210
scf_qe_tasks = []
209211
for i, strain in enumerate(strain_lst):
210-
211212
generate_structures_task.add_output("workgraph.any", f"s_{i}")
212213

213214
scf_qe_task = wg.add_task(
@@ -219,6 +220,7 @@ def construct_wg_qe(
219220

220221
plot_energy_volume_curve_task = wg.add_task(
221222
plot_energy_volume_curve,
223+
name="plot_energy_volume_curve",
222224
register_pickle_by_value=True,
223225
)
224226

@@ -248,7 +250,10 @@ def construct_wg_qe(
248250

249251
# ? relax or SCF, or general? -> Should be relax
250252
relax_get_dict_task = wg.add_task(
251-
task.pythonjob(outputs=["dict"])(get_dict),
253+
task.pythonjob(
254+
# outputs=["structure", "calculation", "kpts", "pseudopotentials", "smearing"]
255+
# outputs=["dict"]
256+
)(get_dict),
252257
name="relax_get_dict",
253258
register_pickle_by_value=True,
254259
)
@@ -262,7 +267,7 @@ def construct_wg_qe(
262267
del pickle_node.TaskCls
263268

264269
pickle_kpts_task = wg.add_task(
265-
task.pythonjob(outputs=["kpts"])(pickle_node), name="kpts_task", value=[3, 3, 3]
270+
task.pythonjob(outputs=["kpts"])(pickle_node), name="kpts_task", value=[1, 1, 1] # FIXME: Back to [3, 3, 3]
266271
)
267272
del pickle_node.TaskCls
268273

@@ -287,20 +292,28 @@ def construct_wg_qe(
287292

288293
strain_dir_tasks, scf_get_dict_tasks = [], []
289294
for i, strain in enumerate(strain_lst):
290-
291295
strain_dir = f"strain_{i}"
292296

293297
strain_dir_task = wg.add_task(
294298
task.pythonjob()(pickle_node),
295-
name=strain_dir,
299+
name=f"pickle_{strain_dir}_dir",
296300
value=strain_dir,
297301
register_pickle_by_value=True,
298302
)
299303
del pickle_node.TaskCls
300304
strain_dir_tasks.append(strain_dir_task)
301305

302306
scf_get_dict_task = wg.add_task(
303-
task.pythonjob(outputs=["dict"])(get_dict),
307+
task.pythonjob(
308+
# outputs=[
309+
# "dict"
310+
# # "structure",
311+
# # "calculation",
312+
# # "kpts",
313+
# # "pseudopotentials",
314+
# # "smearing",
315+
# ]
316+
)(get_dict),
304317
name=f"get_dict_{i}",
305318
register_pickle_by_value=True,
306319
)
@@ -330,55 +343,64 @@ def construct_wg_qe(
330343
del get_list.TaskCls
331344

332345
# Add remaining links
333-
wg.add_link(pickle_element_task.outputs.element, get_bulk_structure_task.inputs.element)
346+
wg.add_link(
347+
pickle_element_task.outputs.element, get_bulk_structure_task.inputs.element
348+
)
334349
wg.add_link(pickle_a_task.outputs.a, get_bulk_structure_task.inputs.a)
335350
wg.add_link(pickle_cubic_task.outputs.cubic, get_bulk_structure_task.inputs.cubic)
336351

337352
# `.set` rather than `.add_link`, as get_dict takes `**kwargs` as input
338353
relax_get_dict_task.set(
339354
{
340-
"structure": get_bulk_structure_task.outputs.bulk_structure,
355+
"structure": get_bulk_structure_task.outputs.structure,
341356
"calculation": pickle_calc_type_relax_task.outputs.calc,
342357
"kpts": pickle_kpts_task.outputs.kpts,
343-
"pseudopotential": pickle_pp_task.outputs.pp,
358+
"pseudopotentials": pickle_pp_task.outputs.pp,
344359
"smearing": pickle_smearing_task.outputs.smearing,
345360
}
346361
)
347362

348-
wg.add_link(relax_get_dict_task.outputs.dict, relax_task.inputs.input_dict)
363+
wg.add_link(relax_get_dict_task.outputs.result, relax_task.inputs.input_dict)
349364
wg.add_link(
350365
pickle_relax_workdir_task.outputs.relax_workdir,
351366
relax_task.inputs.working_directory,
352367
)
353368

354369
wg.add_link(relax_task.outputs.structure, generate_structures_task.inputs.structure)
355-
wg.add_link(strain_lst_task.outputs.strain_lst, generate_structures_task.inputs.strain_lst)
370+
wg.add_link(
371+
strain_lst_task.outputs.strain_lst, generate_structures_task.inputs.strain_lst
372+
)
356373

357-
counter = 0
358-
for scf_get_dict_task, scf_qe_task, strain_dir_task in list(
359-
zip(scf_get_dict_tasks, scf_qe_tasks, strain_dir_tasks)
374+
for i, (scf_get_dict_task, scf_qe_task, strain_dir_task) in enumerate(
375+
list(zip(scf_get_dict_tasks, scf_qe_tasks, strain_dir_tasks))
360376
):
361-
# print(scf_get_dict_task, scf_qe_task, strain_dir_task)
362377
scf_get_dict_task.set(
363378
{
364379
"structure": generate_structures_task.outputs[f"s_{i}"],
365380
"calculation": pickle_calc_type_scf_task.outputs.calc,
366381
"kpts": pickle_kpts_task.outputs.kpts,
367-
"pseudopotential": pickle_pp_task.outputs.pp,
382+
"pseudopotentials": pickle_pp_task.outputs.pp,
368383
"smearing": pickle_smearing_task.outputs.smearing,
369384
}
370385
)
371386

372-
wg.add_link(scf_get_dict_task.outputs.dict, scf_qe_task.inputs.input_dict)
373-
wg.add_link(strain_dir_task.outputs.result, scf_qe_task.inputs.working_directory)
387+
# import ipdb; ipdb.set_trace()
388+
wg.add_link(scf_get_dict_task.outputs.result, scf_qe_task.inputs.input_dict)
389+
wg.add_link(
390+
strain_dir_task.outputs.result, scf_qe_task.inputs.working_directory
391+
)
374392

375393
# collect energy and volume
376-
wg.add_link(scf_qe_task.outputs.energy, wg.tasks.get_energies.inputs.kwargs)
377-
wg.add_link(scf_qe_task.outputs.volume, wg.tasks.get_volumes.inputs.kwargs)
394+
wg.add_link(scf_qe_task.outputs.energy, get_energies_task.inputs.kwargs)
395+
wg.add_link(scf_qe_task.outputs.volume, get_volumes_task.inputs.kwargs)
378396

379-
counter += 1
380-
381-
wg.add_link(get_volumes_task.outputs.volumes, plot_energy_volume_curve_task.inputs.volume_lst)
382-
wg.add_link(get_energies_task.outputs.energies, plot_energy_volume_curve_task.inputs.energy_lst)
397+
wg.add_link(
398+
get_volumes_task.outputs.volumes,
399+
plot_energy_volume_curve_task.inputs.volume_lst,
400+
)
401+
wg.add_link(
402+
get_energies_task.outputs.energies,
403+
plot_energy_volume_curve_task.inputs.energy_lst,
404+
)
383405

384406
return wg

python_workflow_definition/src/python_workflow_definition/shared.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
def get_dict(**kwargs):
2+
# NOTE: In WG, this will automatically be wrapped in a dict with the `result` key
23
return {k: v for k, v in kwargs.items()}
4+
# return {'dict': {k: v for k, v in kwargs.items()}}
35

46

57
def get_list(**kwargs):

quantum_espresso_workflow.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,24 @@
1212
def write_input(input_dict, working_directory="."):
1313
filename = os.path.join(working_directory, "input.pwi")
1414
os.makedirs(working_directory, exist_ok=True)
15-
write(
16-
filename=filename,
17-
images=Atoms(**input_dict["structure"]),
18-
Crystal=True,
19-
kpts=input_dict["kpts"],
20-
input_data={
21-
"calculation": input_dict["calculation"],
22-
"occupations": "smearing",
23-
"degauss": input_dict["smearing"],
24-
},
25-
pseudopotentials=input_dict["pseudopotentials"],
26-
tstress=True,
27-
tprnfor=True,
28-
)
15+
try:
16+
write(
17+
filename=filename,
18+
images=Atoms(**input_dict["structure"]),
19+
Crystal=True,
20+
kpts=input_dict["kpts"],
21+
input_data={
22+
"calculation": input_dict["calculation"],
23+
"occupations": "smearing",
24+
"degauss": input_dict["smearing"],
25+
},
26+
pseudopotentials=input_dict["pseudopotentials"],
27+
tstress=True,
28+
tprnfor=True,
29+
)
30+
except KeyError:
31+
print('INPUT_DICT')
32+
print(input_dict)
2933

3034

3135
def collect_output(working_directory="."):
@@ -82,7 +86,7 @@ def get_bulk_structure(element, a, cubic):
8286
a=a,
8387
cubic=cubic,
8488
)
85-
return atoms_to_json_dict(atoms=ase_atoms)
89+
return {'structure': atoms_to_json_dict(atoms=ase_atoms)}
8690

8791

8892
def atoms_to_json_dict(atoms):

0 commit comments

Comments
 (0)