Skip to content

Conversation

@LucasWilkinson
Copy link
Contributor

@LucasWilkinson LucasWilkinson commented Feb 10, 2025

Various improvements to "Groupwise scaling along M" (#2037) namely to address: #2087, context vllm-project/vllm#11868 (comment)

Improvements:

  1. Multiple threads now participating in copy A scales
  2. Predication when copying A scale loads, this means if there is partial M tile (due to the problem shape not being evenly divided by the M tile shape)
  3. More commonly used scale layouts, currently CUTLASS uses a layout like:
(M_TILES, ScaleMsPerTile, K_TILES, L), ordered: (2, 0, 1, 3)

this PR moves to a layout of (i.e. standard M-major):

(M / ScaleGranularityM, K_TILES, L), ordered: (1, 0, 2)

making it much easier to integrate into inference libraries

These improvements were part of vLLMs adoption of this kernel https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp (PR: vllm-project/vllm#11868) and is in current wide scale use. Our goal is to rely on the CUTLASS implementation but that currently not possible given the issues above.

@LucasWilkinson LucasWilkinson changed the title [WIP][Bugfix] Bug fixes for: Groupwise scaling along M for FP8 gemm Improvements: Groupwise scaling along M for FP8 gemm Feb 10, 2025
@LucasWilkinson LucasWilkinson changed the title Improvements: Groupwise scaling along M for FP8 gemm Improvements for: Groupwise scaling along M for FP8 gemm Feb 10, 2025
@LucasWilkinson LucasWilkinson marked this pull request as ready for review February 10, 2025 21:07
@hwu36
Copy link
Collaborator

hwu36 commented Feb 21, 2025

@LucasWilkinson , we upstreamed our change to groupwise scaling kernels. there are some conflicts in this PR that needs to be solved.

Our change is mainly:

Extend groupwise scaling gemm to support both M dimension and N dimension groupwise scaling in FP8 GEMM.
In examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu, two parameters ScaleGranularityM and ScaleGranularityNcontrol the scaling mode:


ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D scaling (block-wise scaling, same as 67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu , 2Dx2D refers to the shape of the scaling factor)

ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling

ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling

ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling

@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/fix-fp8-blockwise branch from db87722 to 7f541db Compare February 25, 2025 06:50
@LucasWilkinson
Copy link
Contributor Author

@LucasWilkinson , we upstreamed our change to groupwise scaling kernels. there are some conflicts in this PR that needs to be solved.

Our change is mainly:

Extend groupwise scaling gemm to support both M dimension and N dimension groupwise scaling in FP8 GEMM.
In examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu, two parameters ScaleGranularityM and ScaleGranularityNcontrol the scaling mode:


ScaleGranularityM == 128 and ScaleGranularityN == 128 --> 2Dx2D scaling (block-wise scaling, same as 67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu , 2Dx2D refers to the shape of the scaling factor)

ScaleGranularityM == 1 and ScaleGranularityN == 128 --> 1Dx2D scaling

ScaleGranularityM == 128 and ScaleGranularityN == 1 --> 2Dx1D scaling

ScaleGranularityM == 1 and ScaleGranularityN == 1 --> 1Dx1D scaling

apologies for the delay the PR has been updated, currently I am still vectorizing the loads of B scales along N (like main) but it might actually makes sense to not do this to enable transposing A and B (since we currently have partial tiles along M this would mean partial tiles along N)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any promblems when transpose A and transpose B?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently this assumes full tiles in N and K so if using this for inference where activations may have partial tiles if you transpose it to Y^T = WX^T it may report not implementable, I think im going to update this since ideally in vLLM we'd like to transpose it to use smaller tensor core instructions, we do lose vectorization on the loads then though

@hwu36
Copy link
Collaborator

hwu36 commented Feb 25, 2025

@manishucsd

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe still using ScalePromotionInterval here, and move size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{} to can_implement check?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm im not sure I see ScalePromotionInterval, what would be the motivation to not have this determined at compile time? it seems a bit unnecessarily burdensome on the user to have them set mma_promotion_interval manually

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In anycase moving this as constexpr somewhere on the top will better for readability.

static constexpr int ScalePromotionInterval = size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}) and using that here?

@hwu36, So this will be 4 for TileShapeK = 128 and InstructionShape = 32, which is the original case, for TileShape = 64 this will be 2. Will that be not supported?

@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/fix-fp8-blockwise branch from edd90be to 2a9256f Compare February 26, 2025 05:36
Copy link
Contributor

@manishucsd manishucsd Feb 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this restriction only for M and not for N? dim-M usually maps to batch count while dim-N will be model_dimension, a nice multiple of 2? correct?

If this is A_row * B_col groupwise GEMM, it is sometimes required that we do transposed and swap creating an underlying GEMM to be B_row * A_row, swapping M <-> N. This is typically helpful for (a.) mixed-input BF16*F8 which doesn't apply here (b.) M is small say 64, we can swap and transpose to run a better tile. I have seen that to give more performance for small M.

Does vectorizing scale_copy_b vs not-vectorizing give any performance improvements? If not, I would suggest that we be symmetric for this kernel in M and N to allow user to apply swap and transpose trick to this kernel.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was mostly just trying to keep it as close to the original as possible to minimize the chances of perf regressions, but I agree this is much less confusing. And I think we will want to transpose in vLLM in order to use smaller instructions for smaller batch sizes.

Copy link
Contributor Author

@LucasWilkinson LucasWilkinson Feb 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pushed an update that enables partial tiles in N

Comment on lines +486 to +492
Copy link
Contributor

