2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fla/ops/chunk.py
@@ -36,7 +36,7 @@ def chunk_gated_delta_rule_fwd(
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(
k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd(
110 changes: 70 additions & 40 deletions vllm/model_executor/layers/fla/ops/chunk_delta_h.py
@@ -14,14 +14,15 @@

from .index import prepare_chunk_indices, prepare_chunk_offsets
from .op import exp
from .utils import is_nvidia_hopper, use_cuda_graph
from .utils import use_cuda_graph

NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
NUM_WARPS = [2, 4, 8, 16]


@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
@@ -35,7 +36,7 @@
for num_stages in [2, 3, 4]
for BV in [32, 64]
],
key=["H", "K", "V", "BT", "USE_G"],
key=["H", "K", "V", "BT"],
use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
@@ -45,6 +46,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
w,
v_new,
g,
gk,
h,
h0,
ht,
@@ -58,6 +60,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
BT: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
@@ -88,12 +91,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
b_h4 = tl.zeros([64, BV], dtype=tl.float32)

# calculate offset
h += (boh * H + i_h) * K * V
v += (bos * H + i_h) * V
k += (bos * Hg + i_h // (H // Hg)) * K
w += (bos * H + i_h) * K
h += ((boh * H + i_h) * K * V).to(tl.int64)
v += ((bos * H + i_h) * V).to(tl.int64)
k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
w += ((bos * H + i_h) * K).to(tl.int64)
if SAVE_NEW_VALUE:
v_new += (bos * H + i_h) * V
v_new += ((bos * H + i_h) * V).to(tl.int64)
stride_v = H * V
stride_h = H * K * V
stride_k = Hg * K
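The offset arithmetic above is now routed through tl.int64 before being added to the base pointers. The diff itself does not state why, but the usual motivation for such casts is that the int32 intermediate products can overflow once batch size, sequence length, heads and head size grow large; a quick check with illustrative values only:

```python
# Illustrative only: with ~1M total tokens, H = 32 heads and V = 128,
# the element offset (bos * H + i_h) * V already exceeds the int32 range,
# which is the presumed reason for the .to(tl.int64) casts above.
bos, H, i_h, V = 1_000_000, 32, 0, 128
offset = (bos * H + i_h) * V
print(offset, offset > 2**31 - 1)  # 4096000000 True
```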
@@ -145,92 +148,115 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
)
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))

p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_v_new = (
tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
if SAVE_NEW_VALUE
else None
)
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
if K > 64:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
if K > 128:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
if K > 192:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v

if SAVE_NEW_VALUE:
p_v_new = tl.make_block_ptr(
p_v = tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
tl.store(
p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
)
tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))

last_idx = min((i_t + 1) * BT, T) - 1
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)) < T
last_idx = min((i_t + 1) * BT, T) - 1
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr(
g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
b_g_last = exp(b_g_last)
b_h1 = b_h1 * b_g_last
b_h1 *= b_g_last
if K > 64:
b_h2 *= b_g_last
if K > 128:
b_h3 *= b_g_last
if K > 192:
b_h4 *= b_g_last

if USE_GK:
o_k1 = tl.arange(0, 64)
b_gk_last1 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k1,
mask=(o_k1 < K),
other=0.0,
)
b_h1 *= exp(b_gk_last1)[:, None]
if K > 64:
b_h2 = b_h2 * b_g_last
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k2,
mask=(o_k2 < K),
other=0.0,
)
b_h2 *= exp(b_gk_last2)[:, None]
if K > 128:
b_h3 = b_h3 * b_g_last
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k3,
mask=(o_k3 < K),
other=0.0,
)
b_h3 *= exp(b_gk_last3)[:, None]
if K > 192:
b_h4 = b_h4 * b_g_last
b_v_new = b_v_new.to(k.dtype.element_ty)
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k4,
mask=(o_k4 < K),
other=0.0,
)
b_h4 *= exp(b_gk_last4)[:, None]
b_v = b_v.to(k.dtype.element_ty)

p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h1 += tl.dot(b_k, b_v_new)
b_h1 += tl.dot(b_k, b_v)
if K > 64:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h2 += tl.dot(b_k, b_v_new)
b_h2 += tl.dot(b_k, b_v)
if K > 128:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h3 += tl.dot(b_k, b_v_new)
b_h3 += tl.dot(b_k, b_v)
if K > 192:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h4 += tl.dot(b_k, b_v_new)

b_h4 += tl.dot(b_k, b_v)
# epilogue
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
@@ -257,12 +283,15 @@ def chunk_gated_delta_rule_fwd_h(
w: torch.Tensor,
u: torch.Tensor,
g: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
initial_state: torch.Tensor | None = None,
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True,
cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# This kernel differs slightly from fla in order to support Q/K with different numbers of heads.
# In fla, Q and K always have the same number of heads, so Hg is always equal to H.
B, T, Hg, K, V = *k.shape, u.shape[-1]
H = u.shape[-2]
BT = chunk_size
@@ -299,6 +328,7 @@ def grid(meta):
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
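For reference, the following is a minimal single-(batch, head) PyTorch sketch of what one `i_t` iteration of `chunk_gated_delta_rule_fwd_kernel_h_blockdim64` computes, covering both the scalar gate `g` and the new key-wise gate `gk` (KDA) paths added above. The helper and its tensor names are illustrative, not part of the PR, and the intra-chunk treatment of `gk` on `k`/`v_new` is assumed to be handled upstream of this kernel (it is not visible in this hunk):

```python
import torch


def chunk_state_update_reference(
    h: torch.Tensor,                  # [K, V] running state for one (batch, head)
    k: torch.Tensor,                  # [BT, K] keys of the current chunk
    w: torch.Tensor,                  # [BT, K] WY "w" of the current chunk
    v: torch.Tensor,                  # [BT, V] WY "u" of the current chunk
    g: torch.Tensor | None = None,    # [BT] per-chunk cumulative scalar log-gates
    gk: torch.Tensor | None = None,   # [BT, K] per-chunk cumulative key-wise log-gates (KDA)
) -> tuple[torch.Tensor, torch.Tensor]:
    """One chunk step of the hidden-state recurrence, in plain PyTorch."""
    # v_new = u - w @ h; the kernel accumulates this over 64-wide blocks of K
    v_new = v - w @ h
    if g is not None:
        # rows are decayed up to the end of the chunk, the state by the full chunk decay
        g_last = g[-1]
        v_scaled = v_new * torch.exp(g_last - g)[:, None]
        h = h * torch.exp(g_last)
    elif gk is not None:
        # KDA path: per-key (row-wise) decay of the state
        h = h * torch.exp(gk[-1])[:, None]
        v_scaled = v_new
    else:
        v_scaled = v_new
    # rank-BT state update: h += k^T @ v
    h = h + k.transpose(0, 1) @ v_scaled
    # v_new is what the kernel writes back when SAVE_NEW_VALUE is set (pre-decay)
    return h, v_new
```

The four `b_h1 ... b_h4` accumulators in the kernel are simply a 64-row blocking of this `[K, V]` state (supporting `K` up to 256); the arithmetic is the same as the single matmul above.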
24 changes: 11 additions & 13 deletions vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
@@ -18,8 +18,8 @@

@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
"USE_G": lambda args: args["g_cumsum"] is not None,
}
)
@triton.autotune(
@@ -35,7 +35,7 @@
def chunk_scaled_dot_kkt_fwd_kernel(
k,
beta,
g_cumsum,
g,
A,
cu_seqlens,
chunk_indices,
@@ -85,9 +85,7 @@ def chunk_scaled_dot_kkt_fwd_kernel(
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))

if USE_G:
p_g = tl.make_block_ptr(
g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_diff = b_g[:, None] - b_g[None, :]
b_A = b_A * exp(b_g_diff)
@@ -102,8 +100,8 @@

def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
beta: torch.Tensor,
g_cumsum: torch.Tensor | None = None,
g: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
@@ -116,9 +114,8 @@ def chunk_scaled_dot_kkt_fwd(
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
g_cumsum (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`.
Default: None
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
@@ -130,20 +127,21 @@
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""

# This kernel differs slightly from fla in order to support Q/K with different numbers of heads.
# In fla, Q and K always have the same number of heads, so Hg is always equal to H.
B, T, Hg, K = k.shape

H = beta.shape[-1]
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
k=k,
g=g,
beta=beta,
g_cumsum=g_cumsum,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
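As a sanity reference for the docstring above, here is a dense PyTorch sketch of the quantity the kernel produces in the fixed-length, equal-Q/K-heads case (Hg == H, no `cu_seqlens`). The helper name is illustrative, and the final strictly-lower-triangular masking is assumed from the full kernel, which this diff only shows in part:

```python
import torch


def chunk_scaled_dot_kkt_reference(
    k: torch.Tensor,                 # [B, T, H, K]
    beta: torch.Tensor,              # [B, T, H]
    g: torch.Tensor | None = None,   # [B, T, H] per-chunk cumulative log-gates
    chunk_size: int = 64,
) -> torch.Tensor:
    """Dense equivalent for the non-varlen case:
    A[..., i, j] = beta_i * <k_i, k_j> * exp(g_i - g_j) for i > j within a chunk."""
    B, T, H, K = k.shape
    BT = chunk_size
    assert T % BT == 0, "sketch assumes T is a multiple of the chunk size"
    N = T // BT
    kc = k.view(B, N, BT, H, K)
    bc = beta.view(B, N, BT, H).permute(0, 1, 3, 2)       # [B, N, H, BT]
    A = torch.einsum("bnihk,bnjhk->bnhij", kc, kc) * bc[..., :, None]
    if g is not None:
        gc = g.view(B, N, BT, H).permute(0, 1, 3, 2)      # [B, N, H, BT]
        A = A * torch.exp(gc[..., :, None] - gc[..., None, :])
    # keep only the strictly lower triangle (i > j), as the full kernel does
    mask = torch.tril(torch.ones(BT, BT, dtype=torch.bool, device=k.device), -1)
    A = A.masked_fill(~mask, 0)
    # match the kernel's [B, T, H, BT] output layout
    return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, BT)
```

Only the interface changes in this file: `g_cumsum` is renamed to `g` and `beta` becomes optional in the signature; the returned matrix itself is unchanged.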
22 changes: 18 additions & 4 deletions vllm/model_executor/layers/fla/ops/fused_recurrent.py
@@ -57,6 +57,7 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
IS_KDA: tl.constexpr,
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
@@ -86,7 +87,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_beta = beta + (bos * HV + i_hv) * V + o_v
else:
p_beta = beta + bos * HV + i_hv
p_g = g + bos * HV + i_hv

if not IS_KDA:
p_g = g + bos * HV + i_hv
else:
p_gk = g + (bos * HV + i_hv) * K + o_k

p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

mask_k = o_k < K
@@ -116,14 +122,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_g = tl.load(p_g).to(tl.float32)

if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
# [BK, BV]
b_h *= exp(b_g)
if not IS_KDA:
b_g = tl.load(p_g).to(tl.float32)
b_h *= exp(b_g)
else:
b_gk = tl.load(p_gk).to(tl.float32)
b_h *= exp(b_gk[:, None])
# [BV]
b_v -= tl.sum(b_h * b_k[:, None], 0)
if IS_BETA_HEADWISE:
@@ -155,7 +165,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
p_k += H * K
p_o += HV * V
p_v += HV * V
p_g += HV
if not IS_KDA:
p_g += HV
else:
p_gk += HV * K
p_beta += HV * (V if IS_BETA_HEADWISE else 1)


@@ -228,6 +241,7 @@ def fused_recurrent_gated_delta_rule_fwd(
IS_BETA_HEADWISE=beta.ndim == v.ndim,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
INPLACE_FINAL_STATE=inplace_final_state,
IS_KDA=False,
num_warps=num_warps,
num_stages=num_stages,
)
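The `IS_KDA` branch only changes how the recurrent state decays at each step: a scalar gate multiplies the whole `[K, V]` state, while the key-wise gate `gk` decays each key row separately. Below is a minimal single-head, single-sequence PyTorch sketch of the recurrence the kernel unrolls; the helper name is illustrative, the QK l2-norm option, varlen, and continuous-batching paths are omitted, and the rank-1 update/output lines collapsed in this diff view are assumed to follow the standard gated delta rule:

```python
import torch


def gated_delta_rule_recurrent_reference(
    q: torch.Tensor,                  # [T, K]
    k: torch.Tensor,                  # [T, K]
    v: torch.Tensor,                  # [T, V]
    beta: torch.Tensor,               # [T] or [T, V] (head-wise beta)
    g: torch.Tensor | None = None,    # [T]    scalar log-gate per step
    gk: torch.Tensor | None = None,   # [T, K] key-wise log-gate per step (IS_KDA)
    scale: float | None = None,
    h0: torch.Tensor | None = None,   # [K, V] initial state
) -> torch.Tensor:
    """Reference recurrence for one head of one sequence."""
    T, K = k.shape
    V = v.shape[-1]
    scale = K ** -0.5 if scale is None else scale
    h = torch.zeros(K, V) if h0 is None else h0.clone().float()
    o = torch.empty(T, V)
    for t in range(T):
        if gk is not None:
            h = h * torch.exp(gk[t].float())[:, None]   # key-wise decay of the state (KDA)
        elif g is not None:
            h = h * torch.exp(g[t].float())             # scalar decay of the whole state
        u = beta[t].float() * (v[t].float() - h.T @ k[t].float())  # beta-scaled delta-rule error
        h = h + torch.outer(k[t].float(), u)                       # rank-1 state update
        o[t] = h.T @ (q[t].float() * scale)
    return o
```

With `gk=None` and `g` given, this reproduces the existing scalar-gated path, which is how `fused_recurrent_gated_delta_rule_fwd` launches the kernel in this PR (`IS_KDA=False`).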