@@ -117,19 +117,19 @@ def write_workflow_json(wg, file_name):
117117
118118 i += 1
119119
120- # for link in wgdata["links"]:
121- # if (
122- # wgdata["tasks"][link["from_node"]]["executor"]["callable_name"]
123- # == "pickle_node"
124- # ):
125- # link["from_socket"] = None
126- # link["source"] = node_name_mapping[link["from_node"]]
127- # del link["from_node"]
128- # link["target"] = node_name_mapping[link["to_node"]]
129- # del link["to_node"]
130- # link["sourceHandle"] = link.pop("from_socket")
131- # link["targetHandle"] = link.pop("to_socket")
132- # data["edges"].append(link)
120+ for link in wgdata ["links" ]:
121+ if (
122+ wgdata ["tasks" ][link ["from_node" ]]["executor" ]["callable_name" ]
123+ == "pickle_node"
124+ ):
125+ link ["from_socket" ] = None
126+ link ["source" ] = node_name_mapping [link ["from_node" ]]
127+ del link ["from_node" ]
128+ link ["target" ] = node_name_mapping [link ["to_node" ]]
129+ del link ["to_node" ]
130+ link ["sourceHandle" ] = link .pop ("from_socket" )
131+ link ["targetHandle" ] = link .pop ("to_socket" )
132+ data ["edges" ].append (link )
133133
134134 with open (file_name , "w" ) as f :
135135 # json.dump({"nodes": data[], "edges": edges_new_lst}, f)
@@ -179,6 +179,9 @@ def construct_wg_qe(
179179 plot_energy_volume_curve ,
180180 strain_lst ,
181181):
182+
183+ # NOTE: `get_dict` is `get_input_dict`, to compile the input values for the calc tasks
184+ # NOTE: `add_link` must be from outputs to inputs
182185 wg = WorkGraph ()
183186
184187 get_bulk_structure_task = wg .add_task (
@@ -190,68 +193,32 @@ def construct_wg_qe(
190193 relax_task = wg .add_task (
191194 calculate_qe ,
192195 name = "mini" ,
193- # input_dict=wg.tasks.get_dict.outputs.result,
194- # working_directory=wg.tasks.relax_workdir.outputs.result,
195196 register_pickle_by_value = True ,
196197 )
197198
198199 generate_structures_task = wg .add_task (
199200 generate_structures ,
200201 name = "generate_structures" ,
201- # structure=wg.tasks.mini.outputs.structure,
202- # strain_lst=wg.tasks.strain_lst.outputs.result,
203202 register_pickle_by_value = True ,
204203 )
205204
206205 # here we add the structure outputs based on the number of strains
207206 del wg .tasks .generate_structures .outputs ["result" ]
208207
209- qe_tasks = []
208+ scf_qe_tasks = []
210209 for i , strain in enumerate (strain_lst ):
211210
212211 generate_structures_task .add_output ("workgraph.any" , f"s_{ i } " )
213212
214- # Possible remove `get_dict_task` here to have it later
215- # get_dict_task = wg.add_task(
216- # get_dict,
217- # name=f"get_dict_{i}",
218- # # calculation=wg.tasks.calculation_scf.outputs.result,
219- # # structure=wg.tasks.generate_structures.outputs[f"s_{i}"],
220- # # kpts=wg.tasks.kpts.outputs.result,
221- # # pseudopotentials=wg.tasks.pseudopotentials.outputs.result,
222- # # smearing=wg.tasks.smearing.outputs.result,
223- # register_pickle_by_value=True,
224- # )
225-
226- strain_dir = f"strain_{ i } "
227-
228- # strain_dir_task = wg.add_task(
229- # task.pythonjob(outputs=[strain_dir])(pickle_node),
230- # name=strain_dir,
231- # value=strain_dir,
232- # register_pickle_by_value=True,
233- # )
234- # del pickle_node.TaskCls
235-
236- # import ipdb; ipdb.set_trace()
237- qe_task = wg .add_task (
213+ scf_qe_task = wg .add_task (
238214 calculate_qe ,
239215 name = f"qe_{ i } " ,
240- # input_dict=get_dict_task.outputs.result,
241- # working_directory=strain_dir_task.outputs[strain_dir],
242216 register_pickle_by_value = True ,
243217 )
244- qe_tasks .append (qe_task )
245-
246- # collect energy and volume
247- # TODO: Maybe put this outside, in a separate for-loop to again try to fix the order
248- # wg.add_link(qe_task.outputs.energy, wg.tasks.get_energies.inputs.kwargs)
249- # wg.add_link(qe_task.outputs.volume, wg.tasks.get_volumes.inputs.kwargs)
218+ scf_qe_tasks .append (scf_qe_task )
250219
251220 plot_energy_volume_curve_task = wg .add_task (
252221 plot_energy_volume_curve ,
253- # volume_lst=wg.tasks.get_volumes.outputs.result,
254- # energy_lst=wg.tasks.get_energies.outputs.result,
255222 register_pickle_by_value = True ,
256223 )
257224
@@ -262,7 +229,9 @@ def construct_wg_qe(
262229 )
263230 del pickle_node .TaskCls
264231
265- pickle_a_task = wg .add_task (task .pythonjob (outputs = ["a" ])(pickle_node ), name = "pickle_a" , value = 4.05 )
232+ pickle_a_task = wg .add_task (
233+ task .pythonjob (outputs = ["a" ])(pickle_node ), name = "pickle_a" , value = 4.05
234+ )
266235 del pickle_node .TaskCls
267236
268237 pickle_cubic_task = wg .add_task (
@@ -277,34 +246,36 @@ def construct_wg_qe(
277246 )
278247 del pickle_node .TaskCls
279248
280- # ? relax or SCF, or general?
281- get_dict_task = wg .add_task (
282- get_dict ,
283- name = "get_dict" ,
284- structure = wg .tasks .bulk .outputs .result ,
285- # calculation=wg.tasks.calculation.outputs.result,
286- # kpts=wg.tasks.kpts.outputs.result,
287- # pseudopotentials=wg.tasks.pseudopotentials.outputs.result,
288- # smearing=wg.tasks.smearing.outputs.result,
249+ # ? relax or SCF, or general? -> Should be relax
250+ relax_get_dict_task = wg .add_task (
251+ task .pythonjob (outputs = ["dict" ])(get_dict ),
252+ name = "relax_get_dict" ,
289253 register_pickle_by_value = True ,
290254 )
255+ del get_dict .TaskCls
291256
292- pp_task = wg .add_task (
293- task .pythonjob (pickle_node ),
257+ pickle_pp_task = wg .add_task (
258+ task .pythonjob (outputs = [ "pp" ])( pickle_node ),
294259 name = "pseudopotentials" ,
295260 value = {"Al" : "Al.pbe-n-kjpaw_psl.1.0.0.UPF" },
296261 )
297262 del pickle_node .TaskCls
298263
299- kpts_task = wg .add_task (
264+ pickle_kpts_task = wg .add_task (
300265 task .pythonjob (outputs = ["kpts" ])(pickle_node ), name = "kpts_task" , value = [3 , 3 , 3 ]
301266 )
302267 del pickle_node .TaskCls
303268
304- vc_relax_task = wg .add_task (task .pythonjob ()(pickle_node ), name = "calculation" , value = "vc-relax" )
269+ pickle_calc_type_relax_task = wg .add_task (
270+ task .pythonjob (outputs = ["calc" ])(pickle_node ),
271+ name = "calc_type_relax" ,
272+ value = "vc-relax" ,
273+ )
305274 del pickle_node .TaskCls
306-
307- smearing_task = wg .add_task (task .pythonjob ()(pickle_node ), name = "smearing" , value = 0.02 )
275+
276+ pickle_smearing_task = wg .add_task (
277+ task .pythonjob (outputs = ["smearing" ])(pickle_node ), name = "smearing" , value = 0.02
278+ )
308279 del pickle_node .TaskCls
309280
310281 strain_lst_task = wg .add_task (
@@ -314,41 +285,36 @@ def construct_wg_qe(
314285 )
315286 del pickle_node .TaskCls
316287
317- strain_dir_tasks = []
288+ strain_dir_tasks , scf_get_dict_tasks = [], []
318289 for i , strain in enumerate (strain_lst ):
319290
320291 strain_dir = f"strain_{ i } "
321292
322293 strain_dir_task = wg .add_task (
323- task .pythonjob (outputs = [ strain_dir ] )(pickle_node ),
294+ task .pythonjob ()(pickle_node ),
324295 name = strain_dir ,
325296 value = strain_dir ,
326297 register_pickle_by_value = True ,
327298 )
328299 del pickle_node .TaskCls
329300 strain_dir_tasks .append (strain_dir_task )
330301
331- # Possible remove `get_dict_task` here to have it later
332- get_dict_task = wg .add_task (
333- get_dict ,
302+ scf_get_dict_task = wg .add_task (
303+ task .pythonjob (outputs = ["dict" ])(get_dict ),
334304 name = f"get_dict_{ i } " ,
335- # calculation=wg.tasks.calculation_scf.outputs.result,
336- # structure=wg.tasks.generate_structures.outputs[f"s_{i}"],
337- # kpts=wg.tasks.kpts.outputs.result,
338- # pseudopotentials=wg.tasks.pseudopotentials.outputs.result,
339- # smearing=wg.tasks.smearing.outputs.result,
340305 register_pickle_by_value = True ,
341306 )
307+ del get_dict .TaskCls
308+ scf_get_dict_tasks .append (scf_get_dict_task )
342309
343310 if i == 0 :
344- scf_task = wg .add_task (task .pythonjob (pickle_node ), name = "calculation_scf" , value = "scf" )
311+ pickle_calc_type_scf_task = wg .add_task (
312+ task .pythonjob (outputs = ["calc" ])(pickle_node ),
313+ name = "calc_type_scf" ,
314+ value = "scf" ,
315+ )
345316 del pickle_node .TaskCls
346317
347- # collect energy and volume
348- # TODO: Maybe put this outside, in a separate for-loop to again try to fix the order
349- # wg.add_link(qe_task.outputs.energy, wg.tasks.get_energies.inputs.kwargs)
350- # wg.add_link(qe_task.outputs.volume, wg.tasks.get_volumes.inputs.kwargs)
351-
352318 get_energies_task = wg .add_task (
353319 task .pythonjob (outputs = ["energies" ])(get_list ),
354320 name = "get_energies" ,
@@ -364,9 +330,55 @@ def construct_wg_qe(
364330 del get_list .TaskCls
365331
366332 # Add remaining links
367- wg .add_link (get_bulk_structure_task .inputs .element , element = pickle_element_task .outputs .element )
368- # ,
369- # a=wg.tasks.pickle_a.outputs.a,
370- # cubic=wg.tasks.pickle_cubic.outputs.result,
333+ wg .add_link (pickle_element_task .outputs .element , get_bulk_structure_task .inputs .element )
334+ wg .add_link (pickle_a_task .outputs .a , get_bulk_structure_task .inputs .a )
335+ wg .add_link (pickle_cubic_task .outputs .cubic , get_bulk_structure_task .inputs .cubic )
336+
337+ # `.set` rather than `.add_link`, as get_dict takes `**kwargs` as input
338+ relax_get_dict_task .set (
339+ {
340+ "structure" : get_bulk_structure_task .outputs .bulk_structure ,
341+ "calculation" : pickle_calc_type_relax_task .outputs .calc ,
342+ "kpts" : pickle_kpts_task .outputs .kpts ,
343+ "pseudopotential" : pickle_pp_task .outputs .pp ,
344+ "smearing" : pickle_smearing_task .outputs .smearing ,
345+ }
346+ )
347+
348+ wg .add_link (relax_get_dict_task .outputs .dict , relax_task .inputs .input_dict )
349+ wg .add_link (
350+ pickle_relax_workdir_task .outputs .relax_workdir ,
351+ relax_task .inputs .working_directory ,
352+ )
353+
354+ wg .add_link (relax_task .outputs .structure , generate_structures_task .inputs .structure )
355+ wg .add_link (strain_lst_task .outputs .strain_lst , generate_structures_task .inputs .strain_lst )
356+
357+ counter = 0
358+ for scf_get_dict_task , scf_qe_task , strain_dir_task in list (
359+ zip (scf_get_dict_tasks , scf_qe_tasks , strain_dir_tasks )
360+ ):
361+ # print(scf_get_dict_task, scf_qe_task, strain_dir_task)
362+ scf_get_dict_task .set (
363+ {
364+ "structure" : generate_structures_task .outputs [f"s_{ i } " ],
365+ "calculation" : pickle_calc_type_scf_task .outputs .calc ,
366+ "kpts" : pickle_kpts_task .outputs .kpts ,
367+ "pseudopotential" : pickle_pp_task .outputs .pp ,
368+ "smearing" : pickle_smearing_task .outputs .smearing ,
369+ }
370+ )
371+
372+ wg .add_link (scf_get_dict_task .outputs .dict , scf_qe_task .inputs .input_dict )
373+ wg .add_link (strain_dir_task .outputs .result , scf_qe_task .inputs .working_directory )
374+
375+ # collect energy and volume
376+ wg .add_link (scf_qe_task .outputs .energy , wg .tasks .get_energies .inputs .kwargs )
377+ wg .add_link (scf_qe_task .outputs .volume , wg .tasks .get_volumes .inputs .kwargs )
378+
379+ counter += 1
380+
381+ wg .add_link (get_volumes_task .outputs .volumes , plot_energy_volume_curve_task .inputs .volume_lst )
382+ wg .add_link (get_energies_task .outputs .energies , plot_energy_volume_curve_task .inputs .energy_lst )
371383
372384 return wg
0 commit comments