diff --git a/examples/unroll_example.py b/examples/unroll_example.py index 8d7f375b..b7ae33ad 100755 --- a/examples/unroll_example.py +++ b/examples/unroll_example.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name, cyclic-import """ Script demonstrating how to unroll a QASM 3 program using pyqasm. diff --git a/src/pyqasm/modules/base.py b/src/pyqasm/modules/base.py index 3b21330a..2164705a 100644 --- a/src/pyqasm/modules/base.py +++ b/src/pyqasm/modules/base.py @@ -20,6 +20,7 @@ import functools from abc import ABC, abstractmethod +import re from collections import Counter from copy import deepcopy from typing import Optional @@ -36,6 +37,7 @@ from pyqasm.visitor import QasmVisitor, ScopeManager + def track_user_operation(func): """Decorator to track user operations on a QasmModule.""" @@ -761,3 +763,75 @@ def accept(self, visitor): Args: visitor (QasmVisitor): The visitor to accept """ + + + @abstractmethod + def merge( + self, + other: "QasmModule", + device_qubits: Optional[int] = None, + ) -> "QasmModule": + """Merge this module with another module. + + Implemented by concrete subclasses to avoid version mixing and + import-time cycles. Implementations should ensure both operands + are normalized to the same version prior to merging. + """ + + +def offset_statement_qubits(stmt: qasm3_ast.Statement, offset: int): + """Offset qubit indices for a given statement in-place by ``offset``. + + Handles gates, measurements, resets, and barriers (including slice forms). + """ + if isinstance(stmt, qasm3_ast.QuantumMeasurementStatement): + bit = stmt.measure.qubit + if isinstance(bit, qasm3_ast.IndexedIdentifier): + for group in bit.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + return + + if isinstance(stmt, qasm3_ast.QuantumGate): + for q in stmt.qubits: + for group in q.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + return + + if isinstance(stmt, qasm3_ast.QuantumReset): + q = stmt.qubits + if isinstance(q, qasm3_ast.IndexedIdentifier): + for group in q.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + return + + if isinstance(stmt, qasm3_ast.QuantumBarrier): + qubits = stmt.qubits + if len(qubits) == 0: + return + first = qubits[0] + if isinstance(first, qasm3_ast.IndexedIdentifier): + for group in first.indices: + for ind in group: + ind.value += offset # type: ignore[attr-defined] + elif isinstance(first, qasm3_ast.Identifier): + # Handle forms: __PYQASM_QUBITS__[:E], [S:], [S:E] + name = first.name + if name.startswith("__PYQASM_QUBITS__[") and name.endswith("]"): + slice_str = name[len("__PYQASM_QUBITS__"):] + # Parse slice forms [S:E], [:E], or [S:] + m = re.match(r"\[(?:(\d+)?:(\d+)?)\]", slice_str) + if m: + start_s, end_s = m.group(1), m.group(2) + if start_s is None and end_s is not None: + end_v = int(end_s) + offset + first.name = f"__PYQASM_QUBITS__[:{end_v}]" + elif start_s is not None and end_s is None: + start_v = int(start_s) + offset + first.name = f"__PYQASM_QUBITS__[{start_v}:]" + elif start_s is not None and end_s is not None: + start_v = int(start_s) + offset + end_v = int(end_s) + offset + first.name = f"__PYQASM_QUBITS__[{start_v}:{end_v}]" diff --git a/src/pyqasm/modules/qasm2.py b/src/pyqasm/modules/qasm2.py index f4b0de9d..808ed8a3 100644 --- a/src/pyqasm/modules/qasm2.py +++ b/src/pyqasm/modules/qasm2.py @@ -26,6 +26,7 @@ from pyqasm.exceptions import ValidationError from pyqasm.modules.base import QasmModule from pyqasm.modules.qasm3 import Qasm3Module +from pyqasm.modules.base import offset_statement_qubits class Qasm2Module(QasmModule): @@ -108,3 +109,75 @@ def accept(self, visitor): final_stmt_list = visitor.finalize(unrolled_stmt_list) self.unrolled_ast.statements = final_stmt_list + + def merge(self, other: QasmModule, device_qubits: int | None = None) -> QasmModule: + """Merge two modules and return a QASM2 result without mixing versions. + + - If ``other`` is QASM3, it is merged into this module's semantics, and + any standard gate includes are mapped to ``qelib1.inc``. + - The merged program keeps version "2.0" and prints as QASM2. + """ + if not isinstance(other, QasmModule): + raise TypeError(f"Expected QasmModule instance, got {type(other).__name__}") + + left_mod = self.copy() + right_mod = other.copy() + + # Unroll with qubit consolidation so both sides use __PYQASM_QUBITS__ + unroll_kwargs: dict[str, object] = {"consolidate_qubits": True} + if device_qubits is not None: + unroll_kwargs["device_qubits"] = device_qubits + + left_mod.unroll(**unroll_kwargs) + right_mod.unroll(**unroll_kwargs) + + left_qubits = left_mod.num_qubits + total_qubits = left_qubits + right_mod.num_qubits + + merged_program = Program(statements=[], version="2.0") + + # Unique includes first; map stdgates.inc -> qelib1.inc for QASM2 + include_names: list[str] = [] + for module in (left_mod, right_mod): + for stmt in module.unrolled_ast.statements: + if isinstance(stmt, Include): + fname = stmt.filename + if fname == "stdgates.inc": + fname = "qelib1.inc" + if fname not in include_names: + include_names.append(fname) + for name in include_names: + merged_program.statements.append(Include(filename=name)) + + # Consolidated qubit declaration (converted to qreg on print) + merged_program.statements.append( + qasm3_ast.QubitDeclaration( + size=qasm3_ast.IntegerLiteral(value=total_qubits), + qubit=qasm3_ast.Identifier(name="__PYQASM_QUBITS__"), + ) + ) + + # Append left ops (skip decls and includes) + for stmt in left_mod.unrolled_ast.statements: + if isinstance(stmt, (qasm3_ast.QubitDeclaration, Include)): + continue + merged_program.statements.append(deepcopy(stmt)) + + # Append right ops with index offset + for stmt in right_mod.unrolled_ast.statements: + if isinstance(stmt, (qasm3_ast.QubitDeclaration, Include)): + continue + stmt_copy = deepcopy(stmt) + offset_statement_qubits(stmt_copy, left_qubits) + merged_program.statements.append(stmt_copy) + + merged_module = Qasm2Module( + name=f"{left_mod.name}_merged_{right_mod.name}", + program=merged_program, + ) + merged_module.unrolled_ast = Program(statements=list(merged_program.statements), version="2.0") + merged_module._external_gates = list({*left_mod._external_gates, *right_mod._external_gates}) + merged_module._user_operations = list(left_mod.history) + list(right_mod.history) + merged_module._user_operations.append(f"merge(other={right_mod.name})") + merged_module.validate() + return merged_module diff --git a/src/pyqasm/modules/qasm3.py b/src/pyqasm/modules/qasm3.py index 8ed08d51..f8eaa971 100644 --- a/src/pyqasm/modules/qasm3.py +++ b/src/pyqasm/modules/qasm3.py @@ -16,10 +16,11 @@ Defines a module for handling OpenQASM 3.0 programs. """ +import openqasm3.ast as qasm3_ast from openqasm3.ast import Program from openqasm3.printer import dumps -from pyqasm.modules.base import QasmModule +from pyqasm.modules.base import QasmModule, offset_statement_qubits class Qasm3Module(QasmModule): @@ -52,3 +53,70 @@ def accept(self, visitor): final_stmt_list = visitor.finalize(unrolled_stmt_list) self._unrolled_ast.statements = final_stmt_list + + def merge(self, other: QasmModule, device_qubits: int | None = None) -> QasmModule: + """Merge two modules as OpenQASM 3.0 without mixing versions. + + If ``other`` is QASM2, it will be converted to QASM3 before merging. + The merged program keeps version "3.0". + """ + if not isinstance(other, QasmModule): + raise TypeError(f"Expected QasmModule instance, got {type(other).__name__}") + + # Convert right to QASM3 if it supports conversion; otherwise copy + convert = getattr(other, "to_qasm3", None) + right_mod = convert(as_str=False) if callable(convert) else other.copy() # type: ignore[assignment] + + left_mod = self.copy() + + # Unroll with consolidation so both use __PYQASM_QUBITS__ + unroll_kwargs: dict[str, object] = {"consolidate_qubits": True} + if device_qubits is not None: + unroll_kwargs["device_qubits"] = device_qubits + + left_mod.unroll(**unroll_kwargs) + right_mod.unroll(**unroll_kwargs) + + left_qubits = left_mod.num_qubits + total_qubits = left_qubits + right_mod.num_qubits + + merged_program = Program(statements=[], version="3.0") + + # Unique includes first + include_names: list[str] = [] + for module in (left_mod, right_mod): + for stmt in module.unrolled_ast.statements: + if isinstance(stmt, qasm3_ast.Include) and stmt.filename not in include_names: + include_names.append(stmt.filename) + for name in include_names: + merged_program.statements.append(qasm3_ast.Include(filename=name)) + + # Consolidated qubit declaration + merged_program.statements.append( + qasm3_ast.QubitDeclaration( + size=qasm3_ast.IntegerLiteral(value=total_qubits), + qubit=qasm3_ast.Identifier(name="__PYQASM_QUBITS__"), + ) + ) + + # Append left ops + for stmt in left_mod.unrolled_ast.statements: + if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)): + continue + merged_program.statements.append(stmt) + + # Append right ops with index offset + for stmt in right_mod.unrolled_ast.statements: + if isinstance(stmt, (qasm3_ast.QubitDeclaration, qasm3_ast.Include)): + continue + # right_mod is a copy, so it's safe to modify statements in place + offset_statement_qubits(stmt, left_qubits) + merged_program.statements.append(stmt) + + merged_module = Qasm3Module(name=f"{left_mod.name}_merged_{right_mod.name}", program=merged_program) + merged_module.unrolled_ast = Program(statements=list(merged_program.statements), version="3.0") + merged_module._external_gates = list({*left_mod._external_gates, *right_mod._external_gates}) + merged_module._user_operations = list(left_mod.history) + list(right_mod.history) + merged_module._user_operations.append(f"merge(other={right_mod.name})") + merged_module.validate() + return merged_module diff --git a/tests/qasm3/test_merge.py b/tests/qasm3/test_merge.py new file mode 100644 index 00000000..e99f53b4 --- /dev/null +++ b/tests/qasm3/test_merge.py @@ -0,0 +1,118 @@ +# Copyright 2025 qBraid +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for QasmModule.merge(). +""" + +from pyqasm.entrypoint import loads +from pyqasm.modules import QasmModule + + +def _qasm3(qasm: str) -> QasmModule: + return loads(qasm) + + +def test_merge_basic_gates_and_offsets(): + qasm_a = ( + "OPENQASM 3.0;\n" + "include \"stdgates.inc\";\n" + "qubit[2] q;\n" + "x q[0];\n" + "cx q[0], q[1];\n" + ) + qasm_b = ( + "OPENQASM 3.0;\n" + "include \"stdgates.inc\";\n" + "qubit[3] r;\n" + "h r[0];\n" + "cx r[1], r[2];\n" + ) + + mod_a = _qasm3(qasm_a) + mod_b = _qasm3(qasm_b) + + merged = mod_a.merge(mod_b) + + # Unrolled representation should have a single consolidated qubit declaration of size 5 + text = str(merged) + assert "qubit[5] __PYQASM_QUBITS__;" in text + + lines = [l.strip() for l in text.splitlines() if l.strip()] + # Keep only gate lines for comparison; skip version/includes/declarations + gate_lines = [ + l + for l in lines + if l[0].isalpha() + and not l.startswith("include") + and not l.startswith("OPENQASM") + and not l.startswith("qubit") + ] + assert gate_lines[0].startswith("x __PYQASM_QUBITS__[0]") + assert gate_lines[1].startswith("cx __PYQASM_QUBITS__[0], __PYQASM_QUBITS__[1]") + assert any(l.startswith("h __PYQASM_QUBITS__[2]") for l in gate_lines) + assert any(l.startswith("cx __PYQASM_QUBITS__[3], __PYQASM_QUBITS__[4]") for l in gate_lines) + + +def test_merge_with_measurements_and_barriers(): + # Module A: 1 qubit + classical 1; has barrier and measure + qasm_a = ( + "OPENQASM 3.0;\n" + "include \"stdgates.inc\";\n" + "qubit[1] qa; bit[1] ca;\n" + "h qa[0];\n" + "barrier qa;\n" + "ca[0] = measure qa[0];\n" + ) + # Module B: 2 qubits + classical 2 + qasm_b = ( + "OPENQASM 3.0;\n" + "include \"stdgates.inc\";\n" + "qubit[2] qb; bit[2] cb;\n" + "x qb[1];\n" + "cb[1] = measure qb[1];\n" + ) + + mod_a = _qasm3(qasm_a) + mod_b = _qasm3(qasm_b) + + merged = mod_a.merge(mod_b) + merged_text = str(merged) + + assert "qubit[3] __PYQASM_QUBITS__;" in merged_text + assert "measure __PYQASM_QUBITS__[2];" in merged_text + assert "barrier __PYQASM_QUBITS__" in merged_text + + +def test_merge_qasm2_with_qasm3(): + qasm2 = ( + "OPENQASM 2.0;\n" + "include \"qelib1.inc\";\n" + "qreg q[1];\n" + "h q[0];\n" + ) + qasm3 = ( + "OPENQASM 3.0;\n" + "include \"stdgates.inc\";\n" + "qubit[2] r;\n" + "x r[0];\n" + ) + + mod2 = loads(qasm2) + mod3 = loads(qasm3) + + merged = mod2.merge(mod3) + text = str(merged) + assert "qubit[3] __PYQASM_QUBITS__;" in text + assert "x __PYQASM_QUBITS__[1];" in text