Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
216 changes: 22 additions & 194 deletions src/qibojit/custom_operators/gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from numba import njit, prange


@njit("int64(int64, int32[:])", cache=True)
@njit(cache=True)
def multicontrol_index(g, qubits):
i = 0
i += g
Expand All @@ -12,16 +12,7 @@ def multicontrol_index(g, qubits):
return i


@njit(
[
"float32[:](float32[:], float32[:,:], int32, int32)",
"float64[:](float64[:], float64[:,:], int64, int64)",
"complex64[:](complex64[:], complex64[:,:], int64, int64)",
"complex128[:](complex128[:], complex128[:,:], int64, int64)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_gate_kernel(state, gate, nstates, m):
tk = 1 << m
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand All @@ -34,16 +25,7 @@ def apply_gate_kernel(state, gate, nstates, m):
return state


@njit(
[
"float32[:](float32[:], float32[:,:], int32[:], int32, int32)",
"float64[:](float64[:], float64[:,:], int32[:], int64, int64)",
"complex64[:](complex64[:], complex64[:,:], int32[:], int64, int64)",
"complex128[:](complex128[:], complex128[:,:], int32[:], int64, int64)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def multicontrol_apply_gate_kernel(state, gate, qubits, nstates, m):
tk = 1 << m
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand All @@ -56,16 +38,7 @@ def multicontrol_apply_gate_kernel(state, gate, qubits, nstates, m):
return state


@njit(
[
"float32[:](float32[:], optional(float32[:,:]), int32, int32)",
"float64[:](float64[:], optional(float64[:,:]), int64, int64)",
"complex64[:](complex64[:], optional(complex64[:,:]), int64, int64)",
"complex128[:](complex128[:], optional(complex128[:,:]), int64, int64)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_x_kernel(state, gate, nstates, m):
tk = 1 << m
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand All @@ -75,16 +48,7 @@ def apply_x_kernel(state, gate, nstates, m):
return state


@njit(
[
"float32[:](float32[:], optional(float32[:,:]), int32[:], int32, int32)",
"float64[:](float64[:], optional(float64[:,:]), int32[:], int64, int64)",
"complex64[:](complex64[:], optional(complex64[:,:]), int32[:], int64, int64)",
"complex128[:](complex128[:], optional(complex128[:,:]), int32[:], int64, int64)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def multicontrol_apply_x_kernel(state, gate, qubits, nstates, m):
tk = 1 << m
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand All @@ -94,14 +58,7 @@ def multicontrol_apply_x_kernel(state, gate, qubits, nstates, m):
return state


@njit(
[
"complex64[:](complex64[:], optional(complex64[:,:]), int64, int64)",
"complex128[:](complex128[:], optional(complex128[:,:]), int64, int64)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_y_kernel(state, gate, nstates, m):
tk = 1 << m
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand All @@ -111,14 +68,7 @@ def apply_y_kernel(state, gate, nstates, m):
return state


@njit(
[
"complex64[:](complex64[:], optional(complex64[:,:]), int32[:], int64, int64)",
"complex128[:](complex128[:], optional(complex128[:,:]), int32[:], int64, int64)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def multicontrol_apply_y_kernel(state, gate, qubits, nstates, m):
tk = 1 << m
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand All @@ -128,16 +78,7 @@ def multicontrol_apply_y_kernel(state, gate, qubits, nstates, m):
return state


@njit(
[
"float32[:](float32[:], optional(float32[:,:]), int32, int32)",
"float64[:](float64[:], optional(float64[:,:]), int64, int64)",
"complex64[:](complex64[:], optional(complex64[:,:]), int64, int64)",
"complex128[:](complex128[:], optional(complex128[:,:]), int64, int64)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_z_kernel(state, gate, nstates, m):
tk = 1 << m
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand All @@ -146,16 +87,7 @@ def apply_z_kernel(state, gate, nstates, m):
return state


@njit(
[
"float32[:](float32[:], optional(float32[:,:]), int32[:], int32, int32)",
"float64[:](float64[:], optional(float64[:,:]), int32[:], int64, int64)",
"complex64[:](complex64[:], optional(complex64[:,:]), int32[:], int64, int64)",
"complex128[:](complex128[:], optional(complex128[:,:]), int32[:], int64, int64)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def multicontrol_apply_z_kernel(state, gate, qubits, nstates, m):
tk = 1 << m
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand All @@ -164,16 +96,7 @@ def multicontrol_apply_z_kernel(state, gate, qubits, nstates, m):
return state


@njit(
[
"float32[:](float32[:], float32, int32, int32)",
"float64[:](float64[:], float64, int64, int64)",
"complex64[:](complex64[:], complex64, int64, int64)",
"complex128[:](complex128[:], complex128, int64, int64)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_z_pow_kernel(state, gate, nstates, m):
tk = 1 << m
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand All @@ -182,16 +105,7 @@ def apply_z_pow_kernel(state, gate, nstates, m):
return state


@njit(
[
"float32[:](float32[:], float32, int32[:], int32, int32)",
"float64[:](float64[:], float64, int32[:], int64, int64)",
"complex64[:](complex64[:], complex64, int32[:], int64, int64)",
"complex128[:](complex128[:], complex128, int32[:], int64, int64)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def multicontrol_apply_z_pow_kernel(state, gate, qubits, nstates, m):
tk = 1 << m
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand All @@ -200,16 +114,7 @@ def multicontrol_apply_z_pow_kernel(state, gate, qubits, nstates, m):
return state


@njit(
[
"float32[:](float32[:], float32[:,:], int32, int32, int32, boolean)",
"float64[:](float64[:], float64[:,:], int64, int64, int64, boolean)",
"complex64[:](complex64[:], complex64[:,:], int64, int64, int64, boolean)",
"complex128[:](complex128[:], complex128[:,:], int64, int64, int64, boolean)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_two_qubit_gate_kernel(state, gate, nstates, m1, m2, swap_targets=False):
tk1, tk2 = 1 << m1, 1 << m2
uk1, uk2 = tk1, tk2
Expand Down Expand Up @@ -248,16 +153,7 @@ def apply_two_qubit_gate_kernel(state, gate, nstates, m1, m2, swap_targets=False
return state


@njit(
[
"float32[:](float32[:], float32[:,:], int32[:], int32, int32, int32, boolean)",
"float64[:](float64[:], float64[:,:], int32[:], int64, int64, int64, boolean)",
"complex64[:](complex64[:], complex64[:,:], int32[:], int64, int64, int64, boolean)",
"complex128[:](complex128[:], complex128[:,:], int32[:], int64, int64, int64, boolean)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def multicontrol_apply_two_qubit_gate_kernel(
state, gate, qubits, nstates, m1, m2, swap_targets=False
):
Expand Down Expand Up @@ -297,16 +193,7 @@ def multicontrol_apply_two_qubit_gate_kernel(
return state


@njit(
[
"float32[:](float32[:], optional(float32[:,:]), int32, int32, int32, boolean)",
"float64[:](float64[:], optional(float64[:,:]), int64, int64, int64, boolean)",
"complex64[:](complex64[:], optional(complex64[:,:]), int64, int64, int64, boolean)",
"complex128[:](complex128[:], optional(complex128[:,:]), int64, int64, int64, boolean)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_swap_kernel(state, gate, nstates, m1, m2, swap_targets=False):
tk1, tk2 = 1 << m1, 1 << m2
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand All @@ -317,16 +204,7 @@ def apply_swap_kernel(state, gate, nstates, m1, m2, swap_targets=False):
return state


@njit(
[
"float32[:](float32[:], optional(float32[:,:]), int32[:], int32, int32, int32, boolean)",
"float64[:](float64[:], optional(float64[:,:]), int32[:], int64, int64, int64, boolean)",
"complex64[:](complex64[:], optional(complex64[:,:]), int32[:], int64, int64, int64, boolean)",
"complex128[:](complex128[:], optional(complex128[:,:]), int32[:], int64, int64, int64, boolean)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def multicontrol_apply_swap_kernel(
state, gate, qubits, nstates, m1, m2, swap_targets=False
):
Expand All @@ -338,14 +216,7 @@ def multicontrol_apply_swap_kernel(
return state


@njit(
[
"complex64[:](complex64[:], complex64[:], int64, int64, int64, boolean)",
"complex128[:](complex128[:], complex128[:], int64, int64, int64, boolean)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_fsim_kernel(state, gate, nstates, m1, m2, swap_targets=False):
tk1, tk2 = 1 << m1, 1 << m2
uk1, uk2 = tk1, tk2
Expand All @@ -364,14 +235,7 @@ def apply_fsim_kernel(state, gate, nstates, m1, m2, swap_targets=False):
return state


@njit(
[
"complex64[:](complex64[:], complex64[:], int32[:], int64, int64, int64, boolean)",
"complex128[:](complex128[:], complex128[:], int32[:], int64, int64, int64, boolean)",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def multicontrol_apply_fsim_kernel(
state, gate, qubits, nstates, m1, m2, swap_targets=False
):
Expand All @@ -390,24 +254,15 @@ def multicontrol_apply_fsim_kernel(
return state


@njit(["int32(int32, int32[:])", "int64(int64, int64[:])"], cache=True)
@njit(cache=True)
def multitarget_index(i, targets):
t = 0
for u, v in enumerate(targets):
t += ((i >> u) & 1) * v
return t


@njit(
[
"float32[:](float32[:], float32[:,:], int32[:], int32, int32[:])",
"float64[:](float64[:], float64[:,:], int32[:], int64, int64[:])",
"complex64[:](complex64[:], complex64[:,:], int32[:], int64, int64[:])",
"complex128[:](complex128[:], complex128[:,:], int32[:], int64, int64[:])",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_three_qubit_gate_kernel(state, gate, qubits, nstates, targets):
for g in prange(nstates): # pylint: disable=not-an-iterable
ig = multicontrol_index(g, qubits)
Expand All @@ -434,16 +289,7 @@ def apply_three_qubit_gate_kernel(state, gate, qubits, nstates, targets):
return state


@njit(
[
"float32[:](float32[:], float32[:,:], int32[:], int32, int32[:])",
"float64[:](float64[:], float64[:,:], int32[:], int64, int64[:])",
"complex64[:](complex64[:], complex64[:,:], int32[:], int64, int64[:])",
"complex128[:](complex128[:], complex128[:,:], int32[:], int64, int64[:])",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_four_qubit_gate_kernel(state, gate, qubits, nstates, targets):
for g in prange(nstates): # pylint: disable=not-an-iterable
ig = multicontrol_index(g, qubits)
Expand Down Expand Up @@ -486,16 +332,7 @@ def apply_four_qubit_gate_kernel(state, gate, qubits, nstates, targets):
return state


@njit(
[
"float32[:](float32[:], float32[:,:], int32[:], int32, int32[:])",
"float64[:](float64[:], float64[:,:], int32[:], int64, int64[:])",
"complex64[:](complex64[:], complex64[:,:], int32[:], int64, int64[:])",
"complex128[:](complex128[:], complex128[:,:], int32[:], int64, int64[:])",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_five_qubit_gate_kernel(state, gate, qubits, nstates, targets):
for g in prange(nstates): # pylint: disable=not-an-iterable
ig = multicontrol_index(g, qubits)
Expand Down Expand Up @@ -572,16 +409,7 @@ def apply_five_qubit_gate_kernel(state, gate, qubits, nstates, targets):
return state


@njit(
[
"float32[:](float32[:], float32[:,::1], int32[:], int32, int32[:])",
"float64[:](float64[:], float64[:,::1], int32[:], int64, int64[:])",
"complex64[:](complex64[:], complex64[:,::1], int32[:], int64, int64[:])",
"complex128[:](complex128[:], complex128[:,::1], int32[:], int64, int64[:])",
],
parallel=True,
cache=True,
)
@njit(parallel=True, cache=True)
def apply_multi_qubit_gate_kernel(state, gate, qubits, nstates, targets):
nsubstates = 1 << len(targets)
for g in prange(nstates): # pylint: disable=not-an-iterable
Expand Down
Loading
Loading