Skip to content

Commit 1e83dd2

Browse files
add fp8 gemm.
1 parent 54edfc5 commit 1e83dd2

23 files changed

Lines changed: 2352 additions & 123 deletions

custom_ops/gpu_ops/cutlass_kernels/cutlass_helper.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,59 @@ struct enable_sm90_or_later : Kernel {
4848
}
4949
};
5050

// SM90: gate that emits the wrapped kernel body only when compiling device
// code for SM90 (Hopper). For every other target architecture operator()
// compiles to an empty stub, so the kernel can live in a multi-arch
// fatbinary without instantiating unsupported instructions.
template <typename Kernel>
struct enable_sm90_only : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args &&...args) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 900)
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};
// SM100f: gate for the SM100 family — SM100 (Blackwell GB200) and SM103
// (GB10x). The wrapped kernel body is emitted only for those architecture
// values; on any other target operator() is an empty stub.
template <typename Kernel>
struct enable_sm100f_only : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args &&...args) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000 || __CUDA_ARCH__ == 1030)
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};
// SM120: gate that emits the wrapped kernel body only when compiling for
// SM120 (RTX 5090). Everywhere else operator() is an empty stub.
template <typename Kernel>
struct enable_sm120_only : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args &&...args) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1200)
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};
// SM12x family: gate covering the whole SM12x range — SM120 (RTX 5090) and
// SM121 (DGX Spark). The wrapped kernel body is emitted only for
// 1200 <= __CUDA_ARCH__ < 1300; otherwise operator() is an empty stub.
template <typename Kernel>
struct enable_sm120_family : Kernel {
  template <typename... Args>
  CUTLASS_DEVICE void operator()(Args &&...args) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1200 && __CUDA_ARCH__ < 1300)
    Kernel::operator()(std::forward<Args>(args)...);
#endif
  }
};
// Returns the compute capability of the current CUDA device encoded as
// major * 10 + minor (e.g. 90 for SM90, 100 for SM100, 103 for SM103,
// 120 for SM120).
//
// Returns 0 when the device or its attributes cannot be queried, so callers
// that compare the result against a concrete SM number simply fail to
// match. (The previous version ignored every cudaError_t; on failure it
// also happened to return 0 because major/minor stayed zero — the checks
// below make that contract explicit instead of accidental.)
inline int32_t get_sm_version_num() {
  int device = -1;
  if (cudaGetDevice(&device) != cudaSuccess) {
    return 0;
  }
  int major = 0, minor = 0;
  if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor,
                             device) != cudaSuccess ||
      cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor,
                             device) != cudaSuccess) {
    return 0;
  }
  return major * 10 + minor;
}
51104
template <paddle::DataType D>
52105
class CutlassDtypeTraits;
53106

