Merged
4 changes: 4 additions & 0 deletions csrc/cache_kernels.cu
@@ -552,7 +552,11 @@ __global__ void indexer_k_quant_and_cache_kernel(
#ifndef USE_ROCM
  __syncwarp();
#endif
#if defined(__gfx942__)
  float scale = fmaxf(amax, 1e-4) / 224.0f;
#else
  float scale = fmaxf(amax, 1e-4) / 448.0f;
#endif
  if (use_ue8m0) {
    scale = exp2f(ceilf(log2f(scale)));
  }
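For context (not part of the diff): gfx942 (MI300-class) GPUs use the FNUZ FP8 encoding rather than the OCP `e4m3fn` format, and its representable range is smaller, so the amax-to-scale divisor is halved from 448 to 224 on that target. A minimal sketch of the range difference, assuming a PyTorch build that exposes both FP8 dtypes:

```python
import torch

# OCP e4m3fn (CUDA and newer ROCm targets): largest finite value is 448.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
# FNUZ e4m3 (gfx942 / MI300): largest finite value is 240, so dividing amax
# by 224 keeps the quantized magnitudes safely inside the FNUZ range.
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
```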
210 changes: 210 additions & 0 deletions vllm/attention/ops/rocm_aiter_mla_sparse.py
@@ -0,0 +1,210 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
from functools import lru_cache

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L84
def fp8_mqa_logits_torch(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    kv, scale = kv
    seq_len_kv = kv.shape[0]
    k = kv.to(torch.bfloat16)
    q = q.to(torch.bfloat16)

    mask_lo = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] >= cu_seqlen_ks[:, None]
    )
    mask_hi = (
        torch.arange(0, seq_len_kv, device="cuda")[None, :] < cu_seqlen_ke[:, None]
    )
    mask = mask_lo & mask_hi

    score = torch.einsum("mhd,nd->hmn", q, k).float() * scale
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
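As a quick illustration (not part of the diff), here is how this reference implementation might be exercised; the shapes and the `[ks, ke)` windows are made up, and it assumes a GPU build where `torch.float8_e4m3fn` is available:

```python
import torch

M, H, D, N = 4, 2, 8, 6  # query tokens, heads, head dim, KV tokens
q = torch.randn(M, H, D, device="cuda").to(torch.float8_e4m3fn)
k_fp8 = torch.randn(N, D, device="cuda").to(torch.float8_e4m3fn)
k_scales = torch.ones(N, device="cuda", dtype=torch.float32)
weights = torch.rand(M, H, device="cuda", dtype=torch.float32)

# Query position i may attend to KV positions [ks[i], ke[i]).
ks = torch.zeros(M, dtype=torch.int32, device="cuda")
ke = torch.tensor([3, 4, 5, 6], dtype=torch.int32, device="cuda")

logits = fp8_mqa_logits_torch(q, (k_fp8, k_scales), weights, ks, ke)
assert logits.shape == (M, N)
assert torch.isinf(logits[0, 3:]).all()  # outside [ks[0], ke[0]) is masked to -inf
```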


def rocm_fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or
            [N, 1]) with dtype `torch.float32`.
        weights: Weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """

    # TODO(ganyi): Temporary workaround; remove the module check and the
    # reference path once aiter merges this kernel into main.
    @lru_cache
    def has_mqa_logits_module():
        return importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None

    if rocm_aiter_ops.is_enabled() and has_mqa_logits_module():
        from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits

        kv, scale = kv
        return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
    else:
        return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)


# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_attention.py#L156
def fp8_paged_mqa_logits_torch(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
):
    from vllm.utils.math_utils import cdiv

    fp8_dtype = current_platform.fp8_dtype()
    batch_size, next_n, _, dim = q.size()
    kv_cache, scale = kv_cache[..., :dim], kv_cache[..., dim:]
    scale = scale.contiguous().view(torch.float)
    q = q.float()
    kv_cache = kv_cache.view(fp8_dtype).float() * scale
    num_block, block_size, _, dim = kv_cache.size()
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens[i]
        q_offsets = torch.arange(context_len - next_n, context_len, device="cuda")
        weight_slice = (
            weights[i * next_n : (i + 1) * next_n, :].transpose(0, 1).contiguous()
        )
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size, (block_rk + 1) * block_size, device="cuda"
            )
            mask = (k_offsets[None, :] < context_len) & (
                k_offsets[None, :] <= q_offsets[:, None]
            )
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype
                ),
                float("-inf"),
            )
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n : (i + 1) * next_n,
                block_rk * block_size : (block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float("-inf"))
    return logits
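For illustration (my reading of the layout, not stated in the diff): each cache slot stores `D` FP8 bytes followed by 4 bytes that reinterpret as the `float32` dequant scale, which is why the function slices `kv_cache[..., :dim]` and `kv_cache[..., dim:]`. A small sketch of packing a cache in that layout, using `torch.float8_e4m3fn` for simplicity (a FNUZ platform would use its own FP8 dtype):

```python
import torch

num_blocks, block_size, D = 2, 4, 8
k = torch.randn(num_blocks, block_size, 1, D)

# One scale per slot, chosen so the quantized values fit the FP8 range.
amax = k.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)
scale = amax / 448.0
k_fp8 = (k / scale).to(torch.float8_e4m3fn)

# Packed layout expected above: [num_blocks, block_size, 1, D + 4] as uint8.
packed = torch.empty(num_blocks, block_size, 1, D + 4, dtype=torch.uint8)
packed[..., :D] = k_fp8.view(torch.uint8)
packed[..., D:] = scale.view(torch.uint8)  # the 4 bytes of the float32 scale

# Round trip, mirroring the slicing done in the reference implementation.
kv_bytes, scale_bytes = packed[..., :D], packed[..., D:]
recovered = (
    kv_bytes.view(torch.float8_e4m3fn).float()
    * scale_bytes.contiguous().view(torch.float32)
)
# `recovered` approximates `k` up to FP8 quantization error.
```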


def rocm_fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute FP8 MQA logits using a paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Cast to
            `torch.float8_e4m3fn` by the caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block, pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context
            length for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps
            logical block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """

    if rocm_aiter_ops.is_enabled():
        from aiter.ops.triton.pa_mqa_logits import deepgemm_fp8_paged_mqa_logits_stage1

        batch_size, next_n, heads, _ = q_fp8.shape
        out_qk = torch.full(
            (heads, batch_size * next_n, max_model_len),
            float("-inf"),
            device="cuda",
            dtype=torch.float32,
        )
        deepgemm_fp8_paged_mqa_logits_stage1(
            q_fp8,
            kv_cache_fp8,
            weights,
            out_qk,
            context_lens,
            block_tables,
            max_model_len,
        )
        return out_qk.sum(dim=0)
    else:
        return fp8_paged_mqa_logits_torch(
            q_fp8, kv_cache_fp8, weights, context_lens, block_tables, max_model_len
        )
22 changes: 18 additions & 4 deletions vllm/model_executor/models/deepseek_v2.py
@@ -591,6 +591,7 @@ def sparse_attn_indexer(
) -> torch.Tensor:
    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()
    # assert isinstance(attn_metadata, dict)
    if not isinstance(attn_metadata, dict):
        return sparse_attn_indexer_fake(
@@ -630,7 +631,7 @@ def sparse_attn_indexer(
            k_fp8 = torch.empty(
                [chunk.total_seq_lens, head_dim],
                device=k.device,
                dtype=torch.float8_e4m3fn,
                dtype=fp8_dtype,
            )
            k_scale = torch.empty(
                [chunk.total_seq_lens, 4],
@@ -644,7 +645,12 @@ def sparse_attn_indexer(
                chunk.block_table,
                chunk.cu_seq_lens,
            )
            logits = fp8_mqa_logits(
            fp8_mqa_logits_func = fp8_mqa_logits
            if current_platform.is_rocm():
                from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits

                fp8_mqa_logits_func = rocm_fp8_mqa_logits
            logits = fp8_mqa_logits_func(
                q_fp8[chunk.token_start : chunk.token_end],
                (k_fp8, k_scale.view(torch.float32)),
                weights[chunk.token_start : chunk.token_end],
@@ -689,7 +695,14 @@ def sparse_attn_indexer(
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n
        logits = fp8_paged_mqa_logits(
        fp8_paged_mqa_logits_func = fp8_paged_mqa_logits
        if current_platform.is_rocm():
            from vllm.attention.ops.rocm_aiter_mla_sparse import (
                rocm_fp8_paged_mqa_logits,
            )

            fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits
        logits = fp8_paged_mqa_logits_func(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
@@ -746,7 +759,8 @@ def sparse_attn_indexer_fake(
    _flattened_kv = torch.empty(
        [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
    )
    _k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous()
    fp8_dtype = current_platform.fp8_dtype()
    _k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
    _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer

13 changes: 12 additions & 1 deletion vllm/platforms/rocm.py
@@ -225,7 +225,18 @@ def get_attn_backend_cls(
        from vllm.attention.backends.registry import AttentionBackendEnum

        if use_sparse:
            raise NotImplementedError("Sparse Attention is not supported on ROCm.")
            if kv_cache_dtype.startswith("fp8"):
                raise ValueError(
                    "ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
                )
            assert block_size == 1, (
                "Sparse MLA backend on ROCm only supports block size 1 for now."
            )
            logger.info_once("Using Sparse MLA backend on V1 engine.")
            return (
                "vllm.v1.attention.backends.mla.rocm_aiter_mla_sparse."
                "ROCMAiterMLASparseBackend"
            )

Collaborator: @ganyi1996ppo what do you think about adding an assertion of block-size 1 when we are using this backend?

Contributor Author: Sounds good

Contributor Author: Done, please take a look.

        if use_mla:
            if selected_backend is None:
5 changes: 3 additions & 2 deletions vllm/utils/deep_gemm.py
@@ -325,6 +325,7 @@ def _align(x: int, y: int) -> int:
def per_block_cast_to_fp8(
    x: torch.Tensor, block_size: list[int] = DEFAULT_BLOCK_SIZE, use_ue8m0: bool = False
) -> tuple[torch.Tensor, torch.Tensor]:
    fp8_dtype = current_platform.fp8_dtype()
    assert x.dim() == 2
    m, n = x.shape
    block_m, block_n = block_size
@@ -334,9 +335,9 @@
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, block_m, x_padded.size(1) // block_n, block_n)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = x_amax / 448.0
    sf = x_amax / 224.0 if current_platform.is_fp8_fnuz() else x_amax / 448.0
    sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    x_scaled = (x_view * (1.0 / sf)).to(fp8_dtype)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2)
    )
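As a usage sketch (not part of the diff, and assuming the default block size is 128x128): the input is padded up to whole tiles, one scale is produced per tile, and the padding is stripped from the returned tensor:

```python
import torch
from vllm.utils.deep_gemm import per_block_cast_to_fp8

x = torch.randn(300, 500, device="cuda")  # deliberately not a multiple of 128
x_fp8, sf = per_block_cast_to_fp8(x)

assert x_fp8.shape == (300, 500)          # padding removed on return
assert sf.shape == (3, 4)                 # ceil(300/128) x ceil(500/128) scales
# Tile (i, j) dequantizes as that tile of x_fp8, cast to float, times sf[i, j].
```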
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -168,7 +168,7 @@ def _convert_req_index_to_global_index_kernel(
    inblock_off = tok % BLOCK_SIZE

    # Guard block_table access
    valid_block = block_id < max_num_blocks_per_req
    valid_block = (block_id < max_num_blocks_per_req) & (block_id >= 0)
    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    base = tl.load(bt_ptr, mask=valid_block, other=0)
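For context (my reading of the change, not stated in the diff): the sparse top-k token indices can carry negative padding entries, so the guard now also rejects `block_id < 0`; with the mask applied, `tl.load` returns `other=0` instead of dereferencing a bogus block-table slot. A plain-PyTorch analogue of the guarded lookup:

```python
import torch

block_table = torch.tensor([7, 8, 9])   # physical blocks for one request
block_id = torch.tensor([0, 2, -1, 5])  # -1 is padding, 5 is out of range

valid = (block_id < block_table.numel()) & (block_id >= 0)
safe_id = block_id.clamp(0, block_table.numel() - 1)
base = torch.where(valid, block_table[safe_id], torch.zeros_like(block_id))
# base == tensor([7, 9, 0, 0]); invalid entries fall back to 0, like `other=0`.
```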

15 changes: 9 additions & 6 deletions vllm/v1/attention/backends/mla/indexer.py
@@ -11,7 +11,8 @@
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata, is_deep_gemm_supported
from vllm.v1.attention.backends.utils import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
@@ -23,7 +24,9 @@


class DeepseekV32IndexerBackend(AttentionBackend):
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [
        1 if current_platform.is_rocm() else 64
    ]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
@@ -328,10 +331,10 @@ def build(
        requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()

        seq_lens = common_attn_metadata.seq_lens[:num_decodes]

        self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
            seq_lens, self.kv_cache_spec.block_size, self.num_sms
        )
        if is_deep_gemm_supported():
            self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
                seq_lens, self.kv_cache_spec.block_size, self.num_sms
            )
        decode_metadata = DeepSeekV32IndexerDecodeMetadata(
            block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
            seq_lens=common_attn_metadata.seq_lens[:num_decodes],