Hello! We are looking into FP4 MoE performance on B200, mainly following this recipe and setup for Llama 4 Scout.
There's an env variable to control which FlashInfer backend is selected for the MoE kernels:

```
VLLM_FLASHINFER_MOE_BACKEND=throughput  # default, CUTLASS backend
VLLM_FLASHINFER_MOE_BACKEND=latency     # TRT-LLM backend
```
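For reference, a minimal sketch of how one might select the backend from offline Python, assuming the variable is set before the engine is constructed (the checkpoint name is only illustrative):

```python
import os

# Select the FlashInfer MoE backend before the engine is built; vLLM reads
# this environment variable at initialization time.
os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "latency"  # or "throughput" (default)

from vllm import LLM, SamplingParams

# Illustrative checkpoint name; substitute the FP4 Llama 4 Scout model actually used.
llm = LLM(model="meta-llama/Llama-4-Scout-17B-16E-Instruct", tensor_parallel_size=1)
outputs = llm.generate(["Hello!"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```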
From our benchmarking (consistent with the naming), the CUTLASS backend is better at high concurrencies, but there is a significant difference in decode TPOT at lower batch sizes. Concretely, for TP=1, BS=1, 2048 input / 100 output tokens on Llama 4 Scout:

| Backend | Mean TPOT (ms) |
| --- | --- |
| TRT-LLM (`latency`) | 7.60 |
| CUTLASS (`throughput`) | 12.84 |
This is a pretty significant difference. From profiling the decode step, it seems most of the extra latency comes from a few ops that could potentially be fixed:
- fp4 quantize

Both MoE backends use this op to quantize the activations:

```
void tensorrt_llm::kernels::quantize_with_block_size<(tensorrt_llm::BlockScaleQuantizationType)0, __nv_bfloat16, 16, false>(int, int, int, int, __nv_bfloat16 const*, float const*, unsigned int*, unsigned int*, flashinfer::QuantizationSFLayout)
```

The difference seems to be that the CUTLASS path needs the 128x4 swizzled scale layout, which is not performant at low batch sizes, so this op takes ~75µs instead of the expected <5µs. Here is a simple benchmark to reproduce:
```python
# %%
from vllm.utils.flashinfer import fp4_quantize
import torch
from triton.testing import do_bench_cudagraph
from matplotlib import pyplot as plt

# %%
# Sweep batch sizes and time the FP4 activation quantization with the swizzled
# 128x4 scale layout, as required by the CUTLASS MoE backend.
A_scale = torch.randn(16).cuda().float()
bsz = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
times = []
for bs in bsz:
    A = torch.randn(bs, 5120).cuda().to(torch.bfloat16)
    t = do_bench_cudagraph(lambda: fp4_quantize(A, A_scale, is_sf_swizzled_layout=True))
    times.append(t)

# %%
plt.plot(bsz, times)
plt.xscale("log", base=2)
plt.xticks(bsz, labels=[str(v) for v in bsz])
plt.xlabel("Batch Size")
plt.ylabel("Time (ms)")
plt.title("FP4 Quantization Time vs Batch Size")
plt.grid(True, which="both", linestyle="--", alpha=0.7)
plt.show()
```
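To isolate the layout cost, the same benchmark can be run with the swizzled layout disabled. A minimal sketch, assuming `is_sf_swizzled_layout=False` selects the linear (non-swizzled) scale layout in this wrapper:

```python
# Sketch: compare swizzled (128x4) vs linear scale-factor layouts at small batch sizes.
# Assumes the same fp4_quantize wrapper as above; do_bench_cudagraph returns milliseconds.
import torch
from triton.testing import do_bench_cudagraph
from vllm.utils.flashinfer import fp4_quantize

A_scale = torch.randn(16).cuda().float()
for bs in [1, 4, 16, 64]:
    A = torch.randn(bs, 5120).cuda().to(torch.bfloat16)
    t_swizzled = do_bench_cudagraph(
        lambda: fp4_quantize(A, A_scale, is_sf_swizzled_layout=True)
    )
    t_linear = do_bench_cudagraph(
        lambda: fp4_quantize(A, A_scale, is_sf_swizzled_layout=False)
    )
    print(f"bs={bs:4d}  swizzled={t_swizzled * 1e3:6.1f}µs  linear={t_linear * 1e3:6.1f}µs")
```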
- moe prep

One of the MoE prep kernels,

```
void tensorrt_llm::kernels::cutlass_kernels::fusedBuildExpertMapsSortFirstTokenKernel<32, 1, 5>(int const*, int*, int*, long*, long, int, int, int, int)
```

has ~25µs latency. In another (larger) MoE model, the same kernel is instantiated with different template args (`<32, 8, 8>`) and has much lower latency (~4µs).
Other potential improvements based on the profiling:
- bring the grouped GEMMs closer to the memory-bound speed of light (SoL); a rough back-of-envelope estimate is sketched below
- fuse the activation into the FC1 epilogue
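To make the SoL argument concrete, here is a rough memory-bound lower bound for the expert GEMMs during BS=1 decode. The expert intermediate size, top-1 routing, and ~8 TB/s B200 HBM bandwidth used below are assumptions for illustration, not profiled values:

```python
# Back-of-envelope, memory-bound floor for the expert GEMMs at BS=1 decode.
# All sizes are illustrative assumptions, not measurements.
hidden = 5120          # model hidden size (matches the quantize bench above)
inter = 8192           # assumed expert intermediate size
experts_per_tok = 1    # assumed top-1 routing (ignoring the shared expert)
hbm_bw = 8e12          # assumed B200 HBM bandwidth in bytes/s (~8 TB/s)

# FC1 (gate + up) and FC2 weights for the active expert, FP4 -> 0.5 byte/param,
# plus block scales at roughly 1 byte per 16 params.
params = experts_per_tok * (2 * hidden * inter + inter * hidden)
total_bytes = params * 0.5 + params / 16

t_sol_us = total_bytes / hbm_bw * 1e6
print(f"~{total_bytes / 1e6:.1f} MB of active expert weights -> "
      f"memory-bound floor ≈ {t_sol_us:.1f} µs per MoE layer")
```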