limitations under the License. */

#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/selected_rows.h"

namespace paddle {
namespace operators {
23+ template <typename Place, typename T>
24+ struct SparseSGDFunctor {
25+ void operator ()(const platform::DeviceContext& context,
26+ const framework::SelectedRows& input,
27+ const framework::Tensor& learning_rate,
28+ framework::Tensor* output);
29+ };
30+
2231template <typename Place, typename T>
2332class SGDOpKernel : public framework ::OpKernel<T> {
2433 public:
2534 void Compute (const framework::ExecutionContext& ctx) const override {
26- auto param = ctx.Input <framework::Tensor>(" Param" );
27- auto grad = ctx.Input <framework::Tensor>(" Grad" );
28- auto param_out = ctx.Output <framework::Tensor>(" ParamOut" );
29- auto learning_rate = ctx.Input <framework::Tensor>(" LearningRate" );
35+ auto * param = ctx.Input <framework::Tensor>(" Param" );
36+ auto * param_out = ctx.Output <framework::Tensor>(" ParamOut" );
37+ auto * learning_rate = ctx.Input <framework::Tensor>(" LearningRate" );
3038
31- param_out->mutable_data <T>(ctx.GetPlace ());
39+ auto * grad_var = ctx.InputVar (" Grad" );
40+ // Actually, all tensors are LoDTensor except SelectedRows.
41+ if (grad_var->IsType <framework::LoDTensor>()) {
42+ param_out->mutable_data <T>(ctx.GetPlace ());
43+ auto * grad = ctx.Input <framework::Tensor>(" Grad" );
3244
33- auto p = framework::EigenVector<T>::Flatten (*param);
34- auto g = framework::EigenVector<T>::Flatten (*grad);
35- auto o = framework::EigenVector<T>::Flatten (*param_out);
36- auto lr = framework::EigenVector<T>::Flatten (*learning_rate);
37- auto place = ctx.GetEigenDevice <Place>();
45+ auto p = framework::EigenVector<T>::Flatten (*param);
46+ auto g = framework::EigenVector<T>::Flatten (*grad);
47+ auto o = framework::EigenVector<T>::Flatten (*param_out);
48+ auto lr = framework::EigenVector<T>::Flatten (*learning_rate);
49+ auto place = ctx.GetEigenDevice <Place>();
3850
39- Eigen::DSizes<int , 1 > grad_dsize (grad->numel ());
40- o.device (place) = p - lr.broadcast (grad_dsize) * g;
51+ Eigen::DSizes<int , 1 > grad_dsize (grad->numel ());
52+ o.device (place) = p - lr.broadcast (grad_dsize) * g;
53+ } else if (grad_var->IsType <framework::SelectedRows>()) {
54+ // TODO(qijun): In Sparse SGD operator, in-place update is enforced.
55+ // This manual optimization brings difficulty to track data dependency.
56+ // It's better to find a more elegant solution.
57+ PADDLE_ENFORCE_EQ (param, param_out);
58+ auto * grad = ctx.Input <framework::SelectedRows>(" Grad" );
59+ SparseSGDFunctor<Place, T> functor;
60+ functor (ctx.device_context (), *grad, *learning_rate, param_out);
61+ } else {
62+ PADDLE_THROW (" Unsupported Variable Type of Grad" );
63+ }
4164 }
4265};
}  // namespace operators
}  // namespace paddle