@@ -171,27 +171,29 @@ def construct_wg_simple(add_x_and_y_func, add_x_and_y_and_z_func) -> WorkGraph:
171171
172172
173173def construct_wg_qe (
174- get_dict ,
175- get_list ,
176174 get_bulk_structure ,
177175 calculate_qe ,
178176 generate_structures ,
179177 plot_energy_volume_curve ,
180178 strain_lst ,
181179):
182180
181+ from .shared import get_dict
182+ from .shared import get_list
183+
183184 # NOTE: `get_dict` is `get_input_dict`, to compile the input values for the calc tasks
184185 # NOTE: `add_link` must be from outputs to inputs
185- wg = WorkGraph ()
186+ wg = WorkGraph ("wg-qe" )
186187
187188 get_bulk_structure_task = wg .add_task (
188189 get_bulk_structure ,
189- name = "bulk " ,
190+ name = "get_bulk_structure " ,
190191 register_pickle_by_value = True ,
191192 )
192193
193194 relax_task = wg .add_task (
194195 calculate_qe ,
196+ # ! I don't like the `mini` name...
195197 name = "mini" ,
196198 register_pickle_by_value = True ,
197199 )
@@ -207,7 +209,6 @@ def construct_wg_qe(
207209
208210 scf_qe_tasks = []
209211 for i , strain in enumerate (strain_lst ):
210-
211212 generate_structures_task .add_output ("workgraph.any" , f"s_{ i } " )
212213
213214 scf_qe_task = wg .add_task (
@@ -219,6 +220,7 @@ def construct_wg_qe(
219220
220221 plot_energy_volume_curve_task = wg .add_task (
221222 plot_energy_volume_curve ,
223+ name = "plot_energy_volume_curve" ,
222224 register_pickle_by_value = True ,
223225 )
224226
@@ -248,7 +250,10 @@ def construct_wg_qe(
248250
249251 # ? relax or SCF, or general? -> Should be relax
250252 relax_get_dict_task = wg .add_task (
251- task .pythonjob (outputs = ["dict" ])(get_dict ),
253+ task .pythonjob (
254+ # outputs=["structure", "calculation", "kpts", "pseudopotentials", "smearing"]
255+ # outputs=["dict"]
256+ )(get_dict ),
252257 name = "relax_get_dict" ,
253258 register_pickle_by_value = True ,
254259 )
@@ -262,7 +267,7 @@ def construct_wg_qe(
262267 del pickle_node .TaskCls
263268
264269 pickle_kpts_task = wg .add_task (
265- task .pythonjob (outputs = ["kpts" ])(pickle_node ), name = "kpts_task" , value = [3 , 3 , 3 ]
270+ task .pythonjob (outputs = ["kpts" ])(pickle_node ), name = "kpts_task" , value = [1 , 1 , 1 ] # FIXME: Back to [ 3, 3, 3]
266271 )
267272 del pickle_node .TaskCls
268273
@@ -287,20 +292,28 @@ def construct_wg_qe(
287292
288293 strain_dir_tasks , scf_get_dict_tasks = [], []
289294 for i , strain in enumerate (strain_lst ):
290-
291295 strain_dir = f"strain_{ i } "
292296
293297 strain_dir_task = wg .add_task (
294298 task .pythonjob ()(pickle_node ),
295- name = strain_dir ,
299+ name = f"pickle_ { strain_dir } _dir" ,
296300 value = strain_dir ,
297301 register_pickle_by_value = True ,
298302 )
299303 del pickle_node .TaskCls
300304 strain_dir_tasks .append (strain_dir_task )
301305
302306 scf_get_dict_task = wg .add_task (
303- task .pythonjob (outputs = ["dict" ])(get_dict ),
307+ task .pythonjob (
308+ # outputs=[
309+ # "dict"
310+ # # "structure",
311+ # # "calculation",
312+ # # "kpts",
313+ # # "pseudopotentials",
314+ # # "smearing",
315+ # ]
316+ )(get_dict ),
304317 name = f"get_dict_{ i } " ,
305318 register_pickle_by_value = True ,
306319 )
@@ -330,55 +343,64 @@ def construct_wg_qe(
330343 del get_list .TaskCls
331344
332345 # Add remaining links
333- wg .add_link (pickle_element_task .outputs .element , get_bulk_structure_task .inputs .element )
346+ wg .add_link (
347+ pickle_element_task .outputs .element , get_bulk_structure_task .inputs .element
348+ )
334349 wg .add_link (pickle_a_task .outputs .a , get_bulk_structure_task .inputs .a )
335350 wg .add_link (pickle_cubic_task .outputs .cubic , get_bulk_structure_task .inputs .cubic )
336351
337352 # `.set` rather than `.add_link`, as get_dict takes `**kwargs` as input
338353 relax_get_dict_task .set (
339354 {
340- "structure" : get_bulk_structure_task .outputs .bulk_structure ,
355+ "structure" : get_bulk_structure_task .outputs .structure ,
341356 "calculation" : pickle_calc_type_relax_task .outputs .calc ,
342357 "kpts" : pickle_kpts_task .outputs .kpts ,
343- "pseudopotential " : pickle_pp_task .outputs .pp ,
358+ "pseudopotentials " : pickle_pp_task .outputs .pp ,
344359 "smearing" : pickle_smearing_task .outputs .smearing ,
345360 }
346361 )
347362
348- wg .add_link (relax_get_dict_task .outputs .dict , relax_task .inputs .input_dict )
363+ wg .add_link (relax_get_dict_task .outputs .result , relax_task .inputs .input_dict )
349364 wg .add_link (
350365 pickle_relax_workdir_task .outputs .relax_workdir ,
351366 relax_task .inputs .working_directory ,
352367 )
353368
354369 wg .add_link (relax_task .outputs .structure , generate_structures_task .inputs .structure )
355- wg .add_link (strain_lst_task .outputs .strain_lst , generate_structures_task .inputs .strain_lst )
370+ wg .add_link (
371+ strain_lst_task .outputs .strain_lst , generate_structures_task .inputs .strain_lst
372+ )
356373
357- counter = 0
358- for scf_get_dict_task , scf_qe_task , strain_dir_task in list (
359- zip (scf_get_dict_tasks , scf_qe_tasks , strain_dir_tasks )
374+ for i , (scf_get_dict_task , scf_qe_task , strain_dir_task ) in enumerate (
375+ list (zip (scf_get_dict_tasks , scf_qe_tasks , strain_dir_tasks ))
360376 ):
361- # print(scf_get_dict_task, scf_qe_task, strain_dir_task)
362377 scf_get_dict_task .set (
363378 {
364379 "structure" : generate_structures_task .outputs [f"s_{ i } " ],
365380 "calculation" : pickle_calc_type_scf_task .outputs .calc ,
366381 "kpts" : pickle_kpts_task .outputs .kpts ,
367- "pseudopotential " : pickle_pp_task .outputs .pp ,
382+ "pseudopotentials " : pickle_pp_task .outputs .pp ,
368383 "smearing" : pickle_smearing_task .outputs .smearing ,
369384 }
370385 )
371386
372- wg .add_link (scf_get_dict_task .outputs .dict , scf_qe_task .inputs .input_dict )
373- wg .add_link (strain_dir_task .outputs .result , scf_qe_task .inputs .working_directory )
387+ # import ipdb; ipdb.set_trace()
388+ wg .add_link (scf_get_dict_task .outputs .result , scf_qe_task .inputs .input_dict )
389+ wg .add_link (
390+ strain_dir_task .outputs .result , scf_qe_task .inputs .working_directory
391+ )
374392
375393 # collect energy and volume
376- wg .add_link (scf_qe_task .outputs .energy , wg . tasks . get_energies .inputs .kwargs )
377- wg .add_link (scf_qe_task .outputs .volume , wg . tasks . get_volumes .inputs .kwargs )
394+ wg .add_link (scf_qe_task .outputs .energy , get_energies_task .inputs .kwargs )
395+ wg .add_link (scf_qe_task .outputs .volume , get_volumes_task .inputs .kwargs )
378396
379- counter += 1
380-
381- wg .add_link (get_volumes_task .outputs .volumes , plot_energy_volume_curve_task .inputs .volume_lst )
382- wg .add_link (get_energies_task .outputs .energies , plot_energy_volume_curve_task .inputs .energy_lst )
397+ wg .add_link (
398+ get_volumes_task .outputs .volumes ,
399+ plot_energy_volume_curve_task .inputs .volume_lst ,
400+ )
401+ wg .add_link (
402+ get_energies_task .outputs .energies ,
403+ plot_energy_volume_curve_task .inputs .energy_lst ,
404+ )
383405
384406 return wg
0 commit comments