varun-sundar-rabindranath commented Jul 18, 2025

Purpose

Tweak the num_warps and NUM_STAGES (number of pipeline stages used for prefetching) values of the silu_mul_fp8_quant_deep_gemm Triton kernel.
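For context, a minimal pure-Python reference of the math this kernel fuses: SiLU on the gate projection, elementwise multiply with the up projection, then per-group FP8 scaling. This is an illustrative sketch, not the Triton kernel itself; `group_size=128` matches the benchmark configuration below, and the E4M3 max of 448.0 is an assumption about the quantization dtype.

```python
import math

FP8_MAX = 448.0  # max finite value of float8 e4m3fn (assumed quant dtype)

def silu(x):
    # SiLU(x) = x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def silu_mul_quant_ref(gate, up, group_size=128):
    """Reference for the fused op: y = silu(gate) * up, then groupwise scaling.

    Returns the scaled values plus one scale per group of `group_size`
    contiguous elements (group_size=128 as in the benchmarks below).
    """
    y = [silu(g) * u for g, u in zip(gate, up)]
    q, scales = [], []
    for start in range(0, len(y), group_size):
        group = y[start:start + group_size]
        amax = max(abs(v) for v in group) or 1.0  # guard all-zero groups
        scale = amax / FP8_MAX
        scales.append(scale)
        q.extend(v / scale for v in group)
    return q, scales
```

The real kernel performs this per expert and per token in a single pass; the sketch only shows the numerics a unit test could check against.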

Local micro-benchmark numbers:
main:

Benchmark: E=256, T=2048, H=7168, group_size=128, repeat=200
tokens=4:     quant_silu_mul  0.030ms
tokens=8:     quant_silu_mul  0.056ms
tokens=16:    quant_silu_mul  0.106ms
tokens=32:    quant_silu_mul  0.204ms
tokens=64:    quant_silu_mul  0.402ms
tokens=128:   quant_silu_mul  0.799ms
tokens=256:   quant_silu_mul  1.579ms
tokens=384:   quant_silu_mul  2.366ms
tokens=512:   quant_silu_mul  3.148ms
tokens=1024:  quant_silu_mul  6.272ms
tokens=2048:  quant_silu_mul 12.522ms

This PR:

Benchmark: E=256, T=2048, H=7168, group_size=128, repeat=200
tokens=4:     quant_silu_mul  0.017ms
tokens=8:     quant_silu_mul  0.032ms
tokens=16:    quant_silu_mul  0.057ms
tokens=32:    quant_silu_mul  0.108ms
tokens=64:    quant_silu_mul  0.211ms
tokens=128:   quant_silu_mul  0.417ms
tokens=256:   quant_silu_mul  0.830ms
tokens=384:   quant_silu_mul  1.234ms
tokens=512:   quant_silu_mul  1.639ms
tokens=1024:  quant_silu_mul  3.254ms
tokens=2048:  quant_silu_mul  6.514ms

Note: micro-benchmarking script from https://github.com/tlrmchlsmth/ptgq_fp8
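The linked ptgq_fp8 script does the actual kernel timing; as a hypothetical stand-in, the general shape of such a harness (warmup runs, then `repeat=200` timed iterations as in the headers above) looks like this. The function name and parameters here are illustrative, not from the linked script, and a real GPU benchmark would synchronize and use CUDA events rather than wall-clock time.

```python
import time

def bench(fn, repeat=200, warmup=10):
    """Return mean milliseconds per call of `fn` over `repeat` timed runs.

    Warmup iterations absorb one-time costs (compilation, caching) so the
    steady-state average is reported.
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(repeat):
        fn()
    return (time.perf_counter() - start) / repeat * 1e3  # ms per call

# usage: ms = bench(lambda: sum(range(10_000)))
```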

E2E Perf

Server command: VLLM_ALL2ALL_BACKEND="deepep_low_latency" VLLM_USE_DEEP_GEMM=1 canhazgpu run -g 2 -- vllm serve Qwen/Qwen3-30B-A3B-FP8 --trust-remote-code --enable-expert-parallel --data-parallel-size 2 --port 9010 --no-enable-prefix-caching
Benchmark command: python3 ./benchmarks/benchmark_serving.py --model Qwen/Qwen3-30B-A3B-FP8 --dataset-name sharegpt --port 9010 --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json
Methodology: Start the server and execute the benchmark command 3 times. Report the best Total Token Throughput numbers.

main:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  32.44     
Total input tokens:                      217393    
Total generated tokens:                  201847    
Request throughput (req/s):              30.83     
Output token throughput (tok/s):         6222.53   
Total Token throughput (tok/s):          12924.31  
---------------Time to First Token----------------
Mean TTFT (ms):                          6470.31   
Median TTFT (ms):                        6734.54   
P99 TTFT (ms):                           12538.94  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          192.93    
Median TPOT (ms):                        76.87     
P99 TPOT (ms):                           773.24    
---------------Inter-token Latency----------------
Mean ITL (ms):                           61.06     
Median ITL (ms):                         35.02     
P99 ITL (ms):                            778.17    
==================================================

This PR:

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  30.64     
Total input tokens:                      217393    
Total generated tokens:                  201847    
Request throughput (req/s):              32.64     
Output token throughput (tok/s):         6587.82   
Total Token throughput (tok/s):          13683.03  
---------------Time to First Token----------------
Mean TTFT (ms):                          6416.49   
Median TTFT (ms):                        6604.24   
P99 TTFT (ms):                           11718.61  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          174.51    
Median TPOT (ms):                        66.36     
P99 TPOT (ms):                           776.26    
---------------Inter-token Latency----------------
Mean ITL (ms):                           54.63     
Median ITL (ms):                         27.40     
P99 ITL (ms):                            779.23    
==================================================

Test Plan

Local testing: pytest -s tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py

E2E testing:
Server command: VLLM_ALL2ALL_BACKEND="deepep_low_latency" VLLM_USE_DEEP_GEMM=1 canhazgpu run -g 2 -- vllm serve Qwen/Qwen3-30B-A3B-FP8 --trust-remote-code --enable-expert-parallel --data-parallel-size 2 --port 9010 --no-enable-prefix-caching
lm_eval command: lm_eval --model local-completions --tasks gsm8k --model_args model=Qwen/Qwen3-30B-A3B-FP8,base_url=http://127.0.0.1:9010/v1/completions,num_concurrent=30,max_retries=3 --limit 100

Test Result

tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py passes locally.

lm_eval output :

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.84|±  |0.0368|
|     |       |strict-match    |     5|exact_match|↑  | 0.95|±  |0.0219|


Signed-off-by: Varun Sundar Rabindranath <[email protected]>

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces performance optimizations to the silu_mul_fp8_quant_deep_gemm Triton kernel. The changes involve switching from a manual while loop to tl.range to enable software pipelining, and tuning the num_warps and NUM_STAGES parameters.

The code modifications are correct and follow Triton best practices for performance. The provided micro-benchmarks demonstrate a significant performance improvement, which validates the tuning choices. The changes are well-contained and improve the efficiency of the kernel as intended. I have no further comments.
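In pseudocode, the structural change the review describes is roughly the following (a sketch based on the review text, not the actual kernel source; `groups_per_row` is a hypothetical name, while `NUM_STAGES` is the tuned constant from this PR):

```
# before: manual while loop, which Triton does not pipeline
g = 0
while g < groups_per_row:
    ...  # load group, compute silu(gate) * up, quantize, store
    g += 1

# after: tl.range with num_stages lets the compiler software-pipeline
# the loop body, prefetching loads NUM_STAGES iterations ahead
for g in tl.range(0, groups_per_row, num_stages=NUM_STAGES):
    ...  # same body
```

Software pipelining overlaps the memory loads of future iterations with the compute of the current one, which is why the micro-benchmarks above improve most at larger token counts.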

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 18, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) July 18, 2025 16:36
@simon-mo simon-mo disabled auto-merge July 19, 2025 06:09
@simon-mo simon-mo merged commit dcc6cfb into vllm-project:main Jul 19, 2025
78 of 79 checks passed
x22x22 pushed a commit to x22x22/vllm that referenced this pull request Aug 5, 2025
…kernel (vllm-project#21193)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: x22x22 <[email protected]>
Pradyun92 pushed a commit to Pradyun92/vllm that referenced this pull request Aug 6, 2025
…kernel (vllm-project#21193)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
…kernel (vllm-project#21193)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
…kernel (vllm-project#21193)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
…kernel (vllm-project#21193)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Paul Pak <[email protected]>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
…kernel (vllm-project#21193)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Diego-Castan <[email protected]>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 27, 2025
…kernel (vllm-project#21193)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
