
Commit 5abe410

Introduces the new partitioner to implement the reduction StreamK kernel. (#3107)
* Introduces the new partitioner to implement the reduction StreamK kernel
* Add more doc text to functions
* Add persistent-dp option to streamk example
* Update example/ck_tile/40_streamk_gemm/README.md
1 parent 13ba06f commit 5abe410

8 files changed, +298 -75 lines changed
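At a glance, the core of the change as it appears in streamk_gemm_basic.cpp below: the example now builds the Stream-K kernel around the v2 tile partitioner and the ck_tile::reboot types, with the persistence mode taken from the compile-time GemmConfig::Persistent flag. The two aliases below are copied from that diff; the surrounding types (GemmShape, ReductionStrategy, GemmPipeline, GemmEpilogue) are defined in the same file.

// Copied from the streamk_gemm_basic.cpp diff below: partitioner and kernel selection.
using TilePartitioner =
    ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;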


example/ck_tile/40_streamk_gemm/README.md

Lines changed: 1 addition & 1 deletion
@@ -22,8 +22,8 @@ args:
 -a_layout            tensor A data layout (default: R)
 -b_layout            tensor B data layout (default: C)
 -c_layout            tensor C data layout (default: R)
--num_sk_blocks        number of Stream-K blocks. -1: chosen by algorithm, or user selected (default:-1)
 -reduction_strategy  strategy for storing results in C tensor. atomic/reduction (default:atomic)
+-persistent_dp       persistent strategy for data-parallel section. Set to 0 for non-persistent or to 1 for persistent. (default:0)
 -stride_a            tensor A stride (default:0)
 -stride_b            tensor B stride (default:0)
 -stride_c            tensor C stride (default:0)

example/ck_tile/40_streamk_gemm/gemm_utils.hpp

Lines changed: 9 additions & 9 deletions
@@ -18,7 +18,6 @@ struct GemmConfigBase
 
     static constexpr bool TransposeC = false;
     static constexpr bool UseStructuredSparsity = false;
-    static constexpr bool Persistent = false;
 
     static constexpr int kBlockPerCu = 1;
     static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
@@ -27,12 +26,12 @@ struct GemmConfigBase
     static constexpr bool DoubleSmemBuffer = false;
 };
 
-template <typename PrecType>
+template <typename PrecType, bool Persistent_>
 struct GemmConfigMemoryInterwave : public GemmConfigBase
 {
-    static constexpr ck_tile::index_t M_Tile = 128;
-    static constexpr ck_tile::index_t N_Tile = 128;
-    static constexpr ck_tile::index_t K_Tile = 32;
+    static constexpr ck_tile::index_t M_Tile = 256;
+    static constexpr ck_tile::index_t N_Tile = 256;
+    static constexpr ck_tile::index_t K_Tile = 16;
 
     static constexpr ck_tile::index_t M_Warp = 2;
     static constexpr ck_tile::index_t N_Warp = 2;
@@ -42,7 +41,8 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
     static constexpr ck_tile::index_t N_Warp_Tile = 32;
     static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
 
-    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
+    static constexpr bool Persistent = Persistent_;
+    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
 };
 
 template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
@@ -96,12 +96,12 @@ auto create_args(int argc, char* argv[])
         .insert("a_layout", "R", "A tensor data layout - Row by default")
         .insert("b_layout", "C", "B tensor data layout - Column by default")
         .insert("c_layout", "R", "C tensor data layout - Row by default")
-        .insert("num_sk_blocks",
-                "-1",
-                "number of Stream-K blocks. -1: chosen by algorithm, or user selected")
        .insert("reduction_strategy",
                 "atomic",
                 "strategy for storing results in C tensor - atomic/reduction")
+        .insert("persistent_dp",
+                "0",
+                "0. Non-persistent data-parallel section, 1 Fully persistent kernel.")
         .insert("stride_a", "0", "Tensor A stride")
         .insert("stride_b", "0", "Tensor B stride")
         .insert("stride_c", "0", "Tensor C stride")
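A minimal sketch of how the changed config template reads after this commit, assuming the headers from this example directory are on the include path. The extra bool parameter fixes the persistence mode at compile time, and the config exposes it through the Persistent member that streamk_gemm_basic.cpp later hands to the partitioner. fp16 is picked purely for illustration.

#include "gemm_utils.hpp" // example-local header from this directory (assumption)

using PersistentFp16Config    = GemmConfigMemoryInterwave<ck_tile::half_t, /*Persistent_=*/true>;
using NonPersistentFp16Config = GemmConfigMemoryInterwave<ck_tile::half_t, /*Persistent_=*/false>;

static_assert(PersistentFp16Config::Persistent);
static_assert(!NonPersistentFp16Config::Persistent);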

example/ck_tile/40_streamk_gemm/run_gemm_example.inc

Lines changed: 15 additions & 19 deletions
@@ -69,20 +69,18 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
            int n_warmup,
            int n_repeat,
            bool flush_cache,
-           ck_tile::StreamKReductionStrategy reduction_strategy,
-           uint32_t num_sk_blocks)
+           ck_tile::StreamKReductionStrategy reduction_strategy)
 {
-    ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
-                                  b_k_n_dev_buf.GetDeviceBuffer(),
-                                  c_m_n_dev_buf.GetDeviceBuffer(),
-                                  M,
-                                  N,
-                                  K,
-                                  stride_A,
-                                  stride_B,
-                                  stride_C,
-                                  reduction_strategy,
-                                  num_sk_blocks};
+    ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
+                                          b_k_n_dev_buf.GetDeviceBuffer(),
+                                          c_m_n_dev_buf.GetDeviceBuffer(),
+                                          M,
+                                          N,
+                                          K,
+                                          stride_A,
+                                          stride_B,
+                                          stride_C,
+                                          reduction_strategy};
 
     std::tuple<float, ck_tile::index_t> ave_time_and_batch;
 
@@ -197,7 +195,6 @@ int run_gemm_example_with_layouts(int argc,
 
     ck_tile::StreamKReductionStrategy reduction_strategy =
         get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
-    uint32_t num_sk_blocks = static_cast<uint32_t>(arg_parser.get_int("num_sk_blocks"));
 
     stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
     stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -261,8 +258,7 @@ int run_gemm_example_with_layouts(int argc,
                    n_warmup,
                    n_repeat,
                    flush_cache,
-                   reduction_strategy,
-                   num_sk_blocks);
+                   reduction_strategy);
 
     c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
 
@@ -279,10 +275,10 @@ int run_gemm_example_with_layouts(int argc,
              << " B_Type=" << DataTypeTraits<BDataType>::name
              << " C_Type=" << DataTypeTraits<CDataType>::name
              << " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
-             << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
-             << std::endl;
+             << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
+             << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
 
-    bool pass = true;
+    bool pass = false;
 
     // Memory on host to store gpu reference result
     ck_tile::HostTensor<CDataType> c_m_n_ref(

example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp

Lines changed: 64 additions & 26 deletions
@@ -2,7 +2,6 @@
 // SPDX-License-Identifier: MIT
 
 #include "gemm_utils.hpp"
-#include "run_gemm_example.inc"
 #include "ck_tile/ops/common.hpp"
 
 template <typename GemmConfig,
@@ -17,9 +16,8 @@ template <typename GemmConfig,
           typename ELayout,
           typename CDEElementWise,
           ck_tile::StreamKReductionStrategy ReductionStrategy>
-std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
+std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs& args,
                                          const ck_tile::stream_config& s)
-
 {
     using GemmShape = ck_tile::TileGemmShape<
         ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -29,7 +27,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
        GemmConfig::PermuteA,
        GemmConfig::PermuteB>;
 
-    using TilePartitioner = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy>;
+    using TilePartitioner =
+        ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
 
     using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
                                                                  GemmConfig::kPadN,
@@ -78,9 +77,13 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
                                          memory_operation.value,
                                          GemmConfig::NumWaveGroups>>;
 
-    using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
+    using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
 
-    auto kargs = Kernel::MakeKernelArgs(args);
+    auto kargs                = Kernel::MakeKernelArgs(args);
+    const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
+    ck_tile::DeviceMem workspace_data(workspace_size);
+    workspace_data.SetZero();
+    kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
 
     dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
     dim3 blocks = Kernel::BlockSize();
@@ -101,28 +104,28 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
                  << std::endl;
     }
 
-    // Function to clear the output C tensor results after each repetition of the kernel
-    auto clear_gemm_output = [&]() {
+    auto reset_data_buffers = [&]() {
         if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
+        {
+            // Clear the output C tensor results after each repetition of the kernel
            hipGetErrorString(hipMemsetAsync(
                args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
+        }
+        else if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
+        {
+            // Reset sk flags to zero before each repetition of the kernel
+            workspace_data.SetZero();
+        }
     };
 
-    std::function<void()> preprocess = clear_gemm_output;
+    std::function<void()> preprocess = reset_data_buffers;
 
     float ave_time = ck_tile::launch_kernel_time_mask(
        s,
        preprocess,
        ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
 
-    ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
-        kargs.tile_partitioner.sk_num_blocks,
-        // k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
-        // big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
-        // avoid division by zero errors.
-        ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
-        kargs.tile_partitioner.k_iters_per_tile.get());
-
+    ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
     return std::tuple{ave_time, num_wgs_per_tile};
 };
 
@@ -145,6 +148,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
    }
 }
 
+#include "run_gemm_example.inc"
+
 template <typename GemmConfig, typename TypeConfig>
 int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
 {
@@ -164,7 +169,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
     return 0;
 }
 
-template <template <typename PreType> typename GemmConfig>
+template <template <typename PreType, bool Persistent_> typename GemmConfig>
 int run_gemm_example(int argc, char* argv[])
 {
     auto [result, arg_parser] = create_args(argc, argv);
@@ -174,30 +179,63 @@ int run_gemm_example(int argc, char* argv[])
     std::string data_type = arg_parser.get_str("prec");
     std::string a_layout = arg_parser.get_str("a_layout");
     std::string b_layout = arg_parser.get_str("b_layout");
+    auto persistent_dp = arg_parser.get_bool("persistent_dp");
 
     if(data_type == "bf16")
     {
         using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t>;
-        return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, TypeConfig>(
-            a_layout, b_layout, argc, argv);
+        if(persistent_dp)
+        {
+            return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, true>, TypeConfig>(
+                a_layout, b_layout, argc, argv);
+        }
+        else
+        {
+            return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, false>, TypeConfig>(
+                a_layout, b_layout, argc, argv);
+        }
     }
     else if(data_type == "fp16")
     {
         using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t>;
-        return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, TypeConfig>(
-            a_layout, b_layout, argc, argv);
+        if(persistent_dp)
+        {
+            return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, true>, TypeConfig>(
+                a_layout, b_layout, argc, argv);
+        }
+        else
+        {
+            return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, false>, TypeConfig>(
+                a_layout, b_layout, argc, argv);
+        }
     }
     else if(data_type == "fp8")
     {
         using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
-        return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig>(
-            a_layout, b_layout, argc, argv);
+        if(persistent_dp)
+        {
+            return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, true>, TypeConfig>(
+                a_layout, b_layout, argc, argv);
+        }
+        else
+        {
+            return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, false>, TypeConfig>(
+                a_layout, b_layout, argc, argv);
+        }
     }
     else if(data_type == "bf8")
    {
         using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
-        return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig>(
-            a_layout, b_layout, argc, argv);
+        if(persistent_dp)
+        {
+            return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, true>, TypeConfig>(
+                a_layout, b_layout, argc, argv);
+        }
+        else
+        {
+            return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, false>, TypeConfig>(
+                a_layout, b_layout, argc, argv);
+        }
     }
     else
     {
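The four per-precision branches above all repeat the same runtime-to-compile-time mapping of the persistent_dp flag onto the Persistent_ template parameter. A hypothetical helper (not part of this commit) could express that mapping once, assuming run_gemm_example_prec_type keeps the signature shown in the diff:

#include <string>

// Hypothetical sketch only: fold the if(persistent_dp) / else pattern into one place.
template <template <typename, bool> typename GemmConfig, typename PrecType, typename TypeConfig>
int dispatch_persistent(bool persistent_dp,
                        const std::string& a_layout,
                        const std::string& b_layout,
                        int argc,
                        char* argv[])
{
    if(persistent_dp)
    {
        return run_gemm_example_prec_type<GemmConfig<PrecType, true>, TypeConfig>(
            a_layout, b_layout, argc, argv);
    }
    return run_gemm_example_prec_type<GemmConfig<PrecType, false>, TypeConfig>(
        a_layout, b_layout, argc, argv);
}

Each data-type branch would then reduce to a single call, e.g. dispatch_persistent<GemmConfig, ck_tile::bf16_t, TypeConfig>(persistent_dp, a_layout, b_layout, argc, argv).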

include/ck_tile/host/kernel_launch.hpp

Lines changed: 4 additions & 0 deletions
@@ -110,6 +110,10 @@ CK_TILE_HOST double timing_loop_impl(TimerType timer,
     {
         for(int i = 0; i < s.cold_niters_; i++)
         {
+            if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
+            {
+                preprocess();
+            }
             callables_func();
         }
         // Only profile preprocess if it's provided
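For context, this change makes the cold (warm-up) iterations run the same preprocess step as the timed iterations, so the C tensor and the Stream-K workspace flags are reset before every launch. A minimal standalone sketch of the pattern, with hypothetical names rather than the ck_tile API:

#include <cstddef>
#include <type_traits>

// Hypothetical sketch: run `callable` `iters` times, invoking an optional
// `preprocess` before each iteration. Passing nullptr drops the preprocess
// branch at compile time, mirroring the std::nullptr_t check in the diff above.
template <typename PreprocessFunc, typename Callable>
void loop_with_preprocess(PreprocessFunc preprocess, Callable&& callable, int iters)
{
    for(int i = 0; i < iters; ++i)
    {
        if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
        {
            preprocess(); // e.g. clear the output tensor or zero the Stream-K workspace
        }
        callable();
    }
}

// Usage: loop_with_preprocess(nullptr, [] { /* launch kernel */ }, 3);
//        loop_with_preprocess([] { /* reset buffers */ }, [] { /* launch kernel */ }, 3);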
