limitations under the License. */

namespace paddle {
namespace operators {
22- template <typename DeviceContext, typename T, typename AttrType = T>
22+ inline void GetDims (const framework::DDim& dim, int axis, int * pre , int * n,
23+ int * post ) {
24+ *pre = 1 ;
25+ *post = 1 ;
26+ *n = dim[axis];
27+ for (int i = 0 ; i < axis; ++i) {
28+ (*pre ) *= dim[i];
29+ }
30+ for (int i = axis + 1 ; i < dim.size (); ++i) {
31+ (*post ) *= dim[i];
32+ }
33+ }
34+
35+ template <typename DeviceContext, typename T>
2336class NormKernel : public framework ::OpKernel<T> {
2437 public:
25- void Compute (const framework::ExecutionContext& context) const override {
26- const framework::Tensor* in_x = context.Input <framework::Tensor>(" X" );
27- const framework::Tensor* scale = context.Input <framework::Tensor>(" Scale" );
28- auto * out = context.Output <framework::Tensor>(" Out" );
29- auto epsilon = static_cast <T>(context.Attr <AttrType>(" epsilon" ));
30- out->mutable_data <T>(context.GetPlace ());
31- int batch_size = in_x->dims ()[0 ];
32- int channels = in_x->dims ()[1 ];
33- int height = in_x->dims ()[2 ];
34- int width = in_x->dims ()[3 ];
35- int fea_len = height * width;
36- auto * place =
37- context.template device_context <DeviceContext>().eigen_device ();
38- auto x =
39- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
40- *in_x, framework::make_ddim ({batch_size, fea_len * channels}));
41- // get square
42- framework::Tensor x_square;
43- x_square.mutable_data <T>(in_x->dims (), context.GetPlace ());
44- auto x_square_eigen =
45- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
46- x_square, framework::make_ddim ({batch_size, fea_len * channels}));
47- x_square_eigen.device (*place) = x.square ();
48- auto scale_eigen =
49- framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten (
50- *scale);
51- for (int n = 0 ; n < batch_size; ++n) {
52- framework::Tensor in_x_batch = in_x->Slice (n, n + 1 );
53- auto in_x_batch_eigen =
54- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
55- in_x_batch, framework::make_ddim ({channels, fea_len}));
56- framework::Tensor x_square_batch = x_square.Slice (n, n + 1 );
57- auto x_square_batch_eigen =
58- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
59- x_square_batch, framework::make_ddim ({channels, fea_len}));
60- framework::Tensor out_batch = out->Slice (n, n + 1 );
61- auto out_batch_eigen =
62- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
63- out_batch, framework::make_ddim ({channels, fea_len}));
64- framework::Tensor tmp_tensor;
65- tmp_tensor.mutable_data <T>(framework::make_ddim ({1 , fea_len}),
66- context.GetPlace ());
67- auto tmp = framework::EigenVector<T, Eigen::RowMajor,
68- Eigen::DenseIndex>::Flatten (tmp_tensor);
69- // get colsum and sqrt , inverse
70- auto dim = Eigen::array<int , 1 >({{0 }});
71- tmp.device (*place) = x_square_batch_eigen.sum (dim);
72- tmp.device (*place) = (tmp + epsilon).sqrt ().inverse ();
73- Eigen::array<int , 2 > broadcast_dim_col;
74- broadcast_dim_col[1 ] = 1 ;
75- broadcast_dim_col[0 ] = channels;
76- out_batch_eigen.device (*place) =
77- in_x_batch_eigen * (tmp.broadcast (broadcast_dim_col));
78- Eigen::array<int , 2 > broadcast_dim_row;
79- broadcast_dim_row[1 ] = fea_len;
80- broadcast_dim_row[0 ] = 1 ;
81- out_batch_eigen.device (*place) =
82- out_batch_eigen * (scale_eigen.broadcast (broadcast_dim_row));
83- }
38+ void Compute (const framework::ExecutionContext& ctx) const override {
39+ auto * in_x = ctx.Input <framework::Tensor>(" X" );
40+ auto * out_y = ctx.Output <framework::Tensor>(" Out" );
41+ auto * out_norm = ctx.Output <framework::Tensor>(" Norm" );
42+ out_y->mutable_data <T>(ctx.GetPlace ());
43+ out_norm->mutable_data <T>(ctx.GetPlace ());
44+
45+ auto xdim = in_x->dims ();
46+ auto ndim = out_norm->dims ();
47+ T eps = static_cast <T>(ctx.Attr <float >(" epsilon" ));
48+ int axis = ctx.Attr <int >(" axis" );
49+ if (axis < 0 ) axis = xdim.size () + axis;
50+ int pre , n, post ;
51+ GetDims (xdim, axis, &pre , &n, &post );
52+
53+ auto * place = ctx.template device_context <DeviceContext>().eigen_device ();
54+
55+ Eigen::DSizes<int , 3 > shape (pre , n, post );
56+ Eigen::DSizes<int , 2 > norm_shape (pre , post );
57+
58+ auto x_e = framework::EigenVector<T>::Flatten (*in_x);
59+ auto y_e = framework::EigenVector<T>::Flatten (*out_y);
60+ auto norm_e = framework::EigenVector<T>::Flatten (*out_norm);
61+ auto x = x_e.reshape (shape);
62+ auto y = y_e.reshape (shape);
63+ auto norm = norm_e.reshape (norm_shape);
64+
65+ Eigen::DSizes<int , 1 > rdim (1 );
66+ // y = x / sqrt((sum(x * x) + epsilon))
67+ // norm = sqrt(sum(x * x) + epsilon)
68+ auto sum = x.pow (2 ).sum (rdim) + eps;
69+ norm.device (*place) = sum.sqrt ();
70+ // y = x / norm
71+ Eigen::DSizes<int , 3 > rshape (pre , 1 , post );
72+ Eigen::DSizes<int , 3 > bcast (1 , n, 1 );
73+ y.device (*place) = x / norm.reshape (rshape).broadcast (bcast);
8474 }
8575};
8676template <typename DeviceContext, typename T, typename AttrType = T>
8777class NormGradKernel : public framework ::OpKernel<T> {
8878 public:
89- void Compute (const framework::ExecutionContext& context) const override {
90- const framework::Tensor* in_x = context.Input <framework::Tensor>(" X" );
91- const framework::Tensor* scale = context.Input <framework::Tensor>(" Scale" );
92- const framework::Tensor* out_grad =
93- context.Input <framework::Tensor>(framework::GradVarName (" Out" ));
94- auto epsilon = static_cast <T>(context.Attr <AttrType>(" epsilon" ));
95- framework::Tensor* in_x_grad =
96- context.Output <framework::Tensor>(framework::GradVarName (" X" ));
97- in_x_grad->mutable_data <T>(context.GetPlace ());
98- int batch_size = in_x->dims ()[0 ];
99- int channels = in_x->dims ()[1 ];
100- int height = in_x->dims ()[2 ];
101- int width = in_x->dims ()[3 ];
102- int fea_len = height * width;
103- auto * place =
104- context.template device_context <DeviceContext>().eigen_device ();
105-
106- auto scale_eigen =
107- framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten (
108- *scale);
109- auto x =
110- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
111- *in_x, framework::make_ddim ({batch_size, fea_len * channels}));
112- // get square
113- framework::Tensor x_square;
114- x_square.mutable_data <T>(in_x->dims (), context.GetPlace ());
115- auto x_square_eigen =
116- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
117- x_square, framework::make_ddim ({batch_size, fea_len * channels}));
118- x_square_eigen.device (*place) = x.square ();
119-
120- for (int n = 0 ; n < batch_size; ++n) {
121- framework::Tensor in_x_batch = in_x->Slice (n, n + 1 );
122- auto in_x_batch_eigen =
123- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
124- in_x_batch, framework::make_ddim ({channels, fea_len}));
125- framework::Tensor in_g_batch = in_x_grad->Slice (n, n + 1 );
126- auto in_g_batch_eigen =
127- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
128- in_g_batch, framework::make_ddim ({channels, fea_len}));
129- framework::Tensor x_square_batch = x_square.Slice (n, n + 1 );
130- auto x_square_batch_eigen =
131- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
132- x_square_batch, framework::make_ddim ({channels, fea_len}));
133- framework::Tensor outg_batch = out_grad->Slice (n, n + 1 );
134- auto outg_batch_eigen =
135- framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From (
136- outg_batch, framework::make_ddim ({channels, fea_len}));
137-
138- framework::Tensor tmp_tensor;
139- tmp_tensor.mutable_data <T>(framework::make_ddim ({1 , fea_len}),
140- context.GetPlace ());
141- auto tmp_eigen =
142- framework::EigenVector<T, Eigen::RowMajor,
143- Eigen::DenseIndex>::Flatten (tmp_tensor);
144- auto dim = Eigen::array<int , 1 >({{0 }});
145- tmp_eigen.device (*place) = (in_x_batch_eigen * outg_batch_eigen).sum (dim);
146- framework::Tensor norm_tmp_tensor;
147- norm_tmp_tensor.mutable_data <T>(framework::make_ddim ({1 , fea_len}),
148- context.GetPlace ());
149- auto norm_tmp_eigen =
150- framework::EigenVector<T, Eigen::RowMajor,
151- Eigen::DenseIndex>::Flatten (norm_tmp_tensor);
152- norm_tmp_eigen.device (*place) =
153- (x_square_batch_eigen.sum (dim) + epsilon).sqrt ();
154- Eigen::array<int , 2 > broadcast_dim_col;
155- broadcast_dim_col[1 ] = 1 ;
156- broadcast_dim_col[0 ] = channels;
157- in_g_batch_eigen.device (*place) =
158- in_x_batch_eigen * tmp_eigen.broadcast (broadcast_dim_col);
159- in_g_batch_eigen.device (*place) =
160- in_g_batch_eigen /
161- (norm_tmp_eigen * norm_tmp_eigen).broadcast (broadcast_dim_col);
162- in_g_batch_eigen.device (*place) = outg_batch_eigen - in_g_batch_eigen;
163- // outg_batch_eigen + (in_g_batch_eigen * -1);
164- in_g_batch_eigen.device (*place) =
165- in_g_batch_eigen / norm_tmp_eigen.broadcast (broadcast_dim_col);
166- Eigen::array<int , 2 > broadcast_dim_row;
167- broadcast_dim_row[1 ] = fea_len;
168- broadcast_dim_row[0 ] = 1 ;
169- in_g_batch_eigen.device (*place) =
170- in_g_batch_eigen * (scale_eigen.broadcast (broadcast_dim_row));
171- }
79+ void Compute (const framework::ExecutionContext& ctx) const override {
80+ auto * in_x = ctx.Input <framework::Tensor>(" X" );
81+ auto * in_norm = ctx.Input <framework::Tensor>(" Norm" );
82+ auto * in_dy = ctx.Input <framework::Tensor>(framework::GradVarName (" Out" ));
83+ auto * out_dx = ctx.Output <framework::Tensor>(framework::GradVarName (" X" ));
84+ out_dx->mutable_data <T>(ctx.GetPlace ());
85+
86+ auto xdim = in_x->dims ();
87+ int axis = ctx.Attr <int >(" axis" );
88+ if (axis < 0 ) axis = xdim.size () + axis;
89+ int pre , n, post ;
90+ GetDims (xdim, axis, &pre , &n, &post );
91+
92+ auto * place = ctx.template device_context <DeviceContext>().eigen_device ();
93+
94+ auto x_e = framework::EigenVector<T>::Flatten (*in_x);
95+ auto dy_e = framework::EigenVector<T>::Flatten (*in_dy);
96+ auto norm_e = framework::EigenVector<T>::Flatten (*in_norm);
97+ auto dx_e = framework::EigenVector<T>::Flatten (*out_dx);
98+
99+ Eigen::DSizes<int , 3 > shape (pre , n, post );
100+ Eigen::DSizes<int , 2 > norm_shape (pre , post );
101+ auto x = x_e.reshape (shape);
102+ auto dy = dy_e.reshape (shape);
103+ auto norm = norm_e.reshape (norm_shape);
104+ auto dx = dx_e.reshape (shape);
105+
106+ framework::Tensor rsum;
107+ rsum.mutable_data <T>({pre , post }, ctx.GetPlace ());
108+ auto sum = framework::EigenTensor<T, 2 >::From (rsum);
109+
110+ Eigen::DSizes<int , 1 > rdim (1 );
111+ Eigen::DSizes<int , 3 > bcast (1 , n, 1 );
112+ Eigen::DSizes<int , 3 > rshape (pre , 1 , post );
113+
114+ // dx = ( dy/sqrt(sum(x*x)) ) * [1 - x*sum(x) / (sum(x*x) + e)]
115+ // = [dy - dy * x * sum(x) / (sum(x*x) + e)] / sqrt(sum(x*x))
116+ // = [dy - x * sum(x*dy) / (sum(x*x) + e)] / sqrt(sum(x*x))
117+ // 1. sum = sum(x*dy)
118+ sum.device (*place) = (x * dy).sum (rdim);
119+ // 2. dx = x * sum
120+ dx.device (*place) = sum.reshape (rshape).broadcast (bcast) * x;
121+ // 3. dx / (sum(x*x) + e)
122+ // where, norm.pow(2) = sum(x*x) + e, which is calculated in forward.
123+ dx.device (*place) = dx / norm.pow (2 ).broadcast (bcast);
124+ // 4. [dy - dx] / sqrt(sum(x*x))
125+ dx.device (*place) = (dy - dx) / norm.broadcast (bcast);
172126 }
173127};
}  // namespace operators
}  // namespace paddle