22// SPDX-License-Identifier: MIT
33
44#include " gemm_utils.hpp"
5- #include " run_gemm_example.inc"
65#include " ck_tile/ops/common.hpp"
76
87template <typename GemmConfig,
@@ -17,9 +16,8 @@ template <typename GemmConfig,
1716 typename ELayout,
1817 typename CDEElementWise,
1918 ck_tile::StreamKReductionStrategy ReductionStrategy>
20- std::tuple<float , ck_tile::index_t > gemm (const ck_tile::StreamKHostArgs& args,
19+ std::tuple<float , ck_tile::index_t > gemm (const ck_tile::reboot:: StreamKHostArgs& args,
2120 const ck_tile::stream_config& s)
22-
2321{
2422 using GemmShape = ck_tile::TileGemmShape<
2523 ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -29,7 +27,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
2927 GemmConfig::PermuteA,
3028 GemmConfig::PermuteB>;
3129
32- using TilePartitioner = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy>;
30+ using TilePartitioner =
31+ ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
3332
3433 using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM ,
3534 GemmConfig::kPadN ,
@@ -78,9 +77,13 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
7877 memory_operation.value ,
7978 GemmConfig::NumWaveGroups>>;
8079
81- using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
80+ using Kernel = ck_tile::reboot:: StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
8281
83- auto kargs = Kernel::MakeKernelArgs (args);
82+ auto kargs = Kernel::MakeKernelArgs (args);
83+ const auto workspace_size = Kernel::GetWorkSpaceSize (kargs);
84+ ck_tile::DeviceMem workspace_data (workspace_size);
85+ workspace_data.SetZero ();
86+ kargs.workspace_ptr = workspace_data.GetDeviceBuffer ();
8487
8588 dim3 grids = Kernel::GridSize (kargs.tile_partitioner );
8689 dim3 blocks = Kernel::BlockSize ();
@@ -101,28 +104,28 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
101104 << std::endl;
102105 }
103106
104- // Function to clear the output C tensor results after each repetition of the kernel
105- auto clear_gemm_output = [&]() {
107+ auto reset_data_buffers = [&]() {
106108 if (ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
109+ {
110+ // Clear the output C tensor results after each repetition of the kernel
107111 hipGetErrorString (hipMemsetAsync (
108112 args.e_ptr , 0 , args.M * args.N * sizeof (CDataType), s.stream_id_ ));
113+ }
114+ else if (ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
115+ {
116+ // Reset sk flags to zero before each repetition of the kernel
117+ workspace_data.SetZero ();
118+ }
109119 };
110120
111- std::function<void ()> preprocess = clear_gemm_output ;
121+ std::function<void ()> preprocess = reset_data_buffers ;
112122
113123 float ave_time = ck_tile::launch_kernel_time_mask (
114124 s,
115125 preprocess,
116126 ck_tile::make_kernel<GemmConfig::kBlockPerCu >(Kernel{}, grids, blocks, 0 , kargs));
117127
118- ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
119- kargs.tile_partitioner .sk_num_blocks ,
120- // k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
121- // big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
122- // avoid division by zero errors.
123- ck_tile::max (kargs.tile_partitioner .k_iters_per_big_block - 1 , 1u ),
124- kargs.tile_partitioner .k_iters_per_tile .get ());
125-
128+ ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner .estimate_num_wgs_per_tile ();
126129 return std::tuple{ave_time, num_wgs_per_tile};
127130 };
128131
@@ -145,6 +148,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
145148 }
146149}
147150
151+ #include " run_gemm_example.inc"
152+
148153template <typename GemmConfig, typename TypeConfig>
149154int run_gemm_example_prec_type (std::string a_layout, std::string b_layout, int argc, char * argv[])
150155{
@@ -164,7 +169,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
164169 return 0 ;
165170}
166171
167- template <template <typename PreType> typename GemmConfig>
172+ template <template <typename PreType, bool Persistent_ > typename GemmConfig>
168173int run_gemm_example (int argc, char * argv[])
169174{
170175 auto [result, arg_parser] = create_args (argc, argv);
@@ -174,30 +179,63 @@ int run_gemm_example(int argc, char* argv[])
174179 std::string data_type = arg_parser.get_str (" prec" );
175180 std::string a_layout = arg_parser.get_str (" a_layout" );
176181 std::string b_layout = arg_parser.get_str (" b_layout" );
182+ auto persistent_dp = arg_parser.get_bool (" persistent_dp" );
177183
178184 if (data_type == " bf16" )
179185 {
180186 using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t >;
181- return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t >, TypeConfig>(
182- a_layout, b_layout, argc, argv);
187+ if (persistent_dp)
188+ {
189+ return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t , true >, TypeConfig>(
190+ a_layout, b_layout, argc, argv);
191+ }
192+ else
193+ {
194+ return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t , false >, TypeConfig>(
195+ a_layout, b_layout, argc, argv);
196+ }
183197 }
184198 else if (data_type == " fp16" )
185199 {
186200 using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t >;
187- return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t >, TypeConfig>(
188- a_layout, b_layout, argc, argv);
201+ if (persistent_dp)
202+ {
203+ return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t , true >, TypeConfig>(
204+ a_layout, b_layout, argc, argv);
205+ }
206+ else
207+ {
208+ return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t , false >, TypeConfig>(
209+ a_layout, b_layout, argc, argv);
210+ }
189211 }
190212 else if (data_type == " fp8" )
191213 {
192214 using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t , ck_tile::fp8_t , ck_tile::half_t >;
193- return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t >, TypeConfig>(
194- a_layout, b_layout, argc, argv);
215+ if (persistent_dp)
216+ {
217+ return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t , true >, TypeConfig>(
218+ a_layout, b_layout, argc, argv);
219+ }
220+ else
221+ {
222+ return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t , false >, TypeConfig>(
223+ a_layout, b_layout, argc, argv);
224+ }
195225 }
196226 else if (data_type == " bf8" )
197227 {
198228 using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t , ck_tile::bf8_t , ck_tile::half_t >;
199- return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t >, TypeConfig>(
200- a_layout, b_layout, argc, argv);
229+ if (persistent_dp)
230+ {
231+ return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t , true >, TypeConfig>(
232+ a_layout, b_layout, argc, argv);
233+ }
234+ else
235+ {
236+ return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t , false >, TypeConfig>(
237+ a_layout, b_layout, argc, argv);
238+ }
201239 }
202240 else
203241 {
0 commit comments