diff --git a/qbraid_qir/cirq/passes.py b/qbraid_qir/cirq/passes.py index 2592ebf..48c0e5a 100644 --- a/qbraid_qir/cirq/passes.py +++ b/qbraid_qir/cirq/passes.py @@ -13,14 +13,70 @@ """ import itertools -from typing import Iterable +from typing import Iterable, List, Sequence, Type, Union import cirq +from cirq.protocols.decompose_protocol import DecomposeResult from .exceptions import CirqConversionError from .opsets import map_cirq_op_to_pyqir_callable +class QirTargetGateSet(cirq.TwoQubitCompilationTargetGateset): + def __init__( + self, + *, + atol: float = 1e-8, + allow_partial_czs: bool = False, + additional_gates: Sequence[ + Union[Type["cirq.Gate"], "cirq.Gate", "cirq.GateFamily"] + ] = (), + preserve_moment_structure: bool = True, + ) -> None: + super().__init__( + cirq.IdentityGate, + cirq.HPowGate, + cirq.XPowGate, + cirq.YPowGate, + cirq.ZPowGate, + cirq.SWAP, + cirq.CNOT, + cirq.CZ, + cirq.TOFFOLI, + cirq.ResetChannel, + cirq.MeasurementGate, + cirq.PauliMeasurementGate, + *additional_gates, + name="QirTargetGateset", + preserve_moment_structure=preserve_moment_structure, + ) + self.allow_partial_czs = allow_partial_czs + self.atol = atol + + @property + def postprocess_transformers(self) -> List["cirq.TRANSFORMER"]: + return [] + + def _decompose_single_qubit_operation( + self, op: "cirq.Operation", moment_idx: int + ) -> DecomposeResult: + qubit = op.qubits[0] + mat = cirq.unitary(op) + for gate in cirq.single_qubit_matrix_to_gates(mat, self.atol): + yield gate(qubit) + + def _decompose_two_qubit_operation(self, op: "cirq.Operation", _) -> "cirq.OP_TREE": + if not cirq.has_unitary(op): + return NotImplemented + return cirq.two_qubit_matrix_to_cz_operations( + op.qubits[0], + op.qubits[1], + cirq.unitary(op), + allow_partial_czs=self.allow_partial_czs, + atol=self.atol, + ) + + def _decompose_gate_op(operation: cirq.Operation) -> Iterable[cirq.OP_TREE]: """Decomposes a single Cirq gate operation into a sequence of operations that are directly supported by PyQIR. @@ -36,12 +92,10 @@ def _decompose_gate_op(operation: cirq.Operation) -> Iterable[cirq.OP_TREE]: _ = map_cirq_op_to_pyqir_callable(operation) return [operation] except CirqConversionError: - pass - new_ops = cirq.decompose_once(operation, flatten=True, default=[operation]) - if len(new_ops) == 1 and new_ops[0] == operation: - raise CirqConversionError("Couldn't convert circuit to QIR gate set.") - return list(itertools.chain.from_iterable(map(_decompose_gate_op, new_ops))) - + new_ops = cirq.decompose_once(operation, flatten=True, default=[operation]) + if len(new_ops) == 1 and new_ops[0] == operation: + raise CirqConversionError("Couldn't convert circuit to QIR gate set.") + return list(itertools.chain.from_iterable(map(_decompose_gate_op, new_ops))) def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit: """ @@ -53,21 +107,9 @@ def _decompose_unsupported_gates(circuit: cirq.Circuit) -> cirq.Circuit: Returns: cirq.Circuit: A new circuit with unsupported gates decomposed. """ - new_circuit = cirq.Circuit() - for moment in circuit: - new_ops = [] - for operation in moment: - if isinstance(operation, cirq.GateOperation): - decomposed_ops = list(_decompose_gate_op(operation)) - new_ops.extend(decomposed_ops) - elif isinstance(operation, cirq.ClassicallyControlledOperation): - new_ops.append(operation) - else: - new_ops.append(operation) - - new_circuit.append(new_ops) - return new_circuit + circuit = cirq.optimize_for_target_gateset(circuit, gateset=QirTargetGateSet(), ignore_failures=True, max_num_passes=1) + return circuit def preprocess_circuit(circuit: cirq.Circuit) -> cirq.Circuit: """ diff --git a/qbraid_qir/cirq/visitor.py b/qbraid_qir/cirq/visitor.py index 37695aa..a22567f 100644 --- a/qbraid_qir/cirq/visitor.py +++ b/qbraid_qir/cirq/visitor.py @@ -15,6 +15,7 @@ import logging from abc import ABCMeta, abstractmethod +import numpy as np import cirq import pyqir import pyqir._native @@ -108,6 +109,13 @@ def handle_measurement(pyqir_func): for qubit, result in zip(qubits, results): self._measured_qubits[pyqir.qubit_id(qubit)] = True pyqir_func(self._builder, qubit, result) + + def get_rot_gate_angle(operation: cirq.Operation): + if isinstance(operation.gate, (cirq.ops.XPowGate, cirq.ops.YPowGate, cirq.ops.ZPowGate)): + angle = operation.gate.exponent * np.pi + else: + angle = operation.gate._rads + return angle # dealing with conditional gates if isinstance(operation, cirq.ClassicallyControlledOperation): @@ -121,9 +129,10 @@ def handle_measurement(pyqir_func): # pylint: disable=unnecessary-lambda-assignment if op_str in ["Rx", "Ry", "Rz"]: + angle = get_rot_gate_angle(operation._sub_operation) pyqir_func = lambda: temp_pyqir_func( self._builder, - operation._sub_operation.gate._rads, # type: ignore[union-attr] + angle, # type: ignore[union-attr] *qubits, ) else: @@ -144,11 +153,11 @@ def _branch(conds, pyqir_func): _branch(conditions, pyqir_func) else: pyqir_func, op_str = map_cirq_op_to_pyqir_callable(operation) - if op_str.startswith("measure"): handle_measurement(pyqir_func) - elif op_str in ["Rx", "Ry", "Rz"]: - pyqir_func(self._builder, operation.gate._rads, *qubits) # type: ignore[union-attr] + elif op_str in ["Rx", "Ry", "Rz"]: + angle = get_rot_gate_angle(operation) + pyqir_func(self._builder, angle, *qubits) # type: ignore[union-attr] else: pyqir_func(self._builder, *qubits) diff --git a/tests/cirq_qir/test_cirq_preprocess.py b/tests/cirq_qir/test_cirq_preprocess.py index 4520cc0..c9e20f9 100644 --- a/tests/cirq_qir/test_cirq_preprocess.py +++ b/tests/cirq_qir/test_cirq_preprocess.py @@ -15,13 +15,11 @@ import cirq import numpy as np import pytest +import qbraid from qbraid_qir.cirq.exceptions import CirqConversionError from qbraid_qir.cirq.passes import preprocess_circuit -# pylint: disable=redefined-outer-name - - @pytest.fixture def gridqubit_circuit(): qubits = [cirq.GridQubit(x, 0) for x in range(4)] @@ -40,8 +38,8 @@ def test_convert_gridqubits_to_linequbits(gridqubit_circuit): linequbit_circuit = preprocess_circuit(gridqubit_circuit) for qubit in linequbit_circuit.all_qubits(): assert isinstance(qubit, cirq.LineQubit), "Qubit is not a LineQubit" - assert np.allclose( - linequbit_circuit.unitary(), gridqubit_circuit.unitary() + qbraid.interface.assert_allclose_up_to_global_phase( + linequbit_circuit.unitary(), gridqubit_circuit.unitary(), atol=1e-6 ), "Circuits are not equal" @@ -49,8 +47,8 @@ def test_convert_namedqubits_to_linequbits(namedqubit_circuit): linequbit_circuit = preprocess_circuit(namedqubit_circuit) for qubit in linequbit_circuit.all_qubits(): assert isinstance(qubit, cirq.LineQubit), "Qubit is not a LineQubit" - assert np.allclose( - linequbit_circuit.unitary(), namedqubit_circuit.unitary() + qbraid.interface.assert_allclose_up_to_global_phase( + linequbit_circuit.unitary(), namedqubit_circuit.unitary(), atol=1e-6 ), "Circuits are not equal" @@ -59,12 +57,3 @@ def test_empty_circuit_conversion(): converted_circuit = preprocess_circuit(circuit) assert len(converted_circuit.all_qubits()) == 0, "Converted empty circuit should have no qubits" - -def test_multi_qubit_measurement_error(): - qubits = cirq.LineQubit.range(3) - circuit = cirq.Circuit() - ps = cirq.X(qubits[0]) * cirq.Y(qubits[1]) * cirq.X(qubits[2]) - meas_gates = cirq.measure_single_paulistring(ps) - circuit.append(meas_gates) - with pytest.raises(CirqConversionError): - preprocess_circuit(circuit) diff --git a/tests/cirq_qir/test_cirq_to_qir.py b/tests/cirq_qir/test_cirq_to_qir.py index f9973a1..3a79dc2 100644 --- a/tests/cirq_qir/test_cirq_to_qir.py +++ b/tests/cirq_qir/test_cirq_to_qir.py @@ -18,6 +18,7 @@ import cirq import pyqir import pytest +import numpy as np from qbraid_qir.cirq import CirqConversionError, cirq_to_qir from tests.cirq_qir.fixtures.basic_gates import ( @@ -69,14 +70,6 @@ def test_cirq_qir_conversion_error(): cirq_to_qir(None) -def test_cirq_to_qir_conversion_error(): - """Test raising exception for conversion error.""" - op = cirq.XPowGate(exponent=0.25).controlled().on(cirq.LineQubit(1), cirq.LineQubit(2)) - circuit = cirq.Circuit(op) - with pytest.raises(CirqConversionError): - cirq_to_qir(circuit) - - @pytest.mark.parametrize("circuit_name", single_op_tests) def test_single_qubit_gates(circuit_name, request): qir_op, circuit = request.getfixturevalue(circuit_name) diff --git a/tests/qasm3_qir/converter/test_gates.py b/tests/qasm3_qir/converter/test_gates.py index 1abe394..6d7719e 100644 --- a/tests/qasm3_qir/converter/test_gates.py +++ b/tests/qasm3_qir/converter/test_gates.py @@ -263,18 +263,3 @@ def test_nested_gate_modifiers(): check_single_qubit_gate_op(generated_qir, 2, [1, 1, 1], "z") -def test_unsupported_modifiers(): - # TO DO : add implementations, but till then we have tests - for modifier in ["ctrl", "negctrl"]: - with pytest.raises( - NotImplementedError, - match=r"Controlled modifier gates not yet supported .*", - ): - _ = qasm3_to_qir( - f""" - OPENQASM 3; - include "stdgates.inc"; - qubit[2] q; - {modifier} @ h q[0], q[1]; - """ - )