Commit 9d02313

set_value_grad propagates gradients to Input and ValueTensor (#34304)

* add set_value_grad op
* add unittest.
* polish unittest.
* polish code.
* support cuda kernel
* polish code according to CI
* polish code.
* polish code
* remove *.pyc
* polish code.
* add unittest to improve coverage.
* polish code.

1 parent 3429c04 commit 9d02313

4 files changed: +656 −27 lines changed
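
What the new grad path computes, in a nutshell: when a ValueTensor is assigned into a slice of Input, the overwritten elements of Input no longer influence Out, so Input@GRAD is Out@GRAD with the assigned region zeroed, while ValueTensor@GRAD is Out@GRAD gathered from that region. Below is a minimal standalone C++ sketch of this rule on a 1-D tensor with a contiguous slice (illustration only, not the Paddle kernel; the slice bounds and names are hypothetical, and the real kernel additionally handles N-D slices, steps, decreased axes, and broadcasting of ValueTensor):

#include <cstdio>
#include <vector>

// Illustration of the set_value backward rule on a 1-D tensor.
// Forward:  out = input;  out[start:end] = value
// Backward: grad_input = grad_out with [start:end] zeroed
//           grad_value = grad_out[start:end]
int main() {
  const int start = 1, end = 3;                   // assumed slice bounds
  std::vector<float> grad_out = {1, 2, 3, 4, 5};  // upstream gradient d(loss)/d(out)

  std::vector<float> grad_input = grad_out;       // pass-through everywhere ...
  std::vector<float> grad_value;                  // ... except the written region
  for (int i = start; i < end; ++i) {
    grad_value.push_back(grad_out[i]);            // gradient w.r.t. ValueTensor
    grad_input[i] = 0.0f;                         // overwritten elements get no gradient
  }

  for (float g : grad_input) std::printf("%.0f ", g);  // prints: 1 0 0 4 5
  std::printf("\n");
  for (float g : grad_value) std::printf("%.0f ", g);  // prints: 2 3
  std::printf("\n");
  return 0;
}

Before this commit the grad maker only emitted a slice op producing ValueTensor@GRAD, so no gradient was propagated back to Input; the dedicated set_value_grad op below produces both outputs.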

paddle/fluid/operators/set_value_op.cc

Lines changed: 67 additions & 26 deletions
@@ -157,39 +157,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
  protected:
   void Apply(GradOpPtr<T> op) const override {
     if (this->HasInput("ValueTensor")) {
-      op->SetType("slice");
-      op->SetInput("Input", this->OutputGrad("Out"));
+      op->SetType("set_value_grad");
+
+      op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+      op->SetInput("ValueTensor", this->Input("ValueTensor"));
       if (this->HasInput("StartsTensorList")) {
         op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
       }
       if (this->HasInput("EndsTensorList")) {
         op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
       }
+      if (this->HasInput("StepsTensorList")) {
+        op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
+      }
+
+      op->SetAttrMap(this->Attrs());
+
+      op->SetOutput(framework::GradVarName("ValueTensor"),
+                    this->InputGrad("ValueTensor"));
+      op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
 
-      // convert std::vector<int64_t > to std::vector<int >
-      std::vector<int64_t> axes_int64 = static_cast<std::vector<int64_t>>(
-          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("axes")));
-      std::vector<int64_t> starts_int64 = static_cast<std::vector<int64_t>>(
-          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("starts")));
-      std::vector<int64_t> ends_int64 = static_cast<std::vector<int64_t>>(
-          BOOST_GET_CONST(std::vector<int64_t>, this->GetAttr("ends")));
-      std::vector<int64_t> decrease_axes_int64 =
-          static_cast<std::vector<int64_t>>(BOOST_GET_CONST(
-              std::vector<int64_t>, this->GetAttr("decrease_axes")));
-
-      std::vector<int> axes(axes_int64.begin(), axes_int64.end());
-      std::vector<int> starts(starts_int64.begin(), starts_int64.end());
-      std::vector<int> ends(ends_int64.begin(), ends_int64.end());
-      std::vector<int> decrease_axes(decrease_axes_int64.begin(),
-                                     decrease_axes_int64.end());
-
-      op->SetAttr("axes", axes);
-      op->SetAttr("starts", starts);
-      op->SetAttr("ends", ends);
-      op->SetAttr("decrease_axis", decrease_axes);
-      op->SetAttr("infer_flags", std::vector<int>({}));
-
-      op->SetOutput("Out", this->InputGrad("ValueTensor"));
     } else {
       op->SetType("assign");
       op->SetInput("X", this->OutputGrad("Out"));
@@ -198,6 +185,50 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+class SetValueGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   framework::GradVarName("Out"), "set_value_grad");
+
+    auto in_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    PADDLE_ENFORCE_LT(
+        in_dims.size(), 7,
+        platform::errors::InvalidArgument(
+            "The dimension of set_value_grad operator's input should be less "
+            "than 7, but received dimension is %d.",
+            in_dims.size()));
+
+    if (ctx->HasOutput(framework::GradVarName("ValueTensor"))) {
+      ctx->ShareDim("ValueTensor",
+                    /*->*/ framework::GradVarName("ValueTensor"));
+      ctx->ShareLoD("ValueTensor",
+                    /*->*/ framework::GradVarName("ValueTensor"));
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto in_tensor = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   in_tensor->place());
+  }
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const override {
+    if (var_name == "StartsTensorList" || var_name == "EndsTensorList" ||
+        var_name == "StepsTensorList") {
+      return expected_kernel_type;
+    }
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
+};
+
 DECLARE_INPLACE_OP_INFERER(SetValueOpInplaceInferer, {"Input", "Out"});
 
 }  // namespace operators
@@ -218,6 +249,16 @@ REGISTER_OP_CPU_KERNEL(
     ops::SetValueKernel<plat::CPUDeviceContext, double>,
     ops::SetValueKernel<plat::CPUDeviceContext, bool>);
 
+REGISTER_OPERATOR(set_value_grad, ops::SetValueGrad);
+
+REGISTER_OP_CPU_KERNEL(
+    set_value_grad,
+    ops::SetValueGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SetValueGradKernel<plat::CPUDeviceContext, int64_t>,
+    ops::SetValueGradKernel<plat::CPUDeviceContext, float>,
+    ops::SetValueGradKernel<plat::CPUDeviceContext, double>,
+    ops::SetValueGradKernel<plat::CPUDeviceContext, bool>);
+
 REGISTER_OP_VERSION(set_value)
     .AddCheckpoint(
         R"ROC(

paddle/fluid/operators/set_value_op.cu

Lines changed: 8 additions & 0 deletions
@@ -22,3 +22,11 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SetValueKernel<paddle::platform::CUDADeviceContext, float>,
     ops::SetValueKernel<paddle::platform::CUDADeviceContext, double>,
     ops::SetValueKernel<paddle::platform::CUDADeviceContext, bool>);
+
+REGISTER_OP_CUDA_KERNEL(
+    set_value_grad,
+    ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SetValueGradKernel<paddle::platform::CUDADeviceContext, bool>);
