@@ -32,219 +32,5 @@ template <typename T, size_t D, int MajorType = Eigen::RowMajor,
3232 typename IndexType = Eigen::DenseIndex>
3333using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
3434
35- template <typename DeviceContext, typename T>
36- class ExpandAsV2Kernel : public framework ::OpKernel<T> {
37- public:
38- void Compute (const framework::ExecutionContext& context) const override {
39- auto rank = context.Input <Tensor>(" X" )->dims ().size ();
40- auto target_shape = context.Attr <std::vector<int >>(" target_shape" );
41- auto target_rank = target_shape.size ();
42- PADDLE_ENFORCE_GE (target_rank, rank,
43- platform::errors::InvalidArgument (
44- " The rank (%d) of the input 'target_tensor' for "
45- " expand_as_v2 op must be greater than or equal to "
46- " the rank (%d) of the input 'x'." ,
47- target_rank, rank));
48- PADDLE_ENFORCE_GE (rank, 1 , platform::errors::InvalidArgument (
49- " The rank (%d) of the input 'x' for "
50- " expand_as_v2 op must be positive." ,
51- rank));
52- PADDLE_ENFORCE_LE (target_rank, MAX_RANK_SUPPORTED,
53- platform::errors::InvalidArgument (
54- " The rank (%d) of the input 'target_tensor' for "
55- " expand_as_v2 op must be less than or equal to %d." ,
56- target_rank, MAX_RANK_SUPPORTED));
57-
58- switch (target_rank) {
59- case 1 :
60- ExpandAs<1 >(context);
61- break ;
62- case 2 :
63- ExpandAs<2 >(context);
64- break ;
65- case 3 :
66- ExpandAs<3 >(context);
67- break ;
68- case 4 :
69- ExpandAs<4 >(context);
70- break ;
71- case 5 :
72- ExpandAs<5 >(context);
73- break ;
74- case 6 :
75- ExpandAs<6 >(context);
76- break ;
77- }
78- }
79-
80- protected:
81- template <int Rank>
82- void ExpandAs (const framework::ExecutionContext& context) const {
83- auto * in0 = context.Input <Tensor>(" X" );
84- auto in_dims = in0->dims ();
85- auto target_shape = context.Attr <std::vector<int >>(" target_shape" );
86- auto vec_in_dims = phi::vectorize<int >(in_dims);
87- auto diff = target_shape.size () - vec_in_dims.size ();
88- vec_in_dims.insert (vec_in_dims.begin (), diff, 1 );
89- std::vector<int > repeat_times (vec_in_dims.size ());
90- for (size_t i = 0 ; i < vec_in_dims.size (); ++i) {
91- PADDLE_ENFORCE_NE (target_shape[i], 0 ,
92- platform::errors::InvalidArgument (
93- " The value of target shape cannot be zero." ));
94- if (i < diff) {
95- PADDLE_ENFORCE_GT (
96- target_shape[i], 0 ,
97- platform::errors::InvalidArgument (
98- " The expanded size (%d) for non-existing dimensions must be "
99- " positive for expand_as_v2 op." ,
100- target_shape[i]));
101- repeat_times[i] = target_shape[i];
102- } else if (target_shape[i] > 0 ) {
103- if (vec_in_dims[i] != 1 ) {
104- PADDLE_ENFORCE_EQ (
105- vec_in_dims[i], target_shape[i],
106- platform::errors::InvalidArgument (
107- " The value (%d) of the non-singleton dimension does not match"
108- " the corresponding value (%d) in shape for expand_as_v2 op." ,
109- vec_in_dims[i], target_shape[i]));
110- repeat_times[i] = 1 ;
111- } else {
112- repeat_times[i] = target_shape[i];
113- }
114- } else {
115- PADDLE_ENFORCE_EQ (
116- target_shape[i], -1 ,
117- platform::errors::InvalidArgument (
118- " When the value in shape is negative for expand_as_v2 op, "
119- " only -1 is supported, but the value received is %d." ,
120- target_shape[i]));
121- repeat_times[i] = 1 ;
122- }
123- }
124- auto * out0 = context.Output <Tensor>(" Out" );
125- Eigen::DSizes<Eigen::DenseIndex, Rank> bcast_dims;
126- for (size_t i = 0 ; i < repeat_times.size (); ++i) {
127- bcast_dims[i] = repeat_times[i];
128- }
129-
130- framework::DDim new_in_dims = phi::make_ddim (vec_in_dims);
131- framework::DDim out_dims = phi::make_ddim (target_shape);
132-
133- out0->Resize (out_dims);
134- auto x = EigenTensor<T, Rank>::From (*in0, new_in_dims);
135- out0->mutable_data <T>(context.GetPlace ());
136- auto y = EigenTensor<T, Rank>::From (*out0, out_dims);
137- auto & place =
138- *context.template device_context <DeviceContext>().eigen_device ();
139- EigenBroadcast<std::decay_t <decltype (place)>, T, Rank>::Eval (place, y, x,
140- bcast_dims);
141- }
142- };
143-
144- template <typename DeviceContext, typename T>
145- class ExpandAsV2GradKernel : public framework ::OpKernel<T> {
146- public:
147- void Compute (const framework::ExecutionContext& context) const override {
148- auto * in0 = context.Input <Tensor>(" X" );
149- auto target_shape = context.Attr <std::vector<int >>(" target_shape" );
150- auto x_dims = in0->dims ();
151- auto vec_in_dims = phi::vectorize<int >(x_dims);
152- auto diff = target_shape.size () - vec_in_dims.size ();
153- vec_in_dims.insert (vec_in_dims.begin (), diff, 1 );
154- std::vector<int > repeat_times (vec_in_dims.size ());
155- for (size_t i = 0 ; i < vec_in_dims.size (); ++i) {
156- repeat_times[i] = target_shape[i] / vec_in_dims[i];
157- }
158- std::vector<int > reshape_dims_vec;
159- std::vector<int > reduce_dims_vec;
160- for (size_t i = 0 ; i < repeat_times.size (); ++i) {
161- reduce_dims_vec.push_back (reshape_dims_vec.size ());
162- reshape_dims_vec.push_back (repeat_times[i]);
163- reshape_dims_vec.push_back (vec_in_dims[i]);
164- }
165-
166- int dims = reduce_dims_vec.size ();
167- bool just_copy = true ;
168- for (size_t i = 0 ; i < repeat_times.size (); i++) {
169- if (repeat_times[i] != 1 ) {
170- just_copy = false ;
171- break ;
172- }
173- }
174- // no need reduce, just copy
175- if (just_copy) {
176- auto * in0 = context.Input <Tensor>(framework::GradVarName (" Out" ));
177- auto * out0 = context.Output <Tensor>(framework::GradVarName (" X" ));
178- out0->mutable_data <T>(context.GetPlace ());
179- framework::TensorCopy (*in0, context.GetPlace (), context.device_context (),
180- out0);
181- } else {
182- PADDLE_ENFORCE_GE (dims, 1 ,
183- platform::errors::InvalidArgument (
184- " The rank of the input 'Out@GRAD' for "
185- " expand_as_v2_grad op must be greater than or "
186- " equal to 1, but the value received is %d." ,
187- dims));
188- PADDLE_ENFORCE_LE (dims, MAX_RANK_SUPPORTED,
189- platform::errors::InvalidArgument (
190- " The rank of the input 'Out@GRAD' for "
191- " expand_as_v2_grad op must be less than or equal "
192- " to %d, but the value received is %d." ,
193- MAX_RANK_SUPPORTED, dims));
194- switch (dims) {
195- case 1 :
196- ExpandAsBackward<1 >(context, reshape_dims_vec, reduce_dims_vec);
197- break ;
198- case 2 :
199- ExpandAsBackward<2 >(context, reshape_dims_vec, reduce_dims_vec);
200- break ;
201- case 3 :
202- ExpandAsBackward<3 >(context, reshape_dims_vec, reduce_dims_vec);
203- break ;
204- case 4 :
205- ExpandAsBackward<4 >(context, reshape_dims_vec, reduce_dims_vec);
206- break ;
207- case 5 :
208- ExpandAsBackward<5 >(context, reshape_dims_vec, reduce_dims_vec);
209- break ;
210- case 6 :
211- ExpandAsBackward<6 >(context, reshape_dims_vec, reduce_dims_vec);
212- break ;
213- default :
214- PADDLE_THROW (platform::errors::InvalidArgument (
215- " Only support tensor with rank being between 1 and 6. But "
216- " received tensor's rank = %d." ,
217- dims));
218- }
219- }
220- }
221-
222- protected:
223- template <int Dims>
224- void ExpandAsBackward (const framework::ExecutionContext& context,
225- const std::vector<int >& reshape_dims_vec,
226- const std::vector<int >& reduce_dims_vec) const {
227- size_t reshape_size = reshape_dims_vec.size ();
228- size_t reduce_size = reduce_dims_vec.size ();
229- auto * in0 = context.Input <Tensor>(framework::GradVarName (" Out" ));
230- auto * out0 = context.Output <Tensor>(framework::GradVarName (" X" ));
231- out0->mutable_data <T>(context.GetPlace ());
232- auto x_grad = EigenVector<T>::Flatten (*out0);
233- Eigen::DSizes<Eigen::DenseIndex, Dims * 2 > reshape_dims;
234- for (size_t i = 0 ; i < reshape_size; ++i) {
235- reshape_dims[i] = reshape_dims_vec[i];
236- }
237- Eigen::DSizes<Eigen::DenseIndex, Dims> reduce_dims;
238- for (size_t i = 0 ; i < reduce_size; ++i) {
239- reduce_dims[i] = reduce_dims_vec[i];
240- }
241- auto out_grad = EigenVector<T>::Flatten (*in0);
242- auto & place =
243- *context.template device_context <DeviceContext>().eigen_device ();
244- EigenBroadcastGrad<std::decay_t <decltype (place)>, T, Dims>::Eval (
245- place, x_grad, out_grad, reduce_dims, reshape_dims);
246- }
247- };
248-
}  // namespace operators
}  // namespace paddle