custom_ops/gpu_ops/cutlass_kernels/w8a8/c3x/scaled_mm.cuh

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ namespace fastdeploy {
3434

3535
template <typename ElementAB_,
3636
typename ElementD_,
37-
template <typename, typename, typename>
38-
typename Epilogue_,
37+
template <typename, typename, typename> typename Epilogue_,
3938
typename TileShape,
4039
typename ClusterShape,
4140
typename KernelSchedule,
@@ -57,7 +56,8 @@ struct cutlass_3x_gemm {
5756
// These are the minimum alignments needed for the kernels to compile
5857
static constexpr int AlignmentAB =
5958
128 / cutlass::sizeof_bits<ElementAB>::value;
60-
static constexpr int AlignmentCD = 4;
59+
static constexpr int AlignmentCD =
60+
128 / cutlass::sizeof_bits<ElementD>::value;
6161

6262
using CollectiveEpilogue =
6363
typename cutlass::epilogue::collective::CollectiveBuilder<
@@ -104,8 +104,7 @@ struct cutlass_3x_gemm {
104104

105105
template <typename ElementAB_,
106106
typename ElementD_,
107-
template <typename, typename, typename>
108-
typename Epilogue_,
107+
template <typename, typename, typename> typename Epilogue_,
109108
typename TileShape,
110109
typename ClusterShape,
111110
typename KernelSchedule,
@@ -180,11 +179,88 @@ struct cutlass_3x_gemm_sm100 {
180179
sizeof(typename CollectiveEpilogue::SharedStorage))>,
181180
KernelSchedule>::CollectiveOp;
182181

183-
using GemmKernel =
182+
using GemmKernel = enable_sm100f_only<
184183
cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,
185184
CollectiveMainloop,
186185
CollectiveEpilogue,
187-
void>;
186+
void>>;
187+
};
188+
// CUTLASS 3.x GEMM kernel configuration for SM120 (RTX 5090).
// Builds the collective mainloop and epilogue with the Sm120 architecture
// tag and wraps the kernel in enable_sm120_only, so the device body is only
// emitted when compiling for SM120.
//
// Template parameters:
//   ElementAB_       element type of operands A and B (e.g. fp8, int8_t)
//   ElementD_        element type of output D
//   Epilogue_        epilogue policy template taking (acc, out, tile shape)
//   TileShape        CTA tile shape
//   ClusterShape     thread-block cluster shape
//   KernelSchedule   mainloop schedule policy
//   EpilogueSchedule epilogue schedule policy
template <typename ElementAB_,
          typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape,
          typename ClusterShape,
          typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm_sm120 {
  using ElementAB = ElementAB_;
  // A is row-major; alignment is one 128-bit access expressed in elements.
  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementAB>::value;

  // B is column-major with the same 128-bit alignment.
  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementAB>::value;

  // No source C tensor (void); the alignment is still derived from the
  // output element type so the epilogue's vectorized stores line up.
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC =
      128 / cutlass::sizeof_bits<ElementD_>::value;

  using ElementD = ElementD_;
  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = AlignmentC;

  // Epilogue accumulator element: int32 for int8 inputs, float otherwise.
  using ElementAcc = typename std::
      conditional<std::is_same_v<ElementAB, int8_t>, int32_t, float>::type;
  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

  using ElementAccumulator = float;
  using ElementCompute = float;

  // Epilogue visitor tree supplied by the Epilogue_ policy.
  using EVTCompute = typename Epilogue::EVTCompute;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm120,
          cutlass::arch::OpClassTensorOp,
          TileShape,
          ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAccumulator,
          ElementCompute,
          ElementC,
          LayoutC,
          AlignmentC,
          ElementD,
          LayoutD,
          AlignmentD,
          EpilogueSchedule,
          EVTCompute>::CollectiveOp;

  // Mainloop stage count is chosen automatically after carving out the
  // shared memory consumed by the epilogue.
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm120,
          cutlass::arch::OpClassTensorOp,
          ElementAB,
          LayoutA,
          AlignmentA,
          ElementAB,
          LayoutB,
          AlignmentB,
          ElementAccumulator,
          TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  // Guarded so the device code is a no-op on every non-SM120 architecture.
  using GemmKernel = enable_sm120_only<
      cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>,
                                           CollectiveMainloop,
                                           CollectiveEpilogue,
                                           void>>;
};
189265

190266
} // namespace fastdeploy
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
// adapted from:
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu

#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"

namespace fastdeploy {

// Entry point for the SM100 blockwise-scaled FP8 GEMM. Selects the CUTLASS
// output element type from the dtype of `out` and forwards all tensors to
// the templated dispatcher. Only BFLOAT16 and FLOAT16 outputs are
// supported; any other dtype fails the PD_CHECK below.
void cutlass_scaled_mm_blockwise_sm100_fp8(paddle::Tensor &out,
                                           paddle::Tensor const &a,
                                           paddle::Tensor const &b,
                                           paddle::Tensor const &a_scales,
                                           paddle::Tensor const &b_scales) {
  if (out.dtype() == paddle::DataType::BFLOAT16) {
    cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
        out, a, b, a_scales, b_scales);
    return;
  }
  // Not bf16, so the output must be fp16.
  PD_CHECK(out.dtype() == paddle::DataType::FLOAT16);
  cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
      out, a, b, a_scales, b_scales);
}

} // namespace fastdeploy

0 commit comments

Comments
 (0)