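# Benchmark of fp8 blockwise-quantized GEMM: sgl-kernel's fp8_blockwise_scaled_mm
# vs. vLLM's cutlass_scaled_mm, swept over DeepSeek-V3 weight shapes and batch sizes.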
import argparse
import copy
import itertools

import torch
import triton
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm


def get_weight_shapes(args):
    models_tps = list(itertools.product(args.models, args.tp_sizes))
    # NOTE(HandH1998): These weight shapes only work for DeepSeek-V3. Modify them
    # if you tune for a different model.
    # weights that cannot be tensor-parallel sharded
    total = [
        # (512 + 64, 7168), # this weight is not supported by the current kernel
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # weights whose N dimension can be TP-sharded
    n_tp = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # weights whose K dimension can be TP-sharded
    k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
    # only DeepSeek-V3 is supported
    SUPPORT_MODEL = ["deepseek-ai/DeepSeek-V3"]

    weight_shapes = []
    for model, tp_size in models_tps:
        assert model in SUPPORT_MODEL
        for t in total:
            new_t = [t[0], t[1], model]
            weight_shapes.append(new_t)
        for n_t in n_tp:
            new_t = [n_t[0] // tp_size, n_t[1], model]
            weight_shapes.append(new_t)
        for k_t in k_tp:
            new_t = [k_t[0], k_t[1] // tp_size, model]
            weight_shapes.append(new_t)
    return weight_shapes


def cdiv(a: int, b: int) -> int:
    """Ceiling division."""
    return -(a // -b)


def scale_shape(shape, group_shape):
    assert len(shape) == len(group_shape)
    return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
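# e.g. scale_shape((4096, 7168), (1, 128)) == (4096, 56): one scale per row and
# one scale per 128-wide group along K.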


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
        x_log=False,
        line_arg="provider",
        line_vals=["vllm", "sgl-kernel"],
        line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
        styles=[("blue", "-"), ("orange", "-")],
        ylabel="GB/s",
        plot_name="fp8 blockwise scaled matmul",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    M = batch_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)
    fp8_max, fp8_min = fp8_info.max, fp8_info.min

    a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()

    scale_a_group_shape = (1, 128)
    scale_b_group_shape = (128, 128)
    scale_a_shape = scale_shape(a_fp8.shape, scale_a_group_shape)
    scale_b_shape = scale_shape(b_fp8.shape, scale_b_group_shape)

    scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
    scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
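    # Re-lay the scale tensors in column-major order; this appears to be the layout
    # the blockwise-scaled kernels expect.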
    scale_a = scale_a.t().contiguous().t()
    scale_b = scale_b.t().contiguous().t()

    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl-kernel":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: fp8_blockwise_scaled_mm(
                a_fp8, b_fp8, scale_a, scale_b, torch.float16
            ),
            quantiles=quantiles,
        )
    if provider == "vllm":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
            quantiles=quantiles,
        )
    gbps = (
        lambda ms: (
            (2 * M * N * K - M * N) * a_fp8.element_size()
            + (3 * M * N) * scale_a.element_size()
        )
        * 1e-9
        / (ms * 1e-3)
    )
    return gbps(ms), gbps(max_ms), gbps(min_ms)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["deepseek-ai/DeepSeek-V3"],
        help="List of models to benchmark",
    )
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    args = parser.parse_args()

    NK_model_names = get_weight_shapes(args)
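    # Each entry is [N, K, model_name]; benchmark() sweeps M (the batch size) per shape.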
    for N, K, model_name in NK_model_names:
        print(f"{model_name} N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp8_blockwise_res",
            N=N,
            K=K,
        )

    print("Benchmark finished!")
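
# Example invocation (assuming this script is saved as benchmark_fp8_blockwise_gemm.py):
#   python3 benchmark_fp8_blockwise_gemm.py --models deepseek-ai/DeepSeek-V3 --tp-sizes 1 8
# Per-shape results are printed and plots are saved under ./bench_fp8_blockwise_res/.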