Skip to content

Commit 831909c

Browse files
authored
Merge pull request #11313 from sneaxiy/argmin_argmax
Add argmin and argmax ops
2 parents 69b5a62 + 9b43ede commit 831909c

File tree

7 files changed

+434
-0
lines changed

7 files changed

+434
-0
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/arg_min_max_op_base.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the arg_max operator (it has no gradient) together with its CPU
// kernels. Indices are always emitted as int64_t regardless of input type.
REGISTER_OPERATOR(arg_max, ops::ArgMinMaxOp, ops::ArgMaxOpMaker,
                  paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(
    arg_max, ops::ArgMaxKernel<plat::CPUDeviceContext, float>,
    ops::ArgMaxKernel<plat::CPUDeviceContext, double>,
    ops::ArgMaxKernel<plat::CPUDeviceContext, int64_t>,
    ops::ArgMaxKernel<plat::CPUDeviceContext, int32_t>,
    ops::ArgMaxKernel<plat::CPUDeviceContext, int16_t>,
    ops::ArgMaxKernel<plat::CPUDeviceContext, size_t>,
    ops::ArgMaxKernel<plat::CPUDeviceContext, uint8_t>);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/arg_min_max_op_base.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// CUDA kernel registration for arg_max; one instantiation per supported
// input element type. Output indices are always int64_t.
REGISTER_OP_CUDA_KERNEL(
    arg_max, ops::ArgMaxKernel<plat::CUDADeviceContext, float>,
    ops::ArgMaxKernel<plat::CUDADeviceContext, double>,
    ops::ArgMaxKernel<plat::CUDADeviceContext, int64_t>,
    ops::ArgMaxKernel<plat::CUDADeviceContext, int32_t>,
    ops::ArgMaxKernel<plat::CUDADeviceContext, int16_t>,
    ops::ArgMaxKernel<plat::CUDADeviceContext, size_t>,
    ops::ArgMaxKernel<plat::CUDADeviceContext, uint8_t>);
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <string>
17+
#include <type_traits>
18+
#include <vector>
19+
#include "paddle/fluid/framework/ddim.h"
20+
#include "paddle/fluid/framework/eigen.h"
21+
#include "paddle/fluid/framework/lod_tensor.h"
22+
#include "paddle/fluid/framework/op_registry.h"
23+
#include "paddle/fluid/framework/operator.h"
24+
#include "paddle/fluid/platform/enforce.h"
25+
#include "paddle/fluid/string/printf.h"
26+
27+
namespace paddle {
namespace operators {

// Tag that selects which reduction a functor/kernel performs.
enum ArgMinMaxType { kArgMin, kArgMax };

// Primary template: intentionally empty. Only the specializations generated
// by DECLARE_ARG_MIN_MAX_FUNCTOR below are usable.
template <typename DeviceContext, typename T, typename Tout, int64_t Rank,
          ArgMinMaxType argMinMaxValue>
struct ArgMinMaxFunctor {};

// Generates a functor specialization that applies Eigen's argmin/argmax
// (eigen_op_type) along `axis` and casts the resulting indices to Tout.
// The output tensor has rank Rank - 1 (the reduced axis is removed).
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value)      \
  template <typename DeviceContext, typename T, typename Tout, int64_t Rank>  \
  struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank,                       \
                          enum_argminmax_value> {                             \
    void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
                    framework::LoDTensor* out, int64_t axis) {                \
      auto in_eigen = framework::EigenTensor<T, Rank>::From(in);              \
      auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out);    \
      out_eigen.device(*(ctx.eigen_device())) =                               \
          in_eigen.eigen_op_type(axis).template cast<Tout>();                 \
    }                                                                         \
  }

