@@ -34,8 +34,7 @@ namespace fastdeploy {
3434
3535template <typename ElementAB_,
3636 typename ElementD_,
37- template <typename , typename , typename >
38- typename Epilogue_,
37+ template <typename , typename , typename > typename Epilogue_,
3938 typename TileShape,
4039 typename ClusterShape,
4140 typename KernelSchedule,
@@ -57,7 +56,8 @@ struct cutlass_3x_gemm {
5756 // These are the minimum alignments needed for the kernels to compile
5857 static constexpr int AlignmentAB =
5958 128 / cutlass::sizeof_bits<ElementAB>::value;
60- static constexpr int AlignmentCD = 4 ;
59+ static constexpr int AlignmentCD =
60+ 128 / cutlass::sizeof_bits<ElementD>::value;
6161
6262 using CollectiveEpilogue =
6363 typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -104,8 +104,7 @@ struct cutlass_3x_gemm {
104104
105105template <typename ElementAB_,
106106 typename ElementD_,
107- template <typename , typename , typename >
108- typename Epilogue_,
107+ template <typename , typename , typename > typename Epilogue_,
109108 typename TileShape,
110109 typename ClusterShape,
111110 typename KernelSchedule,
@@ -180,11 +179,88 @@ struct cutlass_3x_gemm_sm100 {
180179 sizeof (typename CollectiveEpilogue::SharedStorage))>,
181180 KernelSchedule>::CollectiveOp;
182181
183- using GemmKernel =
182+ using GemmKernel = enable_sm100f_only<
184183 cutlass::gemm::kernel::GemmUniversal<Shape<int , int , int , int >,
185184 CollectiveMainloop,
186185 CollectiveEpilogue,
187- void >;
186+ void >>;
187+ };
188+
189+ template <typename ElementAB_,
190+ typename ElementD_,
191+ template <typename , typename , typename > typename Epilogue_,
192+ typename TileShape,
193+ typename ClusterShape,
194+ typename KernelSchedule,
195+ typename EpilogueSchedule>
196+ struct cutlass_3x_gemm_sm120 {
197+ using ElementAB = ElementAB_;
198+ using LayoutA = cutlass::layout::RowMajor;
199+ static constexpr int AlignmentA =
200+ 128 / cutlass::sizeof_bits<ElementAB>::value;
201+
202+ using LayoutB = cutlass::layout::ColumnMajor;
203+ static constexpr int AlignmentB =
204+ 128 / cutlass::sizeof_bits<ElementAB>::value;
205+
206+ using ElementC = void ;
207+ using LayoutC = cutlass::layout::RowMajor;
208+ static constexpr int AlignmentC =
209+ 128 / cutlass::sizeof_bits<ElementD_>::value;
210+
211+ using ElementD = ElementD_;
212+ using LayoutD = cutlass::layout::RowMajor;
213+ static constexpr int AlignmentD = AlignmentC;
214+
215+ using ElementAcc = typename std::
216+ conditional<std::is_same_v<ElementAB, int8_t >, int32_t , float >::type;
217+ using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
218+
219+ using ElementAccumulator = float ;
220+ using ElementCompute = float ;
221+
222+ using EVTCompute = typename Epilogue::EVTCompute;
223+
224+ using CollectiveEpilogue =
225+ typename cutlass::epilogue::collective::CollectiveBuilder<
226+ cutlass::arch::Sm120,
227+ cutlass::arch::OpClassTensorOp,
228+ TileShape,
229+ ClusterShape,
230+ cutlass::epilogue::collective::EpilogueTileAuto,
231+ ElementAccumulator,
232+ ElementCompute,
233+ ElementC,
234+ LayoutC,
235+ AlignmentC,
236+ ElementD,
237+ LayoutD,
238+ AlignmentD,
239+ EpilogueSchedule,
240+ EVTCompute>::CollectiveOp;
241+
242+ using CollectiveMainloop =
243+ typename cutlass::gemm::collective::CollectiveBuilder<
244+ cutlass::arch::Sm120,
245+ cutlass::arch::OpClassTensorOp,
246+ ElementAB,
247+ LayoutA,
248+ AlignmentA,
249+ ElementAB,
250+ LayoutB,
251+ AlignmentB,
252+ ElementAccumulator,
253+ TileShape,
254+ ClusterShape,
255+ cutlass::gemm::collective::StageCountAutoCarveout<static_cast <int >(
256+ sizeof (typename CollectiveEpilogue::SharedStorage))>,
257+ KernelSchedule>::CollectiveOp;
258+
259+ using GemmKernel = enable_sm120_only<
260+ cutlass::gemm::kernel::GemmUniversal<Shape<int , int , int , int >,
261+ CollectiveMainloop,
262+ CollectiveEpilogue,
263+ void >>;
188264};
189265
190266} // namespace fastdeploy
0 commit comments