Merged
2 changes: 1 addition & 1 deletion .gitignore
@@ -199,4 +199,4 @@ jobflow_to_aiida_qe.json
 aiida_to_jobflow_qe.json
 pyiron_base_to_aiida_simple.json
 pyiron_base_to_jobflow_qe.json
-
+**/*.h5
2 changes: 1 addition & 1 deletion binder/environment.yml
@@ -10,7 +10,7 @@ dependencies:
   - pyiron_base =0.12.0
   - pyiron_workflow =0.13.0
   - pygraphviz =1.14
-  - aiida-workgraph =0.5.2
+  - aiida-workgraph =0.7.4
   - plumpy =0.25.0
   - conda_subprocess =0.0.7
   - networkx =3.5
5,800 changes: 5,799 additions & 1 deletion example_workflows/quantum_espresso/aiida.ipynb

Large diffs are not rendered by default.

5,748 changes: 5,747 additions & 1 deletion example_workflows/quantum_espresso/jobflow.ipynb

Large diffs are not rendered by default.

4,994 changes: 4,993 additions & 1 deletion example_workflows/quantum_espresso/pyiron_base.ipynb

Large diffs are not rendered by default.

6,016 changes: 6,015 additions & 1 deletion example_workflows/quantum_espresso/pyiron_workflow.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -22,7 +22,7 @@ dependencies = [
 
 [project.optional-dependencies]
 aiida = [
-    "aiida-workgraph>=0.5.1,<=0.5.2",
+    "aiida-workgraph>=0.5.1,<=0.7.4",
 ]
 jobflow = [
     "jobflow>=0.1.18,<=0.2.0",
54 changes: 30 additions & 24 deletions src/python_workflow_definition/aiida.py
@@ -3,14 +3,14 @@
 
 from aiida import orm
 from aiida_pythonjob.data.serializer import general_serializer
-from aiida_workgraph import WorkGraph, task
+from aiida_workgraph import WorkGraph, task, Task, namespace
 from aiida_workgraph.socket import TaskSocketNamespace
 
+from dataclasses import replace
+from node_graph.node_spec import SchemaSource
 from python_workflow_definition.models import PythonWorkflowDefinitionWorkflow
 from python_workflow_definition.shared import (
     convert_nodes_list_to_dict,
+    update_node_names,
-    remove_result,
     set_result_node,
     NODES_LABEL,
     EDGES_LABEL,
@@ -24,11 +24,8 @@
 
 
 def load_workflow_json(file_name: str) -> WorkGraph:
-    data = remove_result(
-        workflow_dict=PythonWorkflowDefinitionWorkflow.load_json_file(
-            file_name=file_name
-        )
-    )
+
+    data = PythonWorkflowDefinitionWorkflow.load_json_file(file_name=file_name)
 
     wg = WorkGraph()
     task_name_mapping = {}
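
For orientation, load_json_file returns a plain dict pairing a node list with an edge list. A minimal sketch of the expected shape, assuming the label strings defined in python_workflow_definition.shared; the ids, port names, and the non-function node type are illustrative:

data = {
    "nodes": [
        {"id": 0, "type": "function", "value": "my_module.add"},  # import path, resolved below
        {"id": 1, "type": "data", "value": 2},  # the "data" type string is an assumption
    ],
    "edges": [
        # feed node 1 into input port "x" of node 0
        {"target": 0, "targetPort": "x", "source": 1, "sourcePort": None},
    ],
}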
@@ -40,24 +37,28 @@ def load_workflow_json(file_name: str) -> WorkGraph:
             p, m = identifier.rsplit(".", 1)
             mod = import_module(p)
             func = getattr(mod, m)
-            wg.add_task(func)
-            # Remove the default result output, because we will add the outputs later from the data in the link
-            del wg.tasks[-1].outputs["result"]
-            task_name_mapping[id] = wg.tasks[-1]
+            decorated_func = task(outputs=namespace())(func)
+            new_task = wg.add_task(decorated_func)
+            new_task.spec = replace(new_task.spec, schema_source=SchemaSource.EMBEDDED)
+            task_name_mapping[id] = new_task
         else:
             # data task
             data_node = general_serializer(identifier)
             task_name_mapping[id] = data_node
 
     # add links
     for link in data[EDGES_LABEL]:
         # TODO: continue here
         to_task = task_name_mapping[str(link[TARGET_LABEL])]
         # if the input does not exist, the data is passed in via the kwargs;
         # in that case, we add the input socket
-        if link[TARGET_PORT_LABEL] not in to_task.inputs:
-            to_socket = to_task.add_input("workgraph.any", name=link[TARGET_PORT_LABEL])
-        else:
-            to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
+        if isinstance(to_task, Task):
+            if link[TARGET_PORT_LABEL] not in to_task.inputs:
+                to_socket = to_task.add_input_spec(
+                    "workgraph.any", name=link[TARGET_PORT_LABEL]
+                )
+            else:
+                to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
         from_task = task_name_mapping[str(link[SOURCE_LABEL])]
         if isinstance(from_task, orm.Data):
             to_socket.value = from_task
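
The construction change above is worth a standalone sketch: instead of adding the bare function and deleting its default result socket, the loader now decorates the callable with an empty outputs namespace and declares ports only as the edge data requires them. A toy example under aiida-workgraph 0.7.x; the function and port names are made up:

from aiida_workgraph import WorkGraph, task, namespace

def add(x, y):
    return {"sum": x + y}

wg = WorkGraph()
decorated_func = task(outputs=namespace())(add)  # no predeclared outputs
new_task = wg.add_task(decorated_func)
new_task.add_input_spec("workgraph.any", name="x")     # socket created on demand
new_task.add_output_spec("workgraph.any", name="sum")  # likewise for outputs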
@@ -69,16 +70,14 @@
             # we add it here, and assume the output exists
             if link[SOURCE_PORT_LABEL] not in from_task.outputs:
                 # if str(link["sourcePort"]) not in from_task.outputs:
-                from_socket = from_task.add_output(
+                from_socket = from_task.add_output_spec(
                     "workgraph.any",
                     name=link[SOURCE_PORT_LABEL],
-                    # name=str(link["sourcePort"]),
-                    metadata={"is_function_output": True},
                 )
             else:
                 from_socket = from_task.outputs[link[SOURCE_PORT_LABEL]]
-
-            wg.add_link(from_socket, to_socket)
+            if isinstance(to_task, Task):
+                wg.add_link(from_socket, to_socket)
         except Exception as e:
             traceback.print_exc()
             print("Failed to link", link, "with error:", e)
@@ -90,12 +89,18 @@ def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
     node_name_mapping = {}
     data_node_name_mapping = {}
     i = 0
+    GRAPH_LEVEL_NAMES = ["graph_inputs", "graph_outputs", "graph_ctx"]
+
     for node in wg.tasks:
-        executor = node.get_executor()
-
+        if node.name in GRAPH_LEVEL_NAMES:
+            continue
+
         node_name_mapping[node.name] = i
 
-        callable_name = executor["callable_name"]
-        callable_name = f"{executor['module_path']}.{callable_name}"
+        executor = node.get_executor()
+        callable_name = f"{executor.module_path}.{executor.callable_name}"
+
         data[NODES_LABEL].append({"id": i, "type": "function", "value": callable_name})
         i += 1
@@ -141,6 +146,7 @@ def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
                 SOURCE_PORT_LABEL: None,
             }
         )
+
     data[VERSION_LABEL] = VERSION_NUMBER
     PythonWorkflowDefinitionWorkflow(
         **set_result_node(workflow_dict=update_node_names(workflow_dict=data))
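
Taken together, the updated module still round-trips a workflow between JSON and a WorkGraph. A minimal usage sketch, assuming a configured AiiDA profile and an existing workflow file; the file names are illustrative:

from python_workflow_definition.aiida import load_workflow_json, write_workflow_json

wg = load_workflow_json(file_name="workflow.json")       # JSON -> WorkGraph
write_workflow_json(wg=wg, file_name="roundtrip.json")   # WorkGraph -> JSON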