@manishucsd manishucsd Feb 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make sure that this copy_if is issued by only 32 threads? The thread layout of shape 32 (created above) won't be tiled over entire tile by make_tiled_copy, just confirm please using simple printf

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ran

      if ((!blockIdx.x && !blockIdx.y && !blockIdx.z)) printf("%d ", threadIdx.x);
      if (thread0()) printf("\n");
      pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);

and got:

...
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 
...

I think we should be good 👍

Comment on lines +366 to +372
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should TMA related tensor constructions be in lane_predicate as before, no need for all the threads to construct this even in this implementation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Im not sure, I didn't think this was a big deal since if you look at the 3.6.0 diff with improved the mixed input GEMM (we were told 3.6 had perf improvements for mixed input) in include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp you can see that it was updated to have all the threads compute the TMA tensors, not sure what the recommended approach is, or if this particular change had any impact. Would some guidance!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In anycase moving this as constexpr somewhere on the top will better for readability.

static constexpr int ScalePromotionInterval = size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}) and using that here?

@hwu36, So this will be 4 for TileShapeK = 128 and InstructionShape = 32, which is the original case, for TileShape = 64 this will be 2. Will that be not supported?

Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/fix-fp8-blockwise branch from 1dc4ebd to 460b938 Compare February 26, 2025 23:07
@LucasWilkinson
Copy link
Contributor Author

H100

This PR:

Basic split-K GEMM kernel
Running: 
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 128)
  ScaleGranularityN: 1 (ScaleNsPerTile: 128)
  Disposition: Passed
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0159489 ms
  GFLOPS: 67323.9

Running: 
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 128)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Disposition: Passed
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0147137 ms
  GFLOPS: 72975.5

Running: 
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 1 (ScaleNsPerTile: 128)
  Disposition: Passed
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.015305 ms
  GFLOPS: 70156.5

Running: 
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Disposition: Passed
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.014349 ms
  GFLOPS: 74830.6


StreamK GEMM kernel
Running: 
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 128)
  ScaleGranularityN: 1 (ScaleNsPerTile: 128)
  Disposition: Passed
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0176837 ms
  GFLOPS: 60719.1

Running: 
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 128)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Disposition: Passed
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0159611 ms
  GFLOPS: 67272.3

Running: 
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 1 (ScaleNsPerTile: 128)
  Disposition: Passed
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0170682 ms
  GFLOPS: 62909.1

Running: 
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Disposition: Passed
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0157621 ms
  GFLOPS: 68121.7

Main:


Basic split-K GEMM kernel
  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 128)
  ScaleGranularityN: 1 (ScaleNsPerTile: 128)
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0177991 ms
  GFLOPS: 60325.7

  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 128)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0151617 ms
  GFLOPS: 70819.4

  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 1 (ScaleNsPerTile: 128)
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0158391 ms
  GFLOPS: 67790.6

  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0148361 ms
  GFLOPS: 72373.6


StreamK GEMM kernel
  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 128)
  ScaleGranularityN: 1 (ScaleNsPerTile: 128)
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0185101 ms
  GFLOPS: 58008.4

  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 1 (ScaleMsPerTile: 128)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0160838 ms
  GFLOPS: 66759

  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 1 (ScaleNsPerTile: 128)
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0169781 ms
  GFLOPS: 63242.8

  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Tile shape (M, N, K): _128, _128, _128
  ScaleGranularityM: 128 (ScaleMsPerTile: 1)
  ScaleGranularityN: 128 (ScaleNsPerTile: 1)
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0156655 ms
  GFLOPS: 68541.8

@hwu36 hwu36 merged commit df18f5e into NVIDIA:main Feb 28, 2025
}

if (options.k % size<2>(TileShape{}) != 0) {
std::cout << "Skippig (k size: " << options.k << " less then TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like a repeated typo?

@qijiaxing
Copy link

Hi @LucasWilkinson ,
if I want to pass pytoch tensors to do a 1D2D GEMM, should be the tensor and scale tensor shape like this?

A: [M, K]
B: [N, K]
A scale: [K/128, M]
B scale: [K/128, N/128]

andralex pushed a commit to andralex/cutlass that referenced this pull request Jun 14, 2025
* fix blockwise fp8 kernels

Signed-off-by: Lucas Wilkinson <[email protected]>

* wip, < 128 not working

Signed-off-by: Lucas Wilkinson <[email protected]>

* fix < 128

Signed-off-by: Lucas Wilkinson <[email protected]>

* reduce diff

Signed-off-by: Lucas Wilkinson <[email protected]>

* review comments

Signed-off-by: Lucas Wilkinson <[email protected]>

* support partial n blocks

Signed-off-by: Lucas Wilkinson <[email protected]>

* fix build errors

Signed-off-by: Lucas Wilkinson <[email protected]>

---------

Signed-off-by: Lucas Wilkinson <[email protected]>
Albresky pushed a commit to Albresky/cutlass that referenced this pull request Oct 11, 2025
* fix blockwise fp8 kernels

Signed-off-by: Lucas Wilkinson <[email protected]>

* wip, < 128 not working

Signed-off-by: Lucas Wilkinson <[email protected]>

* fix < 128

Signed-off-by: Lucas Wilkinson <[email protected]>

* reduce diff

Signed-off-by: Lucas Wilkinson <[email protected]>

* review comments

Signed-off-by: Lucas Wilkinson <[email protected]>

* support partial n blocks

Signed-off-by: Lucas Wilkinson <[email protected]>

* fix build errors

Signed-off-by: Lucas Wilkinson <[email protected]>

---------

Signed-off-by: Lucas Wilkinson <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants