
Commit 02fda71

refine sgd-op
1 parent bb58a47 commit 02fda71

File tree

3 files changed: +94 -83 lines


paddle/operators/sgd_op.cc

Lines changed: 1 addition & 35 deletions
@@ -61,43 +61,9 @@ This operator implements one step of the stochastic gradient descent algorithm.
   }
 };
 
-template <typename T>
-struct SparseSGDFunctor<platform::CPUDeviceContext, T> {
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::SelectedRows& input,
-                  const framework::Tensor& learning_rate,
-                  framework::Tensor* output) {
-    auto in_height = input.height();
-    auto out_dims = output->dims();
-    PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
-
-    auto& in_value = input.value();
-    auto& in_rows = input.rows();
-
-    int64_t in_row_numel = in_value.numel() / in_rows.size();
-    PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
-
-    auto* in_data = in_value.data<T>();
-    auto* out_data = output->data<T>();
-    auto* lr = learning_rate.data<T>();
-
-    for (size_t i = 0; i < in_rows.size(); i++) {
-      for (int64_t j = 0; j < in_row_numel; j++) {
-        out_data[in_rows[i] * in_row_numel + j] -=
-            lr[0] * in_data[i * in_row_numel + j];
-      }
-    }
-  }
-};
-
-template struct SparseSGDFunctor<platform::CPUDeviceContext, float>;
-template struct SparseSGDFunctor<platform::CPUDeviceContext, double>;
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    sgd, ops::SGDOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SGDOpKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel<float>, ops::SGDOpKernel<double>);
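For reference, the dense update that the now device-agnostic SGDOpKernel performs is just an element-wise param = param - lr * grad. A minimal standalone sketch of that arithmetic (not Paddle code; sgd_dense_update and its parameter names are illustrative only):

#include <cstddef>

// Sketch of the dense SGD step: param[i] -= lr * grad[i] for every element,
// mirroring `o = p - lr[0] * g` in the refined kernel in sgd_op.h.
template <typename T>
void sgd_dense_update(T* param, const T* grad, T lr, std::size_t numel) {
  for (std::size_t i = 0; i < numel; ++i) {
    param[i] -= lr * grad[i];
  }
}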

paddle/operators/sgd_op.cu

Lines changed: 69 additions & 31 deletions
@@ -20,6 +20,19 @@ namespace paddle {
 namespace operators {
 
 namespace {
+
+template <typename T>
+__global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
+                          const int num, T* p_out) {
+  T lr = learning_rate[0];
+  int grid_size = blockDim.x * gridDim.x;
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += grid_size) {
+    T g_data = g[i];
+    T p_data = p[i];
+    p_out[i] = p_data - lr * g_data;
+  }
+}
+
 template <typename T, int block_size>
 __global__ void SparseSGDFunctorKernel(const T* selected_rows,
                                        const int64_t* rows,
@@ -41,40 +54,65 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows,
 }  // namespace
 
 template <typename T>
-struct SparseSGDFunctor<platform::CUDADeviceContext, T> {
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::SelectedRows& input,
-                  const framework::Tensor& learning_rate,
-                  framework::Tensor* output) {
-    auto in_height = input.height();
-    auto out_dims = output->dims();
-    PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
-
-    auto& in_value = input.value();
-    auto& in_rows = input.rows();
-
-    int64_t in_row_numel = in_value.numel() / in_rows.size();
-    PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
-
-    auto* in_data = in_value.data<T>();
-    auto* out_data = output->data<T>();
-
-    const int block_size = 256;
-    dim3 threads(block_size, 1);
-    dim3 grid(1, in_rows.size());
-    SparseSGDFunctorKernel<T, 256><<<grid, threads, 0, context.stream()>>>(
-        in_data, in_rows.data(), learning_rate.data<T>(), out_data,
-        in_row_numel);
+class SGDOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* param = ctx.Input<framework::Tensor>("Param");
+    auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+
+    auto* grad_var = ctx.InputVar("Grad");
+    // Actually, all tensors are LoDTensor except SelectedRows.
+    if (grad_var->IsType<framework::LoDTensor>()) {
+      param_out->mutable_data<T>(ctx.GetPlace());
+      auto* grad = ctx.Input<framework::Tensor>("Grad");
+      auto* grad_data = grad->data<T>();
+      auto* param_data = param->data<T>();
+      auto* param_out_data = param_out->data<T>();
+
+      int block = 512;
+      int grid = (param->numel() + block - 1) / block;
+
+      SGDKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
+          grad_data, param_data, learning_rate->data<T>(), param->numel(),
+          param_out_data);
+
+    } else if (grad_var->IsType<framework::SelectedRows>()) {
+      // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
+      // This manual optimization brings difficulty to track data dependency.
+      // It's better to find a more elegant solution.
+      PADDLE_ENFORCE_EQ(param, param_out);
+      auto* grad = ctx.Input<framework::SelectedRows>("Grad");
+
+      auto in_height = grad->height();
+      auto out_dims = param_out->dims();
+      PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
+
+      auto& in_value = grad->value();
+      auto& in_rows = grad->rows();
+
+      int64_t in_row_numel = in_value.numel() / in_rows.size();
+      PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);
+
+      auto* in_data = in_value.data<T>();
+      auto* out_data = param_out->data<T>();
+
+      const int block_size = 256;
+      dim3 threads(block_size, 1);
+      dim3 grid(1, in_rows.size());
+      SparseSGDFunctorKernel<
+          T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
+          in_data, in_rows.data(), learning_rate->data<T>(), out_data,
+          in_row_numel);
+
+    } else {
+      PADDLE_THROW("Unsupported Variable Type of Grad");
+    }
   }
 };
-
-template struct SparseSGDFunctor<platform::CUDADeviceContext, float>;
-template struct SparseSGDFunctor<platform::CUDADeviceContext, double>;
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    sgd, ops::SGDOpKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::SGDOpKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(sgd, ops::SGDOpCUDAKernel<float>,
+                        ops::SGDOpCUDAKernel<double>);
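The new dense CUDA path launches SGDKernel with block = 512 threads and grid = ceil(numel / block) blocks, and the kernel body uses a grid-stride loop, so each element is updated exactly once regardless of the launch configuration. A small host-side C++ sketch (not Paddle code; the element count is made up for illustration) that emulates this indexing:

#include <cstdio>

int main() {
  const int num = 1000;                        // hypothetical element count
  const int block = 512;                       // threads per block, as in the diff
  const int grid = (num + block - 1) / block;  // rounded-up block count -> 2
  int visited = 0;
  // Emulate each of the grid * block "threads" running the grid-stride loop
  // from SGDKernel: start at the global thread index, step by grid * block.
  for (int t = 0; t < grid * block; ++t) {
    for (int i = t; i < num; i += grid * block) {
      ++visited;
    }
  }
  std::printf("grid = %d, visited %d of %d elements\n", grid, visited, num);
  return 0;
}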

paddle/operators/sgd_op.h

Lines changed: 24 additions & 17 deletions
@@ -20,15 +20,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename DeviceContext, typename T>
-struct SparseSGDFunctor {
-  void operator()(const DeviceContext& context,
-                  const framework::SelectedRows& input,
-                  const framework::Tensor& learning_rate,
-                  framework::Tensor* output);
-};
-
-template <typename DeviceContext, typename T>
+template <typename T>
 class SGDOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -45,21 +37,36 @@ class SGDOpKernel : public framework::OpKernel<T> {
       auto p = framework::EigenVector<T>::Flatten(*param);
       auto g = framework::EigenVector<T>::Flatten(*grad);
       auto o = framework::EigenVector<T>::Flatten(*param_out);
-      auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
-      auto& place =
-          *ctx.template device_context<DeviceContext>().eigen_device();
+      auto* lr = learning_rate->data<T>();
 
-      Eigen::DSizes<int, 1> grad_dsize(grad->numel());
-      o.device(place) = p - lr.broadcast(grad_dsize) * g;
+      o = p - lr[0] * g;
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
       // This manual optimization brings difficulty to track data dependency.
       // It's better to find a more elegant solution.
       PADDLE_ENFORCE_EQ(param, param_out);
       auto* grad = ctx.Input<framework::SelectedRows>("Grad");
-      SparseSGDFunctor<DeviceContext, T> functor;
-      functor(ctx.template device_context<DeviceContext>(), *grad,
-              *learning_rate, param_out);
+
+      auto in_height = grad->height();
+      auto out_dims = param_out->dims();
+      PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
+
+      auto& in_value = grad->value();
+      auto& in_rows = grad->rows();
+
+      int64_t in_row_numel = in_value.numel() / in_rows.size();
+      PADDLE_ENFORCE_EQ(in_row_numel, param_out->numel() / in_height);
+
+      auto* in_data = in_value.data<T>();
+      auto* out_data = param_out->data<T>();
+      auto* lr = learning_rate->data<T>();
+
+      for (size_t i = 0; i < in_rows.size(); i++) {
+        for (int64_t j = 0; j < in_row_numel; j++) {
+          out_data[in_rows[i] * in_row_numel + j] -=
+              lr[0] * in_data[i * in_row_numel + j];
+        }
+      }
     } else {
       PADDLE_THROW("Unsupported Variable Type of Grad");
     }
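The sparse branch that now lives directly in SGDOpKernel updates only the parameter rows named by the SelectedRows gradient, in place. A self-contained sketch of that row-wise update with plain containers (not Paddle's API; sparse_sgd_update and its arguments are illustrative only):

#include <cstddef>
#include <cstdint>
#include <vector>

// param is a dense height x row_numel matrix stored row-major; rows[i] names
// the parameter row that gradient row i applies to, as in the loop above.
template <typename T>
void sparse_sgd_update(std::vector<T>& param, const std::vector<int64_t>& rows,
                       const std::vector<T>& grad_value, int64_t row_numel,
                       T lr) {
  for (std::size_t i = 0; i < rows.size(); ++i) {
    for (int64_t j = 0; j < row_numel; ++j) {
      param[rows[i] * row_numel + j] -= lr * grad_value[i * row_numel + j];
    }
  }
}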
