Skip to content

Commit 0ee967b

Browse files
authored
Merge pull request #4288 from hedaoyuan/fix_bug
Bug fix for get device_context in conv2d op.
2 parents 8c3b8af + ccbb285 commit 0ee967b

File tree

6 files changed

+70
-61
lines changed

6 files changed

+70
-61
lines changed

paddle/operators/gemm_conv2d_op.h

Lines changed: 14 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -75,9 +75,6 @@ class GemmConv2DKernel : public framework::OpKernel {
7575
framework::DDim output_matrix_shape = {output_channels,
7676
output_height * output_width};
7777

78-
auto* device_context =
79-
const_cast<platform::DeviceContext*>(context.device_context_);
80-
8178
// convolution operator: im2col + gemm
8279
int in_step = input_channels / groups;
8380
int out_step = output_channels / groups;
@@ -87,14 +84,14 @@ class GemmConv2DKernel : public framework::OpKernel {
8784
for (int g = 0; g < groups; g++) {
8885
// im2col
8986
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
90-
im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
91-
device_context);
87+
im2col(context.device_context(), in_slice, col, strides[0], strides[1],
88+
paddings[0], paddings[1]);
9289

9390
// gemm
9491
Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step);
9592
Tensor filter_slice = filter.Slice<T>(g * out_step, (g + 1) * out_step);
96-
math::matmul<Place, T>(filter_slice, false, col_matrix, false, T(1.0),
97-
&out_slice, T(0.0), device_context);
93+
math::matmul<Place, T>(context.device_context(), filter_slice, false,
94+
col_matrix, false, T(1.0), &out_slice, T(0.0));
9895
}
9996
}
10097
}
@@ -160,9 +157,6 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
160157
filter.numel() / filter.dims()[0]};
161158
filter.Resize(filter_matrix_shape);
162159

163-
auto* device_context =
164-
const_cast<platform::DeviceContext*>(context.device_context_);
165-
166160
// convolution backward input operator: gemm + col2im
167161
// convolution backward weight operator: im2col + gemm
168162
int in_step = input_channels / groups;
@@ -184,14 +178,15 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
184178
out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
185179
Tensor filter_slice =
186180
filter.Slice<T>(g * out_step, (g + 1) * out_step);
187-
math::matmul<Place, T>(filter_slice, true, out_grad_slice, false,
188-
T(1.0), &col_matrix, T(0.0), device_context);
181+
math::matmul<Place, T>(context.device_context(), filter_slice, true,
182+
out_grad_slice, false, T(1.0), &col_matrix,
183+
T(0.0));
189184

190185
// col2im
191186
Tensor in_grad_slice =
192187
in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step);
193-
col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
194-
paddings[1], device_context);
188+
col2im(context.device_context(), in_grad_slice, col, strides[0],
189+
strides[1], paddings[0], paddings[1]);
195190
}
196191
}
197192
}
@@ -212,15 +207,15 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
212207
Tensor out_grad_slice =
213208
out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
214209
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
215-
im2col(in_slice, col, strides[0], strides[1], paddings[0],
216-
paddings[1], device_context);
210+
im2col(context.device_context(), in_slice, col, strides[0],
211+
strides[1], paddings[0], paddings[1]);
217212

218213
// gemm
219214
Tensor filter_grad_slice =
220215
filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step);
221-
math::matmul<Place, T>(out_grad_slice, false, col_matrix, true,
222-
T(1.0), &filter_grad_slice, T(1.0),
223-
device_context);
216+
math::matmul<Place, T>(context.device_context(), out_grad_slice,
217+
false, col_matrix, true, T(1.0),
218+
&filter_grad_slice, T(1.0));
224219
}
225220
}
226221
}

