
[Perf] FP4 MoE on B200 (latency) #1734

@czhu-cohere

Description

Hello! We are looking into FP4 MoE performance on B200, mainly following this recipe and setup for Llama 4 Scout.

There's an env variable to control which FlashInfer backend is selected for the MoE kernels:

VLLM_FLASHINFER_MOE_BACKEND=throughput  # default, cutlass backend
VLLM_FLASHINFER_MOE_BACKEND=latency  # trt-llm backend

From our benchmarking (consistent with the naming), the CUTLASS backend is better at high concurrencies, but there is a significant difference in decode TPOT at lower batch sizes. Concretely, for TP=1, BS=1, 2048 input / 100 output tokens on Llama 4 Scout:

TRT-LLM backend
Mean TPOT (ms):  7.60
CUTLASS backend
Mean TPOT (ms): 12.84
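
For reference, a minimal offline sketch of flipping the backend and sanity-checking the decode-time gap. This is not the serving benchmark that produced the numbers above; the checkpoint path is a placeholder and the 1-vs-101-token timing trick is only a rough way to subtract out prefill.

import os
import time

# Must be set before vLLM initializes so the MoE backend selection picks it up.
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "latency"  # or "throughput"

from vllm import LLM, SamplingParams

# Placeholder checkpoint; we used the FP4 Llama 4 Scout setup from the recipe.
llm = LLM(model="<fp4-llama4-scout-checkpoint>", tensor_parallel_size=1)
prompt = "hi " * 2048  # roughly 2048 input tokens

def timed_generate(max_tokens: int) -> float:
    start = time.perf_counter()
    llm.generate([prompt], SamplingParams(max_tokens=max_tokens, ignore_eos=True))
    return time.perf_counter() - start

timed_generate(1)  # warm-up
# Subtract the 1-token run (prefill + first token) to approximate pure decode time.
t_long, t_short = timed_generate(101), timed_generate(1)
print(f"approx TPOT: {(t_long - t_short) / 100 * 1e3:.2f} ms")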

This is a pretty significant difference. Profiling the decode step suggests that most of the extra latency comes from a few ops which could potentially be fixed:

  1. fp4 quantize
    Both MoE backends use this op to quantize the activations:
void tensorrt_llm::kernels::quantize_with_block_size<(tensorrt_llm::BlockScaleQuantizationType)0, __nv_bfloat16, 16, false>(int, int, int, int, __nv_bfloat16 const*, float const*, unsigned int*, unsigned int*, flashinfer::QuantizationSFLayout)

It seems the difference is that the CUTLASS backend needs the 128x4 swizzled scale layout, which is not performant at low batch sizes, so this op takes ~75µs instead of the expected <5µs. Here is a simple bench to reproduce:

# %%
from vllm.utils.flashinfer import fp4_quantize
import torch
from triton.testing import do_bench_cudagraph
from matplotlib import pyplot as plt
# %%
# Sweep the fp4 quantize op (swizzled 128x4 scale layout) over batch sizes.
A_scale = torch.randn(16).cuda().float()
bsz = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
times = []
for bs in bsz:
    A = torch.randn(bs, 5120).cuda().to(torch.bfloat16)
    t = do_bench_cudagraph(lambda: fp4_quantize(A, A_scale, is_sf_swizzled_layout=True))
    times.append(t)
# %%
# Plot quantization time (ms) vs batch size on a log2 x-axis.
plt.plot(bsz, times)
plt.xscale("log", base=2)
plt.xticks(bsz, labels=[str(v) for v in bsz])
plt.xlabel("Batch Size")
plt.ylabel("Time (ms)")
plt.title("FP4 Quantization Time vs Batch Size")
plt.grid(True, which="both", linestyle="--", alpha=0.7)
plt.show()
[Plot: FP4 quantization time vs batch size]
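
For completeness, a small follow-up sketch that toggles is_sf_swizzled_layout at BS=1 to isolate how much of the cost is the 128x4 swizzled layout itself. The linear-layout timing is just a point of comparison; whether the TRT-LLM path actually avoids the swizzled layout here is an assumption on our side.

# %%
import torch
from triton.testing import do_bench_cudagraph
from vllm.utils.flashinfer import fp4_quantize

# Compare the swizzled (128x4) scale layout against the linear layout at BS=1.
A_scale = torch.randn(16).cuda().float()
A = torch.randn(1, 5120, device="cuda", dtype=torch.bfloat16)
for swizzled in (True, False):
    t_ms = do_bench_cudagraph(
        lambda: fp4_quantize(A, A_scale, is_sf_swizzled_layout=swizzled)
    )
    print(f"is_sf_swizzled_layout={swizzled}: {t_ms * 1e3:.1f} us")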
  2. moe prep
    One of the MoE prep kernels
void tensorrt_llm::kernels::cutlass_kernels::fusedBuildExpertMapsSortFirstTokenKernel<32, 1, 5>(int const*, int*, int*, long*, long, int, int, int, int)

has ~25µs latency. In another (larger) MoE model, a variant of this kernel with template args <32, 8, 8> is used instead, which has much lower latency (~4µs).

Other potential improvements based on the profiling:

  • bring the grouped GEMM closer to the memory-bound speed-of-light (a back-of-envelope sketch follows this list)
  • fuse the activation into the FC1 epilogue
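
To put the first bullet in rough numbers, here is a back-of-envelope memory-bound lower bound for the BS=1 grouped GEMM. The expert dimensions, routed-expert count, and HBM bandwidth are illustrative assumptions, not measured values.

# Back-of-envelope "speed of light" for the BS=1 grouped GEMM: at batch size 1 the
# kernel is dominated by streaming the expert weights from HBM once.
hidden = 5120              # activation dim (matches the quantize bench above)
intermediate = 8192        # assumed per-expert FFN intermediate size
routed_experts = 1         # assumed experts touched per token
bytes_per_weight = 0.5     # FP4 = 4 bits per weight (ignoring block scales)
hbm_bw_bytes_per_s = 8e12  # ~8 TB/s, rough B200 HBM bandwidth assumption

# FC1 is gated (hidden -> 2 * intermediate), FC2 maps intermediate -> hidden.
weight_bytes = routed_experts * (hidden * 2 * intermediate + intermediate * hidden) * bytes_per_weight
sol_us = weight_bytes / hbm_bw_bytes_per_s * 1e6
print(f"memory-bound lower bound: ~{sol_us:.1f} us to stream the expert weights")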
