35 changes: 15 additions & 20 deletions paddle/fluid/operators/set_value_op.cc
@@ -151,31 +151,26 @@ class SetValueGradMaker : public framework::SingleGradOpMaker<T> {

  protected:
  void Apply(GradOpPtr<T> op) const override {
-    if (this->HasInput("ValueTensor")) {
-      op->SetType("set_value_grad");
+    op->SetType("set_value_grad");
+    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));

Contributor (review comment, translated from Chinese): This should be discussed: when there is a ValueTensor, call set_value_grad; when there is no ValueTensor, call set_value_with_scalar_grad.

+    if (this->HasInput("StartsTensorList")) {
+      op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
+    }
+    if (this->HasInput("EndsTensorList")) {
+      op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
+    }
+    if (this->HasInput("StepsTensorList")) {
+      op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
+    }
-      op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-      op->SetInput("ValueTensor", this->Input("ValueTensor"));
-      if (this->HasInput("StartsTensorList")) {
-        op->SetInput("StartsTensorList", this->Input("StartsTensorList"));
-      }
-      if (this->HasInput("EndsTensorList")) {
-        op->SetInput("EndsTensorList", this->Input("EndsTensorList"));
-      }
-      if (this->HasInput("StepsTensorList")) {
-        op->SetInput("StepsTensorList", this->Input("StepsTensorList"));
-      }
-      op->SetAttrMap(this->Attrs());
+    op->SetAttrMap(this->Attrs());
+    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
+    if (this->HasInput("ValueTensor")) {
+      op->SetInput("ValueTensor", this->Input("ValueTensor"));
       op->SetOutput(framework::GradVarName("ValueTensor"),
                     this->InputGrad("ValueTensor"));
-      op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
-    } else {
-      op->SetType("assign");
-      op->SetInput("X", this->OutputGrad("Out"));
-      op->SetOutput("Out", this->InputGrad("Input"));
     }
  }
};
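For reference, a minimal sketch of the dispatch the reviewer proposes above. This is hypothetical: it assumes a set_value_with_scalar_grad op type is usable from this grad maker, and it omits the TensorList inputs for brevity; the code merged above keeps a single set_value_grad op type instead.

// Hypothetical sketch of the reviewer's suggestion, not the merged code.
void Apply(GradOpPtr<T> op) const override {
  // Inputs/outputs common to both grad op types.
  op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
  op->SetAttrMap(this->Attrs());
  op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
  // (StartsTensorList/EndsTensorList/StepsTensorList forwarding omitted.)
  if (this->HasInput("ValueTensor")) {
    // Tensor value: the value tensor also receives a gradient.
    op->SetType("set_value_grad");
    op->SetInput("ValueTensor", this->Input("ValueTensor"));
    op->SetOutput(framework::GradVarName("ValueTensor"),
                  this->InputGrad("ValueTensor"));
  } else {
    // Scalar value: only the sliced input receives a gradient.
    op->SetType("set_value_with_scalar_grad");
  }
}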
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -612,14 +612,14 @@

 - backward_op : set_value_grad
   forward : set_value (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values) -> Tensor(out)
-  args : (Tensor out_grad)
+  args : (Tensor out_grad, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes)
   output : Tensor(x_grad)
   infer_meta:
     func: UnchangedInferMeta
     param: [out_grad]
   kernel:
-    func: assign
-    param: [out_grad]
+    func: set_value_with_scalar_grad
+    param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]

- backward_op : set_value_with_tensor_grad
forward: set_value_with_tensor (Tensor x, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) -> Tensor(out)
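Why the backward now needs the slice arguments: for out = set_value(x, ..., values) assigning a scalar to the index set S selected by starts/ends/steps/axes, each element is out_i = values for i in S and out_i = x_i otherwise, so

    dL/dx_i = dL/dout_i * 1[i not in S],

i.e. x_grad is out_grad with the assigned slice zeroed. The old assign kernel returned out_grad unchanged, incorrectly letting gradient flow through the overwritten elements; the zeroing behavior is exactly what the updated tests below assert.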
17 changes: 17 additions & 0 deletions paddle/phi/kernels/cpu/set_value_grad_kernel.cc
@@ -35,3 +35,20 @@ PD_REGISTER_KERNEL(set_value_grad,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(set_value_with_scalar_grad,
CPU,
ALL_LAYOUT,
phi::SetValueWithScalarGradKernel,
float,
double,
int,
int64_t,
bool,
int16_t,
uint8_t,
int8_t,
phi::dtype::bfloat16,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
17 changes: 17 additions & 0 deletions paddle/phi/kernels/gpu/set_value_grad_kernel.cu
@@ -35,3 +35,20 @@ PD_REGISTER_KERNEL(set_value_grad,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(set_value_with_scalar_grad,
GPU,
ALL_LAYOUT,
phi::SetValueWithScalarGradKernel,
float,
double,
int,
int64_t,
bool,
int16_t,
uint8_t,
int8_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
93 changes: 93 additions & 0 deletions paddle/phi/kernels/impl/set_value_grad_kernel_impl.h
@@ -341,4 +341,97 @@ void SetValueGradKernel(const Context& dev_ctx,
}
}

template <typename T, typename Context>
void SetValueWithScalarGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const IntArray& starts,
const IntArray& ends,
const IntArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad) {
const int rank = out_grad.dims().size();

switch (rank) {
case 1:
SetValueGradImpl<T, Context, 1>(dev_ctx,

Contributor (review comment, translated from Chinese): Couldn't this just call SetValueGradKernel directly, passing nullptr for value_grad?

out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 2:
SetValueGradImpl<T, Context, 2>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 3:
SetValueGradImpl<T, Context, 3>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 4:
SetValueGradImpl<T, Context, 4>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 5:
SetValueGradImpl<T, Context, 5>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 6:
SetValueGradImpl<T, Context, 6>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of set_value_with_scalar_grad's input should be less than "
"7, but "
"received %d.",
rank));
}
}
} // namespace phi
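A minimal sketch of the simplification the reviewer suggests above: forward to the existing SetValueGradKernel with value_grad set to nullptr instead of duplicating the rank switch. This assumes SetValueGradKernel (declared in set_value_grad_kernel.h, next file) skips the value-gradient path when value_grad is null.

// Hypothetical alternative to the rank switch above, per the review comment.
// Assumption: SetValueGradKernel tolerates value_grad == nullptr.
template <typename T, typename Context>
void SetValueWithScalarGradKernel(const Context& dev_ctx,
                                  const DenseTensor& out_grad,
                                  const IntArray& starts,
                                  const IntArray& ends,
                                  const IntArray& steps,
                                  const std::vector<int64_t>& axes,
                                  const std::vector<int64_t>& decrease_axes,
                                  const std::vector<int64_t>& none_axes,
                                  DenseTensor* x_grad) {
  // Reuse the tensor-value grad kernel; no value gradient is requested.
  SetValueGradKernel<T, Context>(dev_ctx,
                                 out_grad,
                                 starts,
                                 ends,
                                 steps,
                                 axes,
                                 decrease_axes,
                                 none_axes,
                                 x_grad,
                                 /*value_grad=*/nullptr);
}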
10 changes: 10 additions & 0 deletions paddle/phi/kernels/set_value_grad_kernel.h
@@ -32,4 +32,14 @@ void SetValueGradKernel(const Context& dev_ctx,
DenseTensor* x_grad,
DenseTensor* value_grad);

template <typename T, typename Context>
void SetValueWithScalarGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const IntArray& starts,
const IntArray& ends,
const IntArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad);
} // namespace phi
103 changes: 103 additions & 0 deletions paddle/phi/kernels/xpu/set_value_grad_kernel.cc
@@ -397,6 +397,100 @@ void SetValueGradKernel(const Context& dev_ctx,
}
}

template <typename T, typename Context>
void SetValueWithScalarGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const IntArray& starts,
const IntArray& ends,
const IntArray& steps,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& decrease_axes,
const std::vector<int64_t>& none_axes,
DenseTensor* x_grad) {
const int rank = out_grad.dims().size();

switch (rank) {
case 1:
SetValueGradImpl<T, Context, 1>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 2:
SetValueGradImpl<T, Context, 2>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 3:
SetValueGradImpl<T, Context, 3>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 4:
SetValueGradImpl<T, Context, 4>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 5:
SetValueGradImpl<T, Context, 5>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
case 6:
SetValueGradImpl<T, Context, 6>(dev_ctx,
out_grad,
starts,
ends,
steps,
axes,
decrease_axes,
none_axes,
x_grad,
nullptr);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of set_value_with_scalar_grad's input should be less than "
"7, but "
"received %d.",
rank));
}
}

} // namespace phi

PD_REGISTER_KERNEL(set_value_grad,
@@ -407,3 +501,12 @@ PD_REGISTER_KERNEL(set_value_grad,
phi::dtype::float16,
int,
int64_t) {}

PD_REGISTER_KERNEL(set_value_with_scalar_grad,
XPU,
ALL_LAYOUT,
phi::SetValueWithScalarGradKernel,
float,
phi::dtype::float16,
int,
int64_t) {}
2 changes: 2 additions & 0 deletions test/legacy_test/test_zero_dim_tensor.py
@@ -830,6 +830,7 @@ def test_setitem(self):
np.testing.assert_allclose(out[1, 2, 3, 4], np.array(10))
self.assertEqual(x.grad.shape, [2, 3, 4, 5])
x_grad_expected = np.ones((2, 3, 4, 5)) * 2
+ x_grad_expected[1, 2, 3, 4] = 0
np.testing.assert_allclose(x.grad, x_grad_expected)
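# Why zero at [1, 2, 3, 4]: set_value overwrites that element, so
# out[1, 2, 3, 4] no longer depends on x and its incoming gradient must be
# masked to zero (the old assign-based backward wrongly passed it through).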

# case2: 0-D Tensor indice in some axis
@@ -847,6 +848,7 @@ def test_setitem(self):
self.assertEqual(out.shape, x.shape)
np.testing.assert_allclose(out[1, 1], np.ones((4, 5)) * 0.5)
x_grad_expected = np.ones((2, 3, 4, 5))
+ x_grad_expected[1, 1] = 0
np.testing.assert_allclose(x.grad, x_grad_expected)

# case3:0-D Tensor indice in some axis, value is a Tensor
1 change: 1 addition & 0 deletions test/xpu/test_zero_dim_tensor_xpu.py
@@ -441,6 +441,7 @@ def test_setitem(self):
np.testing.assert_allclose(out[1, 2, 3, 4], np.array(10))
self.assertEqual(x.grad.shape, [2, 3, 4, 5])
x_grad_expected = np.ones((2, 3, 4, 5)) * 2
+ x_grad_expected[1, 2, 3, 4] = 0
np.testing.assert_allclose(x.grad, x_grad_expected)

# case2: 0-D Tensor indice in some axis