Skip to content

Commit a55574d

Browse files
committed
Manual AiiDA WG creation now seems fixed.
1 parent 5a8a66c commit a55574d

File tree

1 file changed

+100
-88
lines changed
  • python_workflow_definition/src/python_workflow_definition

1 file changed

+100
-88
lines changed

python_workflow_definition/src/python_workflow_definition/aiida.py

Lines changed: 100 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,19 @@ def write_workflow_json(wg, file_name):
117117

118118
i += 1
119119

120-
# for link in wgdata["links"]:
121-
# if (
122-
# wgdata["tasks"][link["from_node"]]["executor"]["callable_name"]
123-
# == "pickle_node"
124-
# ):
125-
# link["from_socket"] = None
126-
# link["source"] = node_name_mapping[link["from_node"]]
127-
# del link["from_node"]
128-
# link["target"] = node_name_mapping[link["to_node"]]
129-
# del link["to_node"]
130-
# link["sourceHandle"] = link.pop("from_socket")
131-
# link["targetHandle"] = link.pop("to_socket")
132-
# data["edges"].append(link)
120+
for link in wgdata["links"]:
121+
if (
122+
wgdata["tasks"][link["from_node"]]["executor"]["callable_name"]
123+
== "pickle_node"
124+
):
125+
link["from_socket"] = None
126+
link["source"] = node_name_mapping[link["from_node"]]
127+
del link["from_node"]
128+
link["target"] = node_name_mapping[link["to_node"]]
129+
del link["to_node"]
130+
link["sourceHandle"] = link.pop("from_socket")
131+
link["targetHandle"] = link.pop("to_socket")
132+
data["edges"].append(link)
133133

134134
with open(file_name, "w") as f:
135135
# json.dump({"nodes": data[], "edges": edges_new_lst}, f)
@@ -179,6 +179,9 @@ def construct_wg_qe(
179179
plot_energy_volume_curve,
180180
strain_lst,
181181
):
182+
183+
# NOTE: `get_dict` is `get_input_dict`, to compile the input values for the calc tasks
184+
# NOTE: `add_link` must be from outputs to inputs
182185
wg = WorkGraph()
183186

