Skip to content

Commit 993ede7

Browse files
Feiyu Chanpiotrekobi
authored andcommitted
roll_op: support Tensor as input for shifts (PaddlePaddle#36727)
1 parent 06ac1c3 commit 993ede7

File tree

5 files changed

+105
-22
lines changed

5 files changed

+105
-22
lines changed

paddle/fluid/operators/roll_op.cc

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,21 +40,23 @@ class RollOp : public framework::OperatorWithKernel {
4040
auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
4141
auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");
4242

43-
if (dims.size() != 0) {
44-
PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
45-
platform::errors::InvalidArgument(
46-
"When dims.size() != 0, dims.size() "
47-
"should be equal to "
48-
"shifts.size(). But received "
49-
"dims.size() = %d, shifts.size() = %d",
50-
dims.size(), shifts.size()));
51-
} else {
52-
PADDLE_ENFORCE_EQ(shifts.size(), 1,
53-
platform::errors::InvalidArgument(
54-
"When dims.size() == 0, shifts.size() "
55-
"should be equal to 1, But received "
56-
"shifts.size() = %d",
57-
shifts.size()));
43+
if (!ctx->HasInput("ShiftsTensor")) {
44+
if (dims.size() != 0) {
45+
PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
46+
platform::errors::InvalidArgument(
47+
"When dims.size() != 0, dims.size() "
48+
"should be equal to "
49+
"shifts.size(). But received "
50+
"dims.size() = %d, shifts.size() = %d",
51+
dims.size(), shifts.size()));
52+
} else {
53+
PADDLE_ENFORCE_EQ(shifts.size(), 1,
54+
platform::errors::InvalidArgument(
55+
"When dims.size() == 0, shifts.size() "
56+
"should be equal to 1, But received "
57+
"shifts.size() = %d",
58+
shifts.size()));
59+
}
5860
}
5961

6062
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
@@ -105,6 +107,10 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker {
105107
"The number of places by which the elements "
106108
"of the tensor are shifted.")
107109
.SetDefault({});
110+
AddInput("ShiftsTensor",
111+
"The number of places by which the elements of the tensor "
112+
"are shifted.")
113+
.AsDispensable();
108114
AddAttr<std::vector<int64_t>>(
109115
"axis",
110116
"Axis along which to roll. It must have the same size "
@@ -129,6 +135,9 @@ class RollGradMaker : public framework::SingleGradOpMaker<T> {
129135
void Apply(GradOpPtr<T> op) const override {
130136
op->SetType("roll_grad");
131137
op->SetInput("X", this->Input("X"));
138+
if (this->HasInput("ShiftsTensor")) {
139+
op->SetInput("ShiftsTensor", this->Input("ShiftsTensor"));
140+
}
132141
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
133142
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
134143
op->SetAttrMap(this->Attrs());

paddle/fluid/operators/roll_op.cu

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,16 @@ class RollKernel<platform::CUDADeviceContext, T>
5959
auto* in = context.Input<LoDTensor>("X");
6060
auto* out = context.Output<LoDTensor>("Out");
6161
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
62+
if (context.HasInput("ShiftsTensor")) {
63+
const auto* shifts_tensor =
64+
context.Input<framework::Tensor>("ShiftsTensor");
65+
PADDLE_ENFORCE_EQ(
66+
shifts_tensor->dims().size(), 1,
67+
platform::errors::InvalidArgument(
68+
"The rank of ShiftsTensor is expected to be 1, got %s",
69+
shifts_tensor->dims().size()));
70+
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
71+
}
6272
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
6373

6474
auto* in_data = in->data<T>();
@@ -134,6 +144,16 @@ class RollGradKernel<platform::CUDADeviceContext, T>
134144
auto* in = context.Input<LoDTensor>(framework::GradVarName("Out"));
135145
auto* out = context.Output<LoDTensor>(framework::GradVarName("X"));
136146
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
147+
if (context.HasInput("ShiftsTensor")) {
148+
const auto* shifts_tensor =
149+
context.Input<framework::Tensor>("ShiftsTensor");
150+
PADDLE_ENFORCE_EQ(
151+
shifts_tensor->dims().size(), 1,
152+
platform::errors::InvalidArgument(
153+
"The rank of ShiftsTensor is expected to be 1, got %s",
154+
shifts_tensor->dims().size()));
155+
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
156+
}
137157
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
138158

139159
auto* in_data = in->data<T>();

paddle/fluid/operators/roll_op.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include <memory>
1717
#include <vector>
1818
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/operators/utils.h"
20+
#include "paddle/fluid/platform/enforce.h"
1921

2022
namespace paddle {
2123
namespace operators {
@@ -85,6 +87,16 @@ class RollKernel : public framework::OpKernel<T> {
8587
auto& input = input_var->Get<LoDTensor>();
8688
auto* output = output_var->GetMutable<LoDTensor>();
8789
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
90+
if (context.HasInput("ShiftsTensor")) {
91+
const auto* shifts_tensor =
92+
context.Input<framework::Tensor>("ShiftsTensor");
93+
PADDLE_ENFORCE_EQ(
94+
shifts_tensor->dims().size(), 1,
95+
platform::errors::InvalidArgument(
96+
"The rank of ShiftsTensor is expected to be 1, got %s",
97+
shifts_tensor->dims().size()));
98+
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
99+
}
88100
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
89101

90102
std::vector<T> out_vec;
@@ -123,6 +135,11 @@ class RollGradKernel : public framework::OpKernel<T> {
123135
auto& input = input_var->Get<LoDTensor>();
124136
auto* output = output_var->GetMutable<LoDTensor>();
125137
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
138+
if (context.HasInput("ShiftsTensor")) {
139+
const auto* shifts_tensor =
140+
context.Input<framework::Tensor>("ShiftsTensor");
141+
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
142+
}
126143
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");
127144

128145
std::vector<T> out_vec;

python/paddle/fluid/tests/unittests/test_roll_op.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,34 @@ def test_axis_out_range():
122122

123123
self.assertRaises(ValueError, test_axis_out_range)
124124

125+
def test_shifts_as_tensor_dygraph(self):
126+
with fluid.dygraph.guard():
127+
x = paddle.arange(9).reshape([3, 3])
128+
shape = paddle.shape(x)
129+
shifts = shape // 2
130+
axes = [0, 1]
131+
out = paddle.roll(x, shifts=shifts, axis=axes).numpy()
132+
expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
133+
self.assertTrue(np.allclose(out, expected_out))
134+
135+
def test_shifts_as_tensor_static(self):
136+
with program_guard(Program(), Program()):
137+
x = paddle.arange(9).reshape([3, 3]).astype('float32')
138+
shape = paddle.shape(x)
139+
shifts = shape // 2
140+
axes = [0, 1]
141+
out = paddle.roll(x, shifts=shifts, axis=axes)
142+
expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
143+
144+
exe = fluid.Executor(fluid.CPUPlace())
145+
[out_np] = exe.run(fetch_list=[out])
146+
self.assertTrue(np.allclose(out_np, expected_out))
147+
148+
if paddle.is_compiled_with_cuda():
149+
exe = fluid.Executor(fluid.CPUPlace())
150+
[out_np] = exe.run(fetch_list=[out])
151+
self.assertTrue(np.allclose(out_np, expected_out))
152+
125153

126154
if __name__ == "__main__":
127155
unittest.main()

python/paddle/tensor/manipulation.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -696,15 +696,24 @@ def roll(x, shifts, axis=None, name=None):
696696

697697
helper = LayerHelper("roll", **locals())
698698
check_type(axis, 'axis', (list, tuple), 'roll')
699-
check_type(shifts, 'shifts', (list, tuple), 'roll')
699+
700700
out = helper.create_variable_for_type_inference(x.dtype)
701701

702-
helper.append_op(
703-
type='roll',
704-
inputs={'X': x},
705-
outputs={'Out': out},
706-
attrs={'axis': axis,
707-
'shifts': shifts})
702+
if isinstance(shifts, Variable):
703+
helper.append_op(
704+
type='roll',
705+
inputs={'X': x,
706+
"ShiftsTensor": shifts},
707+
outputs={'Out': out},
708+
attrs={'axis': axis})
709+
else:
710+
check_type(shifts, 'shifts', (list, tuple), 'roll')
711+
helper.append_op(
712+
type='roll',
713+
inputs={'X': x},
714+
outputs={'Out': out},
715+
attrs={'axis': axis,
716+
'shifts': shifts})
708717
return out
709718

710719

0 commit comments

Comments
 (0)