What is your question?
I'm trying to apply StreamK or SplitK to a Hopper warp-specialized GEMM.
You can see the full code here, and the Gemm is declared this way:
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    TileShape, ClusterShape,
    EpilogueTileType,
    ElementAccumulator, ElementCompute,
    ElementC, LayoutC, AlignmentC,
    ElementD, LayoutD, AlignmentD,
    EpilogueSchedule,
    FusionOperation
  >::CollectiveOp;

using CollectiveMainloopWithBlockWiseScaling = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
    >,
    KernelSchedule
  >::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>, // Indicates ProblemShape
    CollectiveMainloopWithBlockWiseScaling,
    CollectiveEpilogue
  >;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
First, I tried simply replacing GemmUniversal with GemmSplitKParallel to apply SplitK, but it didn't work.
Now I'm trying to apply StreamK/SplitK to the GEMM a different way, based on this example:
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>,
    CollectiveOp,
    EpilogueOp,
    cutlass::gemm::StreamKScheduler // <--- Change needed to enable the stream-K scheduler
  >;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
But I'm not sure how to use it from there.
Does this method support FP8 GEMM on Hopper? Which CUDA version should I use?
@jackkosaian I would be very grateful for your help!
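For context, here is my untested understanding of how the scheduler would then be configured through the kernel arguments. This is only a sketch based on my reading of the CUTLASS 3.x stream-K code; the field and type names (`scheduler.splits`, `decomposition_mode`, `DecompositionMode`) are assumptions on my part and may not match the version I'm using:

```cpp
// Untested sketch: configuring the stream-K scheduler via Gemm::Arguments.
// Names below (TileSchedulerArguments fields, DecompositionMode) are my
// guesses from reading the CUTLASS 3.x headers -- please correct me if the
// actual API differs.
using DecompositionMode = cutlass::gemm::kernel::detail::
    PersistentTileSchedulerSm90StreamKParams::DecompositionMode;

typename Gemm::Arguments arguments{
  cutlass::gemm::GemmUniversalMode::kGemm,
  {M, N, K, /*L=*/1},                                 // ProblemShape
  {ptr_A, stride_A, ptr_B, stride_B},                 // mainloop arguments
  {{alpha, beta}, ptr_C, stride_C, ptr_D, stride_D},  // epilogue arguments
  hw_info                                             // cutlass::KernelHardwareInfo
};

// With StreamKScheduler selected, I believe the arguments gain a scheduler
// member that controls the decomposition:
arguments.scheduler.splits = 4;  // request a fixed split-K factor, or ...
arguments.scheduler.decomposition_mode = DecompositionMode::SplitK;
// ... let the scheduler decide with stream-K:
// arguments.scheduler.decomposition_mode = DecompositionMode::StreamK;
```

Is something along these lines the intended way to drive SplitK/StreamK with this kernel?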