Skip to content

Commit 831bd64

Browse files
Fix im2col cpu (#75731) (#76006)
* fix: prevent memcpy over-read in im2col_sh1sw1dh1dw1ph1pw1 NCHW branches
  - Add bounds clamping for all memcpy operations in the specialized fast path
  - Add zero-fill for shortfall cases to ensure complete output tensor coverage
  - Maintain performance by using memcpy when safe, falling back to element-wise operations only when necessary

* fix: prevent memcpy over-read in filter_width==1 case of im2col_sh1sw1dh1dw1ph1pw1
  - Fix unsafe memcpy in NCHW path when filter_width == 1
  - Prevent a negative value from being converted to size_t when output_width < plw + prw
  - Clamp copy size to available source span (im_width) to avoid over-read
  - Add zero-fill for shortfall cases to ensure complete output coverage

* fix: enhance im2col_common to prevent overflow in arithmetic operations
  - Convert dimensions to 64-bit integers to avoid overflow during calculations
  - Update index calculations for col and im arrays to use 64-bit arithmetic
  - Ensure safe access to tensor data by checking bounds before indexing

Co-authored-by: Bvicii <[email protected]>
1 parent da596f0 commit 831bd64

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

paddle/phi/kernels/funcs/im2col_cfo_cpu.h

Lines changed: 28 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,15 @@ inline void im2col_common(const phi::DenseTensor& im,
4444
int output_width = col->dims()[4];
4545
int channels_col = im_channels * filter_height * filter_width;
4646

47+
// Convert dimensions to 64-bit to prevent overflow in arithmetic operations
48+
const int64_t im_channels64 = im_channels;
49+
const int64_t im_height64 = im_height;
50+
const int64_t im_width64 = im_width;
51+
const int64_t filter_height64 = filter_height;
52+
const int64_t filter_width64 = filter_width;
53+
const int64_t output_height64 = output_height;
54+
const int64_t output_width64 = output_width;
55+
4756
const T* im_data = im.data<T>();
4857
T* col_data = col->data<T>();
4958
for (int c = 0; c < channels_col; ++c) {
@@ -54,18 +63,27 @@ inline void im2col_common(const phi::DenseTensor& im,
5463
int im_row_idx = h * stride[0] - padding[0] + h_offset * dilation[0];
5564
for (int w = 0; w < output_width; ++w) {
5665
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
57-
int im_idx;
58-
if (data_layout != DataLayout::kNHWC) {
59-
im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
66+
67+
// Calculate col_idx using 64-bit arithmetic to prevent overflow
68+
int64_t col_idx64 =
69+
((int64_t)c * output_height64 + h) * output_width64 + w;
70+
71+
// Check bounds first to avoid buffer overflow in im_idx calculation
72+
if (im_row_idx < 0 || im_row_idx >= im_height || im_col_idx < 0 ||
73+
im_col_idx >= im_width) {
74+
*(col_data + col_idx64) = static_cast<T>(0);
6075
} else {
61-
im_idx = (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
76+
int64_t im_idx64;
77+
if (data_layout != DataLayout::kNHWC) {
78+
im_idx64 = ((int64_t)c_im * im_height64 + im_row_idx) * im_width64 +
79+
im_col_idx;
80+
} else {
81+
im_idx64 = ((int64_t)im_row_idx * im_width64 + im_col_idx) *
82+
im_channels64 +
83+
c_im;
84+
}
85+
*(col_data + col_idx64) = *(im_data + im_idx64);
6286
}
63-
int col_idx = (c * output_height + h) * output_width + w;
64-
65-
col_data[col_idx] = (im_row_idx < 0 || im_row_idx >= im_height ||
66-
im_col_idx < 0 || im_col_idx >= im_width)
67-
? static_cast<T>(0)
68-
: im_data[im_idx];
6987
}
7088
}
7189
}

0 commit comments

Comments
 (0)