Skip to content

Commit 7ad13fb

Browse files
authored
Merge pull request #4876 from QiJune/sgd_op_sparse_kernel
add sparse update kernel for sgd operator
2 parents c93596d + f968145 commit 7ad13fb

File tree

6 files changed

+224
-29
lines changed

6 files changed

+224
-29
lines changed

paddle/operators/sgd_op.cc

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
2121
public:
2222
using framework::OperatorWithKernel::OperatorWithKernel;
2323

24-
void InferShape(framework::InferShapeContext *ctx) const override {
24+
void InferShape(framework::InferShapeContext* ctx) const override {
2525
PADDLE_ENFORCE(ctx->HasInput("Param"),
2626
"Input(Param) of SGDOp should not be null.");
2727
PADDLE_ENFORCE(ctx->HasInput("Grad"),
@@ -35,15 +35,15 @@ class SGDOp : public framework::OperatorWithKernel {
3535
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
3636
"Learning rate should have 1 element");
3737
auto param_dim = ctx->GetInputDim("Param");
38-
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"),
39-
"Two input of SGD Op's dimension must be same.");
38+
// TODO(qijun): check dimensions of Param and Grad at compile
39+
// and run time.
4040
ctx->SetOutputDim("ParamOut", param_dim);
4141
}
4242
};
4343

4444
class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
4545
public:
46-
SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
46+
SGDOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
4747
: OpProtoAndCheckerMaker(proto, op_checker) {
4848
AddInput("Param", "Input parameter");
4949
AddInput("LearningRate", "Learning rate of SGD");
@@ -58,6 +58,38 @@ param_out = param - learning_rate * grad;
5858
)DOC");
5959
}
6060
};
61+
62+
template <typename T>
63+
struct SparseSGDFunctor<platform::CPUPlace, T> {
64+
void operator()(const platform::DeviceContext& context,
65+
const framework::SelectedRows& input,
66+
const framework::Tensor& learning_rate,
67+
framework::Tensor* output) {
68+
auto in_height = input.height();
69+
auto out_dims = output->dims();
70+
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
71+
72+
auto& in_value = input.value();
73+
auto& in_rows = input.rows();
74+
75+
int64_t in_row_numel = in_value.numel() / in_rows.size();
76+
PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
77+
78+
auto* in_data = in_value.data<T>();
79+
auto* out_data = output->data<T>();
80+
auto* lr = learning_rate.data<T>();
81+
82+
for (size_t i = 0; i < in_rows.size(); i++) {
83+
for (int64_t j = 0; j < in_row_numel; j++) {
84+
out_data[in_rows[i] * in_row_numel + j] -=
85+
lr[0] * in_data[i * in_row_numel + j];
86+
}
87+
}
88+
}
89+
};
90+
91+
template struct SparseSGDFunctor<platform::CPUPlace, float>;
92+
6193
} // namespace operators
6294
} // namespace paddle
6395

paddle/operators/sgd_op.cu

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,66 @@
1414

1515
#define EIGEN_USE_GPU
1616
#include "paddle/operators/sgd_op.h"
17+
#include "paddle/platform/cuda_helper.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
namespace {
23+
template <typename T>
24+
__global__ void SparseSGDFunctorKernel(const T* selected_rows,
25+
const int64_t* rows,
26+
const T* learning_rate, T* tensor_out,
27+
int64_t row_numel, int block_size) {
28+
const int ty = blockIdx.y;
29+
int tid = threadIdx.x;
30+
31+
selected_rows += ty * row_numel;
32+
tensor_out += rows[ty] * row_numel;
33+
34+
for (int index = tid; index < row_numel; index += block_size) {
35+
// Since index in rows of SelectedRows can be duplicate, we have to use
36+
// Atomic Operation to avoid concurrent write error.
37+
paddle::platform::CudaAtomicAdd(
38+
tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]);
39+
}
40+
}
41+
} // namespace
42+
43+
template <typename T>
44+
struct SparseSGDFunctor<platform::GPUPlace, T> {
45+
void operator()(const platform::DeviceContext& context,
46+
const framework::SelectedRows& input,
47+
const framework::Tensor& learning_rate,
48+
framework::Tensor* output) {
49+
auto in_height = input.height();
50+
auto out_dims = output->dims();
51+
PADDLE_ENFORCE_EQ(in_height, out_dims[0]);
52+
53+
auto& in_value = input.value();
54+
auto& in_rows = input.rows();
55+
56+
int64_t in_row_numel = in_value.numel() / in_rows.size();
57+
PADDLE_ENFORCE_EQ(in_row_numel, output->numel() / in_height);
58+
59+
auto* in_data = in_value.data<T>();
60+
auto* out_data = output->data<T>();
61+
62+
int block_size = 256;
63+
dim3 threads(block_size, 1);
64+
dim3 grid(1, in_rows.size());
65+
SparseSGDFunctorKernel<
66+
T><<<grid, threads, 0,
67+
reinterpret_cast<const platform::CUDADeviceContext&>(context)
68+
.stream()>>>(in_data, in_rows.data(), learning_rate.data<T>(),
69+
out_data, in_row_numel, block_size);
70+
}
71+
};
72+
73+
template struct SparseSGDFunctor<platform::GPUPlace, float>;
74+
75+
} // namespace operators
76+
} // namespace paddle
1777