184187
get_bulk_structure_task = wg.add_task(
@@ -190,68 +193,32 @@ def construct_wg_qe(
190193
relax_task = wg.add_task(
191194
calculate_qe,
192195
name="mini",
193-
# input_dict=wg.tasks.get_dict.outputs.result,
194-
# working_directory=wg.tasks.relax_workdir.outputs.result,
195196
register_pickle_by_value=True,
196197
)
197198

198199
generate_structures_task = wg.add_task(
199200
generate_structures,
200201
name="generate_structures",
201-
# structure=wg.tasks.mini.outputs.structure,
202-
# strain_lst=wg.tasks.strain_lst.outputs.result,
203202
register_pickle_by_value=True,
204203
)
205204

206205
# here we add the structure outputs based on the number of strains
207206
del wg.tasks.generate_structures.outputs["result"]
208207

209-
qe_tasks = []
208+
scf_qe_tasks = []
210209
for i, strain in enumerate(strain_lst):
211210

212211
generate_structures_task.add_output("workgraph.any", f"s_{i}")
213212

214-
# Possible remove `get_dict_task` here to have it later
215-
# get_dict_task = wg.add_task(
216-
# get_dict,
217-
# name=f"get_dict_{i}",
218-
# # calculation=wg.tasks.calculation_scf.outputs.result,
219-
# # structure=wg.tasks.generate_structures.outputs[f"s_{i}"],
220-
# # kpts=wg.tasks.kpts.outputs.result,
221-
# # pseudopotentials=wg.tasks.pseudopotentials.outputs.result,
222-
# # smearing=wg.tasks.smearing.outputs.result,
223-
# register_pickle_by_value=True,
224-
# )
225-
226-
strain_dir = f"strain_{i}"
227-
228-
# strain_dir_task = wg.add_task(
229-
# task.pythonjob(outputs=[strain_dir])(pickle_node),
230-
# name=strain_dir,
231-
# value=strain_dir,
232-
# register_pickle_by_value=True,
233-
# )
234-
# del pickle_node.TaskCls
235-
236-
# import ipdb; ipdb.set_trace()
237-
qe_task = wg.add_task(
213+
scf_qe_task = wg.add_task(
238214
calculate_qe,
239215
name=f"qe_{i}",
240-
# input_dict=get_dict_task.outputs.result,
241-
# working_directory=strain_dir_task.outputs[strain_dir],
242216
register_pickle_by_value=True,
243217
)
244-
qe_tasks.append(qe_task)
245-
246-
# collect energy and volume
247-
# TODO: Maybe put this outside, in a separate for-loop to again try to fix the order
248-
# wg.add_link(qe_task.outputs.energy, wg.tasks.get_energies.inputs.kwargs)
249-
# wg.add_link(qe_task.outputs.volume, wg.tasks.get_volumes.inputs.kwargs)
218+
scf_qe_tasks.append(scf_qe_task)
250219

251220
plot_energy_volume_curve_task = wg.add_task(
252221
plot_energy_volume_curve,
253-
# volume_lst=wg.tasks.get_volumes.outputs.result,
254-
# energy_lst=wg.tasks.get_energies.outputs.result,
255222
register_pickle_by_value=True,
256223
)
257224

@@ -262,7 +229,9 @@ def construct_wg_qe(
262229
)
263230
del pickle_node.TaskCls
264231

265-
pickle_a_task = wg.add_task(task.pythonjob(outputs=["a"])(pickle_node), name="pickle_a", value=4.05)
232+
pickle_a_task = wg.add_task(
233+
task.pythonjob(outputs=["a"])(pickle_node), name="pickle_a", value=4.05
234+
)
266235
del pickle_node.TaskCls
267236

268237
pickle_cubic_task = wg.add_task(
@@ -277,34 +246,36 @@ def construct_wg_qe(
277246
)
278247
del pickle_node.TaskCls
279248

280-
# ? relax or SCF, or general?
281-
get_dict_task = wg.add_task(
282-
get_dict,
283-
name="get_dict",
284-
structure=wg.tasks.bulk.outputs.result,
285-
# calculation=wg.tasks.calculation.outputs.result,
286-
# kpts=wg.tasks.kpts.outputs.result,
287-
# pseudopotentials=wg.tasks.pseudopotentials.outputs.result,
288-
# smearing=wg.tasks.smearing.outputs.result,
249+
# ? relax or SCF, or general? -> Should be relax
250+
relax_get_dict_task = wg.add_task(
251+
task.pythonjob(outputs=["dict"])(get_dict),
252+
name="relax_get_dict",
289253
register_pickle_by_value=True,
290254
)
255+
del get_dict.TaskCls
291256

292-
pp_task = wg.add_task(
293-
task.pythonjob(pickle_node),
257+
pickle_pp_task = wg.add_task(
258+
task.pythonjob(outputs=["pp"])(pickle_node),
294259
name="pseudopotentials",
295260
value={"Al": "Al.pbe-n-kjpaw_psl.1.0.0.UPF"},
296261
)
297262
del pickle_node.TaskCls
298263

299-
kpts_task = wg.add_task(
264+
pickle_kpts_task = wg.add_task(
300265
task.pythonjob(outputs=["kpts"])(pickle_node), name="kpts_task", value=[3, 3, 3]
301266
)
302267
del pickle_node.TaskCls
303268

304-
vc_relax_task = wg.add_task(task.pythonjob()(pickle_node), name="calculation", value="vc-relax")
269+
pickle_calc_type_relax_task = wg.add_task(
270+
task.pythonjob(outputs=["calc"])(pickle_node),
271+
name="calc_type_relax",
272+
value="vc-relax",
273+
)
305274
del pickle_node.TaskCls
306-
307-
smearing_task = wg.add_task(task.pythonjob()(pickle_node), name="smearing", value=0.02)
275+
276+
pickle_smearing_task = wg.add_task(
277+
task.pythonjob(outputs=["smearing"])(pickle_node), name="smearing", value=0.02
278+
)
308279
del pickle_node.TaskCls
309280

310281
strain_lst_task = wg.add_task(
@@ -314,41 +285,36 @@ def construct_wg_qe(
314285
)
315286
del pickle_node.TaskCls
316287

317-
strain_dir_tasks = []
288+
strain_dir_tasks, scf_get_dict_tasks = [], []
318289
for i, strain in enumerate(strain_lst):
319290

320291
strain_dir = f"strain_{i}"
321292

322293
strain_dir_task = wg.add_task(
323-
task.pythonjob(outputs=[strain_dir])(pickle_node),
294+
task.pythonjob()(pickle_node),
324295
name=strain_dir,
325296
value=strain_dir,
326297
register_pickle_by_value=True,
327298
)
328299
del pickle_node.TaskCls
329300
strain_dir_tasks.append(strain_dir_task)
330301

331-
# Possible remove `get_dict_task` here to have it later
332-
get_dict_task = wg.add_task(
333-
get_dict,
302+
scf_get_dict_task = wg.add_task(
303+
task.pythonjob(outputs=["dict"])(get_dict),
334304
name=f"get_dict_{i}",
335-
# calculation=wg.tasks.calculation_scf.outputs.result,
336-
# structure=wg.tasks.generate_structures.outputs[f"s_{i}"],
337-
# kpts=wg.tasks.kpts.outputs.result,
338-
# pseudopotentials=wg.tasks.pseudopotentials.outputs.result,
339-
# smearing=wg.tasks.smearing.outputs.result,
340305
register_pickle_by_value=True,
341306
)
307+
del get_dict.TaskCls
308+
scf_get_dict_tasks.append(scf_get_dict_task)
342309

343310
if i == 0:
344-
scf_task = wg.add_task(task.pythonjob(pickle_node), name="calculation_scf", value="scf")
311+
pickle_calc_type_scf_task = wg.add_task(
312+
task.pythonjob(outputs=["calc"])(pickle_node),
313+
name="calc_type_scf",
314+
value="scf",
315+
)
345316
del pickle_node.TaskCls
346317

347-
# collect energy and volume
348-
# TODO: Maybe put this outside, in a separate for-loop to again try to fix the order
349-
# wg.add_link(qe_task.outputs.energy, wg.tasks.get_energies.inputs.kwargs)
350-
# wg.add_link(qe_task.outputs.volume, wg.tasks.get_volumes.inputs.kwargs)
351-
352318
get_energies_task = wg.add_task(
353319
task.pythonjob(outputs=["energies"])(get_list),
354320
name="get_energies",
@@ -364,9 +330,55 @@ def construct_wg_qe(
364330
del get_list.TaskCls
365331

366332
# Add remaining links
367-
wg.add_link(get_bulk_structure_task.inputs.element, element=pickle_element_task.outputs.element)
368-
# ,
369-
# a=wg.tasks.pickle_a.outputs.a,
370-
# cubic=wg.tasks.pickle_cubic.outputs.result,
333+
wg.add_link(pickle_element_task.outputs.element, get_bulk_structure_task.inputs.element)
334+
wg.add_link(pickle_a_task.outputs.a, get_bulk_structure_task.inputs.a)
335+
wg.add_link(pickle_cubic_task.outputs.cubic, get_bulk_structure_task.inputs.cubic)
336+
337+
# `.set` rather than `.add_link`, as get_dict takes `**kwargs` as input
338+
relax_get_dict_task.set(
339+
{
340+
"structure": get_bulk_structure_task.outputs.bulk_structure,
341+
"calculation": pickle_calc_type_relax_task.outputs.calc,
342+
"kpts": pickle_kpts_task.outputs.kpts,
343+
"pseudopotential": pickle_pp_task.outputs.pp,
344+
"smearing": pickle_smearing_task.outputs.smearing,
345+
}
346+
)
347+
348+
wg.add_link(relax_get_dict_task.outputs.dict, relax_task.inputs.input_dict)
349+
wg.add_link(
350+
pickle_relax_workdir_task.outputs.relax_workdir,
351+
relax_task.inputs.working_directory,
352+
)
353+
354+
wg.add_link(relax_task.outputs.structure, generate_structures_task.inputs.structure)
355+
wg.add_link(strain_lst_task.outputs.strain_lst, generate_structures_task.inputs.strain_lst)
356+
357+
counter = 0
358+
for scf_get_dict_task, scf_qe_task, strain_dir_task in list(
359+
zip(scf_get_dict_tasks, scf_qe_tasks, strain_dir_tasks)
360+
):
361+
# print(scf_get_dict_task, scf_qe_task, strain_dir_task)
362+
scf_get_dict_task.set(
363+
{
364+
"structure": generate_structures_task.outputs[f"s_{i}"],
365+
"calculation": pickle_calc_type_scf_task.outputs.calc,
366+
"kpts": pickle_kpts_task.outputs.kpts,
367+
"pseudopotential": pickle_pp_task.outputs.pp,
368+
"smearing": pickle_smearing_task.outputs.smearing,
369+
}
370+
)
371+
372+
wg.add_link(scf_get_dict_task.outputs.dict, scf_qe_task.inputs.input_dict)
373+
wg.add_link(strain_dir_task.outputs.result, scf_qe_task.inputs.working_directory)
374+
375+
# collect energy and volume
376+
wg.add_link(scf_qe_task.outputs.energy, wg.tasks.get_energies.inputs.kwargs)
377+
wg.add_link(scf_qe_task.outputs.volume, wg.tasks.get_volumes.inputs.kwargs)
378+
379+
counter += 1
380+
381+
wg.add_link(get_volumes_task.outputs.volumes, plot_energy_volume_curve_task.inputs.volume_lst)
382+
wg.add_link(get_energies_task.outputs.energies, plot_energy_volume_curve_task.inputs.energy_lst)
371383

372384
return wg

0 commit comments

Comments
 (0)