Skip to content

Commit 0ca88c4

Browse files
[Accuracy diff No.16] Fix accuracy diff for paddle.cumsum, paddle.logcumsumexp API (#74081)
* using Kahan summation * fix test
1 parent 29b4bf4 commit 0ca88c4

File tree

2 files changed

+88
-106
lines changed

2 files changed

+88
-106
lines changed

paddle/phi/kernels/gpu/cum_kernel.cu

Lines changed: 81 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -56,24 +56,6 @@ __global__ void MatrixRowReverse(const T* matrix_data,
5656
}
5757
}
5858

59-
template <typename T, typename Op>
60-
struct BlockPrefixCallbackOp {
61-
// Running prefix
62-
T running_total_;
63-
Op op_;
64-
65-
__device__ BlockPrefixCallbackOp(T running_total, Op op)
66-
: running_total_(running_total), op_(op) {}
67-
68-
// Callback operator to be entered by the first warp of threads in the block.
69-
// tid 0 is responsible for returning a value for seeding the block-wide scan.
70-
__device__ T operator()(T block_aggregate) {
71-
T old_prefix = running_total_;
72-
running_total_ = op_(old_prefix, block_aggregate);
73-
return old_prefix;
74-
}
75-
};
76-
7759
// No bank-conflict transpose
7860
template <typename T, int TILE_DIM, int BLOCK_ROWS>
7961
__global__ void MatrixTranspose(T* odata,
@@ -146,6 +128,73 @@ struct Identity<T, ComplexSum> {
146128
static constexpr T value = {0, 0};
147129
};
148130

131+
template <typename T, typename Op>
132+
struct BlockPrefixCallbackOp {
133+
// Running prefix
134+
T running_total_;
135+
T compensation_;
136+
Op op_;
137+
138+
__device__ BlockPrefixCallbackOp(T identity, Op op)
139+
: running_total_(identity), compensation_(identity), op_(op) {}
140+
141+
// Callback operator to be entered by the first warp of threads in the block.
142+
// tid 0 is responsible for returning a value for seeding the block-wide scan.
143+
__device__ T operator()(T block_aggregate) {
144+
T old_prefix = running_total_;
145+
146+
// Kahan Summation
147+
T y = op_(block_aggregate, static_cast<T>(-compensation_));
148+
T t = op_(running_total_, y);
149+
T y_high = op_(t, static_cast<T>(-running_total_));
150+
compensation_ = op_(y_high, static_cast<T>(-y));
151+
running_total_ = t;
152+
153+
return old_prefix;
154+
}
155+
};
156+
157+
template <typename T>
158+
struct BlockPrefixCallbackOp<T, LogAddExp> {
159+
T max_so_far_;
160+
T scaled_sum_;
161+
T compensation_;
162+
LogAddExp op_;
163+
164+
__device__ BlockPrefixCallbackOp(T identity, LogAddExp op)
165+
: max_so_far_(identity), scaled_sum_(0.0), compensation_(0.0), op_(op) {}
166+
167+
__device__ T operator()(T block_aggregate) {
168+
if (scaled_sum_ == 0.0) {
169+
max_so_far_ = block_aggregate;
170+
scaled_sum_ = 1.0;
171+
compensation_ = 0.0;
172+
return std::numeric_limits<T>::lowest();
173+
}
174+
175+
// Online Scaling
176+
T old_prefix = max_so_far_ + std::log(scaled_sum_);
177+
T m_old = max_so_far_;
178+
T m_new = std::max(m_old, block_aggregate);
179+
180+
if (m_new > m_old) {
181+
T scale = std::exp(m_old - m_new);
182+
scaled_sum_ *= scale;
183+
compensation_ *= scale;
184+
}
185+
186+
// Kahan Summation
187+
T term = std::exp(block_aggregate - m_new);
188+
T y = term - compensation_;
189+
T t = scaled_sum_ + y;
190+
compensation_ = (t - scaled_sum_) - y;
191+
scaled_sum_ = t;
192+
max_so_far_ = m_new;
193+
194+
return old_prefix;
195+
}
196+
};
197+
149198
template <typename T, int BLOCK_THREADS, int ITEMS_PER_THREAD, typename Op>
150199
__global__ void BlockScanKernel(T* d_out,
151200
const T* d_in,
@@ -154,17 +203,17 @@ __global__ void BlockScanKernel(T* d_out,
154203
bool exclusive,
155204
Op op) {
156205
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
206+
using CallbackOp = BlockPrefixCallbackOp<MT, Op>;
157207

158208
// Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
159-
typedef cub::
160-
BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
161-
BlockLoadT;
162-
typedef cub::BlockStore<MT,
163-
BLOCK_THREADS,
164-
ITEMS_PER_THREAD,
165-
cub::BLOCK_STORE_TRANSPOSE>
166-
BlockStoreT;
167-
typedef cub::BlockScan<MT, BLOCK_THREADS> BlockScanT;
209+
using BlockLoadT = cub::
210+
BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>;
211+
using BlockStoreT = cub::BlockStore<MT,
212+
BLOCK_THREADS,
213+
ITEMS_PER_THREAD,
214+
cub::BLOCK_STORE_TRANSPOSE>;
215+
using BlockScanT = cub::BlockScan<MT, BLOCK_THREADS>;
216+
168217
// Allocate type-safe, repurposable shared memory for collectives
169218
__shared__ union {
170219
typename BlockLoadT::TempStorage load;
@@ -176,24 +225,21 @@ __global__ void BlockScanKernel(T* d_out,
176225
int64_t item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
177226

178227
for (int64_t bx = blockIdx.x; bx < grid_size; bx += gridDim.x) {
179-
BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);
228+
CallbackOp prefix_op(Identity<MT, Op>::value, op);
180229

181230
for (int64_t block_offset = 0; block_offset < scan_size;
182231
block_offset += item_per_block) {
183-
int64_t valid_item = (scan_size - block_offset > item_per_block)
184-
? item_per_block
185-
: (scan_size - block_offset);
186-
if (scan_size < item_per_block) {
187-
valid_item = scan_size;
188-
}
232+
int64_t valid_item = std::min(scan_size - block_offset, item_per_block);
189233

190234
int64_t offset = bx * scan_size + block_offset;
191235

192236
MT thread_keys[ITEMS_PER_THREAD];
193237
BlockLoadT(temp_storage.load)
194-
.Load(d_in + offset, thread_keys, valid_item, 0);
238+
.Load(
239+
d_in + offset, thread_keys, valid_item, Identity<MT, Op>::value);
195240

196241
__syncthreads();
242+
197243
if (exclusive) {
198244
BlockScanT(temp_storage.scan)
199245
.ExclusiveScan(thread_keys, thread_keys, op, prefix_op);
@@ -209,63 +255,6 @@ __global__ void BlockScanKernel(T* d_out,
209255
}
210256
}
211257

212-
template <typename Context, typename T>
213-
typename std::enable_if<!std::is_same<T, phi::dtype::float16>::value &&
214-
!std::is_same<T, phi::dtype::bfloat16>::value>::type
215-
ThrustCumsumKernel(const Context& dev_ctx,
216-
const T* in_data,
217-
T* out_data,
218-
int64_t size,
219-
bool reverse,
220-
bool exclusive) {
221-
#ifdef __HIPCC__
222-
const auto& policy = thrust::hip::par.on(dev_ctx.stream());
223-
#else
224-
phi::memory_utils::ThrustAllocator<cudaStream_t> allocator(dev_ctx.GetPlace(),
225-
dev_ctx.stream());
226-
const auto& policy = thrust::cuda::par(allocator).on(dev_ctx.stream());
227-
#endif
228-
if (reverse) {
229-
thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
230-
thrust::device_pointer_cast(in_data) + size);
231-
thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
232-
thrust::device_pointer_cast(out_data) + size);
233-
if (exclusive) {
234-
thrust::exclusive_scan(
235-
policy, reversed_in, reversed_in + size, reversed_out);
236-
} else {
237-
thrust::inclusive_scan(
238-
policy, reversed_in, reversed_in + size, reversed_out);
239-
}
240-
} else {
241-
if (exclusive) {
242-
thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
243-
} else {
244-
thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
245-
}
246-
}
247-
248-
return;
249-
}
250-
251-
template <typename Context, typename T>
252-
typename std::enable_if<std::is_same<T, phi::dtype::float16>::value>::type
253-
ThrustCumsumKernel(const Context& dev_ctx,
254-
const phi::dtype::float16* in_data,
255-
phi::dtype::float16* out_data,
256-
int64_t size,
257-
bool reverse,
258-
bool exclusive) {}
259-
260-
template <typename Context, typename T>
261-
typename std::enable_if<std::is_same<T, phi::dtype::bfloat16>::value>::type
262-
ThrustCumsumKernel(const Context& dev_ctx,
263-
const phi::dtype::bfloat16* in_data,
264-
phi::dtype::bfloat16* out_data,
265-
int64_t size,
266-
bool reverse,
267-
bool exclusive) {}
268-
269258
template <typename T, typename Context, typename Op>
270259
void ScanKernel(const Context& dev_ctx,
271260
const DenseTensor& x,
@@ -290,7 +279,6 @@ void ScanKernel(const Context& dev_ctx,
290279
}
291280

292281
auto out_dims = out->dims();
293-
auto size = x.numel();
294282

295283
PADDLE_ENFORCE_EQ(
296284
axis < out_dims.size() && axis >= (0 - out_dims.size()),
@@ -307,22 +295,11 @@ void ScanKernel(const Context& dev_ctx,
307295

308296
const T* in_data = x.data<T>();
309297

310-
// Use thrust for parallel acceleration when the input size is equal to the
311-
// length of the 'axis' dimension.
312-
if (!std::is_same<T, phi::dtype::float16>::value &&
313-
!std::is_same<T, phi::dtype::bfloat16>::value &&
314-
std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
315-
ThrustCumsumKernel<Context, T>(
316-
dev_ctx, in_data, out_data, size, reverse, exclusive);
317-
return;
318-
}
319-
320298
size_t height = 1;
321299
size_t width = 1;
322300
for (size_t i = 0; i <= axis; i++) {
323301
height *= out_dims[i];
324302
}
325-
326303
for (size_t i = axis + 1; i < out_dims.size(); i++) {
327304
width *= out_dims[i];
328305
}

test/legacy_test/test_logcumsumexp_op.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,13 @@ def np_logcumsumexp_grad(
7777
exclusive: bool = False,
7878
):
7979
out = np_logcumsumexp(x, axis, flatten, reverse, exclusive)
80-
log_grad_positive = np.where(dout > 0, np.log(dout), np.finfo(x.dtype).min)
81-
log_grad_negative = np.where(dout < 0, np.log(-dout), np.finfo(x.dtype).min)
80+
dout = np.asarray(dout)
81+
pos_mask = dout > 0
82+
neg_mask = dout < 0
83+
log_grad_positive = np.full_like(dout, np.finfo(x.dtype).min)
84+
log_grad_negative = np.full_like(dout, np.finfo(x.dtype).min)
85+
log_grad_positive[pos_mask] = np.log(dout[pos_mask])
86+
log_grad_negative[neg_mask] = np.log(-dout[neg_mask])
8287

8388
output_pos = np.exp(
8489
np_logcumsumexp(

0 commit comments

Comments
 (0)