|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.utils.benchmark as benchmark |
| 5 | +from benchmark_shapes import WEIGHT_SHAPES_MOE |
| 6 | + |
| 7 | +from vllm import _custom_ops as ops |
| 8 | +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config |
| 9 | +from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, |
| 10 | + fused_experts, |
| 11 | + fused_topk) |
| 12 | +from vllm.utils import FlexibleArgumentParser |
| 13 | + |
| 14 | +DEFAULT_MODELS = [ |
| 15 | + "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite", |
| 16 | + "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m" |
| 17 | +] |
| 18 | +DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] |
| 19 | +DEFAULT_TP_SIZES = [1] |
| 20 | + |
| 21 | +PER_ACT_TOKEN_OPTS = [False] |
| 22 | +PER_OUT_CH_OPTS = [False] |
| 23 | + |
| 24 | + |
| 25 | +def to_fp8(tensor: torch.Tensor): |
| 26 | + finfo = torch.finfo(torch.float8_e4m3fn) |
| 27 | + return torch.round(tensor.clamp( |
| 28 | + min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) |
| 29 | + |
| 30 | + |
| 31 | +def bench_run(results: list[benchmark.Measurement], model: str, |
| 32 | + num_experts: int, topk: int, per_act_token: bool, |
| 33 | + per_out_ch: bool, mkn: tuple[int, int, int]): |
| 34 | + label = "Quant Matmul" |
| 35 | + |
| 36 | + sub_label = ( |
| 37 | + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " |
| 38 | + "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch, |
| 39 | + mkn)) |
| 40 | + |
| 41 | + print(f"Testing: {sub_label}") |
| 42 | + |
| 43 | + (m, k, n) = mkn |
| 44 | + |
| 45 | + dtype = torch.half |
| 46 | + |
| 47 | + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 |
| 48 | + w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10 |
| 49 | + w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10 |
| 50 | + |
| 51 | + _, a_scale = ops.scaled_fp8_quant(a) |
| 52 | + |
| 53 | + w1_q = torch.empty((num_experts, 2 * n, k), |
| 54 | + device="cuda", |
| 55 | + dtype=torch.float8_e4m3fn) |
| 56 | + w2_q = torch.empty((num_experts, k, n), |
| 57 | + device="cuda", |
| 58 | + dtype=torch.float8_e4m3fn) |
| 59 | + w1_scale = torch.empty((num_experts, 1, 1), |
| 60 | + device="cuda", |
| 61 | + dtype=torch.float32) |
| 62 | + w2_scale = torch.empty((num_experts, 1, 1), |
| 63 | + device="cuda", |
| 64 | + dtype=torch.float32) |
| 65 | + |
| 66 | + ab_strides1 = torch.full((num_experts, ), |
| 67 | + k, |
| 68 | + device="cuda", |
| 69 | + dtype=torch.int64) |
| 70 | + c_strides1 = torch.full((num_experts, ), |
| 71 | + 2 * n, |
| 72 | + device="cuda", |
| 73 | + dtype=torch.int64) |
| 74 | + ab_strides2 = torch.full((num_experts, ), |
| 75 | + n, |
| 76 | + device="cuda", |
| 77 | + dtype=torch.int64) |
| 78 | + c_strides2 = torch.full((num_experts, ), |
| 79 | + k, |
| 80 | + device="cuda", |
| 81 | + dtype=torch.int64) |
| 82 | + |
| 83 | + for expert in range(num_experts): |
| 84 | + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) |
| 85 | + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) |
| 86 | + w1_q_notransp = w1_q.clone() |
| 87 | + w2_q_notransp = w2_q.clone() |
| 88 | + w1_q = w1_q.transpose(1, 2) |
| 89 | + w2_q = w2_q.transpose(1, 2) |
| 90 | + |
| 91 | + score = torch.randn((m, num_experts), device="cuda", dtype=dtype) |
| 92 | + |
| 93 | + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) |
| 94 | + |
| 95 | + def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, |
| 96 | + topk_weights: torch.Tensor, topk_ids: torch.Tensor, |
| 97 | + w1_scale: torch.Tensor, w2_scale: torch.Tensor, |
| 98 | + a_scale: torch.Tensor, num_repeats: int): |
| 99 | + for _ in range(num_repeats): |
| 100 | + fused_experts(a, |
| 101 | + w1, |
| 102 | + w2, |
| 103 | + topk_weights, |
| 104 | + topk_ids, |
| 105 | + use_fp8_w8a8=True, |
| 106 | + w1_scale=w1_scale, |
| 107 | + w2_scale=w2_scale, |
| 108 | + a1_scale=a_scale) |
| 109 | + |
| 110 | + def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, |
| 111 | + w1: torch.Tensor, w2: torch.Tensor, |
| 112 | + w1_scale: torch.Tensor, w2_scale: torch.Tensor, |
| 113 | + topk_weights: torch.Tensor, topk_ids: torch.Tensor, |
| 114 | + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, |
| 115 | + ab_strides2: torch.Tensor, c_strides2: torch.Tensor, |
| 116 | + num_repeats: int): |
| 117 | + for _ in range(num_repeats): |
| 118 | + cutlass_moe_fp8(a, |
| 119 | + w1, |
| 120 | + w2, |
| 121 | + w1_scale, |
| 122 | + w2_scale, |
| 123 | + topk_weights, |
| 124 | + topk_ids, |
| 125 | + ab_strides1, |
| 126 | + c_strides1, |
| 127 | + ab_strides2, |
| 128 | + c_strides2, |
| 129 | + a1_scale=a_scale) |
| 130 | + |
| 131 | + def run_cutlass_from_graph( |
| 132 | + a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, |
| 133 | + w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, |
| 134 | + topk_weights: torch.Tensor, topk_ids: torch.Tensor, |
| 135 | + ab_strides1: torch.Tensor, c_strides1: torch.Tensor, |
| 136 | + ab_strides2: torch.Tensor, c_strides2: torch.Tensor): |
| 137 | + with set_current_vllm_config( |
| 138 | + VllmConfig(parallel_config=ParallelConfig( |
| 139 | + pipeline_parallel_size=1))): |
| 140 | + return cutlass_moe_fp8(a, |
| 141 | + w1_q, |
| 142 | + w2_q, |
| 143 | + w1_scale, |
| 144 | + w2_scale, |
| 145 | + topk_weights, |
| 146 | + topk_ids, |
| 147 | + ab_strides1, |
| 148 | + c_strides1, |
| 149 | + ab_strides2, |
| 150 | + c_strides2, |
| 151 | + a1_scale=a_scale) |
| 152 | + |
| 153 | + def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, |
| 154 | + w2: torch.Tensor, topk_weights: torch.Tensor, |
| 155 | + topk_ids: torch.Tensor, w1_scale: torch.Tensor, |
| 156 | + w2_scale: torch.Tensor, a_scale: torch.Tensor): |
| 157 | + with set_current_vllm_config( |
| 158 | + VllmConfig(parallel_config=ParallelConfig( |
| 159 | + pipeline_parallel_size=1))): |
| 160 | + return fused_experts(a, |
| 161 | + w1, |
| 162 | + w2, |
| 163 | + topk_weights, |
| 164 | + topk_ids, |
| 165 | + use_fp8_w8a8=True, |
| 166 | + w1_scale=w1_scale, |
| 167 | + w2_scale=w2_scale, |
| 168 | + a1_scale=a_scale) |
| 169 | + |
| 170 | + def replay_graph(graph, num_repeats): |
| 171 | + for _ in range(num_repeats): |
| 172 | + graph.replay() |
| 173 | + torch.cuda.synchronize() |
| 174 | + |
| 175 | + cutlass_stream = torch.cuda.Stream() |
| 176 | + cutlass_graph = torch.cuda.CUDAGraph() |
| 177 | + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): |
| 178 | + run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, |
| 179 | + topk_weights, topk_ids, ab_strides1, c_strides1, |
| 180 | + ab_strides2, c_strides2) |
| 181 | + torch.cuda.synchronize() |
| 182 | + |
| 183 | + triton_stream = torch.cuda.Stream() |
| 184 | + triton_graph = torch.cuda.CUDAGraph() |
| 185 | + with torch.cuda.graph(triton_graph, stream=triton_stream): |
| 186 | + run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights, |
| 187 | + topk_ids, w1_scale, w2_scale, a_scale) |
| 188 | + torch.cuda.synchronize() |
| 189 | + |
| 190 | + min_run_time = 5 |
| 191 | + num_warmup = 5 |
| 192 | + num_runs = 25 |
| 193 | + |
| 194 | + globals = { |
| 195 | + # Baseline params |
| 196 | + "w1": w1, |
| 197 | + "w2": w2, |
| 198 | + "score": score, |
| 199 | + "topk": topk, |
| 200 | + "w1_q_notransp": w1_q_notransp, |
| 201 | + "w2_q_notransp": w2_q_notransp, |
| 202 | + # Cutlass params |
| 203 | + "a_scale": a_scale, |
| 204 | + "w1_q": w1_q, |
| 205 | + "w2_q": w2_q, |
| 206 | + "w1_scale": w1_scale, |
| 207 | + "w2_scale": w2_scale, |
| 208 | + "ab_strides1": ab_strides1, |
| 209 | + "c_strides1": c_strides1, |
| 210 | + "ab_strides2": ab_strides2, |
| 211 | + "c_strides2": c_strides2, |
| 212 | + # cuda graph params |
| 213 | + "cutlass_graph": cutlass_graph, |
| 214 | + "triton_graph": triton_graph, |
| 215 | + # Gen params |
| 216 | + "a": a, |
| 217 | + "topk_weights": topk_weights, |
| 218 | + "topk_ids": topk_ids, |
| 219 | + "num_runs": num_runs, |
| 220 | + # Kernels |
| 221 | + "run_triton_moe": run_triton_moe, |
| 222 | + "run_cutlass_moe": run_cutlass_moe, |
| 223 | + "replay_graph": replay_graph, |
| 224 | + } |
| 225 | + |
| 226 | + # Warmup |
| 227 | + run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, |
| 228 | + w1_scale, w2_scale, a_scale, num_warmup) |
| 229 | + |
| 230 | + results.append( |
| 231 | + benchmark.Timer( |
| 232 | + stmt= |
| 233 | + "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 |
| 234 | + globals=globals, |
| 235 | + label=label, |
| 236 | + sub_label=sub_label, |
| 237 | + description="triton_moe", |
| 238 | + ).blocked_autorange(min_run_time=min_run_time)) |
| 239 | + |
| 240 | + # Warmup |
| 241 | + replay_graph(triton_graph, num_warmup) |
| 242 | + |
| 243 | + results.append( |
| 244 | + benchmark.Timer( |
| 245 | + stmt="replay_graph(triton_graph, num_runs)", |
| 246 | + globals=globals, |
| 247 | + label=label, |
| 248 | + sub_label=sub_label, |
| 249 | + description="triton_moe_cuda_graphs", |
| 250 | + ).blocked_autorange(min_run_time=min_run_time)) |
| 251 | + |
| 252 | + # Warmup |
| 253 | + run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, |
| 254 | + topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, |
| 255 | + num_warmup) |
| 256 | + |
| 257 | + results.append( |
| 258 | + benchmark.Timer( |
| 259 | + stmt= |
| 260 | + "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 |
| 261 | + globals=globals, |
| 262 | + label=label, |
| 263 | + sub_label=sub_label, |
| 264 | + description="grouped_gemm_moe", |
| 265 | + ).blocked_autorange(min_run_time=min_run_time)) |
| 266 | + |
| 267 | + # Warmup |
| 268 | + replay_graph(cutlass_graph, num_warmup) |
| 269 | + |
| 270 | + results.append( |
| 271 | + benchmark.Timer( |
| 272 | + stmt="replay_graph(cutlass_graph, num_runs)", |
| 273 | + globals=globals, |
| 274 | + label=label, |
| 275 | + sub_label=sub_label, |
| 276 | + description="grouped_gemm_moe_cuda_graphs", |
| 277 | + ).blocked_autorange(min_run_time=min_run_time)) |
| 278 | + |
| 279 | + |
| 280 | +def main(args): |
| 281 | + print("Benchmarking models:") |
| 282 | + for i, model in enumerate(args.models): |
| 283 | + print(f"[{i}] {model}") |
| 284 | + |
| 285 | + results: list[benchmark.Measurement] = [] |
| 286 | + |
| 287 | + for model in args.models: |
| 288 | + for tp in args.tp_sizes: |
| 289 | + for layer in WEIGHT_SHAPES_MOE[model]: |
| 290 | + num_experts = layer[0] |
| 291 | + topk = layer[1] |
| 292 | + size_k = layer[2] |
| 293 | + size_n = layer[3] // tp |
| 294 | + |
| 295 | + if len(args.limit_k) > 0 and size_k not in args.limit_k: |
| 296 | + continue |
| 297 | + |
| 298 | + if len(args.limit_n) > 0 and size_n not in args.limit_n: |
| 299 | + continue |
| 300 | + |
| 301 | + for per_act_token in PER_ACT_TOKEN_OPTS: |
| 302 | + for per_out_ch in PER_OUT_CH_OPTS: |
| 303 | + for size_m in DEFAULT_BATCH_SIZES: |
| 304 | + mkn = (size_m, size_k, size_n) |
| 305 | + bench_run(results, model, num_experts, topk, |
| 306 | + per_act_token, per_out_ch, mkn) |
| 307 | + |
| 308 | + compare = benchmark.Compare(results) |
| 309 | + compare.print() |
| 310 | + |
| 311 | + |
| 312 | +if __name__ == "__main__": |
| 313 | + parser = FlexibleArgumentParser( |
| 314 | + description="Benchmark Marlin across specified models/shapes/batches") |
| 315 | + parser.add_argument( |
| 316 | + "--models", |
| 317 | + nargs="+", |
| 318 | + type=str, |
| 319 | + default=DEFAULT_MODELS, |
| 320 | + choices=WEIGHT_SHAPES_MOE.keys(), |
| 321 | + ) |
| 322 | + parser.add_argument("--tp-sizes", |
| 323 | + nargs="+", |
| 324 | + type=int, |
| 325 | + default=DEFAULT_TP_SIZES) |
| 326 | + parser.add_argument("--batch-sizes", |
| 327 | + nargs="+", |
| 328 | + type=int, |
| 329 | + default=DEFAULT_BATCH_SIZES) |
| 330 | + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) |
| 331 | + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) |
| 332 | + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) |
| 333 | + parser.add_argument("--limit-per-act-token", |
| 334 | + nargs="+", |
| 335 | + type=int, |
| 336 | + default=[]) |
| 337 | + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) |
| 338 | + |
| 339 | + args = parser.parse_args() |
| 340 | + main(args) |
0 commit comments