
Commit 3c28697

soundOfDestiny and zlhwu36 authored
Groupwise scaling along M for FP8 gemm (#2037)
* FP8 groupwise scaling along M

* small updates

--------

Co-authored-by: zl <[email protected]>
Co-authored-by: Haicheng Wu <[email protected]>
1 parent bdd6417 commit 3c28697

File tree

8 files changed, +1418 -48 lines changed

examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ using ArchTag = cutlass::arch::Sm90; // T
 using OperatorClass = cutlass::arch::OpClassTensorOp;                 // Operator class tag
 using TileShape = Shape<_128,_128,_128>;                              // Threadblock-level tile size
 using ClusterShape = Shape<_1,_2,_1>;                                 // Shape of the threadblocks in a cluster
-using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
+using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<>;
 using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative;

 using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu

Lines changed: 770 additions & 0 deletions
Large diffs are not rendered by default.

examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -30,3 +30,8 @@ cutlass_example_add_executable(
   67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
   67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling.cu
   )
+
+cutlass_example_add_executable(
+  67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling
+  67_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling.cu
+  )

examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/reference/host/gemm_with_groupwise_scaling.h

Lines changed: 507 additions & 0 deletions
Large diffs are not rendered by default.
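Since the reference diff is not rendered, here is a hedged scalar sketch of the semantics it is expected to check (an assumption, not the code in gemm_with_groupwise_scaling.h): every scale_granularity_m rows of A share one scale per K-block, B keeps one scale per (N-block, K-block), and each K-block's partial sum is promoted with the product of the two scales.

#include <vector>

// Hypothetical scalar reference; names and layouts are illustrative.
void gemm_groupwise_scaled_ref(
    int M, int N, int K,
    int scale_granularity_m, int block_n, int block_k,
    std::vector<float> const& A,        // M x K, row-major (stand-in for fp8 values)
    std::vector<float> const& B,        // K x N, row-major
    std::vector<float> const& scales_a, // (M / scale_granularity_m) x (K / block_k), row-major
    std::vector<float> const& scales_b, // (N / block_n) x (K / block_k), row-major
    std::vector<float>& D) {            // M x N, row-major
  int const k_blocks = K / block_k;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int kb = 0; kb < k_blocks; ++kb) {
        // Unscaled partial sum over one K-block, mirroring the FP8 partial accumulator.
        float partial = 0.f;
        for (int k = kb * block_k; k < (kb + 1) * block_k; ++k) {
          partial += A[m * K + k] * B[k * N + n];
        }
        // Promote with the per-row-group A scale and the per-block B scale.
        float const sa = scales_a[(m / scale_granularity_m) * k_blocks + kb];
        float const sb = scales_b[(n / block_n) * k_blocks + kb];
        acc += sa * sb * partial;
      }
      D[m * N + n] = acc;
    }
  }
}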

include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl

Lines changed: 36 additions & 9 deletions
@@ -84,6 +84,28 @@ compute_stage_count_or_override(StageCountAutoCarveout<carveout_bytes_> stage_co
   return (capacity_bytes - carveout_bytes) / stage_bytes;
 }

+// Returns the maximum number of smem tiles that can be used with a given smem capacity in a GEMM with blockwise/groupwise scaling.
+template<int capacity_bytes_, class ElementA, class ElementB, class ElementBlockScale, class TileShapeMNK, int ScaleMsPerTile, int carveout_bytes_, int alignment = 128>
+constexpr int
+compute_stage_count_with_blockwise_scale(StageCountAutoCarveout<carveout_bytes_> stage_count) {
+  constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage);
+  constexpr auto a_bits = cute::sizeof_bits_v<ElementA>;
+  constexpr auto b_bits = cute::sizeof_bits_v<ElementB>;
+  constexpr auto scale_bits = cute::sizeof_bits_v<ElementBlockScale>;
+  constexpr int stage_bytes_ =
+    cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
+    cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) +
+    cutlass::bits_to_bytes(scale_bits * ScaleMsPerTile) + // scale of tensor A
+    cutlass::bits_to_bytes(scale_bits * 1);               // scale of tensor B
+
+  constexpr int stage_bytes = cutlass::round_up(stage_bytes_, alignment) +
+    static_cast<int>(mainloop_pipeline_bytes);
+  constexpr int carveout_bytes = cutlass::round_up(carveout_bytes_, alignment);
+  constexpr int capacity_bytes = capacity_bytes_ / alignment * alignment;
+
+  return (capacity_bytes - carveout_bytes) / stage_bytes;
+}
+
 // Returns the maximum number of smem tiles that can be used with a given smem capacity (with an optional scale matrix), or overrides with manual count.
 template<int capacity_bytes, class ElementA, class ElementB, class ElementScale, class ElementZero, class TileShapeMNK, int stages, int alignment = 128>
 constexpr int
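To make the new stage-count arithmetic concrete, here is a hedged back-of-the-envelope sketch for a 128x128x128 tile with fp8 A/B, float block scales, and one scale per row of the M tile (ScaleMsPerTile = 128). The capacity, carveout, and pipeline-barrier sizes are illustrative placeholders rather than CUTLASS's actual detail::sm90_smem_capacity_bytes or PipelineTmaAsync values, and the 128-byte round-up is omitted for brevity.

#include <cstdio>

int main() {
  // Bytes held in shared memory by one pipeline stage.
  constexpr int a_tile_bytes  = (8  * 128 * 128) / 8;  // 16384 B: fp8 A tile (M x K)
  constexpr int b_tile_bytes  = (8  * 128 * 128) / 8;  // 16384 B: fp8 B tile (N x K)
  constexpr int scale_a_bytes = (32 * 128) / 8;        //   512 B: 128 float scales for A (one per row)
  constexpr int scale_b_bytes = 32 / 8;                //     4 B: one float scale for the B tile
  constexpr int stage_bytes   = a_tile_bytes + b_tile_bytes + scale_a_bytes + scale_b_bytes;  // 33284 B

  // Placeholder capacity and epilogue carveout (assumed values, not CUTLASS constants).
  constexpr int smem_capacity = 224 * 1024;
  constexpr int carveout      = 32 * 1024;

  std::printf("stages ~= %d\n", (smem_capacity - carveout) / stage_bytes);  // prints 5 with these numbers
  return 0;
}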
@@ -1009,7 +1031,7 @@ template <
   class TileShape_MNK,
   class ClusterShape_MNK,
   class StageCountType,
-  class KernelScheduleType
+  int ScaleGranularityM_
 >
 struct CollectiveBuilder<
     arch::Sm90,
@@ -1024,12 +1046,12 @@ struct CollectiveBuilder<
     TileShape_MNK,
     ClusterShape_MNK,
     StageCountType,
-    KernelScheduleType,
+    KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>,
     cute::enable_if_t<
-      (cute::is_any_of_v<KernelScheduleType,
-                         KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>) &&
-       not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
+      not detail::is_use_rmem_A<ElementA, GmemLayoutATag, ElementB, GmemLayoutBTag>()>
 > {
+  using KernelScheduleType = KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>;
+
   static_assert(is_static<TileShape_MNK>::value);
   static_assert(is_static<ClusterShape_MNK>::value);
 #ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
@@ -1048,14 +1070,15 @@ struct CollectiveBuilder<
   // For fp32 types, map to tf32 MMA value type
   using ElementAMma = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
   using ElementBMma = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
+  using ElementBlockScale = ElementAccumulator;

   static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementAMma, GmemLayoutATag>();
   static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementBMma, GmemLayoutBTag>();

   static constexpr bool IsCooperative = cute::is_any_of_v<KernelScheduleType,
                                                           KernelTmaWarpSpecializedCooperative,
                                                           KernelPtrArrayTmaWarpSpecializedCooperative,
-                                                          KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>;
+                                                          KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum<ScaleGranularityM_>>;
   using AtomLayoutMNK = cute::conditional_t<IsCooperative,
       Layout<Shape<_2,_1,_1>>, Layout<Shape<_1,_1,_1>>>;
@@ -1073,9 +1096,13 @@
   static constexpr size_t TensorMapStorage = IsArrayOfPointersGemm ? sizeof(cute::TmaDescriptor) * 2 /* for A and B */ : 0;
   static constexpr int KernelSmemCarveout = static_cast<int>(TensorMapStorage);

-  static constexpr int PipelineStages = detail::compute_stage_count_or_override<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
-      ElementAMma, ElementBMma, TileShape_MNK>(StageCountType{});
-  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>;
+  static constexpr int ScaleGranularityM = ScaleGranularityM_ == 0 ? size<0>(TileShape_MNK{}) : ScaleGranularityM_;
+  static constexpr int ScaleMsPerTile = size<0>(TileShape_MNK{}) / ScaleGranularityM;
+  static_assert((size<0>(TileShape_MNK{}) % ScaleGranularityM) == 0, "FP8 scaling granularity must evenly divide tile shape along M.");
+
+  static constexpr int PipelineStages = detail::compute_stage_count_with_blockwise_scale<detail::sm90_smem_capacity_bytes - KernelSmemCarveout,
+      ElementAMma, ElementBMma, ElementBlockScale, TileShape_MNK, ScaleMsPerTile>(StageCountType{});
+  using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedBlockScalingFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType, ScaleGranularityM_>;

   using SmemCopyAtomA = void;
   using SmemCopyAtomB = void;
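Spelled out for the 128-row M tile used in example 67 (values other than those in the hunk above are illustrative), the granularity parameter works out as follows:

// For TileShape_MNK = Shape<_128,_128,_128>:
//   ScaleGranularityM_ = 0   -> ScaleGranularityM = 128, ScaleMsPerTile = 1   (tile-wide, i.e. blockwise)
//   ScaleGranularityM_ = 128 -> ScaleGranularityM = 128, ScaleMsPerTile = 1   (explicit blockwise)
//   ScaleGranularityM_ = 1   -> ScaleGranularityM = 1,   ScaleMsPerTile = 128 (one scale per row)
//   ScaleGranularityM_ = 3   -> rejected: 128 % 3 != 0 trips the static_assert above.
constexpr int tile_m = 128;
constexpr int scale_granularity_m_ = 1;  // illustrative choice
constexpr int scale_granularity_m  = (scale_granularity_m_ == 0) ? tile_m : scale_granularity_m_;
constexpr int scale_ms_per_tile    = tile_m / scale_granularity_m;  // 128 scales staged per tile
static_assert(tile_m % scale_granularity_m == 0, "granularity must evenly divide the M tile");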

include/cutlass/gemm/collective/fp8_accumulation.hpp

Lines changed: 20 additions & 4 deletions
@@ -75,12 +75,22 @@ struct GmmaFP8Accumulation {
   }

   // `multiply` scale the partial accumulators and `add` to main accumulator (FFMA).
+  template <
+    class EngineScale,
+    class LayoutScale>
   CUTLASS_DEVICE
-  void scale_core(ElementAccumulator const& scale) {
+  void scale_core(const cute::Tensor<EngineScale, LayoutScale> &scale) {
+    using TensorScale = cute::Tensor<EngineScale, LayoutScale>;
+
+    static_assert(is_static<LayoutScale>::value, "Scale Layout should be static");
+    static_assert(is_rmem<TensorScale>::value, "Scale tensor must be rmem resident.");
+
+    static_assert(LayoutAccum{}.shape() == LayoutScale{}.shape(), "Accumulator and scale must have same shape.");
+
     warpgroup_wait<0>();
     CUTLASS_PRAGMA_UNROLL
     for (int i = 0; i < size(accum_); ++i) {
-      accum_(i) += accum_temp_(i) * scale;
+      accum_(i) += accum_temp_(i) * scale(i);
     }
   }
@@ -142,8 +152,11 @@ struct GmmaFP8Accumulation {
   //

   /// scale (multiply_add) the results from the MMA accumulators to main accumulator if needed.
+  template <
+    class EngineScale,
+    class LayoutScale>
   CUTLASS_DEVICE
-  void scale_if_needed(ElementAccumulator const& scale) {
+  void scale_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
     mma_count_ += mma_count_per_mainloop_iteration_;
     reset_accum_flag_ = __shfl_sync(0xffffffff, mma_count_ == accum_promotion_interval_, 0);
     if (reset_accum_flag_) {
@@ -153,8 +166,11 @@
   }

   /// scale (multiply_add) the residue results from the MMA accumulators to main accumulator if needed.
+  template <
+    class EngineScale,
+    class LayoutScale>
   CUTLASS_DEVICE
-  void scale_residue_if_needed(ElementAccumulator const& scale) {
+  void scale_residue_if_needed(const cute::Tensor<EngineScale, LayoutScale> &scale) {
     if (__shfl_sync(0xffffffff, mma_count_ > 0, 0)) {
       scale_core(scale);
     }
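The promotion step now applies one scale per accumulator element instead of a single scalar. A minimal plain-C++ sketch of the same idea (names and types are illustrative stand-ins for accum_, accum_temp_, and the register-resident scale fragment, not CuTe tensors):

#include <array>
#include <cstddef>

template <std::size_t N>
void promote_with_groupwise_scale(std::array<float, N>& accum,
                                  std::array<float, N> const& accum_temp,
                                  std::array<float, N> const& scale) {
  // Mirrors the loop in scale_core: each partial FP8 accumulator is multiplied by the
  // scale associated with its element (i.e. the row group it covers) and added to the
  // persistent accumulator.
  for (std::size_t i = 0; i < N; ++i) {
    accum[i] += accum_temp[i] * scale[i];
  }
}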
