diff --git a/src/qibo/gates/gates.py b/src/qibo/gates/gates.py index 16c4b97dee..cb145f571b 100644 --- a/src/qibo/gates/gates.py +++ b/src/qibo/gates/gates.py @@ -2413,11 +2413,7 @@ def _dagger(self) -> "Gate": def decompose(self, *free, use_toffolis: bool = True) -> List[Gate]: """Decomposition of RBS gate according to `ArXiv:2109.09685 `_.""" - from qibo.transpiler.decompositions import ( # pylint: disable=C0415 - standard_decompositions, - ) - - return standard_decompositions(self) + return _controlled_decompose(self, *free, use_toffolis=use_toffolis) class ECR(Gate): @@ -2461,11 +2457,7 @@ def decompose(self, *free, use_toffolis: bool = True) -> List[Gate]: \\textup{ECR} = e^{i 7 \\pi / 4} \\, S(q_{0}) \\, \\sqrt{X}(q_{1}) \\, \\textup{CNOT}(q_{0}, q_{1}) \\, X(q_{0}) """ - from qibo.transpiler.decompositions import ( # pylint: disable=C0415 - standard_decompositions, - ) - - return standard_decompositions(self) + return _controlled_decompose(self, *free, use_toffolis=use_toffolis) class TOFFOLI(Gate): @@ -2826,3 +2818,39 @@ def _check_engine(array): return torch return np + + +def _controlled_decompose(self, *free, use_toffolis: bool = True) -> List[Gate]: + """Decompose non-conjugation gates component-wise.""" + from qibo.transpiler.decompositions import ( # pylint: disable=C0415 + standard_decompositions, + ) + + if self.is_controlled_by: + control_qubits = self.control_qubits + self.is_controlled_by = False + self.control_qubits = () + decomp = standard_decompositions(self) + c_decomps1 = [] + c_decomps2 = [] + while len(decomp) > 1: + g1 = decomp[0].dagger() + g2 = decomp[-1] + if g1.to_json() != g2.to_json(): + break + + c_decomps1.append(decomp[0]) + c_decomps2.append(decomp[-1]) + decomp = decomp[1:-1] + + for i, gate in enumerate(decomp): + gate_control_qubits = gate.control_qubits + gate.is_controlled_by = False + gate.control_qubits = () + c_gate = gate.controlled_by(*control_qubits, *gate_control_qubits) + decomp[i] = c_gate + self.is_controlled_by = True + decomp = [*c_decomps1, *decomp, *c_decomps2[::-1]] + return decomp + + return standard_decompositions(self)