DECLARE_ARG_MIN_MAX_FUNCTOR(argmin, ArgMinMaxType::kArgMin);
DECLARE_ARG_MIN_MAX_FUNCTOR(argmax, ArgMinMaxType::kArgMax);
51+
52+
template <typename DeviceContext, typename T, typename Tout,
53+
ArgMinMaxType EnumArgMinMaxValue>
54+
class ArgMinMaxKernel : public framework::OpKernel<T> {
55+
public:
56+
void Compute(const framework::ExecutionContext& ctx) const override {
57+
auto& x = *(ctx.Input<framework::LoDTensor>("X"));
58+
auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
59+
out.mutable_data<Tout>(ctx.GetPlace());
60+
auto axis = ctx.Attr<int64_t>("axis");
61+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
62+
63+
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
64+
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
65+
functor##rank; \
66+
functor##rank(dev_ctx, x, &out, axis)
67+
68+
switch (x.dims().size()) {
69+
case 1:
70+
CALL_ARG_MINMAX_FUNCTOR(1);
71+
break;
72+
case 2:
73+
CALL_ARG_MINMAX_FUNCTOR(2);
74+
break;
75+
case 3:
76+
CALL_ARG_MINMAX_FUNCTOR(3);
77+
break;
78+
case 4:
79+
CALL_ARG_MINMAX_FUNCTOR(4);
80+
break;
81+
case 5:
82+
CALL_ARG_MINMAX_FUNCTOR(5);
83+
break;
84+
case 6:
85+
CALL_ARG_MINMAX_FUNCTOR(6);
86+
break;
87+
default:
88+
PADDLE_THROW(
89+
"%s operator doesn't supports tensors whose ranks are greater "
90+
"than 6.",
91+
(EnumArgMinMaxValue == kArgMin ? "argmin" : "argmax"));
92+
break;
93+
#undef CALL_ARG_MINMAX_FUNCTOR
94+
}
95+
}
96+
};
97+
98+
// Convenience aliases: both ops always produce int64_t indices, matching
// the int64 output variable created by the Python layer wrappers.
template <typename DeviceContext, typename T>
using ArgMinKernel =
    ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMin>;

template <typename DeviceContext, typename T>
using ArgMaxKernel =
    ArgMinMaxKernel<DeviceContext, T, int64_t, ArgMinMaxType::kArgMax>;
