Merged
Changes from 1 commit
4 changes: 2 additions & 2 deletions benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
@@ -17,7 +17,7 @@

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_triton_block_scaled_mm,
w8a8_block_fp8_matmul,
)
from vllm.utils import FlexibleArgumentParser, cdiv

@@ -158,7 +158,7 @@ def bench_fp8(
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
),
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul(
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
),
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
4 changes: 2 additions & 2 deletions benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
@@ -10,7 +10,7 @@
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
get_col_major_tma_aligned_tensor,
per_token_group_quant_fp8,
w8a8_triton_block_scaled_mm,
w8a8_block_fp8_matmul,
)
from vllm.triton_utils import triton
from vllm.utils.deep_gemm import calc_diff, fp8_gemm_nt, per_block_cast_to_fp8
@@ -59,7 +59,7 @@ def deepgemm_gemm():

# === vLLM Triton Implementation ===
def vllm_triton_gemm():
return w8a8_triton_block_scaled_mm(A_vllm,
return w8a8_block_fp8_matmul(A_vllm,
B_vllm,
A_scale_vllm,
B_scale_vllm,
5 changes: 2 additions & 3 deletions tests/kernels/quantization/test_block_fp8.py
@@ -12,7 +12,7 @@
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, get_col_major_tma_aligned_tensor,
per_token_group_quant_fp8, w8a8_triton_block_scaled_mm)
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
@@ -90,8 +90,7 @@ def test_w8a8_block_fp8_matmul(M, N, K, block_size, out_dtype, seed):

ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = w8a8_triton_block_scaled_mm(A_fp8, B_fp8, As, Bs, block_size,
out_dtype)
out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype)

rel_diff = (torch.mean(
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
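For readers following the rename, the call shape of `w8a8_block_fp8_matmul` is unchanged; the sketch below assembles a standalone invocation from the benchmark and test call sites in this diff. The tensor shapes, CUDA device, and (128, 128) block size are illustrative assumptions, not taken from the kernel's documentation.

```python
# Hedged usage sketch of w8a8_block_fp8_matmul, assembled from the call sites
# shown in this diff. Shapes, device, and block size are assumptions.
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_block_fp8_matmul)

M, N, K = 16, 4096, 7168           # assumed problem size
block_n, block_k = 128, 128        # block size used by the benchmark above

# Activations are quantized on the fly: fp8 data plus one scale per
# (1, block_k) group, so As has shape (M, K // block_k).
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
A_fp8, As = per_token_group_quant_fp8(A, block_k)

# Weights are assumed pre-quantized offline: fp8 data of shape (N, K) with one
# float32 scale per (block_n, block_k) tile, so Bs has shape
# (N // block_n, K // block_k).
B_fp8 = (torch.randn(N, K, device="cuda") / 8).to(torch.float8_e4m3fn)
Bs = torch.rand(N // block_n, K // block_k,
                device="cuda", dtype=torch.float32)

out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, (block_n, block_k),
                            torch.bfloat16)
assert out.shape == (M, N) and out.dtype == torch.bfloat16
```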
26 changes: 7 additions & 19 deletions tests/kernels/quantization/test_fp8_quant_group.py
@@ -20,11 +20,9 @@
(8, 513, 64), # Non-divisible (native only)
])
@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_size: int, seed: int,
use_ue8m0: bool) -> None:
group_size: int, seed: int) -> None:
"""Test QuantFP8 group quantization with various configurations.

Tests both CUDA and native implementations, column-major scales,
@@ -40,8 +38,7 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0)
column_major_scales=False)

# 1. Test native implementation (always available)
x_quant_native, scales_native = quant_op.forward_native(x.clone())
@@ -51,15 +48,9 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,
# 2. Test column-major scales configuration
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0)
column_major_scales=True)
_, scales_col = quant_op_col.forward_native(x.clone())
assert scales_col.shape == (batch_size, expected_num_groups)
assert scales_col.stride(0) == 1
assert scales_col.stride(1) == batch_size

# Test column-major scales consistency
assert torch.allclose(scales_col, scales_native, rtol=1e-9, atol=1e-8)
assert scales_col.shape == (expected_num_groups, batch_size)

# 3. Test CUDA implementation (only for divisible dimensions)
if is_divisible:
@@ -77,9 +68,8 @@ def test_quantfp8_group_functionality(batch_size: int, hidden_dim: int,


@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("use_ue8m0", [True, False])
@torch.inference_mode()
def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
def test_quantfp8_group_multidimensional(seed: int) -> None:
current_platform.seed_everything(seed)

group_size = 64
@@ -92,8 +82,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
group_shape = GroupShape(1, group_size)
quant_op = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=False,
use_ue8m0=use_ue8m0)
column_major_scales=False)

x_quant, scales = quant_op.forward_native(x_3d.clone())
assert x_quant.shape == x_3d.shape
@@ -102,8 +91,7 @@ def test_quantfp8_group_multidimensional(seed: int, use_ue8m0: bool) -> None:
# Test column_major_scales with multi-dim
quant_op_col = QuantFP8(static=False,
group_shape=group_shape,
column_major_scales=True,
use_ue8m0=use_ue8m0)
column_major_scales=True)
_, scales_col = quant_op_col.forward_native(x_3d.clone())
assert scales_col.shape == (batch1, hidden_dim // group_size, batch2)

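As background for the parameter removal above, `QuantFP8` group quantization is exercised the same way with or without the dropped `use_ue8m0` flag. Below is a minimal sketch of the path these tests drive; sizes and device are illustrative assumptions, and the `GroupShape` import path is assumed (only the `QuantFP8` import is visible elsewhere in this PR).

```python
# Minimal sketch of QuantFP8 dynamic group quantization as exercised by the
# tests above; sizes and device are illustrative assumptions.
import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

batch_size, hidden_dim, group_size = 16, 512, 64
x = torch.randn(batch_size, hidden_dim, device="cuda",
                dtype=torch.bfloat16) * 8

quant_op = QuantFP8(static=False,
                    group_shape=GroupShape(1, group_size),
                    column_major_scales=False)
x_quant, scales = quant_op.forward_native(x)

# One scale per (row, group) of 64 contiguous elements.
assert x_quant.shape == x.shape
assert scales.shape == (batch_size, hidden_dim // group_size)
```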
30 changes: 30 additions & 0 deletions tests/model_executor/test_enabled_custom_ops.py
@@ -17,6 +17,8 @@
from vllm.model_executor.layers.layernorm import (RMSNorm,
dispatch_rocm_rmsnorm_func,
fused_add_rms_norm, rms_norm)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul)
from vllm.platforms import current_platform

RMS_NORM_SUPPORTED_DTYPES = [torch.float16, torch.bfloat16]
@@ -109,6 +111,34 @@ def test_enabled_ops_invalid(env: str):
RMSNorm(1024).enabled()


@pytest.mark.skipif(
not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(),
reason="AITER is a feature exclusive for ROCm and FP8_FNUZ")
@pytest.mark.parametrize("use_cutlass", [True, False])
@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"])
def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str,
use_rocm_aiter_gemm_w8a8_blockscale: str,
monkeypatch):

monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR",
use_rocm_aiter_gemm_w8a8_blockscale)

use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool(
int(use_rocm_aiter_gemm_w8a8_blockscale)))
block_scale_func = dispatch_w8a8_blockscale_func(
use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported)
if use_cutlass:
assert block_scale_func == cutlass_scaled_mm
elif current_platform.is_rocm() and int(use_rocm_aiter) and int(
use_rocm_aiter_gemm_w8a8_blockscale):
assert block_scale_func == (
torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale)
else:
assert block_scale_func == w8a8_block_fp8_matmul


@pytest.mark.parametrize("use_rocm_aiter", ["0", "1"])
def test_topk_dispatch(use_rocm_aiter: str, monkeypatch):
monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter)
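The new dispatch test pins down which kernel `dispatch_w8a8_blockscale_func` must return for each configuration. The sketch below is a condensed restatement of the decision the test asserts, illustrative only and not the body of the real dispatcher.

```python
# Condensed restatement of the selection order asserted by
# test_w8a8_blockscale_dispatch above; not the actual implementation in
# vllm/model_executor/layers/quantization/utils/fp8_utils.py.
import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    cutlass_scaled_mm, w8a8_block_fp8_matmul)


def expected_blockscale_kernel(use_cutlass: bool,
                               use_aiter_and_is_supported: bool):
    if use_cutlass:
        # CUTLASS path wins whenever the caller requests it.
        return cutlass_scaled_mm
    if use_aiter_and_is_supported:
        # ROCm + AITER path, enabled via VLLM_ROCM_USE_AITER and
        # VLLM_ROCM_USE_AITER_LINEAR.
        return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale
    # Fallback: the Triton kernel, named w8a8_block_fp8_matmul after this change.
    return w8a8_block_fp8_matmul
```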
35 changes: 0 additions & 35 deletions tests/quantization/test_compressed_tensors.py
@@ -18,9 +18,6 @@
CompressedTensorsW4A16Fp4, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
@@ -745,35 +742,3 @@ def test_compressed_tensors_transforms_perplexity(vllm_runner, model, prompt,
perplexity = llm.generate_prompt_perplexity([prompt])[0]
print(perplexity)
assert perplexity <= exp_perplexity


def test_compressed_tensors_fp8_block_enabled(vllm_runner):
model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
with vllm_runner(model_path) as llm:

fp8_dtype = current_platform.fp8_dtype()

def check_model(model):
layer = model.model.layers[0]

qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
assert isinstance(qkv_proj.scheme.w8a8_block_fp8_linear,
W8A8BlockFp8LinearOp)

assert qkv_proj.weight.dtype is fp8_dtype
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight.shape) == 2
assert len(qkv_proj.weight_scale.shape) == 2

input_quant_op = \
qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
assert isinstance(input_quant_op, QuantFP8)
assert input_quant_op._forward_method == input_quant_op.forward_cuda

llm.apply_model(check_model)

output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
17 changes: 0 additions & 17 deletions vllm/config/__init__.py
@@ -545,23 +545,6 @@ def __post_init__(self):
# local attention.
self.scheduler_config.disable_hybrid_kv_cache_manager = True

def has_blocked_weights():
if self.quant_config is not None:
if hasattr(self.quant_config, "weight_block_size"):
return self.quant_config.weight_block_size is not None
elif hasattr(self.quant_config, "has_blocked_weights"):
return self.quant_config.has_blocked_weights()
return False

# Enable quant_fp8 CUDA ops (TODO disable in follow up)
# On H100 the CUDA kernel is faster than
# native implementation
# https://github.com/vllm-project/vllm/issues/25094
if has_blocked_weights():
custom_ops = self.compilation_config.custom_ops
if "none" not in custom_ops and "-quant_fp8" not in custom_ops:
custom_ops.append("+quant_fp8")

def update_sizes_for_sequence_parallelism(self,
possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when
Original file line number Diff line number Diff line change
@@ -644,14 +644,6 @@ def get_cache_scale(self, name: str) -> Optional[str]:
# If no matches, return None
return None

def has_blocked_weights(self) -> bool:
for scheme in self.target_scheme_map.values():
weight_quant = scheme.get("weights")
if (weight_quant is not None
and weight_quant.strategy == QuantizationStrategy.BLOCK):
return True
return False

@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp, check_aiter_fp8_linear_support,
apply_fp8_block_linear, check_aiter_fp8_linear_support,
create_fp8_input_scale, create_fp8_scale_parameter,
create_fp8_weight_parameter, maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy, process_fp8_weight_channel_strategy,
@@ -41,30 +41,16 @@ def __init__(self, weight_quant: QuantizationArgs,
self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype()
self.is_static_input_scheme = is_static_input_scheme
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)

self.weight_block_size = self.weight_quant.block_structure
if self.weight_block_size is not None:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
self.act_q_group_shape = GroupShape.PER_TENSOR \
if is_static_input_scheme else GroupShape.PER_TOKEN

self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.use_aiter_and_is_supported = check_aiter_fp8_linear_support()

if self.weight_block_size is not None:
assert not self.is_static_input_scheme
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.is_static_input_scheme,
act_quant_group_shape=self.act_q_group_shape)

@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
@@ -155,14 +141,13 @@ def apply_weights(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:

if self.weight_block_size is not None:
return self.w8a8_block_fp8_linear.apply(
if layer.weight_block_size is not None:
return apply_fp8_block_linear(
layer,
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported)

return self.fp8_linear.apply(input=x,
weight=layer.weight,
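To connect the pieces, `apply_weights` above now hands the block-quantized layer to a helper instead of the removed `W8A8BlockFp8LinearOp` instance. The sketch below is a rough, hypothetical outline of what a block-wise FP8 linear step of this shape does — activation group quantization followed by a block-scaled GEMM — inferred from the kernels named in this PR. It is not the body of `apply_fp8_block_linear`, and it covers only the Triton fallback; the CUTLASS and AITER paths use different argument conventions.

```python
# Rough, hypothetical outline of a block-wise FP8 linear step with the call
# shape used above. NOT the actual body of apply_fp8_block_linear; Triton
# fallback only.
from typing import Optional

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8, w8a8_block_fp8_matmul)


def block_fp8_linear_sketch(layer: torch.nn.Module,
                            input: torch.Tensor,
                            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Assumed layer attributes (names as used in the diff above):
    #   layer.weight: fp8 tensor of shape (N, K)
    #   layer.weight_scale: float32 per-block scales, (N // bn, K // bk)
    #   layer.weight_block_size: e.g. [128, 128]
    block_n, block_k = layer.weight_block_size
    # Quantize activations on the fly, one scale per (1, block_k) group.
    q_input, input_scales = per_token_group_quant_fp8(input, block_k)
    out = w8a8_block_fp8_matmul(q_input, layer.weight, input_scales,
                                layer.weight_scale, (block_n, block_k),
                                input.dtype)
    return out if bias is None else out + bias
```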
10 changes: 5 additions & 5 deletions vllm/model_executor/layers/quantization/deepgemm.py
@@ -42,7 +42,7 @@ def prepare_block_fp8_matmul_inputs(
return M, N, K, C


def w8a8_deepgemm_block_scaled_mm(
def w8a8_block_fp8_matmul_deepgemm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@@ -58,7 +58,7 @@ def w8a8_deepgemm_block_scaled_mm(
return C


def w8a8_deepgemm_block_scaled_mm_fake(
def w8a8_block_fp8_matmul_deepgemm_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
@@ -72,7 +72,7 @@ def w8a8_deepgemm_block_scaled_mm_fake(


direct_register_custom_op(
op_name="w8a8_deepgemm_block_scaled_mm",
op_func=w8a8_deepgemm_block_scaled_mm,
fake_impl=w8a8_deepgemm_block_scaled_mm_fake,
op_name="w8a8_block_fp8_matmul_deepgemm",
op_func=w8a8_block_fp8_matmul_deepgemm,
fake_impl=w8a8_block_fp8_matmul_deepgemm_fake,
)