@@ -157,39 +157,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
157157 protected:
158158 void Apply (GradOpPtr<T> op) const override {
159159 if (this ->HasInput (" ValueTensor" )) {
160- op->SetType (" slice" );
161- op->SetInput (" Input" , this ->OutputGrad (" Out" ));
160+ op->SetType (" set_value_grad" );
161+
162+ op->SetInput (framework::GradVarName (" Out" ), this ->OutputGrad (" Out" ));
163+ op->SetInput (" ValueTensor" , this ->Input (" ValueTensor" ));
162164 if (this ->HasInput (" StartsTensorList" )) {
163165 op->SetInput (" StartsTensorList" , this ->Input (" StartsTensorList" ));
164166 }
165167 if (this ->HasInput (" EndsTensorList" )) {
166168 op->SetInput (" EndsTensorList" , this ->Input (" EndsTensorList" ));
167169 }
170+ if (this ->HasInput (" StepsTensorList" )) {
171+ op->SetInput (" StepsTensorList" , this ->Input (" StepsTensorList" ));
172+ }
173+
174+ op->SetAttrMap (this ->Attrs ());
175+
176+ op->SetOutput (framework::GradVarName (" ValueTensor" ),
177+ this ->InputGrad (" ValueTensor" ));
178+ op->SetOutput (framework::GradVarName (" Input" ), this ->InputGrad (" Input" ));
168179
169- // convert std::vector<int64_t > to std::vector<int >
170- std::vector<int64_t > axes_int64 = static_cast <std::vector<int64_t >>(
171- BOOST_GET_CONST (std::vector<int64_t >, this ->GetAttr (" axes" )));
172- std::vector<int64_t > starts_int64 = static_cast <std::vector<int64_t >>(
173- BOOST_GET_CONST (std::vector<int64_t >, this ->GetAttr (" starts" )));
174- std::vector<int64_t > ends_int64 = static_cast <std::vector<int64_t >>(
175- BOOST_GET_CONST (std::vector<int64_t >, this ->GetAttr (" ends" )));
176- std::vector<int64_t > decrease_axes_int64 =
177- static_cast <std::vector<int64_t >>(BOOST_GET_CONST (
178- std::vector<int64_t >, this ->GetAttr (" decrease_axes" )));
179-
180- std::vector<int > axes (axes_int64.begin (), axes_int64.end ());
181- std::vector<int > starts (starts_int64.begin (), starts_int64.end ());
182- std::vector<int > ends (ends_int64.begin (), ends_int64.end ());
183- std::vector<int > decrease_axes (decrease_axes_int64.begin (),
184- decrease_axes_int64.end ());
185-
186- op->SetAttr (" axes" , axes);
187- op->SetAttr (" starts" , starts);
188- op->SetAttr (" ends" , ends);
189- op->SetAttr (" decrease_axis" , decrease_axes);
190- op->SetAttr (" infer_flags" , std::vector<int >({}));
191-
192- op->SetOutput (" Out" , this ->InputGrad (" ValueTensor" ));
193180 } else {
194181 op->SetType (" assign" );
195182 op->SetInput (" X" , this ->OutputGrad (" Out" ));
@@ -198,6 +185,50 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
198185 }
199186};
200187
188+ class SetValueGrad : public framework ::OperatorWithKernel {
189+ public:
190+ using framework::OperatorWithKernel::OperatorWithKernel;
191+
192+ void InferShape (framework::InferShapeContext *ctx) const override {
193+ OP_INOUT_CHECK (ctx->HasInput (framework::GradVarName (" Out" )), " Input" ,
194+ framework::GradVarName (" Out" ), " set_value_grad" );
195+
196+ auto in_dims = ctx->GetInputDim (framework::GradVarName (" Out" ));
197+ PADDLE_ENFORCE_LT (
198+ in_dims.size (), 7 ,
199+ platform::errors::InvalidArgument (
200+ " The dimension of set_value_grad operator's input should be less "
201+ " than 7, but received dimension is %d." ,
202+ in_dims.size ()));
203+
204+ if (ctx->HasOutput (framework::GradVarName (" ValueTensor" ))) {
205+ ctx->ShareDim (" ValueTensor" ,
206+ /* ->*/ framework::GradVarName (" ValueTensor" ));
207+ ctx->ShareLoD (" ValueTensor" ,
208+ /* ->*/ framework::GradVarName (" ValueTensor" ));
209+ }
210+ }
211+
212+ protected:
213+ framework::OpKernelType GetExpectedKernelType (
214+ const framework::ExecutionContext &ctx) const override {
215+ auto in_tensor = ctx.Input <Tensor>(framework::GradVarName (" Out" ));
216+ return framework::OpKernelType (OperatorWithKernel::IndicateVarDataType (
217+ ctx, framework::GradVarName (" Out" )),
218+ in_tensor->place ());
219+ }
220+ framework::OpKernelType GetKernelTypeForVar (
221+ const std::string &var_name, const Tensor &tensor,
222+ const framework::OpKernelType &expected_kernel_type) const override {
223+ if (var_name == " StartsTensorList" || var_name == " EndsTensorList" ||
224+ var_name == " StepsTensorList" ) {
225+ return expected_kernel_type;
226+ }
227+ return framework::OpKernelType (expected_kernel_type.data_type_ ,
228+ tensor.place (), tensor.layout ());
229+ }
230+ };
231+
201232DECLARE_INPLACE_OP_INFERER (SetValueOpInplaceInferer, {" Input" , " Out" });
202233
203234} // namespace operators
@@ -218,6 +249,16 @@ REGISTER_OP_CPU_KERNEL(
218249 ops::SetValueKernel<plat::CPUDeviceContext, double >,
219250 ops::SetValueKernel<plat::CPUDeviceContext, bool >);
220251
252+ REGISTER_OPERATOR (set_value_grad, ops::SetValueGrad);
253+
254+ REGISTER_OP_CPU_KERNEL (
255+ set_value_grad,
256+ ops::SetValueGradKernel<paddle::platform::CPUDeviceContext, int >,
257+ ops::SetValueGradKernel<plat::CPUDeviceContext, int64_t >,
258+ ops::SetValueGradKernel<plat::CPUDeviceContext, float >,
259+ ops::SetValueGradKernel<plat::CPUDeviceContext, double >,
260+ ops::SetValueGradKernel<plat::CPUDeviceContext, bool >);
261+
221262REGISTER_OP_VERSION (set_value)
222263 .AddCheckpoint(
223264 R"ROC(
0 commit comments