Improvements for: Groupwise scaling along M for FP8 gemm #2095
Conversation
@LucasWilkinson, we upstreamed our change to the groupwise scaling kernels; there are some conflicts in this PR that need to be resolved. Our change is mainly:
Force-pushed from db87722 to 7f541db.
Apologies for the delay; the PR has been updated. Currently I am still vectorizing the loads of B scales along N.
Are there any problems when transposing A and transposing B?
Currently this assumes full tiles in N and K. So if you use this for inference, where activations may have partial tiles, and transpose it to Y^T = WX^T, it may report not-implementable. I think I'm going to update this, since ideally in vLLM we'd like to transpose it to use smaller tensor-core instructions; we do lose vectorization on the loads then, though.
include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp (outdated, resolved)
...p8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h (outdated, resolved)
Maybe still use ScalePromotionInterval here, and move size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}) into the can_implement check?
Hmm, I'm not sure I see ScalePromotionInterval. What would be the motivation for not having this determined at compile time? It seems unnecessarily burdensome to make the user set mma_promotion_interval manually.
In any case, moving this to a constexpr near the top would be better for readability: static constexpr int ScalePromotionInterval = size<2>(TileShape{}) / size<2>(typename TiledMma::AtomShape_MNK{}); and then using that here?
@hwu36, so this will be 4 for TileShapeK = 128 and InstructionShape = 32, which is the original case; for TileShapeK = 64 it will be 2. Will that be unsupported?
Force-pushed from edd90be to 2a9256f.
Why is this restriction only on M and not on N? dim-M usually maps to the batch count, while dim-N is the model dimension, a nice multiple of 2, correct?
If this is an A_row * B_col groupwise GEMM, it is sometimes necessary to transpose and swap, creating an underlying GEMM of B_row * A_row and swapping M <-> N. This is typically helpful for (a) mixed-input BF16 * F8, which doesn't apply here, and (b) small M, say 64, where we can swap and transpose to run a better tile. I have seen that give more performance for small M.
Does vectorizing scale_copy_b versus not vectorizing give any performance improvement? If not, I would suggest we keep this kernel symmetric in M and N so that users can apply the swap-and-transpose trick to it.
I was mostly just trying to keep it as close to the original as possible to minimize the chance of perf regressions, but I agree this is much less confusing. And I think we will want to transpose in vLLM in order to use smaller instructions for smaller batch sizes.
Pushed an update that enables partial tiles in N.
Can you make sure this copy_if is issued by only 32 threads? The thread layout of shape 32 (created above) won't be tiled over the entire tile by make_tiled_copy; please just confirm with a simple printf.
Ran
if ((!blockIdx.x && !blockIdx.y && !blockIdx.z)) printf("%d ", threadIdx.x);
if (thread0()) printf("\n");
pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive_noinc);
and got:
...
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
...
I think we should be good 👍
Should the TMA-related tensor constructions be inside the lane_predicate as before? There's no need for all the threads to construct these, even in this implementation?
I'm not sure. I didn't think this was a big deal since, if you look at the 3.6.0 diff that improved the mixed-input GEMM (we were told 3.6 had perf improvements for mixed input) in include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp, you can see it was updated to have all threads compute the TMA tensors. I'm not sure what the recommended approach is, or whether this particular change had any impact. Would appreciate some guidance!
Signed-off-by: Lucas Wilkinson <[email protected]>
Force-pushed from 1dc4ebd to 460b938.
H100 benchmarks: This PR vs. Main.
if (options.k % size<2>(TileShape{}) != 0) {
  std::cout << "Skippig (k size: " << options.k << " less then TileShape[2]: " << size<2>(TileShape{}) << "):" << std::endl;
Seems like a repeated typo?
Hi @LucasWilkinson,
* fix blockwise fp8 kernels
* wip, < 128 not working
* fix < 128
* reduce diff
* review comments
* support partial n blocks
* fix build errors
Signed-off-by: Lucas Wilkinson <[email protected]>
Various improvements to "Groupwise scaling along M" (#2037), namely to address #2087; context: vllm-project/vllm#11868 (comment).
Improvements:
this PR moves to a standard M-major scale layout, making it much easier to integrate into inference libraries
These improvements were part of vLLM's adoption of this kernel (https://github.com/vllm-project/vllm/blob/v0.7.1/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp, PR: vllm-project/vllm#11868), which is now in wide-scale use. Our goal is to rely on the CUTLASS implementation, but that is currently not possible given the issues above.