1878
namespace ops = paddle::operators;
1979
REGISTER_OP_GPU_KERNEL(sgd,

paddle/operators/sgd_op.h

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,31 +15,53 @@ limitations under the License. */
1515
#pragma once
1616
#include "paddle/framework/eigen.h"
1717
#include "paddle/framework/op_registry.h"
18+
#include "paddle/framework/selected_rows.h"
1819

1920
namespace paddle {
2021
namespace operators {
2122

23+
template <typename Place, typename T>
24+
struct SparseSGDFunctor {
25+
void operator()(const platform::DeviceContext& context,
26+
const framework::SelectedRows& input,
27+
const framework::Tensor& learning_rate,
28+
framework::Tensor* output);
29+
};
30+
2231
template <typename Place, typename T>
2332
class SGDOpKernel : public framework::OpKernel<T> {
2433
public:
2534
void Compute(const framework::ExecutionContext& ctx) const override {
26-
auto param = ctx.Input<framework::Tensor>("Param");
27-
auto grad = ctx.Input<framework::Tensor>("Grad");
28-
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
29-
auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
35+
auto* param = ctx.Input<framework::Tensor>("Param");
36+
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
37+
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
3038

31-
param_out->mutable_data<T>(ctx.GetPlace());
39+
auto* grad_var = ctx.InputVar("Grad");
40+
// Actually, all tensors are LoDTensor except SelectedRows.
41+
if (grad_var->IsType<framework::LoDTensor>()) {
42+
param_out->mutable_data<T>(ctx.GetPlace());
43+
auto* grad = ctx.Input<framework::Tensor>("Grad");
3244

33-
auto p = framework::EigenVector<T>::Flatten(*param);
34-
auto g = framework::EigenVector<T>::Flatten(*grad);
35-
auto o = framework::EigenVector<T>::Flatten(*param_out);
36-
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
37-
auto place = ctx.GetEigenDevice<Place>();
45+
auto p = framework::EigenVector<T>::Flatten(*param);
46+
auto g = framework::EigenVector<T>::Flatten(*grad);
47+
auto o = framework::EigenVector<T>::Flatten(*param_out);
48+
auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
49+
auto place = ctx.GetEigenDevice<Place>();
3850

39-
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
40-
o.device(place) = p - lr.broadcast(grad_dsize) * g;
51+
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
52+
o.device(place) = p - lr.broadcast(grad_dsize) * g;
53+
} else if (grad_var->IsType<framework::SelectedRows>()) {
54+
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
55+
// This manual optimization brings difficulty to track data dependency.
56+
// It's better to find a more elegant solution.
57+
PADDLE_ENFORCE_EQ(param, param_out);
58+
auto* grad = ctx.Input<framework::SelectedRows>("Grad");
59+
SparseSGDFunctor<Place, T> functor;
60+
functor(ctx.device_context(), *grad, *learning_rate, param_out);
61+
} else {
62+
PADDLE_THROW("Unsupported Variable Type of Grad");
63+
}
4164
}
4265
};
43-
4466
} // namespace operators
4567
} // namespace paddle

paddle/pybind/pybind.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,15 @@ PYBIND11_PLUGIN(core) {
154154
py::return_value_policy::reference)
155155
.def("set_height", &SelectedRows::set_height)
156156
.def("height", &SelectedRows::height)
157-
.def("set_rows", &SelectedRows::set_rows)
157+
.def("set_rows",
158+
[](SelectedRows &self, std::vector<int64_t> rows) {
159+
#ifndef PADDLE_WITH_CUDA
160+
self.set_rows(rows);
161+
#else
162+
Vector<int64_t> new_rows(rows);
163+
self.set_rows(new_rows);
164+
#endif
165+
})
158166
.def("rows", [](SelectedRows &self) {
159167
#ifndef PADDLE_WITH_CUDA
160168
return self.rows();
@@ -187,6 +195,11 @@ All parameter, weight, gradient are variables in Paddle.
187195
return self.GetMutable<LoDTensor>();
188196
},
189197
py::return_value_policy::reference)
198+
.def("get_selected_rows",
199+
[](Variable &self) -> SelectedRows * {
200+
return self.GetMutable<SelectedRows>();
201+
},
202+
py::return_value_policy::reference)
190203
.def("get_net",
191204
[](Variable &self) -> operators::NetOp * {
192205
return self.GetMutable<operators::NetOp>();

python/paddle/v2/framework/tests/test_selected_rows.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,29 +8,30 @@ def test_selected_rows(self):
88
place = core.CPUPlace()
99
height = 10
1010
rows = [0, 4, 7]
11-
row_numel = 10
12-
selcted_rows = core.SelectedRows(rows, row_numel)
13-
np_array = np.ones((len(rows), height)).astype("float32")
11+
row_numel = 12
12+
selected_rows = core.SelectedRows(rows, height)
13+
np_array = np.ones((len(rows), row_numel)).astype("float32")
1414
np_array[0, 0] = 2.0
1515
np_array[2, 8] = 4.0
16-
tensor = selcted_rows.get_tensor()
16+
tensor = selected_rows.get_tensor()
1717
tensor.set(np_array, place)
1818

1919
# compare rows
20-
self.assertEqual(0, selcted_rows.rows()[0])
21-
self.assertEqual(4, selcted_rows.rows()[1])
22-
self.assertEqual(7, selcted_rows.rows()[2])
20+
self.assertEqual(0, selected_rows.rows()[0])
21+
self.assertEqual(4, selected_rows.rows()[1])
22+
self.assertEqual(7, selected_rows.rows()[2])
2323

2424
# compare height
25-
self.assertEqual(10, selcted_rows.height())
25+
self.assertEqual(10, selected_rows.height())
2626

2727
# compare tensor
2828
self.assertAlmostEqual(2.0,
29-
selcted_rows.get_tensor().get_float_element(0))
29+
selected_rows.get_tensor().get_float_element(0))
3030
self.assertAlmostEqual(1.0,
31-
selcted_rows.get_tensor().get_float_element(1))
31+
selected_rows.get_tensor().get_float_element(1))
3232
self.assertAlmostEqual(
33-
4.0, selcted_rows.get_tensor().get_float_element(2 * row_numel + 8))
33+
4.0,
34+
selected_rows.get_tensor().get_float_element(2 * row_numel + 8))
3435

