Skip to content

Conversation

@soundOfDestiny
Copy link
Contributor

Background (copied from #1932)

As we adopt narrower datatypes, traditional scaling methods struggle to maintain accuracy, particularly with 8-bit floating-point types (e.g., e5m2_t, e4m3_t). The typical GEMM operation uses tensorwise scaling with $D = alpha * (A @ B) + beta * C$, but narrower datatypes necessitate more finer-grained scaling techniques. Before we dive deep into groupwise scaling below is a glossary of various scaling methods:

  1. Tensorwise Scaling: Uses a single scaling factor per tensor, applied in the epilogue.
  2. Rowwise Scaling: Uses a row vector for scaling, with dimensions Mx1 for operand A and 1xN for operand B, avoiding the scaling along the reduction dimension. This can also be handled in the epilogue with EpilogueVisitorTree.
  3. Blockwise Scaling (Blockwise Scaling for FP8 #1932): Introduces a 2D scaling tensor, assigning one scaling value per CTA Block. Since this scaling involves the reduction dimension (M, N, K), it must be applied during the mainloop, impacting performance. Blockwise Scaling for FP8 #1932 implements blockwise scaling for CUTLASS F8 GEMM, staging scaling tensors via shared memory, and preparing for future support of groupwise scaling.
  4. Groupwise Scaling (along M in A tensor, this PR): Uses a 2D scaling tensor with multiple scaling values per CTA Block. Scaling granularity is independent of CTA Block configuration, allowing greater flexibility for future implementations.

Summary

As #1932 adds blockwise scaling strategy, this PR is a patch based on #1932 and adds groupwise scaling strategy along M in A tensor. Scaling granularity along M is made independent of CTA Block configuration, however, scaling granularities along N and K are still blockwise (i.e. one scaling value per CTA Block).

This PR restricts scaling granularity along M to a factor of TILE_SHAPE_M in CTA Block configuration, while one can set the GEMM scaling granularity along M to exactly TILE_SHAPE_M (i.e. fallback to blockwise scaling strategy) and call repeat_interleave method on input tensor ScaleA to simulate the situation that scaling granularity is multiplies of TILE_SHAPE_M.

Groupwise Scaling

In this implementation, we load scaling tensors with more elements than #1932 to shared memory since there might be various scaling along M per CTA Block. However, each thread only needs to load at most 2 scale values for A tensor and exactly one scale value for B tensor from shared memory to registers per iteration because WGMMA accumulators of each thread involve only 2 rows in result tensor.

Performance

I haven't observed a performance degradation compared with #1932
blockwise scaling

./64_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling 
  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0112583 ms
  GFLOPS: 95373.3

groupwise scaling (this PR, setting scaling granularity along M to 64)

./64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling 
  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0112435 ms
  GFLOPS: 95499.3

@zhyncs
Copy link

zhyncs commented Jan 17, 2025

Hi @hwu36 This PR is from the DeepSeek Team. Could you help review and merge it? The SGLang team wants to implement block-wise FP8 using CUTLASS for DeepSeek V3. This PR is essential for us. Thanks!

@ll2088
Copy link

ll2088 commented Jan 21, 2025

Hi @hwu36 This PR is from the DeepSeek Team. Could you help review and merge it? The SGLang team wants to implement block-wise FP8 using CUTLASS for DeepSeek V3. This PR is essential for us. Thanks!

Hi @zhyncs zh This PR looks like a example demo,Has the integration with SGLang been done? Could you post a PR about the integration code with SGLang?

@zhyncs
Copy link

zhyncs commented Jan 21, 2025

@ll2088
Our current open source version https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py has been referenced and adapted by other projects, including vLLM and LightLLM.
The version developed based on CUTLASS is currently based on this branch. https://github.com/soundOfDestiny/cutlass/tree/f8_groupwise_scaling_pr_branch.
We hope the official CUTLASS will review and merge this PR soon so we can use the official version. Currently, v3.7.0 includes block-wise but not per-token-per-128-channel support.

@ll2088
Copy link

ll2088 commented Jan 21, 2025

@ll2088 Our current open source version https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py has been referenced and adapted by other projects, including vLLM and LightLLM. The version developed based on CUTLASS is currently based on this branch. https://github.com/soundOfDestiny/cutlass/tree/f8_groupwise_scaling_pr_branch. We hope the official CUTLASS will review and merge this PR soon so we can use the official version. Currently, v3.7.0 includes block-wise but not per-token-per-128-channel support.

The version developed based on CUTLASS in SGLang, Does it PRed? Could you post it here?

@zhyncs
Copy link

zhyncs commented Jan 21, 2025

Not yet.

@soundOfDestiny soundOfDestiny force-pushed the f8_groupwise_scaling_pr_branch branch from 9d997ce to a08ef31 Compare January 21, 2025 06:57
@ll2088
Copy link

ll2088 commented Jan 21, 2025

image @soundOfDestiny using TileShape = Shape<_1,_128,_128>; why does it not work? compile problem occurs.

And why does ScaleMsPerTile = 128 not work? @soundOfDestiny

@soundOfDestiny soundOfDestiny force-pushed the f8_groupwise_scaling_pr_branch branch from a08ef31 to 0c08d7c Compare January 21, 2025 14:40
@soundOfDestiny
Copy link
Contributor Author

ad5c27dc5369702a20ba7d80c218083a 51f4cfba6b99089f4beac0af8b411f8e @zhyncs ScaleMsPerTile=128 is not supported here, the shared memory is not enough.

/workspace/applied-ai/kernels/cuda/cutlass_gemm/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:336 Setting smem size to 234496 /workspace/applied-ai/kernels/cuda/cutlass_gemm/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:343 cudaFuncSetAttribute() returned error: invalid argument Got cutlass error: Error Internal at: 673

The issue of incorrect calculation of shared memory size has appeared since #1932.
It has been fixed in latest commit.

@soundOfDestiny soundOfDestiny force-pushed the f8_groupwise_scaling_pr_branch branch from 0c08d7c to df73dd0 Compare January 21, 2025 14:50
@soundOfDestiny soundOfDestiny force-pushed the f8_groupwise_scaling_pr_branch branch from df73dd0 to 3197c81 Compare January 21, 2025 14:57
@hwu36 hwu36 merged commit 3c28697 into NVIDIA:main Jan 31, 2025
@hwu36
Copy link
Collaborator

hwu36 commented Jan 31, 2025

cuda 12.9 will improve the performance of blockscale/groupscale kernels.

@soundOfDestiny soundOfDestiny deleted the f8_groupwise_scaling_pr_branch branch February 1, 2025 01:45
@Hongbosherlock
Copy link

Hongbosherlock commented Feb 3, 2025

Hi @soundOfDestiny and @hwu36
I've noticed that GEMM is slower when K is large, such as m128-n1536-k7168
Is it possible to use Stream-k or Split-K to accelerate this Groupwise-GEMM ? I'm unsure if this PR supports it.
Will be very grateful for your help!

@hwu36
Copy link
Collaborator

hwu36 commented Feb 3, 2025

Hi @soundOfDestiny and @hwu36
I've noticed that GEMM is slower when K is large, such as m128-n1536-k7168
Is it possible to use Stream-k or Split-K to accelerate this Groupwise-GEMM ? I'm unsure if this PR supports it.
Will be very grateful for your help!

@jackkosaian

@Hongbosherlock
Copy link

Hi @soundOfDestiny and @hwu36
I've noticed that GEMM is slower when K is large, such as m128-n1536-k7168
Is it possible to use Stream-k or Split-K to accelerate this Groupwise-GEMM ? I'm unsure if this PR supports it.
Will be very grateful for your help!

@jackkosaian

hi @jackkosaian

I'm currently working on optimizing this Groupwise-GEMM performance for the Hopper architecture using CUTLASS 3.x and exploring the split-K technique. I've reviewed previous issues related to split-K (#702 (comment),
#1586) and tried understanding its principles.

I initially attempted to implement split-K by directly modifying the code here:

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>, // Indicates ProblemShape
    CollectiveMainloopWithBlockWiseScaling,
    CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

I tried replacing GemmUniversal with GemmSplitKParallel
However, this approach resulted in errors
image

I'm trying to use kernel::GemmUniversal with mode GemmUniversalMode::kGemmSplitKParallel now, but I'm not sure if it will work for this gemm.
I'm curious if there are more elegant or recommended ways to achieve split-K for this Groupwise-GEMM in CUTLASS 3.x on Hopper.

@soundOfDestiny
Copy link
Contributor Author

soundOfDestiny commented Feb 3, 2025

Hi @soundOfDestiny and @hwu36
I've noticed that GEMM is slower when K is large, such as m128-n1536-k7168
Is it possible to use Stream-k or Split-K to accelerate this Groupwise-GEMM ? I'm unsure if this PR supports it.
Will be very grateful for your help!

@jackkosaian

hi @jackkosaian

I'm currently working on optimizing this Groupwise-GEMM performance for the Hopper architecture using CUTLASS 3.x and exploring the split-K technique. I've reviewed previous issues related to split-K (#702 (comment), #1586) and tried understanding its principles.

I initially attempted to implement split-K by directly modifying the code here:

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>, // Indicates ProblemShape
    CollectiveMainloopWithBlockWiseScaling,
    CollectiveEpilogue
>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

I tried replacing GemmUniversal with GemmSplitKParallel However, this approach resulted in errors image

I'm trying to use kernel::GemmUniversal with mode GemmUniversalMode::kGemmSplitKParallel now, but I'm not sure if it will work for this gemm. I'm curious if there are more elegant or recommended ways to achieve split-K for this Groupwise-GEMM in CUTLASS 3.x on Hopper.

the definition of GemmSplitKParallel
is

template <
  typename Mma_,                  ///! Threadblock-scoped matrix multiply-accumulate 
  typename Epilogue_,             ///! Epilogue
  typename ThreadblockSwizzle_    ///! Threadblock swizzling function
>

the second template argument should be epilogue, rather than CollectiveMainloopWithBlockWiseScaling in your example

@Hongbosherlock
Copy link

hi @soundOfDestiny and @jackkosaian
thanks for your reply, I have described my issue in this thread : #2075

@Maximilianxu
Copy link

cuda 12.9 will improve the performance of blockscale/groupscale kernels.

In cuBLAS?

@ginowu
Copy link

ginowu commented Feb 5, 2025

cuda 12.9 will improve the performance of blockscale/groupscale kernels.

@hwu36 BTW, would newer version of Transformer-Engine support generating this kind of group-wise scaling factors? As current TE only supports generating per-tensor scales, thanks!

@hwu36
Copy link
Collaborator

hwu36 commented Feb 5, 2025

cuda 12.9 will improve the performance of blockscale/groupscale kernels.

@hwu36 BTW, would newer version of Transformer-Engine support generating this kind of group-wise scaling factors? As current TE only supports generating per-tensor scales, thanks!

Sorry, I don't know.

@yizhang2077
Copy link

yizhang2077 commented Feb 7, 2025

hi @soundOfDestiny, I try to use groupwise scaling to implement per-token-per-128-channel and blockwise, but it can not work, I describe my issue #2087, will be grateful for your help, thanks!

sijialouintel added a commit to sijialouintel/cutlass that referenced this pull request Feb 12, 2025
* Handle MNK Sm90{Row, Col}Reduction problem shapes (NVIDIA#1803)

* add is_last_tile

* Improve sm90 mixed dtype kernel (NVIDIA#1883)

* Add GMMA shape m64n40k16 (NVIDIA#1864)

* Add all supported GMMA shapes (NVIDIA#1890)

* add maximum support (NVIDIA#1833)

* fix typo (NVIDIA#1853)

* fix by adding public (NVIDIA#1753)

* added mapping for bf16 to torch::kBFloat16 (NVIDIA#1843)

Co-authored-by: Haicheng Wu <[email protected]>

* Fix README (NVIDIA#1658)

* Fix README

* Improve README

---------

Co-authored-by: Haicheng Wu <[email protected]>

* Adjusting code indentation (NVIDIA#1639)

* Include of regular_tile_iterator.h fixed for NVRTC (NVIDIA#1765)

* Include of regular_tile_iterator.h fixed for NVRTC

* More include fixed for NVRTC

* Update gemm_f16n_f16t_f32t_tensor_op_f32_sm80.cu with include "cutlass/gemm/device/gemm_universal.h" (NVIDIA#1569)

fix compile with `cmake .. -DCUTLASS_ENABLE_TESTS=ON -DCUTLASS_TEST_LEVEL=2`

* remove redundant hardcoded packing configs in mixed dtype gemm (NVIDIA#1894)

Co-authored-by: Siyuan Fu <[email protected]>

* fix wrong A/BLayout in MMA_Traits for binary mma and append other MMA_Traits support  (NVIDIA#1856)

* fix wrong A/BLayout in  MMA_Traits<SM80_16x8x256_S32U1U1S32_TN_XORPOPC> and append support for  m8n8k128, m16n8k128  mma.and.popc in MMA_Traits instantiation

* add "print" template for  subbyte_reference<T>

* Add a print for the uint{x}b_t type. (NVIDIA#1871)

* Refactor some GroupedGEMM logic (NVIDIA#1899)

* feat: support kFactor 8 used in mma tensor op tile iterator (NVIDIA#1512)

* Update publications (NVIDIA#1912)

* remove restriction of stride == kernel in nhwc_pooling (NVIDIA#1896)

* fix undefined in device code error (NVIDIA#1880)

* Fix the racing condition of mixed-input gemm when writing the registers (NVIDIA#1931)

* move two warpgroup_wait

* merge main

---------

Co-authored-by: Siyuan Fu <[email protected]>

* Fix `cutlass` python library with cuda `12.6.2.post1` (NVIDIA#1942)

* Fix `cutlass` python library with cuda `12.6.2.post1`

Previously we had this error:
```
  File "/storage/home/cutlass/python/cutlass/backend/operation.py", line 39, in <listcomp>
    _version_splits = [int(x) for x in __version__.split("rc")[0].split(".")]
                       ^^^^^^
ValueError: invalid literal for int() with base 10: 'post1'
```

* Update sm90_utils.py

* Update generator.py

* Update python/cutlass_library/generator.py

Co-authored-by: Jack Kosaian <[email protected]>

* Update python/cutlass_library/sm90_utils.py

Co-authored-by: Jack Kosaian <[email protected]>

---------

Co-authored-by: Jack Kosaian <[email protected]>

* add {uint4, uint2, int2} => {fp16, bf16} conversion (NVIDIA#1966)

* Improve mixed dtype GEMM (NVIDIA#1972)

* update

* fix a typo

* fix a typo that fails the compiling when ElementScale is not the same as MmaType (NVIDIA#1977)

* Fix CuTe README Typo (NVIDIA#1951)

* Fix Typo (NVIDIA#1962)

* 3.6.0 update (NVIDIA#2005)

* 3.6.0 update

* doc and swap stuff

---------

Co-authored-by: yuzhai <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>

* Update CHANGELOG.md

* Update 0x_gemm_tutorial.md (NVIDIA#1982)

Shouldn't this be BLK_M, BLK_**K**, k

* fix bug: arch/mma_sm60.h Mma<2,2,1> calculate wrong (NVIDIA#1989)

* fix mem fence (NVIDIA#2030)

Co-authored-by: yuzhai <[email protected]>

* Add half->int8 saturate conversion to promise valid range (NVIDIA#1983)

* Add half->int8 saturate conversion to promise valid range

* add gpu only macro

---------

Co-authored-by: Haicheng Wu <[email protected]>

* Add vector-types back to platform.h (NVIDIA#2026)

* Fix typo in library_defaults.py (NVIDIA#2024)

* Fix Typos (NVIDIA#2021)

* Fix Typo

* Fix Typo

* Add Line Break (NVIDIA#2020)

* Blockwise Scaling for FP8 (NVIDIA#1932)

* F8 Blockwise Scaling

* two more NumProducerThreadEvents

---------

Co-authored-by: Haicheng Wu <[email protected]>

* fix assertion in integer_subbytes.h (NVIDIA#1961)

* CUTLASS 3.7 (NVIDIA#2045)

* CUTLASS 3.7

* clean up changelog

---------

Co-authored-by: yuzhai <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>

* update 3.7 docs (NVIDIA#2051)

* update docs

* update docs

* update docs

* update docs

* update docs

---------

Co-authored-by: yuzhai <[email protected]>

* CUTLASS 3.8 Release (NVIDIA#2059)

* CUTLASS 3.8 Release

* update

* Update README.md

* Revert "Update README.md"

This reverts commit b353e36.

* update

* update

---------

Co-authored-by: Haicheng Wu <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>

* fix cuda 12.6 issues (NVIDIA#2066)

* fix a readme broken link (NVIDIA#2069)

* Update README.md

* Groupwise scaling along M for FP8 gemm (NVIDIA#2037)

* FP8 groupwise scaling along M

* small updates

---------

Co-authored-by: zl <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>

* bugfix generic-k code in top-k with softmax (NVIDIA#1993)

* bugfix generic-k code in top-k with softmax

* Update include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp

Co-authored-by: Ali Hassani <[email protected]>

* Update examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu

Co-authored-by: Ali Hassani <[email protected]>

---------

Co-authored-by: Ali Hassani <[email protected]>

* [EVT] Add support for Row/Col broadcast PtrArray (NVIDIA#2033)

* Add group support to EVT row/col broadcast.

* small modifications

---------

Co-authored-by: Haicheng Wu <[email protected]>

* v3.8.0 update (NVIDIA#2082)

* 3.8 update

* fix Markus' name

---------

Co-authored-by: yuzhai <[email protected]>

* [WA] Fix compiling errors

---------

Co-authored-by: Saagar Jha <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>
Co-authored-by: Sergey Klevtsov <[email protected]>
Co-authored-by: Tri Dao <[email protected]>
Co-authored-by: Xinyu Yang <[email protected]>
Co-authored-by: sijialou <[email protected]>
Co-authored-by: Bogumil Sapinski Mobica <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>
Co-authored-by: Lei Mao <[email protected]>
Co-authored-by: 103yiran <[email protected]>
Co-authored-by: MaxAkaAltmer <[email protected]>
Co-authored-by: 侯奇 <[email protected]>
Co-authored-by: Lain <[email protected]>
Co-authored-by: Siyuan Fu <[email protected]>
Co-authored-by: Caleb_Du <[email protected]>
Co-authored-by: LiYu Lu <[email protected]>
Co-authored-by: azhurkevich <[email protected]>
Co-authored-by: chenwei <[email protected]>
Co-authored-by: Wenlei Bao <[email protected]>
Co-authored-by: LiuQiang <[email protected]>
Co-authored-by: dan_the_3rd <[email protected]>
Co-authored-by: Jack Kosaian <[email protected]>
Co-authored-by: Yujia Zhai <[email protected]>
Co-authored-by: yuzhai <[email protected]>
Co-authored-by: Andrew O'Neill <[email protected]>
Co-authored-by: Dongxu.Wang <[email protected]>
Co-authored-by: ZZK <[email protected]>
Co-authored-by: Driss Guessous <[email protected]>
Co-authored-by: ZincCat <[email protected]>
Co-authored-by: Manish Gupta <[email protected]>
Co-authored-by: bobliao <[email protected]>
Co-authored-by: mihir-awatramani <[email protected]>
Co-authored-by: Liang <[email protected]>
Co-authored-by: zl <[email protected]>
Co-authored-by: Tadej Ciglarič <[email protected]>
Co-authored-by: Ali Hassani <[email protected]>
Co-authored-by: Josh Fromm <[email protected]>
hgl71964 pushed a commit to hgl71964/cutlass that referenced this pull request Feb 21, 2025
* FP8 groupwise scaling along M

* small updates

---------

Co-authored-by: zl <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>
@mnicely
Copy link
Collaborator

mnicely commented May 2, 2025

@soundOfDestiny would you mind testing performance with CUTLASS v3.9.1 + CUDA 12.9?

@soundOfDestiny
Copy link
Contributor Author

@soundOfDestiny would you mind testing performance with CUTLASS v3.9.1 + CUDA 12.9?

blockwise scaling

./67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
Result MSE: 1.08661e-06, MRE: 5.4241, greatest error: 0.0193539
Aux MSE: 2.16905e-06, MRE: 11.8225, greatest error: 0.019928
Disposition: Passed
Problem Size: 1024x512x1024x1
Rasterization: Heuristic with a maximum CTA swizzle of 1
Avg runtime: 0.0106264 ms
GFLOPS: 101044

groupwise scaling (setting scaling granularity along M to 64, which is the config in the description of this PR)

./67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
Problem Size: 1024x512x1024x1
Tile shape (M, N, K): _128, _128, _128
ScaleGranularityM: 64 (ScaleMsPerTile: 2)
ScaleGranularityN: 128 (ScaleNsPerTile: 1)
Running...
Result MSE: 1.13043e-06, MRE: 5.48901, greatest error: 0.0216789
Aux MSE: 2.27379e-06, MRE: 11.9547, greatest error: 0.0216789
Disposition: Passed
Rasterization: Heuristic with a maximum CTA swizzle of 1
Avg runtime: 0.0110664 ms
GFLOPS: 97027.5

groupwise scaling (setting scaling granularity along M to 1)

./67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
Problem Size: 1024x512x1024x1
Tile shape (M, N, K): _128, _128, _128
ScaleGranularityM: 1 (ScaleMsPerTile: 128)
ScaleGranularityN: 128 (ScaleNsPerTile: 1)
Running...
Result MSE: 1.26768e-06, MRE: 5.84238, greatest error: 0.0191631
Aux MSE: 2.52203e-06, MRE: 12.6349, greatest error: 0.0191631
Disposition: Passed
Rasterization: Heuristic with a maximum CTA swizzle of 1
Avg runtime: 0.0119117 ms
GFLOPS: 90141.7

btw, there has been a PR to improve performance of groupwise scaling: #2095

andralex pushed a commit to andralex/cutlass that referenced this pull request Jun 14, 2025
* FP8 groupwise scaling along M

* small updates

---------

Co-authored-by: zl <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>
Albresky pushed a commit to Albresky/cutlass that referenced this pull request Oct 11, 2025
* FP8 groupwise scaling along M

* small updates

---------

Co-authored-by: zl <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

10 participants