
Commit 4860feb

Thunderbrook authored and chen-zhiyu committed
add xpu slice op (PaddlePaddle#27349)
* add xpu slice op test=xpu
* add slice xpu op test=xpu
* code style test=kunlun
* style test=kunlun
* format test=kunlun
1 parent 828e343 commit 4860feb
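
For orientation, here is a minimal sketch (not part of the commit) of what this change enables: running the existing `slice` operator on a Kunlun XPU device through the static-graph API. It assumes a Paddle build with `PADDLE_WITH_XPU` and at least one XPU device; the kernels below are registered for float32 only, and the exact front-end calls may differ by Paddle version.

import numpy as np
import paddle
import paddle.fluid as fluid

if hasattr(paddle, "enable_static"):
    paddle.enable_static()  # newer Paddle versions default to dynamic graph

place = paddle.XPUPlace(0)
main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name="x", shape=[3, 4, 5, 6], dtype="float32")
    # keep rows 1:3 along axis 0 and 0:3 along axis 1
    y = fluid.layers.slice(x, axes=[0, 1], starts=[1, 0], ends=[3, 3])

exe = fluid.Executor(place)
exe.run(startup_prog)
x_np = np.random.random([3, 4, 5, 6]).astype("float32")
out, = exe.run(main_prog, feed={"x": x_np}, fetch_list=[y])
assert out.shape == (2, 3, 5, 6)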

2 files changed: +374 −0 lines changed
File 1 of 2 — Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU

#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/operators/slice_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class SliceXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto in = ctx.Input<framework::Tensor>("Input");
    auto out = ctx.Output<framework::Tensor>("Out");
    auto axes = ctx.Attr<std::vector<int>>("axes");
    auto starts = ctx.Attr<std::vector<int>>("starts");
    auto ends = ctx.Attr<std::vector<int>>("ends");
    auto in_dims = in->dims();

    // Prepare starts and ends for the XPU kernel.
    // If a negative value is passed for a start or end index, it counts from
    // the end of that dimension. If the value passed to start or end is
    // larger than n (the number of elements in this dimension), it is
    // clamped to n.
    int dim_value = 0, start = 0, end = 0;
    for (size_t i = 0; i < axes.size(); ++i) {
      dim_value = in_dims[axes[i]];
      start = starts[i];
      end = ends[i];
      start = start < 0 ? (start + dim_value) : start;
      end = end < 0 ? (end + dim_value) : end;
      start = std::max(start, 0);
      end = std::max(end, 0);
      end = std::min(end, dim_value);
      PADDLE_ENFORCE_GT(end, start,
                        platform::errors::InvalidArgument(
                            "end should be greater than start"));
      starts[i] = start;
      ends[i] = end;
    }
    size_t shape_size = in_dims.size();
    // The slice XPU kernel requires the length of `starts` and `ends` to be
    // equal to the rank of the input tensor, so if shape_size > axes.size()
    // the `starts_extension` and `ends_extension` vectors are necessary.
    std::vector<int> starts_extension(shape_size, 0);
    std::vector<int> ends_extension(shape_size, 0);
    if (shape_size > axes.size()) {
      for (size_t i = 0; i < shape_size; ++i) {
        ends_extension[i] = in_dims[i];
      }
      for (size_t i = 0; i < axes.size(); ++i) {
        starts_extension[axes[i]] = starts[i];
        ends_extension[axes[i]] = ends[i];
      }
    } else {
      starts_extension = std::move(starts);
      ends_extension = std::move(ends);
    }

    // Prepare the shape for the XPU kernel.
    std::vector<int> shape(shape_size, 0);
    for (size_t i = 0; i < shape_size; ++i) {
      shape[i] = in_dims[i];
    }

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto* in_data = in->data<T>();
    auto* out_data = out->mutable_data<T>(ctx.GetPlace());

    int r = xpu::slice_forward(dev_ctx.x_context(), shape.data(),
                               starts_extension.data(), ends_extension.data(),
                               shape_size, in_data, out_data);
    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                      platform::errors::External("XPU slice kernel error!"));
  }
};

template <typename DeviceContext, typename T>
class SliceGradXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* d_in = ctx.Output<framework::Tensor>(framework::GradVarName("Input"));
    d_in->mutable_data<T>(ctx.GetPlace());

    auto in_dims = d_in->dims();
    auto axes = ctx.Attr<std::vector<int>>("axes");
    auto starts = ctx.Attr<std::vector<int>>("starts");
    auto ends = ctx.Attr<std::vector<int>>("ends");

    // Prepare starts and ends with the same normalization as the forward
    // kernel: negative indices count from the end of the dimension, and
    // values larger than n (the number of elements in this dimension) are
    // clamped to n.
    int dim_value = 0, start = 0, end = 0;
    for (size_t i = 0; i < axes.size(); ++i) {
      dim_value = in_dims[axes[i]];
      start = starts[i];
      end = ends[i];
      start = start < 0 ? (start + dim_value) : start;
      end = end < 0 ? (end + dim_value) : end;
      start = std::max(start, 0);
      end = std::max(end, 0);
      end = std::min(end, dim_value);
      PADDLE_ENFORCE_GT(end, start,
                        platform::errors::InvalidArgument(
                            "end should be greater than start"));
      starts[i] = start;
      ends[i] = end;
    }
    size_t shape_size = in_dims.size();
    // The slice XPU kernel requires the length of `starts` and `ends` to be
    // equal to the rank of the input tensor, so if shape_size > axes.size()
    // the `starts_extension` and `ends_extension` vectors are necessary.
    std::vector<int> starts_extension(shape_size, 0);
    std::vector<int> ends_extension(shape_size, 0);
    if (shape_size > axes.size()) {
      for (size_t i = 0; i < shape_size; ++i) {
        ends_extension[i] = in_dims[i];
      }
      for (size_t i = 0; i < axes.size(); ++i) {
        starts_extension[axes[i]] = starts[i];
        ends_extension[axes[i]] = ends[i];
      }
    }
    int* starts_device = nullptr;
    int* ends_device = nullptr;
    int* starts_host =
        shape_size > axes.size() ? starts_extension.data() : starts.data();
    int* ends_host =
        shape_size > axes.size() ? ends_extension.data() : ends.data();
    PADDLE_ENFORCE_EQ(
        xpu_malloc((void**)(&starts_device), shape_size * sizeof(int)),
        XPU_SUCCESS,
        platform::errors::External("XPU does not have enough memory"));
    PADDLE_ENFORCE_EQ(
        xpu_malloc((void**)(&ends_device), shape_size * sizeof(int)),
        XPU_SUCCESS,
        platform::errors::External("XPU does not have enough memory"));
    memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
                 starts_device, platform::CPUPlace(), starts_host,
                 shape_size * sizeof(int));
    memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
                 ends_device, platform::CPUPlace(), ends_host,
                 shape_size * sizeof(int));

    // Prepare the shape on the XPU device.
    std::vector<int> shape(shape_size, 0);
    for (size_t i = 0; i < shape_size; ++i) {
      shape[i] = in_dims[i];
    }
    int* shape_device = nullptr;
    PADDLE_ENFORCE_EQ(
        xpu_malloc((void**)(&shape_device), shape_size * sizeof(int)),
        XPU_SUCCESS,
        platform::errors::External("XPU does not have enough memory"));
    memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()),
                 shape_device, platform::CPUPlace(), shape.data(),
                 shape_size * sizeof(int));

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    int r =
        xpu::slice_backward(dev_ctx.x_context(), shape_device, starts_device,
                            ends_device, shape_size, d_out->data<T>(),
                            d_in->data<T>(), d_in->numel(), d_out->numel());
    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                      platform::errors::External("XPU slice kernel error!"));
    dev_ctx.Wait();
    // Free the device buffers.
    xpu_free(shape_device);
    xpu_free(starts_device);
    xpu_free(ends_device);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_XPU_KERNEL(
    slice, ops::SliceXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
    slice_grad,
    ops::SliceGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
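
To make the index handling in both kernels easier to follow, here is a small numpy sketch (not part of the commit) of the same start/end normalization: negative indices wrap around, values are clamped to [0, dim], end must stay strictly greater than start, and axes not listed keep their full extent. The helper names are illustrative only.

import numpy as np

def normalize_slice_attrs(shape, axes, starts, ends):
    # Mirror of the normalization loop in SliceXPUKernel / SliceGradXPUKernel.
    norm_starts, norm_ends = [], []
    for axis, start, end in zip(axes, starts, ends):
        dim = shape[axis]
        start = start + dim if start < 0 else start
        end = end + dim if end < 0 else end
        start = max(start, 0)
        end = min(max(end, 0), dim)
        assert end > start, "end should be greater than start"
        norm_starts.append(start)
        norm_ends.append(end)
    return norm_starts, norm_ends

def reference_slice(x, axes, starts, ends):
    # Build a full index tuple; axes not listed keep their whole extent.
    starts, ends = normalize_slice_attrs(x.shape, axes, starts, ends)
    index = [slice(None)] * x.ndim
    for axis, s, e in zip(axes, starts, ends):
        index[axis] = slice(s, e)
    return x[tuple(index)]

x = np.random.random([3, 4, 5, 6]).astype("float32")
np.testing.assert_allclose(
    reference_slice(x, [0, 1, 2], [-3, 0, 2], [3, 100, -1]),
    x[-3:3, 0:100, 2:-1, :])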
File 2 of 2 — Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import sys
sys.path.append("..")
import paddle
import paddle.fluid.core as core
from op_test import OpTest
import paddle.fluid as fluid
import paddle.fluid.layers as layers


# Situation 1: starts(list, no tensor), ends(list, no tensor)
# 1.1 without attr(decrease)
class TestSliceOp(OpTest):
    def setUp(self):
        self.op_type = "slice"
        self.config()
        self.inputs = {'Input': self.input}
        self.outputs = {'Out': self.out}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts,
            'ends': self.ends,
            'infer_flags': self.infer_flags,
            "use_xpu": True
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [1, 0, 2]
        self.ends = [3, 3, 4]
        self.axes = [0, 1, 2]
        self.infer_flags = [1, 1, 1]
        self.out = self.input[1:3, 0:3, 2:4, :]

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place)

    def test_check_grad_normal(self):
        place = paddle.XPUPlace(0)
        self.check_grad_with_place(place, ['Input'], 'Out')


class TestCase1(TestSliceOp):
    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [-3, 0, 2]
        self.ends = [3, 100, -1]
        self.axes = [0, 1, 2]
        self.infer_flags = [1, 1, 1]
        self.out = self.input[-3:3, 0:100, 2:-1, :]


class TestCase2(TestSliceOp):
    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [-3, 0, 2]
        self.ends = [3, 100, -1]
        self.axes = [0, 1, 3]
        self.infer_flags = [1, 1, 1]
        self.out = self.input[-3:3, 0:100, :, 2:-1]


# 1.2 with attr(decrease)
class TestSliceOp_decs_dim(OpTest):
    def setUp(self):
        self.op_type = "slice"
        self.config()
        self.inputs = {'Input': self.input}
        self.outputs = {'Out': self.out}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts,
            'ends': self.ends,
            'infer_flags': self.infer_flags,
            'decrease_axis': self.decrease_axis,
            "use_xpu": True
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [1, 0, 2]
        self.ends = [2, 3, 4]
        self.axes = [0, 1, 2]
        self.decrease_axis = [0]
        self.infer_flags = [1, 1, 1]
        self.out = self.input[1, 0:3, 2:4, :]

    def test_check_output(self):
        place = paddle.XPUPlace(0)
        self.check_output_with_place(place)

    def test_check_grad_normal(self):
        place = paddle.XPUPlace(0)
        self.check_grad_with_place(place, ['Input'], 'Out')


class TestSliceOp_decs_dim_2(TestSliceOp_decs_dim):
    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [1, 0, 2]
        self.ends = [2, 1, 4]
        self.axes = [0, 1, 2]
        self.decrease_axis = [0, 1]
        self.infer_flags = [1, 1, 1]
        self.out = self.input[1, 0, 2:4, :]


class TestSliceOp_decs_dim_3(TestSliceOp_decs_dim):
    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [-1, 0, 2]
        self.ends = [1000000, 1, 4]
        self.axes = [0, 1, 2]
        self.decrease_axis = [0, 1]
        self.infer_flags = [1, 1, 1]
        self.out = self.input[-1, 0, 2:4, :]


class TestSliceOp_decs_dim_4(TestSliceOp_decs_dim):
    def config(self):
        self.input = np.random.random([3, 4, 5, 7]).astype("float64")
        self.starts = [0, 1, 2, 3]
        self.ends = [1, 2, 3, 4]
        self.axes = [0, 1, 2, 3]
        self.decrease_axis = [0, 1, 2, 3]
        self.infer_flags = [1, 1, 1]
        self.out = self.input[0, 1, 2, 3:4]


class TestSliceOp_decs_dim_5(TestSliceOp_decs_dim):
    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [-1]
        self.ends = [1000000]
        self.axes = [3]
        self.decrease_axis = [3]
        self.infer_flags = [1, 1, 1]
        self.out = self.input[:, :, :, -1]


class TestSliceOp_decs_dim_6(TestSliceOp_decs_dim):
    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [0, 1, 2, 3]
        self.ends = [1, 2, 3, 4]
        self.axes = [0, 1, 2, 3]
        self.decrease_axis = [0, 1, 2, 3]
        self.infer_flags = [1, 1, 1]
        self.out = self.input[0, 1, 2, 3:4]


if __name__ == '__main__':
    unittest.main()
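
Adding coverage follows the same pattern: subclass one of the base cases and override config() only, with self.out computed by plain numpy slicing; the output and gradient checks are inherited. A hypothetical extra case (not part of the commit) that slices only the last axis might look like this.

class TestCaseLastAxis(TestSliceOp):
    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [2]
        self.ends = [6]
        self.axes = [3]
        self.infer_flags = [1]
        # expected result via numpy slicing on the last axis
        self.out = self.input[:, :, :, 2:6]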