3536

3637
if __name__ == "__main__":

python/paddle/v2/framework/tests/test_sgd_op.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import unittest
22
import numpy as np
3+
import paddle.v2.framework.core as core
4+
from paddle.v2.framework.op import Operator
35
from op_test import OpTest
46

57

@@ -17,5 +19,70 @@ def test_check_output(self):
1719
self.check_output()
1820

1921

22+
class TestSparseSGDOp(unittest.TestCase):
23+
def check_with_place(self, place):
24+
scope = core.Scope()
25+
26+
# create and initialize Grad Variable
27+
height = 10
28+
rows = [0, 4, 7]
29+
row_numel = 12
30+
31+
grad_selected_rows = scope.var('Grad').get_selected_rows()
32+
grad_selected_rows.set_height(height)
33+
grad_selected_rows.set_rows(rows)
34+
np_array = np.ones((len(rows), row_numel)).astype("float32")
35+
np_array[0, 0] = 2.0
36+
np_array[2, 8] = 4.0
37+
38+
grad_tensor = grad_selected_rows.get_tensor()
39+
grad_tensor.set(np_array, place)
40+
41+
# create and initialize Param Variable
42+
param = scope.var('Param').get_tensor()
43+
param_array = np.full((height, row_numel), 5.0).astype("float32")
44+
param.set(param_array, place)
45+
46+
# create and initialize LearningRate Variable
47+
lr = scope.var('LearningRate').get_tensor()
48+
lr_array = np.full((1), 2.0).astype("float32")
49+
lr.set(lr_array, place)
50+
51+
# create and run sgd operator
52+
sgd_op = Operator(
53+
"sgd",
54+
Param='Param',
55+
Grad='Grad',
56+
ParamOut='Param',
57+
LearningRate='LearningRate')
58+
ctx = core.DeviceContext.create(place)
59+
sgd_op.run(scope, ctx)
60+
61+
# get and compare result
62+
result_array = np.array(param)
63+
64+
# rows[0] = 0, 5.0 - 2.0 * 2.0
65+
self.assertAlmostEqual(1.0, result_array[rows[0], 0])
66+
# rows[0] = 0, 5.0 - 2.0 * 1.0
67+
self.assertAlmostEqual(3.0, result_array[rows[0], 2])
68+
# 5.0 - 2.0 * 0.0
69+
self.assertAlmostEqual(5.0, result_array[1, 0])
70+
# rows[1] = 4, 5.0 - 2.0 * 1.0
71+
self.assertAlmostEqual(3.0, result_array[rows[1], 10])
72+
# 5.0 - 2.0 * 0.0
73+
self.assertAlmostEqual(5.0, result_array[5, 8])
74+
# rows[2] = 7, 5.0 - 2.0 * 1.0
75+
self.assertAlmostEqual(3.0, result_array[rows[2], 1])
76+
# rows[2] = 7, 5.0 - 2.0 * 4.0
77+
self.assertAlmostEqual(-3.0, result_array[rows[2], 8])
78+
79+
def test_sparse_sgd(self):
80+
places = [core.CPUPlace()]
81+
if core.is_compile_gpu():
82+
places.append(core.GPUPlace(0))
83+
for place in places:
84+
self.check_with_place(place)
85+
86+
2087
if __name__ == "__main__":
2188
unittest.main()

0 commit comments

Comments
 (0)