105+
106+
class ArgMinMaxOp : public framework::OperatorWithKernel {
107+
public:
108+
using framework::OperatorWithKernel::OperatorWithKernel;
109+
110+
void InferShape(framework::InferShapeContext* ctx) const override {
111+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
112+
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
113+
const auto& x_dims = ctx->GetInputDim("X");
114+
int64_t axis = ctx->Attrs().Get<int64_t>("axis");
115+
PADDLE_ENFORCE(axis >= -x_dims.size() && axis < x_dims.size(),
116+
"'axis' must be inside [-Rank(X), Rank(X))");
117+
118+
auto x_rank = x_dims.size();
119+
if (axis < 0) axis += x_rank;
120+
121+
std::vector<int64_t> vec;
122+
for (int64_t i = 0; i < axis; i++) vec.push_back(x_dims[i]);
123+
for (int64_t i = axis + 1; i < x_rank; i++) vec.push_back(x_dims[i]);
124+
ctx->SetOutputDim("Out", framework::make_ddim(vec));
125+
}
126+
};
127+
128+
class BaseArgMinMaxOpMaker : public framework::OpProtoAndCheckerMaker {
129+
protected:
130+
virtual const char* OpName() const = 0;
131+
virtual const char* Name() const = 0;
132+
133+
public:
134+
void Make() override {
135+
AddInput("X", "Input tensor.");
136+
AddOutput("Out", "Output tensor.");
137+
AddAttr<int64_t>("axis", "The axis in which to compute the arg indics.");
138+
AddComment(string::Sprintf(R"DOC(
139+
%s Operator.
140+
141+
Computes the indices of the %s elements of the input tensor's element
142+
along the provided axis.
143+
)DOC",
144+
OpName(), Name()));
145+
}
146+
};
147+
148+
// Proto maker for the arg_min operator.
class ArgMinOpMaker : public BaseArgMinMaxOpMaker {
 protected:
  const char* OpName() const override { return "ArgMin"; }
  const char* Name() const override { return "min"; }
};
153+
154+
// Proto maker for the arg_max operator.
class ArgMaxOpMaker : public BaseArgMinMaxOpMaker {
 protected:
  const char* OpName() const override { return "ArgMax"; }
  const char* Name() const override { return "max"; }
};
}  // namespace operators
}  // namespace paddle
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/arg_min_max_op_base.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the arg_min operator (it has no gradient) together with its CPU
// kernels. Indices are always emitted as int64_t regardless of input type.
REGISTER_OPERATOR(arg_min, ops::ArgMinMaxOp, ops::ArgMinOpMaker,
                  paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(
    arg_min, ops::ArgMinKernel<plat::CPUDeviceContext, float>,
    ops::ArgMinKernel<plat::CPUDeviceContext, double>,
    ops::ArgMinKernel<plat::CPUDeviceContext, int64_t>,
    ops::ArgMinKernel<plat::CPUDeviceContext, int32_t>,
    ops::ArgMinKernel<plat::CPUDeviceContext, int16_t>,
    ops::ArgMinKernel<plat::CPUDeviceContext, size_t>,
    ops::ArgMinKernel<plat::CPUDeviceContext, uint8_t>);
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/arg_min_max_op_base.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// CUDA kernel registration for arg_min; one instantiation per supported
// input element type. Output indices are always int64_t.
REGISTER_OP_CUDA_KERNEL(
    arg_min, ops::ArgMinKernel<plat::CUDADeviceContext, float>,
    ops::ArgMinKernel<plat::CUDADeviceContext, double>,
    ops::ArgMinKernel<plat::CUDADeviceContext, int64_t>,
    ops::ArgMinKernel<plat::CUDADeviceContext, int32_t>,
    ops::ArgMinKernel<plat::CUDADeviceContext, int16_t>,
    ops::ArgMinKernel<plat::CUDADeviceContext, size_t>,
    ops::ArgMinKernel<plat::CUDADeviceContext, uint8_t>);

python/paddle/fluid/layers/tensor.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
'assign',
3232
'fill_constant_batch_size_like',
3333
'fill_constant',
34+
'argmin',
35+
'argmax',
3436
'ones',
3537
'zeros',
3638
]
@@ -315,6 +317,68 @@ def fill_constant_batch_size_like(input,
315317
return out
316318

317319

320+
def argmin(x, axis=0):
    """
    **argmin**

    Computes the indices of the minimum elements of the input
    tensor along the given axis.

    Args:
        x(Variable): The input tensor whose minimum-element
            indices are computed.
        axis(int): Axis to compute indices along.

    Returns:
        Variable: The tensor variable storing the output

    Examples:
        .. code-block:: python

          out = fluid.layers.argmin(x=in, axis=0)
          out = fluid.layers.argmin(x=in, axis=-1)
    """
    helper = LayerHelper("arg_min", **locals())
    # Indices are always int64, matching the C++ kernel registrations.
    result = helper.create_tmp_variable(VarDesc.VarType.INT64)
    helper.append_op(
        type='arg_min',
        inputs={'X': x},
        attrs={'axis': axis},
        outputs={'Out': [result]})
    return result
349+
350+
351+
def argmax(x, axis=0):
    """
    **argmax**

    Computes the indices of the maximum elements of the input
    tensor along the given axis.

    Args:
        x(Variable): The input tensor whose maximum-element
            indices are computed.
        axis(int): Axis to compute indices along.

    Returns:
        Variable: The tensor variable storing the output

    Examples:
        .. code-block:: python

          out = fluid.layers.argmax(x=in, axis=0)
          out = fluid.layers.argmax(x=in, axis=-1)
    """
    helper = LayerHelper("arg_max", **locals())
    # Indices are always int64, matching the C++ kernel registrations.
    result = helper.create_tmp_variable(VarDesc.VarType.INT64)
    helper.append_op(
        type='arg_max',
        inputs={'X': x},
        attrs={'axis': axis},
        outputs={'Out': [result]})
    return result
380+
381+
318382
def ones(shape, dtype, force_cpu=False):
319383
"""
320384
**ones**

0 commit comments

Comments
 (0)