diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index cef53b183cef..99d8d3eee0df 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -86,6 +86,9 @@ def benchmark_config( (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 ) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) + if use_deep_gemm: + # we use the default block shape for deepgemm + block_quant_shape = [128, 128] if use_fp8_w8a8: if block_quant_shape: block_n, block_k = block_quant_shape[0], block_quant_shape[1] diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index e67ce0545318..253d2984aa9d 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -1,13 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# fmt: off -# ruff: noqa: E501 import time -# Import DeepGEMM functions -import deep_gemm import torch -from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor +from deep_gemm import fp8_gemm_nt +from deep_gemm.testing.numeric import calc_diff +from deep_gemm.utils.math import ceil_div, per_block_cast_to_fp8, per_token_cast_to_fp8 # Import vLLM functions from vllm import _custom_ops as ops @@ -18,107 +16,84 @@ from vllm.triton_utils import triton -# Copied from -# https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L9 -def per_token_cast_to_fp8( - x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - """Convert tensor to FP8 format with per-token scaling.""" - assert x.dim() == 2 and x.size(1) % 128 == 0 - m, n = x.shape - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to( - torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) - - # Copied from # https://github.com/deepseek-ai/DeepGEMM/blob/78cacf70d41d15d688bd493ebc85845f7f2a3d5d/tests/test_core.py#L17 -def per_block_cast_to_fp8( - x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def per_block_cast_to_fp8_vllm(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Convert tensor to FP8 format with per-block scaling.""" assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), - dtype=x.dtype, - device=x.device) + x_padded = torch.zeros( + (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device + ) x_padded[:m, :n] = x x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - return x_scaled.view_as(x_padded)[:m, :n].contiguous(), ( - x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - - -def benchmark_shape(m: int, - n: int, - k: int, - warmup: int = 100, - repeat: int = 10000, - verbose: bool = False) -> dict: + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view( + x_view.size(0), x_view.size(2) + ) + + +def benchmark_shape( + m: int, + n: int, + k: int, + warmup: int = 100, + repeat: int = 10000, + verbose: bool = False, +) -> dict: """Benchmark all implementations for a specific (m, n, k) shape.""" if verbose: print(f"\n=== Benchmarking shape: m={m}, 
n={n}, k={k} ===") - # Create test tensors - A = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) - B = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) - - # Reference result in BF16 + A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) + B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) torch.cuda.synchronize() C_ref = A @ B.t() # Pre-quantize B for all implementations # (weights can be pre-quantized offline) B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B) - B_vllm, B_scale_vllm = per_block_cast_to_fp8(B) + B_vllm, B_scale_vllm = per_block_cast_to_fp8_vllm(B) # Block size configuration block_size = [128, 128] # Pre-quantize A for all implementations A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) - A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm) - C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + C_deepgemm = ( + torch.empty((n, m), device="cuda", dtype=torch.bfloat16).t().contiguous() + ) A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( - A, block_size[1], column_major_scales=True) + A, block_size[1], column_major_scales=True + ) - # === DeepGEMM Implementation === def deepgemm_gemm(): - # A quantization is inside the loop as it depends on activations - # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) - # A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( - # A, block_size[1]) - # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) - # C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) - deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm), - (B_deepgemm, B_scale_deepgemm), - C_deepgemm) + fp8_gemm_nt( + (A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm + ) return C_deepgemm - # === vLLM Triton Implementation === def vllm_triton_gemm(): - # A quantization is inside the loop as it depends on activations - # A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) - return w8a8_block_fp8_matmul(A_vllm, - B_vllm, - A_scale_vllm, - B_scale_vllm, - block_size, - output_dtype=torch.bfloat16) - - # === vLLM CUTLASS Implementation === + return w8a8_block_fp8_matmul( + A_vllm, + B_vllm, + A_scale_vllm, + B_scale_vllm, + block_size, + output_dtype=torch.bfloat16, + ) + def vllm_cutlass_gemm(): - # A quantization is inside the loop as it depends on activations - # A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( - # A, block_size[1], column_major_scales=True) - return ops.cutlass_scaled_mm(A_vllm_cutlass, - B_vllm.T, - scale_a=A_scale_vllm_cutlass, - scale_b=B_scale_vllm.T, - out_dtype=torch.bfloat16) - - # Run correctness check first + return ops.cutlass_scaled_mm( + A_vllm_cutlass, + B_vllm.T, + scale_a=A_scale_vllm_cutlass, + scale_b=B_scale_vllm.T, + out_dtype=torch.bfloat16, + ) + if verbose: print("Running correctness check...") C_deepgemm = deepgemm_gemm() @@ -133,26 +108,22 @@ def vllm_cutlass_gemm(): print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}") print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}") print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}") - print("vLLM Triton vs DeepGEMM difference: " - f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}") - print("vLLM CUTLASS vs DeepGEMM difference: " - f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}") + print( + "vLLM Triton vs DeepGEMM difference: " + f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}" + ) + print( 
+ "vLLM CUTLASS vs DeepGEMM difference: " + f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}" + ) - # Benchmark implementations implementations = { "DeepGEMM": deepgemm_gemm, "vLLM Triton": vllm_triton_gemm, - "vLLM CUTLASS": vllm_cutlass_gemm + "vLLM CUTLASS": vllm_cutlass_gemm, } - benchmark_results = { - "shape": { - "m": m, - "n": n, - "k": k - }, - "implementations": {} - } + benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}} for name, func in implementations.items(): # Warmup @@ -180,38 +151,36 @@ def vllm_cutlass_gemm(): "tflops": tflops, "gb_s": gb_s, "diff": { - "DeepGEMM": - 0.0 if name == "DeepGEMM" else calc_diff(func(), C_deepgemm), - "Reference": - deepgemm_diff if name == "DeepGEMM" else - (vllm_triton_diff - if name == "vLLM Triton" else vllm_cutlass_diff) - } + "DeepGEMM": 0.0 + if name == "DeepGEMM" + else calc_diff(func(), C_deepgemm), + "Reference": deepgemm_diff + if name == "DeepGEMM" + else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff), + }, } if verbose: - print( - f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s" - ) + print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s") # Calculate speedups baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"] for name, data in benchmark_results["implementations"].items(): if name != "DeepGEMM": speedup = baseline / data["time_ms"] - benchmark_results["implementations"][name][ - "speedup_vs_deepgemm"] = speedup + benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup if verbose: - print(f"DeepGEMM is {1/speedup:.2f}x " - f"{'faster' if 1/speedup > 1 else 'slower'} than {name}") + print( + f"DeepGEMM is {1 / speedup:.2f}x " + f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}" + ) - vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"][ - "time_ms"] - vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"][ - "time_ms"] + vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"] + vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"] cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time - benchmark_results["implementations"]["vLLM CUTLASS"][ - "speedup_vs_triton"] = cutlass_vs_triton + benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = ( + cutlass_vs_triton + ) if verbose: print( f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x " @@ -223,8 +192,7 @@ def vllm_cutlass_gemm(): def format_table_row(values, widths): """Format a row with specified column widths.""" - return "| " + " | ".join(f"{val:{w}}" - for val, w in zip(values, widths)) + " |" + return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |" def print_table(headers, rows, title=None): @@ -232,16 +200,12 @@ def print_table(headers, rows, title=None): if title: print(f"\n{title}") - # Calculate column widths based on headers and data widths = [ max(len(str(h)), max(len(str(row[i])) for row in rows)) for i, h in enumerate(headers) ] - # Create separator line separator = "+-" + "-+-".join("-" * w for w in widths) + "-+" - - # Print table print(separator) print(format_table_row(headers, widths)) print(separator) @@ -259,44 +223,22 @@ def run_benchmarks(verbose: bool = False): """Run benchmarks for a set of common shapes.""" print("===== STARTING FP8 GEMM BENCHMARK =====") - # Make sure we're using the GPU if not torch.cuda.is_available(): print("CUDA not available! 
Tests require GPU.") return - # Print system information print(f"PyTorch version: {torch.__version__}") print(f"CUDA version: {torch.version.cuda}") print(f"Triton version: {triton.__version__}") print(f"Using device: {torch.cuda.get_device_name()}") - # Enable TF32 for better performance torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True - # Set seeds for reproducibility torch.manual_seed(42) torch.cuda.manual_seed(42) # Define benchmark shapes (m, n, k) - shapes = [ - (8, 4096, 7168), - (8, 7168, 18432), - (8, 18432, 7168), - (64, 4096, 7168), - (64, 7168, 18432), - (64, 18432, 7168), - (64, 24576, 1536), - (64, 32768, 512), - (64, 7168, 16384), - (128, 4096, 7168), - (128, 7168, 18432), - (128, 18432, 7168), - (1024, 4096, 7168), - (1024, 18432, 7168), - (2048, 4096, 7168), - (4096, 4096, 7168), - ] shapes = [ # (64, 2112, 7168), (64, 24576, 1536), @@ -323,7 +265,6 @@ def run_benchmarks(verbose: bool = False): result = benchmark_shape(m, n, k, verbose=verbose) all_results.append(result) - # Print results in a nicely formatted table print("\n===== PERFORMANCE COMPARISON =====") # Print DeepGEMM table @@ -332,38 +273,50 @@ def run_benchmarks(verbose: bool = False): for result in all_results: shape = result["shape"] impl_data = result["implementations"]["DeepGEMM"] - deepgemm_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}" - ]) + deepgemm_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + ] + ) - print_table(deepgemm_headers, - deepgemm_rows, - title="DeepGEMM Implementation:") + print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:") # Print vLLM Triton table - triton_headers = [ - "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM" - ] + triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"] triton_rows = [] for result in all_results: shape = result["shape"] impl_data = result["implementations"]["vLLM Triton"] speedup = impl_data.get("speedup_vs_deepgemm", 1.0) - triton_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", - format_speedup(speedup) - ]) + triton_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + format_speedup(speedup), + ] + ) - print_table(triton_headers, - triton_rows, - title="vLLM Triton Implementation:") + print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:") # Print vLLM CUTLASS table cutlass_headers = [ - "m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM", - "vs Triton" + "m", + "n", + "k", + "Time (μs)", + "TFLOPS", + "GB/s", + "vs DeepGEMM", + "vs Triton", ] cutlass_rows = [] for result in all_results: @@ -371,28 +324,27 @@ def run_benchmarks(verbose: bool = False): impl_data = result["implementations"]["vLLM CUTLASS"] vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0) vs_triton = impl_data.get("speedup_vs_triton", 1.0) - cutlass_rows.append([ - shape["m"], shape["n"], shape["k"], f"{impl_data['time_us']:.1f}", - f"{impl_data['tflops']:.1f}", f"{impl_data['gb_s']:.1f}", - format_speedup(vs_deepgemm), - format_speedup(vs_triton) - ]) + cutlass_rows.append( + [ + shape["m"], + shape["n"], + shape["k"], + f"{impl_data['time_us']:.1f}", + 
f"{impl_data['tflops']:.1f}", + f"{impl_data['gb_s']:.1f}", + format_speedup(vs_deepgemm), + format_speedup(vs_triton), + ] + ) - print_table(cutlass_headers, - cutlass_rows, - title="vLLM CUTLASS Implementation:") + print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:") # Calculate and print averages print("\n===== AVERAGE PERFORMANCE =====") implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"] avg_metrics = { - impl: { - "tflops": 0, - "gb_s": 0, - "time_ms": 0 - } - for impl in implementations + impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations } for result in all_results: @@ -410,9 +362,9 @@ def run_benchmarks(verbose: bool = False): avg_tflops = avg_metrics[impl]["tflops"] / num_shapes avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes avg_time = avg_metrics[impl]["time_ms"] / num_shapes - avg_rows.append([ - impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}" - ]) + avg_rows.append( + [impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"] + ) print_table(avg_headers, avg_rows) @@ -420,21 +372,19 @@ def run_benchmarks(verbose: bool = False): avg_speedups = { "DeepGEMM vs vLLM Triton": 0, "DeepGEMM vs vLLM CUTLASS": 0, - "vLLM CUTLASS vs vLLM Triton": 0 + "vLLM CUTLASS vs vLLM Triton": 0, } for result in all_results: deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"] vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"] - vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"][ - "time_ms"] + vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"] - avg_speedups[ - "DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time - avg_speedups[ - "DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time - avg_speedups[ - "vLLM CUTLASS vs vLLM Triton"] += vllm_triton_time / vllm_cutlass_time + avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time + avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time + avg_speedups["vLLM CUTLASS vs vLLM Triton"] += ( + vllm_triton_time / vllm_cutlass_time + ) print("\n===== AVERAGE SPEEDUPS =====") speedup_headers = ["Comparison", "Speedup"] @@ -446,14 +396,12 @@ def run_benchmarks(verbose: bool = False): print_table(speedup_headers, speedup_rows) - # Average accuracy comparison print("\n===== ACCURACY COMPARISON =====") avg_diff = {impl: 0 for impl in implementations} for result in all_results: for impl in implementations: - avg_diff[impl] += result["implementations"][impl]["diff"][ - "Reference"] + avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"] diff_headers = ["Implementation", "Avg Diff vs Reference"] diff_rows = [] diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py index 2d7cf39a8cca..b418a22a48ec 100644 --- a/tests/kernels/moe/test_deepep_deepgemm_moe.py +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -66,25 +66,6 @@ def next_power_of_2(x): return 2**math.ceil(math.log2(x)) -def per_block_cast_to_fp8( - x: torch.Tensor, - block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (deep_gemm.ceil_div(m, 128) * 128, - deep_gemm.ceil_div(n, block_size_n) * block_size_n), - dtype=x.dtype, - device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / 
x_amax)).to(torch.float8_e4m3fn) - x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() - scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - return x_scaled_sub, scales - - def make_block_quant_fp8_weights( e: int, n: int, @@ -125,8 +106,8 @@ def make_block_quant_fp8_weights( assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] for i in range(e): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + w1[i], w1_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w1_bf16[i]) + w2[i], w2_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w2_bf16[i]) return w1, w2, w1_s, w2_s diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index eec59573792d..ca9f1d39af5e 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -18,7 +18,8 @@ from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, w8a8_block_fp8_matmul) + get_col_major_tma_aligned_tensor, per_token_group_quant_fp8, + w8a8_block_fp8_matmul) from vllm.platforms import current_platform dg_available = False @@ -263,25 +264,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): assert rel_diff < 0.03 -def per_block_cast_to_fp8( - x: torch.Tensor, - block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 - m, n = x.shape - x_padded = torch.zeros( - (deep_gemm.ceil_div(m, 128) * 128, - deep_gemm.ceil_div(n, block_size_n) * block_size_n), - dtype=x.dtype, - device=x.device) - x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) - x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) - x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) - x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() - scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) - return x_scaled_sub, scales - - @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) @@ -299,10 +281,8 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max - _, block_k = block_size[0], block_size[1] - - A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_k) - B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32) + A_fp8, As_fp8 = deep_gemm.utils.math.per_token_cast_to_fp8(A_fp32) + B_fp8, Bs_fp8 = deep_gemm.utils.math.per_block_cast_to_fp8(B_fp32) As = As_fp8.to(torch.float32) Bs = Bs_fp8.to(torch.float32) @@ -310,15 +290,12 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): ref_out = native_w8a8_block_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) - # Transpose earlier so that the testing will not trigger transposing kernels - As_fp8 = deep_gemm.get_col_major_tma_aligned_tensor(As_fp8) - out = torch.zeros((M, N), device='cuda', dtype=out_dtype) assert As_fp8.shape == (M, (K + 127) // 128), f"{As_fp8.shape} != {(M, (K + 127) // 128)}" - deep_gemm.gemm_fp8_fp8_bf16_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) + deep_gemm.fp8_gemm_nt((A_fp8, As_fp8), (B_fp8, Bs_fp8), out) rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / @@ -382,16 +359,16 
@@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, dtype=torch.bfloat16, device=a.device) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous((a_q, a_s), (w1, w1_s), - inter_out, m_indices) + deep_gemm.m_grouped_fp8_gemm_nt_contiguous((a_q, a_s), (w1, w1_s), + inter_out, m_indices) act_out = SiluAndMul().forward_native(inter_out) act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k) out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device) - deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (act_out_q, act_out_s), (w2, w2_s), out, m_indices) + deep_gemm.m_grouped_fp8_gemm_nt_contiguous((act_out_q, act_out_s), + (w2, w2_s), out, m_indices) final_out = _moe_unpermute(out, inv_perm, topk, K, topk_weight) @@ -441,15 +418,15 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): w1_s = torch.empty((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) w2_s = torch.empty((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) - w1_s = deep_gemm.get_col_major_tma_aligned_tensor(w1_s).contiguous() - w2_s = deep_gemm.get_col_major_tma_aligned_tensor(w2_s).contiguous() + w1_s = get_col_major_tma_aligned_tensor(w1_s).contiguous() + w2_s = get_col_major_tma_aligned_tensor(w2_s).contiguous() assert w1_s.shape == (E, (2 * N + 127) // 128, (K + 127) // 128) assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] for i in range(E): - w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) - w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + w1[i], w1_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w1_bf16[i]) + w2[i], w2_s[i] = deep_gemm.utils.math.per_block_cast_to_fp8(w2_bf16[i]) # Set the context to avoid lots of warning spam. with set_current_vllm_config(vllm_config): @@ -460,14 +437,10 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - topk_weights, topk_ids, token_expert_indices = fused_topk( - a, score.float(), topk, False) + topk_weights, topk_ids, _ = fused_topk(a, score.float(), topk, False) out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) - #print(f"{out.sum()=}") - #print(f"{ref_out.sum()=}") - rel_diff = (torch.mean( torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) / torch.mean(torch.abs(ref_out.to(torch.float32)))) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 70836879d17c..fd313b828266 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -266,19 +266,16 @@ def apply( # for the M expectation of each batch, correctly setting this value # may lead to better performance. 
expected_m = max_num_tokens - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a1q, a1q_scale), - (w1, w1_scale), - out=workspace1, - masked_m=expert_num_tokens, - expected_m=expected_m) + dg.fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale), + out=workspace1, + masked_m=expert_num_tokens, + expected_m=expected_m) assert expert_num_tokens is not None a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1, expert_num_tokens) - dg.m_grouped_gemm_fp8_fp8_bf16_nt_masked((a2q, a2q_scale), - (w2, w2_scale), - out=output, - masked_m=expert_num_tokens, - expected_m=expected_m) + dg.fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), + out=output, + masked_m=expert_num_tokens, + expected_m=expected_m) diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index b4473b907381..f349d2802de1 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -144,8 +144,8 @@ def apply( (M_sum, N // 2)) mm2_out = _resize_cache(workspace2, (M_sum, K)) - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids) + dg.m_grouped_fp8_gemm_nt_contiguous((a1q, a1q_scale), (w1, w1_scale), + mm1_out, expert_ids) self.activation(activation, act_out, mm1_out.view(-1, N)) @@ -154,9 +154,8 @@ def apply( self.block_shape[1], column_major_scales=True, out_q=quant_out) - - dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( - (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids) + dg.m_grouped_fp8_gemm_nt_contiguous((a2q, a2q_scale), (w2, w2_scale), + mm2_out, expert_ids) torch.index_select(mm2_out, 0, inv_perm, out=output) diff --git a/vllm/model_executor/layers/quantization/deepgemm.py b/vllm/model_executor/layers/quantization/deepgemm.py index 1d40f4915a1b..304d9af9c921 100644 --- a/vllm/model_executor/layers/quantization/deepgemm.py +++ b/vllm/model_executor/layers/quantization/deepgemm.py @@ -58,7 +58,7 @@ def w8a8_block_fp8_matmul_deepgemm( output_dtype) # Deepgemm only supports output tensor type as bfloat16 assert C.dtype == torch.bfloat16 - deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C) + deep_gemm.fp8_gemm_nt((A, As), (B, Bs), C) return C diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 754650ebeffb..a4ba2783a0a9 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -114,6 +114,10 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor): and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 def apply_w8a8_block_fp8_linear( @@ -158,9 +162,6 @@ def apply_w8a8_block_fp8_linear( if current_platform.is_cuda(): if current_platform.has_device_capability(100): - def ceil_div(x: int, y: int) -> int: - return (x + y - 1) // y - use_cutlass = cutlass_block_fp8_supported and ( ceil_div(weight.shape[0], 128) == weight_scale.shape[0] and ceil_div(weight.shape[1], 128) == weight_scale.shape[1]) @@ -655,3 +656,67 @@ def grid(META): ) return C + + +# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947 +# TODO(wentao): remove this function when DeepGEMM exposes this function +def get_tma_aligned_size(x: 
int, element_size: int) -> int: + """ + Global memory address of TMA must be 16-byte aligned. + Since we use column-major layout for the LHS scaling tensor, + the M-axis of the LHS scaling tensor needs to be padded to a multiple of + 16 bytes. + + Arguments: + x: original M-axis shape of the LHS scaling tensor. + element_size: element size of the LHS scaling tensor. + + Returns: + M-axis shape of the LHS scaling tensor after padding. + """ + tma_alignment_bytes = 16 + assert tma_alignment_bytes % element_size == 0 + alignment = tma_alignment_bytes // element_size + return ceil_div(x, alignment) * alignment + + +# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947 +# TODO(wentao): remove this function when DeepGEMM exposes this function +def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor: + """ + Returns TMA-aligned transposed format of the input tensor. `torch.transpose` + will be called if necessary. + If the input tensor is already column-major layout and 16-byte aligned along + the M axis (thus meets the requirement of LHS scaling tensor in + DeepGEMM), this function will do nothing. + + Arguments: + x: usually the LHS scaling tensor in GEMM. + + Returns: + The LHS scaling tensor of TMA-aligned transposed format. + """ + # NOTES: for the extreme performance, you may rewrite/fuse this function in + # CUDA + assert x.dim() in (2, 3) + remove_dim = False + m, n = x.shape[-2], x.shape[-1] + aligned_m = get_tma_aligned_size(m, x.element_size()) + if x.dim() == 2: + if x.stride(0) == 1 and x.stride(1) == aligned_m: + return x + x, remove_dim = x.unsqueeze(0), True + + b = x.shape[0] + + # The last kernel gives a column-major TMA aligned layout + if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride( + 2) == aligned_m: + return x.squeeze(0) if remove_dim else x + + # Normal layout requires transposing + aligned_x = torch.transpose( + torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2) + aligned_x[:, :m, :] = x + aligned_x = aligned_x[:, :m, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x
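
For reference, the new DeepGEMM entry point adopted throughout this diff (fp8_gemm_nt, together with the per_token_cast_to_fp8 / per_block_cast_to_fp8 helpers now imported from deep_gemm.utils.math) can be exercised in isolation as below. This is a minimal sketch under the same assumptions as the benchmark script in this patch: a DeepGEMM release that exposes these module paths and an FP8-capable CUDA GPU. The shape and variable names are illustrative, not part of the change itself.

# Minimal sketch of the updated DeepGEMM call pattern used in this diff.
# Assumes a recent DeepGEMM install exposing deep_gemm.utils.math and
# deep_gemm.testing.numeric, plus an FP8-capable CUDA device.
import torch
from deep_gemm import fp8_gemm_nt
from deep_gemm.testing.numeric import calc_diff
from deep_gemm.utils.math import per_block_cast_to_fp8, per_token_cast_to_fp8

m, n, k = 64, 24576, 1536  # one of the shapes benchmarked above
A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)

# Activations are quantized per 128-element token group, weights per 128x128 block.
A_fp8, A_scale = per_token_cast_to_fp8(A)
B_fp8, B_scale = per_block_cast_to_fp8(B)

# Output must be bfloat16; the "nt" layout computes C = A @ B.T.
C = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
fp8_gemm_nt((A_fp8, A_scale), (B_fp8, B_scale), C)

print("diff vs bf16 reference:", calc_diff(C, A @ B.t()))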