Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 22 additions & 26 deletions src/qibojit/backends/clifford_operations_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
PARALLEL = True


@njit("Tuple((u1[:], u1[:,:], u1[:,:]))(u1[:,:], u8)", parallel=PARALLEL, cache=True)
@njit(cache=True)
def _get_rxz(symplectic_matrix, nqubits):
return (
symplectic_matrix[:, -1],
Expand All @@ -15,7 +15,7 @@ def _get_rxz(symplectic_matrix, nqubits):
)


@njit("u1[:,:](u1[:,:], u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def H(symplectic_matrix, q, nqubits):
r, x, z = _get_rxz(symplectic_matrix, nqubits)

Expand All @@ -27,7 +27,7 @@ def H(symplectic_matrix, q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def CNOT(symplectic_matrix, control_q, target_q, nqubits):
r, x, z = _get_rxz(symplectic_matrix, nqubits)

Expand All @@ -40,7 +40,7 @@ def CNOT(symplectic_matrix, control_q, target_q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def CZ(symplectic_matrix, control_q, target_q, nqubits):
"""Decomposition --> H-CNOT-H"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand All @@ -59,7 +59,7 @@ def CZ(symplectic_matrix, control_q, target_q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def S(symplectic_matrix, q, nqubits):
r, x, z = _get_rxz(symplectic_matrix, nqubits)

Expand All @@ -69,7 +69,7 @@ def S(symplectic_matrix, q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def Z(symplectic_matrix, q, nqubits):
"""Decomposition --> S-S"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand All @@ -81,7 +81,7 @@ def Z(symplectic_matrix, q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def X(symplectic_matrix, q, nqubits):
"""Decomposition --> H-S-S-H"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand All @@ -93,7 +93,7 @@ def X(symplectic_matrix, q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def Y(symplectic_matrix, q, nqubits):
"""Decomposition --> S-S-H-S-S-H"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand All @@ -105,7 +105,7 @@ def Y(symplectic_matrix, q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def SX(symplectic_matrix, q, nqubits):
"""Decomposition --> H-S-H"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand All @@ -116,7 +116,7 @@ def SX(symplectic_matrix, q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def SDG(symplectic_matrix, q, nqubits):
"""Decomposition --> S-S-S"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand All @@ -127,7 +127,7 @@ def SDG(symplectic_matrix, q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def SXDG(symplectic_matrix, q, nqubits):
"""Decomposition --> H-S-S-S-H"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand All @@ -138,7 +138,7 @@ def SXDG(symplectic_matrix, q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def RY_pi(symplectic_matrix, q, nqubits):
"""Decomposition --> H-S-S"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand All @@ -151,7 +151,7 @@ def RY_pi(symplectic_matrix, q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def RY_3pi_2(symplectic_matrix, q, nqubits):
"""Decomposition --> H-S-S"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand All @@ -164,7 +164,7 @@ def RY_3pi_2(symplectic_matrix, q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def SWAP(symplectic_matrix, control_q, target_q, nqubits):
"""Decomposition --> CNOT-CNOT-CNOT"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand Down Expand Up @@ -195,7 +195,7 @@ def SWAP(symplectic_matrix, control_q, target_q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def iSWAP(symplectic_matrix, control_q, target_q, nqubits):
"""Decomposition --> H-CNOT-CNOT-H-S-S"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand Down Expand Up @@ -228,7 +228,7 @@ def iSWAP(symplectic_matrix, control_q, target_q, nqubits):
return symplectic_matrix


@njit("u1[:,:](u1[:,:], u8, u8, u8)", parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def CY(symplectic_matrix, control_q, target_q, nqubits):
"""Decomposition --> S-CNOT-SDG"""
r, x, z = _get_rxz(symplectic_matrix, nqubits)
Expand All @@ -254,13 +254,13 @@ def CY(symplectic_matrix, control_q, target_q, nqubits):


# this cannot be cached anymore with numba unfortunately
@njit("i8(i8)", parallel=False, cache=True)
@njit(cache=True)
def _packed_size(n):
"""Returns the size of an array of `n` booleans after packing."""
return int(np.ceil(n / 8))


@njit(["u1[:,:](u1[:,:], i8)", "u1[:,:](b1[:,:], i8)"], parallel=PARALLEL, cache=True)
@njit(parallel=PARALLEL, cache=True)
def _packbits(array, axis):
array = array.astype(np.uint8)
array = np.ascontiguousarray(np.swapaxes(array, axis, -1))
Expand All @@ -281,7 +281,7 @@ def _packbits(array, axis):
return np.swapaxes(out, axis, -1)


@njit("u1[:,:](u1[:,:], i8)", parallel=PARALLEL, cache=True)
@njit(cache=True)
def _pack_for_measurements(state, nqubits):
"""Prepares the state for measurements by packing the rows of the X and Z sections of the symplectic matrix."""
r, x, z = _get_rxz(state, nqubits)
Expand All @@ -290,15 +290,15 @@ def _pack_for_measurements(state, nqubits):
return np.hstack((x, z, r[:, None]))


@njit("u1[:](u1)", parallel=PARALLEL, cache=True)
@njit(cache=True)
def _unpack_byte(byte):
bits = np.empty(8, dtype=np.uint8)
for i in range(8):
bits[i] = (byte >> (7 - i)) & 1
return bits


@njit("u1[:,:](u1[:,:], i8, i8)", parallel=PARALLEL, cache=True)
@njit(cache=True)
def _unpackbits(array, axis, count):
# this is gonnna be used on 2d arrays only
# i.e. portions of the symplectic matrix
Expand Down Expand Up @@ -327,7 +327,7 @@ def _unpackbits(array, axis, count):
return out


@njit("u1[:,:](u1[:,:], i8)", parallel=PARALLEL, cache=True)
@njit(cache=True)
def _unpack_for_measurements(state, nqubits):
"""Unpacks the symplectc matrix that was packed for measurements."""
x = _unpackbits(state[:, : _packed_size(nqubits)], axis=1, count=nqubits)
Expand All @@ -336,10 +336,6 @@ def _unpack_for_measurements(state, nqubits):


@njit(
[
"u1[:,:](u1[:,:], u8[:], u8[:], u8, b1)",
"u1[:,:](u1[:,:], u4[:], u4[:], u4, b1)",
],
parallel=PARALLEL,
cache=True,
fastmath=True,
Expand Down
Loading
Loading