@@ -235,7 +235,7 @@ __global__ void TransposeNormalKernel(const T* in_ptr,
235235
236236template <typename DeviceContext, typename T>
237237void TransposeNormal<DeviceContext, T>::operator ()(
238- const DeviceContext& context ,
238+ const DeviceContext& dev_ctx ,
239239 const phi::DenseTensor& in,
240240 phi::DenseTensor* out,
241241 const std::vector<int >& axis) {
@@ -246,7 +246,7 @@ void TransposeNormal<DeviceContext, T>::operator()(
246246 auto * out_ptr = out->data <T>();
247247
248248 // copy in_stride, out_stride, axis to gpu device
249- const phi::Place& cuda_place = context .GetPlace ();
249+ const phi::Place& cuda_place = dev_ctx .GetPlace ();
250250 phi::CPUPlace cpu_place = phi::CPUPlace ();
251251 size_t size = 3 * rank * sizeof (int64_t );
252252 auto cpu_buf_holder = phi::memory_utils::Alloc (cpu_place, size);
@@ -259,26 +259,26 @@ void TransposeNormal<DeviceContext, T>::operator()(
259259 cpu_buf[2 * rank + i] = axis[i];
260260 }
261261 memory_utils::Copy (
262- cuda_place, cuda_buf, cpu_place, cpu_buf, size, context .stream ());
262+ cuda_place, cuda_buf, cpu_place, cpu_buf, size, dev_ctx .stream ());
263263 REINTERPRET (const int64_t , in_stride_ptr, cuda_buf);
264264 REINTERPRET (const int64_t , out_stride_ptr, cuda_buf + rank);
265265 REINTERPRET (const int64_t , axis_ptr, cuda_buf + 2 * rank);
266266
267- const int MAX_BLOCK_DIM = context .GetMaxThreadsPerBlock ();
268- const int MAX_GRID_DIM = context .GetMaxPhysicalThreadCount () / MAX_BLOCK_DIM;
267+ const int MAX_BLOCK_DIM = dev_ctx .GetMaxThreadsPerBlock ();
268+ const int MAX_GRID_DIM = dev_ctx .GetMaxPhysicalThreadCount () / MAX_BLOCK_DIM;
269269 int64_t elements = in.numel ();
270270 int block_size = (elements >= MAX_BLOCK_DIM)
271271 ? MAX_BLOCK_DIM
272272 : (1 << static_cast <int >(std::log2 (elements)));
273273 int grid_size = elements / block_size;
274274 grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size;
275- TransposeNormalKernel<T><<<grid_size, block_size, 0 , context .stream()>>> (
275+ TransposeNormalKernel<T><<<grid_size, block_size, 0 , dev_ctx .stream()>>> (
276276 in_ptr, out_ptr, elements, in_stride_ptr, out_stride_ptr, axis_ptr, rank);
277277}
278278
279279template <typename T>
280280struct TransposeNormal <phi::GPUContext, T> {
281- void operator ()(const phi::GPUContext& context ,
281+ void operator ()(const phi::GPUContext& dev_ctx ,
282282 const DenseTensor& in,
283283 DenseTensor* out,
284284 const std::vector<int >& axis) {
@@ -289,7 +289,7 @@ struct TransposeNormal<phi::GPUContext, T> {
289289 auto * out_ptr = out->data <T>();
290290
291291 // copy in_stride, out_stride, axis to gpu device
292- const phi::Place& cuda_place = context .GetPlace ();
292+ const phi::Place& cuda_place = dev_ctx .GetPlace ();
293293 phi::CPUPlace cpu_place = phi::CPUPlace ();
294294 size_t size = 3 * rank * sizeof (int64_t );
295295 auto cpu_buf_holder = phi::memory_utils::Alloc (cpu_place, size);
@@ -302,22 +302,22 @@ struct TransposeNormal<phi::GPUContext, T> {
302302 cpu_buf[2 * rank + i] = axis[i];
303303 }
304304 memory_utils::Copy (
305- cuda_place, cuda_buf, cpu_place, cpu_buf, size, context .stream ());
305+ cuda_place, cuda_buf, cpu_place, cpu_buf, size, dev_ctx .stream ());
306306 REINTERPRET (const int64_t , in_stride_ptr, cuda_buf);
307307 REINTERPRET (const int64_t , out_stride_ptr, cuda_buf + rank);
308308 REINTERPRET (const int64_t , axis_ptr, cuda_buf + 2 * rank);
309309
310- const int MAX_BLOCK_DIM = context .GetMaxThreadsPerBlock ();
310+ const int MAX_BLOCK_DIM = dev_ctx .GetMaxThreadsPerBlock ();
311311 const int MAX_GRID_DIM =
312- context .GetMaxPhysicalThreadCount () / MAX_BLOCK_DIM;
312+ dev_ctx .GetMaxPhysicalThreadCount () / MAX_BLOCK_DIM;
313313 int64_t elements = in.numel ();
314314 int block_size = (elements >= MAX_BLOCK_DIM)
315315 ? MAX_BLOCK_DIM
316316 : (1 << static_cast <int >(std::log2 (elements)));
317317 int grid_size = elements / block_size;
318318 grid_size = (grid_size >= MAX_GRID_DIM) ? MAX_GRID_DIM : grid_size;
319319 TransposeNormalKernel<T>
320- <<<grid_size, block_size, 0 , context .stream()>>> (in_ptr,
320+ <<<grid_size, block_size, 0 , dev_ctx .stream()>>> (in_ptr,
321321 out_ptr,
322322 elements,
323323 in_stride_ptr,
@@ -347,30 +347,30 @@ DEFINE_GPU_TRANS_NORMAL(phi::dtype::complex<float>);
347347DEFINE_GPU_TRANS_NORMAL (phi::dtype::complex <double >);
348348
// Visitor object handed to phi::VisitDataType by
// set_constant_with_place<phi::GPUPlace>. It stores a reference to the device
// context, a pointer to the target tensor and the fill value as a float;
// apply<T>() forwards all three to SetConstant<phi::GPUContext, T>.
349349struct TensorSetConstantGPU {
350- TensorSetConstantGPU (const phi::DeviceContext& context ,
350+ TensorSetConstantGPU (const phi::DeviceContext& dev_ctx ,
351351 phi::DenseTensor* tensor,
352352 float value)
353- : context_(context ), tensor_(tensor), value_(value) {}
353+ : dev_ctx_(dev_ctx ), tensor_(tensor), value_(value) {}
354354
 // Invokes the SetConstant functor for element type T; the stored float value
 // is converted to T with static_cast before the call.
355355 template <typename T>
356356 void apply () const {
357357 SetConstant<phi::GPUContext, T> functor;
 // NOTE(review): the generic phi::DeviceContext reference is reinterpret_cast
 // to phi::GPUContext with no runtime check; presumably callers only pass a
 // GPU context here -- confirm.
358- functor (reinterpret_cast <const phi::GPUContext&>(context_ ),
358+ functor (reinterpret_cast <const phi::GPUContext&>(dev_ctx_ ),
359359 tensor_,
360360 static_cast <T>(value_));
361361 }
362362
 // The context is held by reference and the tensor by raw pointer, so both
 // must outlive this object.
363- const phi::DeviceContext& context_ ;
363+ const phi::DeviceContext& dev_ctx_ ;
364364 phi::DenseTensor* tensor_;
365365 float value_;
366366};
367367
// phi::GPUPlace specialization of set_constant_with_place. Builds a
// TensorSetConstantGPU visitor from the device context, tensor and value, and
// passes it to phi::VisitDataType together with tensor->dtype().
368368template <>
369- void set_constant_with_place<phi::GPUPlace>(const phi::DeviceContext& context ,
369+ void set_constant_with_place<phi::GPUPlace>(const phi::DeviceContext& dev_ctx ,
370370 phi::DenseTensor* tensor,
371371 float value) {
372372 phi::VisitDataType (tensor->dtype (),
373- TensorSetConstantGPU (context , tensor, value));
373+ TensorSetConstantGPU (dev_ctx , tensor, value));
374374}
375375
376376template <typename T>
@@ -386,7 +386,7 @@ __global__ void RowwiseAddKernel(
386386
387387template <typename T>
388388struct RowwiseAdd <phi::GPUContext, T> {
389- void operator ()(const phi::GPUContext& context ,
389+ void operator ()(const phi::GPUContext& dev_ctx ,
390390 const phi::DenseTensor& input,
391391 const phi::DenseTensor& vector,
392392 phi::DenseTensor* output) {
@@ -415,7 +415,7 @@ struct RowwiseAdd<phi::GPUContext, T> {
415415 out_dims_cstr));
416416 int blocks = 512 ;
417417 int grids = (input.numel () + blocks - 1 ) / blocks;
418- RowwiseAddKernel<T><<<grids, blocks, 0 , context .stream()>>> (
418+ RowwiseAddKernel<T><<<grids, blocks, 0 , dev_ctx .stream()>>> (
419419 input.data <T>(),
420420 vector.data <T>(),
421421 output->data <T>(),
0 commit comments