paddle/operators/math/im2col.cc

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,10 @@ template <class T>
2727
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
2828
platform::CPUPlace, T> {
2929
public:
30-
void operator()(const framework::Tensor& im, framework::Tensor& col,
30+
void operator()(const platform::DeviceContext& context,
31+
const framework::Tensor& im, framework::Tensor& col,
3132
int stride_height, int stride_width, int padding_height,
32-
int padding_width, platform::DeviceContext* context) {
33+
int padding_width) {
3334
PADDLE_ENFORCE(im.dims().size() == 3);
3435
PADDLE_ENFORCE(col.dims().size() == 5);
3536

@@ -79,9 +80,9 @@ template <class T>
7980
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
8081
platform::CPUPlace, T> {
8182
public:
82-
void operator()(framework::Tensor& im, const framework::Tensor& col,
83-
int stride_height, int stride_width, int padding_height,
84-
int padding_width, platform::DeviceContext* context) {
83+
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
84+
const framework::Tensor& col, int stride_height,
85+
int stride_width, int padding_height, int padding_width) {
8586
PADDLE_ENFORCE(im.dims().size() == 3);
8687
PADDLE_ENFORCE(col.dims().size() == 5);
8788
int input_channels = im.dims()[0];
@@ -137,9 +138,10 @@ template <class T>
137138
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
138139
platform::CPUPlace, T> {
139140
public:
140-
void operator()(const framework::Tensor& im, framework::Tensor& col,
141+
void operator()(const platform::DeviceContext& context,
142+
const framework::Tensor& im, framework::Tensor& col,
141143
int stride_height, int stride_width, int padding_height,
142-
int padding_width, platform::DeviceContext* context) {
144+
int padding_width) {
143145
PADDLE_ENFORCE(im.dims().size() == 3);
144146
PADDLE_ENFORCE(col.dims().size() == 5);
145147
int input_channels = im.dims()[0];
@@ -197,9 +199,9 @@ template <class T>
197199
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
198200
platform::CPUPlace, T> {
199201
public:
200-
void operator()(framework::Tensor& im, const framework::Tensor& col,
201-
int stride_height, int stride_width, int padding_height,
202-
int padding_width, platform::DeviceContext* context) {
202+
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
203+
const framework::Tensor& col, int stride_height,
204+
int stride_width, int padding_height, int padding_width) {
203205
PADDLE_ENFORCE(im.dims().size() == 3);
204206
PADDLE_ENFORCE(col.dims().size() == 5);
205207
int input_channels = im.dims()[0];

paddle/operators/math/im2col.cu

Lines changed: 24 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -64,9 +64,10 @@ template <class T>
6464
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
6565
platform::GPUPlace, T> {
6666
public:
67-
void operator()(const framework::Tensor& im, framework::Tensor& col,
67+
void operator()(const platform::DeviceContext& context,
68+
const framework::Tensor& im, framework::Tensor& col,
6869
int stride_height, int stride_width, int padding_height,
69-
int padding_width, platform::DeviceContext* context) {
70+
int padding_width) {
7071
PADDLE_ENFORCE(im.dims().size() == 3);
7172
PADDLE_ENFORCE(col.dims().size() == 5);
7273

@@ -84,9 +85,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
8485
int block_y = (blocks + 512 - 1) / 512;
8586
dim3 threads(1024, 1);
8687
dim3 grid(block_x, block_y);
87-
im2col<T><<<
88-
grid, threads, 0,
89-
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
88+
im2col<T><<<grid, threads, 0,
89+
reinterpret_cast<const platform::CUDADeviceContext&>(context)
90+
.stream()>>>(
9091
im.data<T>(), num_outputs, input_height, input_width, filter_height,
9192
filter_width, stride_height, stride_width, padding_height,
9293
padding_width, output_height, output_width, col.data<T>());
@@ -149,9 +150,9 @@ template <class T>
149150
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
150151
platform::GPUPlace, T> {
151152
public:
152-
void operator()(framework::Tensor& im, const framework::Tensor& col,
153-
int stride_height, int stride_width, int padding_height,
154-
int padding_width, platform::DeviceContext* context) {
153+
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
154+
const framework::Tensor& col, int stride_height,
155+
int stride_width, int padding_height, int padding_width) {
155156
PADDLE_ENFORCE(im.dims().size() == 3);
156157
PADDLE_ENFORCE(col.dims().size() == 5);
157158

@@ -174,9 +175,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
174175

175176
// To avoid involving atomic operations, we will launch one kernel per
176177
// bottom dimension, and then in the kernel add up the top dimensions.
177-
col2im<T><<<
178-
grid, threads, 0,
179-
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
178+
col2im<T><<<grid, threads, 0,
179+
reinterpret_cast<const platform::CUDADeviceContext&>(context)
180+
.stream()>>>(
180181
num_kernels, col.data<T>(), input_height + 2 * padding_height,
181182
input_width + 2 * padding_width, input_channels, filter_height,
182183
filter_width, stride_height, stride_width, padding_height,
@@ -235,9 +236,10 @@ template <class T>
235236
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
236237
platform::GPUPlace, T> {
237238
public:
238-
void operator()(const framework::Tensor& im, framework::Tensor& col,
239+
void operator()(const platform::DeviceContext& context,
240+
const framework::Tensor& im, framework::Tensor& col,
239241
int stride_height, int stride_width, int padding_height,
240-
int padding_width, platform::DeviceContext* context) {
242+
int padding_width) {
241243
PADDLE_ENFORCE(im.dims().size() == 3);
242244
PADDLE_ENFORCE(col.dims().size() == 5);
243245
int input_channels = im.dims()[0];
@@ -268,9 +270,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
268270
dim3 threads(block_dim_x, block_dim_y,
269271
std::min(block_dim_z, input_channels));
270272
dim3 grid(output_width, output_height);
271-
im2colOCF<T><<<
272-
grid, threads, 0,
273-
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
273+
im2colOCF<T><<<grid, threads, 0,
274+
reinterpret_cast<const platform::CUDADeviceContext&>(context)
275+
.stream()>>>(
274276
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
275277
filter_height, filter_width, stride_height, stride_width,
276278
padding_height, padding_width, output_height, output_width);
@@ -318,9 +320,9 @@ template <class T>
318320
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
319321
platform::GPUPlace, T> {
320322
public:
321-
void operator()(framework::Tensor& im, const framework::Tensor& col,
322-
int stride_height, int stride_width, int padding_height,
323-
int padding_width, platform::DeviceContext* context) {
323+
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
324+
const framework::Tensor& col, int stride_height,
325+
int stride_width, int padding_height, int padding_width) {
324326
PADDLE_ENFORCE(im.dims().size() == 3);
325327
PADDLE_ENFORCE(col.dims().size() == 5);
326328
int input_channels = im.dims()[0];
@@ -351,9 +353,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
351353
dim3 threads(block_dim_x, block_dim_y,
352354
std::min(block_dim_z, input_channels));
353355
dim3 grid(output_width, output_height);
354-
col2imOCF<T><<<
355-
grid, threads, 0,
356-
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
356+
col2imOCF<T><<<grid, threads, 0,
357+
reinterpret_cast<const platform::CUDADeviceContext&>(context)
358+
.stream()>>>(
357359
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
358360
filter_height, filter_width, stride_height, stride_width,
359361
padding_height, padding_width, output_height, output_width);

paddle/operators/math/im2col.h

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -72,17 +72,18 @@ enum class ColFormat { kCFO = 0, kOCF = 1 };
7272
template <ColFormat Format, typename Place, typename T>
7373
class Im2ColFunctor {
7474
public:
75-
void operator()(const framework::Tensor& im, framework::Tensor& col,
75+
void operator()(const platform::DeviceContext& context,
76+
const framework::Tensor& im, framework::Tensor& col,
7677
int stride_height, int stride_width, int padding_height,
77-
int padding_width, platform::DeviceContext* context);
78+
int padding_width);
7879
};
7980

8081
template <ColFormat Format, typename Place, typename T>
8182
class Col2ImFunctor {
8283
public:
83-
void operator()(framework::Tensor& im, const framework::Tensor& col,
84-
int stride_height, int stride_width, int padding_height,
85-
int padding_width, platform::DeviceContext* context);
84+
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
85+
const framework::Tensor& col, int stride_height,
86+
int stride_width, int padding_height, int padding_width);
8687
};
8788

8889
} // namespace math

paddle/operators/math/im2col_test.cc

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -78,8 +78,8 @@ void testIm2col() {
7878
PADDLE_THROW("no GPU support");
7979
#endif // PADDLE_ONLY_CPU
8080
}
81-
im2col(input, output_cfo, stride, stride, padding, padding, context);
82-
im2col_ocf(input, output_ocf, stride, stride, padding, padding, context);
81+
im2col(*context, input, output_cfo, stride, stride, padding, padding);
82+
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding);
8383

8484
float* out_cfo_ptr;
8585
if (paddle::platform::is_cpu_place(*place)) {

python/paddle/v2/framework/tests/test_conv2d_op.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -73,13 +73,22 @@ def test_check_output(self):
7373
self.check_output()
7474

7575
def test_check_grad(self):
76-
self.check_grad(set(['Input', 'Filter']), 'Output')
76+
self.check_grad(
77+
set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
7778

7879
def test_check_grad_no_filter(self):
79-
self.check_grad(['Input'], 'Output', no_grad_set=set(['Filter']))
80+
self.check_grad(
81+
['Input'],
82+
'Output',
83+
max_relative_error=0.05,
84+
no_grad_set=set(['Filter']))
8085

8186
def test_check_grad_no_input(self):
82-
self.check_grad(['Filter'], 'Output', no_grad_set=set(['Input']))
87+
self.check_grad(
88+
['Filter'],
89+
'Output',
90+
max_relative_error=0.05,
91+
no_grad_set=set(['Input']))
8392

8493
def init_groups(self):
8594
self.groups = 1

0 commit comments

Comments
 (0)