@@ -56,24 +56,6 @@ __global__ void MatrixRowReverse(const T* matrix_data,
5656 }
5757}
5858
59- template <typename T, typename Op>
60- struct BlockPrefixCallbackOp {
61- // Running prefix
62- T running_total_;
63- Op op_;
64-
65- __device__ BlockPrefixCallbackOp (T running_total, Op op)
66- : running_total_(running_total), op_(op) {}
67-
68- // Callback operator to be entered by the first warp of threads in the block.
69- // tid 0 is responsible for returning a value for seeding the block-wide scan.
70- __device__ T operator ()(T block_aggregate) {
71- T old_prefix = running_total_;
72- running_total_ = op_ (old_prefix, block_aggregate);
73- return old_prefix;
74- }
75- };
76-
7759// No bank-conflict transpose
7860template <typename T, int TILE_DIM, int BLOCK_ROWS>
7961__global__ void MatrixTranspose (T* odata,
@@ -146,6 +128,73 @@ struct Identity<T, ComplexSum> {
146128 static constexpr T value = {0 , 0 };
147129};
148130
131+ template <typename T, typename Op>
132+ struct BlockPrefixCallbackOp {
133+ // Running prefix
134+ T running_total_;
135+ T compensation_;
136+ Op op_;
137+
138+ __device__ BlockPrefixCallbackOp (T identity, Op op)
139+ : running_total_(identity), compensation_(identity), op_(op) {}
140+
141+ // Callback operator to be entered by the first warp of threads in the block.
142+ // tid 0 is responsible for returning a value for seeding the block-wide scan.
143+ __device__ T operator ()(T block_aggregate) {
144+ T old_prefix = running_total_;
145+
146+ // Kahan Summation
147+ T y = op_ (block_aggregate, static_cast <T>(-compensation_));
148+ T t = op_ (running_total_, y);
149+ T y_high = op_ (t, static_cast <T>(-running_total_));
150+ compensation_ = op_ (y_high, static_cast <T>(-y));
151+ running_total_ = t;
152+
153+ return old_prefix;
154+ }
155+ };
156+
157+ template <typename T>
158+ struct BlockPrefixCallbackOp <T, LogAddExp> {
159+ T max_so_far_;
160+ T scaled_sum_;
161+ T compensation_;
162+ LogAddExp op_;
163+
164+ __device__ BlockPrefixCallbackOp (T identity, LogAddExp op)
165+ : max_so_far_(identity), scaled_sum_(0.0 ), compensation_(0.0 ), op_(op) {}
166+
167+ __device__ T operator ()(T block_aggregate) {
168+ if (scaled_sum_ == 0.0 ) {
169+ max_so_far_ = block_aggregate;
170+ scaled_sum_ = 1.0 ;
171+ compensation_ = 0.0 ;
172+ return std::numeric_limits<T>::lowest ();
173+ }
174+
175+ // Online Scaling
176+ T old_prefix = max_so_far_ + std::log (scaled_sum_);
177+ T m_old = max_so_far_;
178+ T m_new = std::max (m_old, block_aggregate);
179+
180+ if (m_new > m_old) {
181+ T scale = std::exp (m_old - m_new);
182+ scaled_sum_ *= scale;
183+ compensation_ *= scale;
184+ }
185+
186+ // Kahan Summation
187+ T term = std::exp (block_aggregate - m_new);
188+ T y = term - compensation_;
189+ T t = scaled_sum_ + y;
190+ compensation_ = (t - scaled_sum_) - y;
191+ scaled_sum_ = t;
192+ max_so_far_ = m_new;
193+
194+ return old_prefix;
195+ }
196+ };
197+
149198template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
150199__global__ void BlockScanKernel (T* d_out,
151200 const T* d_in,
@@ -154,17 +203,17 @@ __global__ void BlockScanKernel(T* d_out,
154203 bool exclusive,
155204 Op op) {
156205 using MT = typename phi::dtype::MPTypeTrait<T>::Type;
206+ using CallbackOp = BlockPrefixCallbackOp<MT, Op>;
157207
158208 // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
159- typedef cub::
160- BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
161- BlockLoadT;
162- typedef cub::BlockStore<MT,
163- BLOCK_THREADS,
164- ITEMS_PER_THREAD,
165- cub::BLOCK_STORE_TRANSPOSE>
166- BlockStoreT;
167- typedef cub::BlockScan<MT, BLOCK_THREADS> BlockScanT;
209+ using BlockLoadT = cub::
210+ BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>;
211+ using BlockStoreT = cub::BlockStore<MT,
212+ BLOCK_THREADS,
213+ ITEMS_PER_THREAD,
214+ cub::BLOCK_STORE_TRANSPOSE>;
215+ using BlockScanT = cub::BlockScan<MT, BLOCK_THREADS>;
216+
168217 // Allocate type-safe, repurposable shared memory for collectives
169218 __shared__ union {
170219 typename BlockLoadT::TempStorage load;
@@ -176,24 +225,21 @@ __global__ void BlockScanKernel(T* d_out,
176225 int64_t item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
177226
178227 for (int64_t bx = blockIdx .x ; bx < grid_size; bx += gridDim .x ) {
179- BlockPrefixCallbackOp<MT, Op> prefix_op (Identity<MT, Op>::value, op);
228+ CallbackOp prefix_op (Identity<MT, Op>::value, op);
180229
181230 for (int64_t block_offset = 0 ; block_offset < scan_size;
182231 block_offset += item_per_block) {
183- int64_t valid_item = (scan_size - block_offset > item_per_block)
184- ? item_per_block
185- : (scan_size - block_offset);
186- if (scan_size < item_per_block) {
187- valid_item = scan_size;
188- }
232+ int64_t valid_item = std::min (scan_size - block_offset, item_per_block);
189233
190234 int64_t offset = bx * scan_size + block_offset;
191235
192236 MT thread_keys[ITEMS_PER_THREAD];
193237 BlockLoadT (temp_storage.load )
194- .Load (d_in + offset, thread_keys, valid_item, 0 );
238+ .Load (
239+ d_in + offset, thread_keys, valid_item, Identity<MT, Op>::value);
195240
196241 __syncthreads ();
242+
197243 if (exclusive) {
198244 BlockScanT (temp_storage.scan )
199245 .ExclusiveScan (thread_keys, thread_keys, op, prefix_op);
@@ -209,63 +255,6 @@ __global__ void BlockScanKernel(T* d_out,
209255 }
210256}
211257
212- template <typename Context, typename T>
213- typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value &&
214- !std::is_same<T, phi::dtype::bfloat16>::value>::type
215- ThrustCumsumKernel (const Context& dev_ctx,
216- const T* in_data,
217- T* out_data,
218- int64_t size,
219- bool reverse,
220- bool exclusive) {
221- #ifdef __HIPCC__
222- const auto & policy = thrust::hip::par.on (dev_ctx.stream ());
223- #else
224- phi::memory_utils::ThrustAllocator<cudaStream_t> allocator (dev_ctx.GetPlace (),
225- dev_ctx.stream ());
226- const auto & policy = thrust::cuda::par (allocator).on (dev_ctx.stream ());
227- #endif
228- if (reverse) {
229- thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in (
230- thrust::device_pointer_cast (in_data) + size);
231- thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out (
232- thrust::device_pointer_cast (out_data) + size);
233- if (exclusive) {
234- thrust::exclusive_scan (
235- policy, reversed_in, reversed_in + size, reversed_out);
236- } else {
237- thrust::inclusive_scan (
238- policy, reversed_in, reversed_in + size, reversed_out);
239- }
240- } else {
241- if (exclusive) {
242- thrust::exclusive_scan (policy, in_data, in_data + size, out_data);
243- } else {
244- thrust::inclusive_scan (policy, in_data, in_data + size, out_data);
245- }
246- }
247-
248- return ;
249- }
250-
251- template <typename Context, typename T>
252- typename std::enable_if<std::is_same<T, phi::dtype::float16>::value>::type
253- ThrustCumsumKernel (const Context& dev_ctx,
254- const phi::dtype::float16* in_data,
255- phi::dtype::float16* out_data,
256- int64_t size,
257- bool reverse,
258- bool exclusive) {}
259-
260- template <typename Context, typename T>
261- typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
262- ThrustCumsumKernel (const Context& dev_ctx,
263- const phi::dtype::bfloat16* in_data,
264- phi::dtype::bfloat16* out_data,
265- int64_t size,
266- bool reverse,
267- bool exclusive) {}
268-
269258template <typename T, typename Context, typename Op>
270259void ScanKernel (const Context& dev_ctx,
271260 const DenseTensor& x,
@@ -290,7 +279,6 @@ void ScanKernel(const Context& dev_ctx,
290279 }
291280
292281 auto out_dims = out->dims ();
293- auto size = x.numel ();
294282
295283 PADDLE_ENFORCE_EQ (
296284 axis < out_dims.size () && axis >= (0 - out_dims.size ()),
@@ -307,22 +295,11 @@ void ScanKernel(const Context& dev_ctx,
307295
308296 const T* in_data = x.data <T>();
309297
310- // Use thrust for parallel acceleration when the input size is equal to the
311- // length of the 'axis' dimension.
312- if (!std::is_same<T, phi::dtype::float16>::value &&
313- !std::is_same<T, phi::dtype::bfloat16>::value &&
314- std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
315- ThrustCumsumKernel<Context, T>(
316- dev_ctx, in_data, out_data, size, reverse, exclusive);
317- return ;
318- }
319-
320298 size_t height = 1 ;
321299 size_t width = 1 ;
322300 for (size_t i = 0 ; i <= axis; i++) {
323301 height *= out_dims[i];
324302 }
325-
326303 for (size_t i = axis + 1 ; i < out_dims.size (); i++) {
327304 width *= out_dims[i];
328305 }
0 commit comments