From 3c0cf818a41d081c7b14228295a612a2808a7c8b Mon Sep 17 00:00:00 2001
From: Julian Geiger
Date: Fri, 2 May 2025 12:20:24 +0200
Subject: [PATCH] Modify write and load AiiDA methods to work without data
 nodes

---
 example_workflows/arithmetic/workflow.json  |  17 +-
 .../src/python_workflow_definition/aiida.py | 235 ++++++++++--------
 2 files changed, 144 insertions(+), 108 deletions(-)

diff --git a/example_workflows/arithmetic/workflow.json b/example_workflows/arithmetic/workflow.json
index e68cfc7..2429bca 100644
--- a/example_workflows/arithmetic/workflow.json
+++ b/example_workflows/arithmetic/workflow.json
@@ -1,14 +1,13 @@
 {
   "nodes": [
-    {"id": 0, "function": "workflow.get_prod_and_div"},
-    {"id": 1, "function": "workflow.get_sum"},
-    {"id": 2, "value": 1},
-    {"id": 3, "value": 2}
+    { "id": 0, "function": "workflow.get_prod_and_div" },
+    { "id": 1, "function": "workflow.get_sum" }
   ],
   "edges": [
-    {"target": 0, "targetPort": "x", "source": 2, "sourcePort": null},
-    {"target": 0, "targetPort": "y", "source": 3, "sourcePort": null},
-    {"target": 1, "targetPort": "x", "source": 0, "sourcePort": "prod"},
-    {"target": 1, "targetPort": "y", "source": 0, "sourcePort": "div"}
+    { "target": 0, "targetPort": "x", "source": null, "sourcePort": null },
+    { "target": 0, "targetPort": "y", "source": null, "sourcePort": null },
+    { "target": 1, "targetPort": "x", "source": 0, "sourcePort": "prod" },
+    { "target": 1, "targetPort": "y", "source": 0, "sourcePort": "div" },
+    { "target": null, "targetPort": null, "source": 1, "sourcePort": "result" }
   ]
-}
+}
\ No newline at end of file
diff --git a/python_workflow_definition/src/python_workflow_definition/aiida.py b/python_workflow_definition/src/python_workflow_definition/aiida.py
index c3fbf12..1054513 100644
--- a/python_workflow_definition/src/python_workflow_definition/aiida.py
+++ b/python_workflow_definition/src/python_workflow_definition/aiida.py
@@ -1,11 +1,6 @@
 from importlib import import_module
 import json
-import traceback
 
-from aiida import orm
-from aiida_pythonjob.data.serializer import general_serializer
-from aiida_workgraph import WorkGraph, task
-from aiida_workgraph.socket import TaskSocketNamespace
 
 from python_workflow_definition.shared import (
     convert_nodes_list_to_dict,
@@ -18,121 +13,163 @@
 )
 
 
-def load_workflow_json(file_name: str) -> WorkGraph:
+def load_workflow_json(file_name: str) -> "WorkGraph":
+    """
+    Load a workflow JSON file and convert it to a WorkGraph.
+
+    Args:
+        file_name: Path to the JSON workflow file
+
+    Returns:
+        A populated WorkGraph object
+    """
     with open(file_name) as f:
         data = json.load(f)
 
+    from aiida_workgraph import WorkGraph
     wg = WorkGraph()
     task_name_mapping = {}
 
-    for id, identifier in convert_nodes_list_to_dict(
-        nodes_list=data[NODES_LABEL]
-    ).items():
+    # Create all tasks first
+    for task_id, identifier in convert_nodes_list_to_dict(data[NODES_LABEL]).items():
         if isinstance(identifier, str) and "." in identifier:
-            p, m = identifier.rsplit(".", 1)
-            mod = import_module(p)
-            func = getattr(mod, m)
+            # Import the function dynamically
+            package_path, function_name = identifier.rsplit(".", 1)
+            module = import_module(package_path)
+            func = getattr(module, function_name)
+
+            # Add task and prepare for linking
             wg.add_task(func)
-            # Remove the default result output, because we will add the outputs later from the data in the link
-            del wg.tasks[-1].outputs["result"]
-            task_name_mapping[id] = wg.tasks[-1]
-        else:
-            # data task
-            data_node = general_serializer(identifier)
-            task_name_mapping[id] = data_node
-
-    # add links
+            current_task = wg.tasks[-1]
+
+            # Remove the default output; custom outputs are added below from the edges
+            del current_task.outputs["result"]
+            task_name_mapping[task_id] = current_task
+
+    # Add all connections between tasks
     for link in data[EDGES_LABEL]:
-        to_task = task_name_mapping[str(link[TARGET_LABEL])]
-        # if the input is not exit, it means we pass the data into to the kwargs
-        # in this case, we add the input socket
-        if link[TARGET_PORT_LABEL] not in to_task.inputs:
-            to_socket = to_task.add_input("workgraph.any", name=link[TARGET_PORT_LABEL])
-        else:
-            to_socket = to_task.inputs[link[TARGET_PORT_LABEL]]
-        from_task = task_name_mapping[str(link[SOURCE_LABEL])]
-        if isinstance(from_task, orm.Data):
-            to_socket.value = from_task
-        else:
-            try:
-                if link[SOURCE_PORT_LABEL] is None:
-                    link[SOURCE_PORT_LABEL] = "result"
-                # because we are not define the outputs explicitly during the pythonjob creation
-                # we add it here, and assume the output exit
-                if link[SOURCE_PORT_LABEL] not in from_task.outputs:
-                    # if str(link["sourcePort"]) not in from_task.outputs:
-                    from_socket = from_task.add_output(
-                        "workgraph.any",
-                        name=link[SOURCE_PORT_LABEL],
-                        # name=str(link["sourcePort"]),
-                        metadata={"is_function_output": True},
-                    )
-                else:
-                    from_socket = from_task.outputs[link[SOURCE_PORT_LABEL]]
-
-                wg.add_link(from_socket, to_socket)
-            except Exception as e:
-                traceback.print_exc()
-                print("Failed to link", link, "with error:", e)
+        source_id = link[SOURCE_LABEL]
+        target_id = link[TARGET_LABEL]
+        source_port = link[SOURCE_PORT_LABEL]
+        target_port = link[TARGET_PORT_LABEL]
+
+        # Handle task-to-task connections
+        if source_id is not None and target_id is not None:
+            from_task = task_name_mapping[str(source_id)]
+            to_task = task_name_mapping[str(target_id)]
+
+            # Create the output socket on the source task, or reuse it if an
+            # earlier edge already created it
+            if source_port not in from_task.outputs:
+                from_socket = from_task.add_output(
+                    "workgraph.any",
+                    name=source_port,
+                    metadata={"is_function_output": True},
+                )
+            else:
+                from_socket = from_task.outputs[source_port]
+
+            # Create or get the input socket on the target task
+            if target_port not in to_task.inputs:
+                to_socket = to_task.add_input("workgraph.any", name=target_port)
+            else:
+                to_socket = to_task.inputs[target_port]
+
+            # Connect the tasks
+            wg.add_link(from_socket, to_socket)
+
+        # Handle dangling outputs (no target), i.e. workflow outputs
+        elif source_id is not None and target_id is None:
+            from_task = task_name_mapping[str(source_id)]
+            if source_port not in from_task.outputs:
+                from_task.add_output(
+                    "workgraph.any",
+                    name=source_port,
+                    metadata={"is_function_output": True},
+                )
 
     return wg
 
 
-def write_workflow_json(wg: WorkGraph, file_name: str) -> dict:
+def write_workflow_json(wg: "WorkGraph", file_name: str) -> dict:
+    """
+    Write a WorkGraph to a JSON file.
+
+    Args:
+        wg: WorkGraph object to serialize
+        file_name: Path where the JSON file will be written
+
+    Returns:
+        Dictionary representation of the serialized workflow
+    """
     data = {NODES_LABEL: [], EDGES_LABEL: []}
     node_name_mapping = {}
-    data_node_name_mapping = {}
-    i = 0
-    for node in wg.tasks:
-        executor = node.get_executor()
+
+    # Add all process nodes first
+    for i, node in enumerate(wg.tasks):
+        # Store node index for later reference
        node_name_mapping[node.name] = i
 
-        callable_name = executor["callable_name"]
-        callable_name = f"{executor['module_path']}.{callable_name}"
+        # Get executor info and build the full callable name
+        executor = node.get_executor()
+        callable_name = f"{executor['module_path']}.{executor['callable_name']}"
+
+        # Add node to data structure
         data[NODES_LABEL].append({"id": i, "function": callable_name})
 
-        i += 1
 
+    # Create edges from WorkGraph links
     for link in wg.links:
         link_data = link.to_dict()
-        # if the from socket is the default result, we set it to None
-        if link_data["from_socket"] == "result":
-            link_data["from_socket"] = None
-        link_data[TARGET_LABEL] = node_name_mapping[link_data.pop("to_node")]
-        link_data[TARGET_PORT_LABEL] = link_data.pop("to_socket")
-        link_data[SOURCE_LABEL] = node_name_mapping[link_data.pop("from_node")]
-        link_data[SOURCE_PORT_LABEL] = link_data.pop("from_socket")
-        data[EDGES_LABEL].append(link_data)
-
-    for node in wg.tasks:
-        for input in node.inputs:
-            # assume namespace is not used as input
-            if isinstance(input, TaskSocketNamespace):
-                continue
-            if isinstance(input.value, orm.Data):
-                if input.value.uuid not in data_node_name_mapping:
-                    if isinstance(input.value, orm.List):
-                        raw_value = input.value.get_list()
-                    elif isinstance(input.value, orm.Dict):
-                        raw_value = input.value.get_dict()
-                        # unknow reason, there is a key "node_type" in the dict
-                        raw_value.pop("node_type", None)
-                    else:
-                        raw_value = input.value.value
-                    data[NODES_LABEL].append({"id": i, "value": raw_value})
-                    input_node_name = i
-                    data_node_name_mapping[input.value.uuid] = input_node_name
-                    i += 1
-                else:
-                    input_node_name = data_node_name_mapping[input.value.uuid]
-                data[EDGES_LABEL].append(
-                    {
-                        TARGET_LABEL: node_name_mapping[node.name],
-                        TARGET_PORT_LABEL: input._name,
-                        SOURCE_LABEL: input_node_name,
-                        SOURCE_PORT_LABEL: None,
-                    }
-                )
+
+        # Map the default result socket to None
+        from_socket = link_data.pop("from_socket")
+        if from_socket == "result":
+            from_socket = None
+
+        # Convert to the expected format
+        edge = {
+            SOURCE_LABEL: node_name_mapping[link_data.pop("from_node")],
+            SOURCE_PORT_LABEL: from_socket,
+            TARGET_LABEL: node_name_mapping[link_data.pop("to_node")],
+            TARGET_PORT_LABEL: link_data.pop("to_socket"),
+        }
+
+        data[EDGES_LABEL].append(edge)
+
+    # Define sockets to ignore when adding workflow I/O connections
+    wg_input_sockets = {
+        "_wait",
+        "metadata",
+        "function_data",
+        "process_label",
+        "function_inputs",
+        "deserializers",
+        "serializers",
+    }
+    wg_output_sockets = {"_wait", "_outputs", "exit_code"}
+
+    # Handle the first node's inputs (external inputs to the workflow)
+    first_node = wg.tasks[0]
+    first_node_id = node_name_mapping[first_node.name]
+
+    for input_name in first_node.inputs._sockets:
+        if input_name not in wg_input_sockets:
+            data[EDGES_LABEL].append(
+                {
+                    SOURCE_LABEL: None,
+                    SOURCE_PORT_LABEL: None,
+                    TARGET_LABEL: first_node_id,
+                    TARGET_PORT_LABEL: input_name,
+                }
+            )
+
+    # Handle the last node's outputs (workflow outputs)
+    last_node = wg.tasks[-1]
+    last_node_id = node_name_mapping[last_node.name]
+
+    for output_name in last_node.outputs._sockets:
+        if output_name not in wg_output_sockets:
+            data[EDGES_LABEL].append(
+                {
+                    SOURCE_LABEL: last_node_id,
+                    SOURCE_PORT_LABEL: output_name,
+                    TARGET_LABEL: None,
+                    TARGET_PORT_LABEL: None,
+                }
+            )
+
+    # Write the data to file
     with open(file_name, "w") as f:
-        # json.dump({"nodes": data[], "edges": edges_new_lst}, f)
         json.dump(data, f, indent=2)
 
     return data
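
A note on the example above: the rewritten workflow.json no longer carries value nodes. Edges whose "source" is null mark external workflow inputs, and the final edge with a null "target" exposes get_sum's "result" as the workflow output. The workflow module the JSON refers to is not part of this patch; a minimal sketch of what it might look like, inferred from the "prod", "div", and "result" ports (hypothetical, not the actual example module):

# workflow.py -- hypothetical sketch, not included in this patch
def get_prod_and_div(x, y):
    # The dict keys become the named output ports "prod" and "div"
    # referenced by the sourcePort entries in workflow.json.
    return {"prod": x * y, "div": x / y}


def get_sum(x, y):
    # A plain return value is exposed under the default "result" port.
    return x + y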
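
To sanity-check the round trip, something along these lines should work (a sketch, not part of the patch; it assumes aiida-workgraph is installed and an AiiDA profile is configured, and the output file name is illustrative):

from aiida import load_profile
from python_workflow_definition.aiida import load_workflow_json, write_workflow_json

load_profile()  # assumes a configured AiiDA profile

# Run from example_workflows/arithmetic/ so that import_module("workflow")
# can resolve the functions referenced in the JSON.
wg = load_workflow_json("workflow.json")

# Serialize back; edges with null source/target encode the workflow's own
# inputs and outputs, as in the JSON above.
data = write_workflow_json(wg, "roundtrip.json")  # output name is illustrative
print(data["edges"])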