@@ -20,6 +20,19 @@ namespace paddle {
 namespace operators {
 
 namespace {
+
+template <typename T>
+__global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
+                          const int num, T* p_out) {
+  T lr = learning_rate[0];
+  int grid_size = blockDim.x * gridDim.x;
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += grid_size) {
+    T g_data = g[i];
+    T p_data = p[i];
+    p_out[i] = p_data - lr * g_data;
+  }
+}
+
 template <typename T, int block_size>
 __global__ void SparseSGDFunctorKernel(const T* selected_rows,
                                        const int64_t* rows,
@@ -41,40 +54,65 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows,
 }  // namespace
 
 template <typename T>
-struct SparseSGDFunctor<platform::CUDADeviceContext, T> {
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::SelectedRows& input,
-                  const framework::Tensor& learning_rate,
-                  framework::Tensor* output) {
-    auto in_height = input.height();
-    auto out_dims = output->dims();
-    PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
-
-    auto& in_value = input.value();
-    auto& in_rows = input.rows();
-
-    int64_t in_row_numel = in_value.numel() / in_rows.size();
-    PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
-
-    auto* in_data = in_value.data<T>();
-    auto* out_data = output->data<T>();
-
-    const int block_size = 256;
-    dim3 threads(block_size, 1);
-    dim3 grid(1, in_rows.size());
-    SparseSGDFunctorKernel<T, 256><<<grid, threads, 0, context.stream()>>>(
-        in_data, in_rows.data(), learning_rate.data<T>(), out_data,
-        in_row_numel);
+class SGDOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* param = ctx.Input<framework::Tensor>("Param");
+    auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+
+    auto* grad_var = ctx.InputVar("Grad");
+    // Actually, all tensors are LoDTensor except SelectedRows.
+    if (grad_var->IsType<framework::LoDTensor>()) {
+      param_out->mutable_data<T>(ctx.GetPlace());
+      auto* grad = ctx.Input<framework::Tensor>("Grad");
+      auto* grad_data = grad->data<T>();
+      auto* param_data = param->data<T>();
+      auto* param_out_data = param_out->data<T>();
+
+      int block = 512;
+      int grid = (param->numel() + block - 1) / block;
+
+      SGDKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
+          grad_data, param_data, learning_rate->data<T>(), param->numel(),
+          param_out_data);
+
+    } else if (grad_var->IsType<framework::SelectedRows>()) {
+      // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
+      // This manual optimization brings difficulty to track data dependency.
+      // It's better to find a more elegant solution.
+      PADDLE_ENFORCE_EQ(param, param_out);
+      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
+
+      auto in_height = grad->height();
+      auto out_dims = param_out->dims();
+      PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
+
+      auto& in_value = grad->value();
+      auto& in_rows = grad->rows();
+
+      int64_t in_row_numel = in_value.numel() / in_rows.size();
+      PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);
+
+      auto* in_data = in_value.data<T>();
+      auto* out_data = param_out->data<T>();
+
+      const int block_size = 256;
+      dim3 threads(block_size, 1);
+      dim3 grid(1, in_rows.size());
+      SparseSGDFunctorKernel<
+          T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
+          in_data, in_rows.data(), learning_rate->data<T>(), out_data,
+          in_row_numel);
+
+    } else {
+      PADDLE_THROW("Unsupported Variable Type of Grad");
+    }
   }
 };
-
-template struct SparseSGDFunctor<platform::CUDADeviceContext, float>;
-template struct SparseSGDFunctor<platform::CUDADeviceContext, double>;
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    sgd, ops::SGDOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::SGDOpKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(sgd, ops::SGDOpCUDAKernel<float>,
+                        ops::SGDOpCUDAKernel<double>);
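
For reference, the dense branch boils down to an element-wise update `p_out[i] = p[i] - lr * g[i]` over a grid-stride loop, with the learning rate read from device memory so the scalar never has to be copied back to the host. The standalone sketch below is not part of this commit: the kernel mirrors `SGDKernel` above, while the host-side setup (names such as `n`, `d_param`, `d_grad`, `d_lr`) is invented for illustration and reuses the same 512-thread launch configuration as the dense branch.

```cuda
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

// Mirrors the dense SGDKernel: element-wise p_out = p - lr * g.
template <typename T>
__global__ void SGDKernelSketch(const T* g, const T* p, const T* lr,
                                const int num, T* p_out) {
  T learning_rate = lr[0];  // learning rate lives in device memory
  int grid_size = blockDim.x * gridDim.x;
  // Grid-stride loop: each thread handles every grid_size-th element.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += grid_size) {
    p_out[i] = p[i] - learning_rate * g[i];
  }
}

int main() {
  const int n = 1 << 20;  // hypothetical parameter count
  std::vector<float> h_param(n, 1.0f), h_grad(n, 0.5f);
  float h_lr = 0.1f;

  float *d_param, *d_grad, *d_lr, *d_param_out;
  cudaMalloc(&d_param, n * sizeof(float));
  cudaMalloc(&d_grad, n * sizeof(float));
  cudaMalloc(&d_param_out, n * sizeof(float));
  cudaMalloc(&d_lr, sizeof(float));
  cudaMemcpy(d_param, h_param.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_grad, h_grad.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_lr, &h_lr, sizeof(float), cudaMemcpyHostToDevice);

  // Same sizing as the dense branch: 512 threads per block, enough blocks
  // to cover all n elements.
  int block = 512;
  int grid = (n + block - 1) / block;
  SGDKernelSketch<float><<<grid, block>>>(d_grad, d_param, d_lr, n, d_param_out);

  cudaMemcpy(h_param.data(), d_param_out, n * sizeof(float),
             cudaMemcpyDeviceToHost);
  printf("param[0] after one step: %f\n", h_param[0]);  // 1.0 - 0.1 * 0.5 = 0.95

  cudaFree(d_param);
  cudaFree(d_grad);
  cudaFree(d_param_out);
  cudaFree(d_lr);
  return 0;
}
```

The grid-stride loop keeps the kernel correct even if it is launched with fewer blocks than `(num + block - 1) / block`, since each thread keeps striding until it has covered its share of the remaining elements.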