Conversation

@LucasWilkinson (Collaborator) commented Sep 10, 2024

Layerwise profiler for seeing how much time is spent on CUDA (GPU kernels) for each module/layer.

Example of how to run a profile:

python examples/offline_profile.py --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 --prompt-len 512 --json Llama31-8b-FP8 --max-num-batched-tokens 8196 --enforce-eager

Then there are some utilities for looking at the profile breakdown, e.g. to get a summary table of the prefill phase you can run:

$ python tools/profiler/print_layerwise_table.py --json-trace Llama31-8b-FP8.json --phase prefill --table summary
name                                                                             | cuda_time_us | pct_cuda_... | invocations    
================================================================================================================================
LlamaForCausalLM                                                                 |     31788.89 |        97.33 |            1.00
|- LlamaModel                                                                    |     31788.89 |        97.33 |            1.00
|-- VocabParallelEmbedding(weight=bfloat16[128256, 4096])                        |        59.20 |         0.18 |            1.00
|--- void at::native::(anonymous namespace)::indexSelectLargeIndex<c10::BFloa... |        59.20 |         0.18 |            1.00
|-- LlamaDecoderLayer                                                            |     31709.72 |        97.09 |           32.00
|--- RMSNorm(weight=bfloat16[4096])                                              |      1336.33 |         4.09 |           64.00
|---- void vllm::rms_norm_kernel<c10::BFloat16>(c10::BFloat16*, c10::BFloat16... |        26.56 |         0.08 |            1.00
|---- std::enable_if<(((8)>(0)))&&vllm::_typeConvert<c10::BFloat16>::exists, ... |      1309.77 |         4.01 |           63.00
|--- LlamaAttention                                                              |      8511.74 |        26.06 |           32.00
|---- QKVParallelLinear(weight=float8_e4m3fn[4096, 6144], weight_scale=float3... |      3014.17 |         9.23 |           32.00
|----- void vllm::scaled_fp8_quant_kernel<c10::BFloat16>(c10::Float8_e4m3fn*,... |       472.03 |         1.45 |           32.00
|----- void cutlass::device_kernel<(anonymous namespace)::cutlass_3x_gemm<cut... |      2542.14 |         7.78 |           32.00
|---- Llama3RotaryEmbedding                                                      |       892.66 |         2.73 |           32.00
|----- void vllm::rotary_embedding_kernel<c10::BFloat16, true>(long const*, c... |       892.66 |         2.73 |           32.00
|---- Attention                                                                  |      2454.87 |         7.52 |           32.00
|----- void vllm::reshape_and_cache_flash_kernel<__nv_bfloat16, __nv_bfloat16... |       343.77 |         1.05 |           32.00
|----- void flash_fwd_splitkv_kernel<Flash_fwd_kernel_traits<128, 64, 128, 4,... |      1756.16 |         5.38 |           32.00
|----- Memcpy DtoD (Device -> Device)                                            |       354.94 |         1.09 |           32.00
|---- RowParallelLinear(weight=float8_e4m3fn[4096, 4096], weight_scale=float3... |      2150.05 |         6.58 |           32.00
|----- void vllm::scaled_fp8_quant_kernel<c10::BFloat16>(c10::Float8_e4m3fn*,... |       474.97 |         1.45 |           32.00
|----- void cutlass::device_kernel<(anonymous namespace)::cutlass_3x_gemm<cut... |      1675.08 |         5.13 |           32.00
|--- LlamaMLP                                                                    |     21861.65 |        66.94 |           32.00
|---- MergedColumnParallelLinear(weight=float8_e4m3fn[4096, 28672], weight_sc... |     12077.86 |        36.98 |           32.00
|----- void vllm::scaled_fp8_quant_kernel<c10::BFloat16>(c10::Float8_e4m3fn*,... |       483.19 |         1.48 |           32.00
|----- void cutlass::device_kernel<(anonymous namespace)::cutlass_3x_gemm<cut... |     11594.67 |        35.50 |           32.00
|---- SiluAndMul                                                                 |      2882.36 |         8.83 |           32.00
|----- void vllm::act_and_mul_kernel<c10::BFloat16, &(c10::BFloat16 vllm::sil... |      2882.36 |         8.83 |           32.00
|---- RowParallelLinear(weight=float8_e4m3fn[14336, 4096], weight_scale=float... |      6901.43 |        21.13 |           32.00
|----- void vllm::scaled_fp8_quant_kernel<c10::BFloat16>(c10::Float8_e4m3fn*,... |      1268.66 |         3.88 |           32.00
|----- void cutlass::device_kernel<(anonymous namespace)::cutlass_3x_gemm<cut... |      5632.78 |        17.25 |           32.00
|-- RMSNorm(weight=bfloat16[4096])                                               |        19.97 |         0.06 |            1.00
|--- std::enable_if<(((8)>(0)))&&vllm::_typeConvert<c10::BFloat16>::exists, v... |        19.97 |         0.06 |            1.00
LogitsProcessor                                                                  |       360.95 |         1.11 |            1.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         3.81 |         0.01 |            1.00
|- Memset (Device)                                                               |         1.12 |         0.00 |            1.00
|- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x... |       356.03 |         1.09 |            1.00
Sampler                                                                          |       510.01 |         1.56 |            1.00
|- Memcpy HtoD (Pinned -> Device)                                                |        16.67 |         0.05 |            7.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         4.48 |         0.01 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl<at... |         4.86 |         0.01 |            1.00
|- at::native::(anonymous namespace)::fill_index_and_segment_kernel(int2*, in... |         3.33 |         0.01 |            1.00
|- Memset (Device)                                                               |        11.94 |         0.04 |            9.00
|- void at_cuda_detail::cub::DeviceRadixSortHistogramKernel<at_cuda_detail::c... |         6.14 |         0.02 |            1.00
|- void at_cuda_detail::cub::DeviceRadixSortExclusiveSumKernel<at_cuda_detail... |         1.89 |         0.01 |            1.00
|- void at_cuda_detail::cub::DeviceRadixSortOnesweepKernel<at_cuda_detail::cu... |        61.31 |         0.19 |            4.00
|- void at_cuda_detail::cub::DeviceRadixSortHistogramKernel<at_cuda_detail::c... |         3.20 |         0.01 |            1.00
|- void at_cuda_detail::cub::DeviceRadixSortExclusiveSumKernel<at_cuda_detail... |         1.54 |         0.00 |            1.00
|- void at_cuda_detail::cub::DeviceRadixSortOnesweepKernel<at_cuda_detail::cu... |        11.10 |         0.03 |            1.00
|- void at::native::(anonymous namespace)::sort_postprocess_kernel<float>(flo... |         6.79 |         0.02 |            1.00
|- Memcpy DtoD (Device -> Device)                                                |         2.75 |         0.01 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         1.82 |         0.01 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.50 |         0.00 |            1.00
|- void at::native::_scatter_gather_elementwise_kernel<128, 4, at::native::_c... |         7.68 |         0.02 |            2.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         3.78 |         0.01 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::(anonymous n... |         4.22 |         0.01 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        89.31 |         0.27 |            2.00
|- void at::native::tensor_kernel_scan_innermost_dim<float, std::plus<float> ... |       169.34 |         0.52 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.54 |         0.00 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl<at... |         4.90 |         0.01 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         1.82 |         0.01 |            1.00
|- void (anonymous namespace)::elementwise_kernel_with_index<int, at::native:... |         2.46 |         0.01 |            1.00
|- void at::native::_scatter_gather_elementwise_kernel<128, 4, at::native::_c... |        10.53 |         0.03 |            1.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        28.00 |         0.09 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl<at... |         2.30 |         0.01 |            1.00
|- void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_ke... |         5.02 |         0.02 |            1.00
|- void at::native::(anonymous namespace)::distribution_elementwise_grid_stri... |         4.54 |         0.01 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::BinaryFuncto... |         3.23 |         0.01 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<float, at::nat... |        28.96 |         0.09 |            1.00
|- Memcpy DtoH (Device -> Pageable)                                              |         3.04 |         0.01 |            1.00
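The JSON trace can also be post-processed directly. Below is a minimal sketch of how the indented summary rows above could be rebuilt from a trace tree; the nested `{"name", "cuda_time_us", "children"}` layout used here is an assumption for illustration, not the exact schema written by `examples/offline_profile.py`.

```python
# Sketch: walk a layerwise trace tree and emit (indented name, time, pct) rows.
# The dict layout is assumed for illustration; the real JSON schema may differ.

def summarize(node, total_us, depth=0, rows=None):
    """Recursively collect rows in the same |- / |-- indentation style."""
    if rows is None:
        rows = []
    prefix = "|" + "-" * depth + " " if depth else ""
    pct = round(100.0 * node["cuda_time_us"] / total_us, 2)
    rows.append((prefix + node["name"], node["cuda_time_us"], pct))
    for child in node.get("children", []):
        summarize(child, total_us, depth + 1, rows)
    return rows

# Toy trace mirroring the shape of the table above.
trace = {
    "name": "LlamaForCausalLM",
    "cuda_time_us": 31788.89,
    "children": [{"name": "LlamaModel", "cuda_time_us": 31788.89}],
}
for name, us, pct in summarize(trace, total_us=32659.85):
    print(f"{name:<40} | {us:>12.2f} | {pct:>6.2f}")
```

Run against the toy trace this reproduces the 97.33% share shown for `LlamaForCausalLM` in the table (31788.89 of a 32659.85 µs total).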

Or to view it as a graph you can run:

python tools/profiler/visualize_layerwise_profile.py  --json-trace Llama31-8b-FP8.json --output-directory profile_breakdown --plot-metric pct_cuda_time
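The plot essentially turns those same percentage columns into bars. A crude text rendition of what the `--plot-metric pct_cuda_time` view conveys (the `(name, pct)` pairs below are hand-copied from the summary table for illustration; the real script reads them from the JSON trace and renders matplotlib bars):

```python
# Sketch: render pct_cuda_time values as fixed-width text bars.
# Inputs are illustrative; the real script consumes the JSON trace.

def text_bars(entries, width=40):
    """Render (name, pct) pairs as text bars scaled to `width` characters."""
    lines = []
    for name, pct in entries:
        bar = "#" * int(round(pct / 100 * width))
        lines.append(f"{name:<24} {pct:>6.2f}% {bar}")
    return lines

for line in text_bars([
    ("LlamaMLP", 66.94),
    ("LlamaAttention", 26.06),
    ("RMSNorm", 4.09),
]):
    print(line)
```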

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of these by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@LucasWilkinson LucasWilkinson force-pushed the varun/main-with-profiler branch from 3516c46 to d857de9 Compare September 16, 2024 16:14
@LucasWilkinson LucasWilkinson changed the title from [WIP, misc] CUDA Time Layerwise Profiler to [misc] CUDA Time Layerwise Profiler Sep 16, 2024
@LucasWilkinson LucasWilkinson marked this pull request as ready for review September 16, 2024 16:15
@LucasWilkinson LucasWilkinson force-pushed the varun/main-with-profiler branch 2 times, most recently from 1a0844e to 52aafcf Compare September 17, 2024 15:03
@LucasWilkinson LucasWilkinson force-pushed the varun/main-with-profiler branch from 52aafcf to e03bedb Compare October 7, 2024 04:28
@mgoin (Member) left a comment

This is great! Since the script is pretty technically involved and relies on exact attributes to exist, could you add a simple e2e test to run in CI so we can know if torch updates break it?

@LucasWilkinson (Collaborator, Author) commented Oct 7, 2024

could you add a simple e2e test to run in CI so we can know if torch updates break it?

what's the easiest way to do this? just add a pytest test, or invoke offline_profile somehow? are there instructions on how to register something with Buildkite, or are all pytest folders already run automatically?

@mgoin added examples test
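An e2e smoke test along these lines could be as simple as shelling out to the example script on a tiny model. The sketch below only builds the command; the model name and exact flags are illustrative, and the actual examples test added in this PR may differ.

```python
# Hedged sketch of an e2e smoke test for the profiler example.
# Model name and flags are illustrative assumptions, not the PR's actual test.
import shlex

def profile_smoke_cmd(model):
    """Build a minimal offline_profile.py invocation suitable for CI."""
    return [
        "python", "examples/offline_profile.py",
        "--model", model,
        "--batch-size", "1",
        "--prompt-len", "16",
        "--enforce-eager",
        "--json", "smoke-profile",
    ]

# In CI, a pytest test would subprocess.run(cmd, check=True) on a tiny model
# so the assertion is simply "the script exits 0 and writes the JSON trace".
print(shlex.join(profile_smoke_cmd("facebook/opt-125m")))
```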

@mgoin (Member) left a comment

LGTM and works well!

@mgoin mgoin added the ready label (ONLY add when PR is ready to merge/full CI is needed) Oct 7, 2024
@LucasWilkinson LucasWilkinson force-pushed the varun/main-with-profiler branch from e88f143 to 97647e1 Compare October 8, 2024 16:27
@tlrmchlsmth (Member) left a comment

Came in here because I wanted to use offline_profile.py

I added some minor comments, but LGTM

@LucasWilkinson LucasWilkinson force-pushed the varun/main-with-profiler branch from 84c4e81 to c1a5507 Compare October 16, 2024 19:36
@mgoin mgoin merged commit 9d30a05 into vllm-project:main Oct 17, 2024
76 checks passed
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: Alvant <[email protected]>
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: Amit Garg <[email protected]>
FerdinandZhong pushed a commit to FerdinandZhong/vllm that referenced this pull request Oct 29, 2024
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: qishuai <[email protected]>
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Signed-off-by: LeiWang1999 <[email protected]>