Commit 0943b57

Merge remote-tracking branch 'origin/develop' into feat/group2
2 parents: 2c7f32a + 53f2a48

91 files changed (+4023, -10246 lines)

paddle/fluid/pir/dialect/op_generator/python_c_gen.py

Lines changed: 15 additions & 0 deletions
@@ -225,6 +225,8 @@
     {mutable_cast_attrs}
 }}else if (PyObject_CheckIRVectorOfValue({name}_obj)){{
     {mutable_vector_cast_attrs}
+}}else if (PyObject_CheckIRVectorOfValueOrLong({name}_obj)){{
+    {mix_vector_cast_attrs}
 }}else{{
     {no_mutable_cast_attrs}
 }}"""
@@ -525,6 +527,18 @@ def _gen_cast_attrs(self, op_info, op_name):
                     name=name
                 )
 
+                mix_vector_cast_str = MUTABLE_ATTR_CAST_TEMPLATE.format(
+                    type='std::vector<pir::Value>',
+                    name_=name + '_tmp',
+                    name=name,
+                    cast_func='CastPyArg2VectorOfValueOrLong',
+                    api_name=op_name,
+                    index=input_size + i,
+                )
+                mix_vector_cast_str += BUILTIN_STACK_OP_TEMPLATE.format(
+                    name=name
+                )
+
             else:
                 mutable_cast_str = MUTABLE_ATTR_CAST_TEMPLATE.format(
                     type='',
@@ -570,6 +584,7 @@ def _gen_cast_attrs(self, op_info, op_name):
                 name=name,
                 mutable_cast_attrs=mutable_cast_str,
                 mutable_vector_cast_attrs=mutable_vector_cast_str,
+                mix_vector_cast_attrs=mix_vector_cast_str,
                 no_mutable_cast_attrs=no_mutable_cast_str,
             )
         else:
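
The template change above adds a third cast branch for mutable attributes: when the Python argument is a list mixing plain integers with pir::Value objects, the generated code casts it with CastPyArg2VectorOfValueOrLong and then stacks the pieces into a single Value (the BUILTIN_STACK_OP_TEMPLATE part). A rough sketch of the C++ this could emit for a hypothetical `shape` attribute of a `reshape` API follows; the first branch and the exact stack-op expansion are not shown in this hunk and are assumptions.

// Hedged sketch only, not generated by this commit; the branch conditions
// outside this hunk and the stack-op expansion are assumptions.
if (PyObject_CheckIRValue(shape_obj)) {
  // single pir::Value passed for the whole attribute (assumed first branch)
  shape = CastPyArg2Value(shape_obj, "reshape", 1);
} else if (PyObject_CheckIRVectorOfValue(shape_obj)) {
  // existing branch: a list containing only pir::Value elements
  std::vector<pir::Value> shape_tmp =
      CastPyArg2VectorOfValue(shape_obj, "reshape", 1);
  // ... combined into one Value by the existing template code ...
} else if (PyObject_CheckIRVectorOfValueOrLong(shape_obj)) {
  // new branch: a list mixing Python ints and pir::Value elements
  std::vector<pir::Value> shape_tmp =
      CastPyArg2VectorOfValueOrLong(shape_obj, "reshape", 1);
  // ... BUILTIN_STACK_OP_TEMPLATE then stacks shape_tmp into a single Value ...
} else {
  // fallback: parse shape as a plain constant attribute
}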

paddle/fluid/pybind/arg_pre_process.cc

Lines changed: 68 additions & 1 deletion
@@ -19,11 +19,78 @@
 // processing of parameters originally done in the Python API
 #include "paddle/fluid/pybind/arg_pre_process.h"
 #include "paddle/fluid/eager/utils.h"
+#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
+#include "paddle/fluid/pir/utils/general_functions.h"
 #include "paddle/fluid/pybind/eager_utils.h"
 #include "paddle/fluid/pybind/op_function_common.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/enforce.h"
 namespace paddle {
-namespace pybind {} // namespace pybind
+namespace pybind {
+void RollPreProcess(Tensor* x, IntArray* shifts, IntVector* axis) {
+  int64_t len_origin_shape = x->dims().size();
+  if (axis != NULL) {
+    int64_t axis_len = axis->size();
+    for (int64_t i = 0; i < axis_len; i++) {
+      PADDLE_ENFORCE_EQ(
+          ((*axis)[i] < len_origin_shape && (*axis)[i] >= -len_origin_shape),
+          true,
+          common::errors::InvalidArgument("axis is out of range, it should be "
+                                          "in range [%d, %d), but received %ld",
+                                          -len_origin_shape,
+                                          len_origin_shape,
+                                          (*axis)[i]));
+    }
+  } else {
+    axis = new IntVector();
+  }
+}
+void RollPreProcess(Value* x, Value* shifts, IntVector* axis) {
+  std::vector<int64_t> x_shape = pir::GetShapeFromValue(*x);
+  int64_t len_origin_shape = x_shape.size();
+  if (axis != NULL) {
+    int64_t axis_len = axis->size();
+    for (int64_t i = 0; i < axis_len; i++) {
+      PADDLE_ENFORCE_EQ(
+          ((*axis)[i] < len_origin_shape && (*axis)[i] >= -len_origin_shape),
+          true,
+          common::errors::InvalidArgument("axis is out of range, it should be "
+                                          "in range [%d, %d), but received %ld",
+                                          -len_origin_shape,
+                                          len_origin_shape,
+                                          (*axis)[i]));
+    }
+  } else {
+    axis = new IntVector();
+  }
+}
+
+void LogsumexpPreProcess(Tensor* x, std::vector<int>* axis, bool* reduce_all) {
+  /**
+  if axis == [] or len(axis) == len(x.shape):
+      reduce_all = True
+  else:
+      reduce_all = False
+  */
+  if (axis->empty() || axis->size() == x->dims().size()) {
+    *reduce_all = true;
+  } else {
+    *reduce_all = false;
+  }
+  return;
+}
+
+void LogsumexpPreProcess(pir::Value* x,
+                         std::vector<int>* axis,
+                         bool* reduce_all) {
+  std::vector<int64_t> x_shape = pir::GetShapeFromValue(*x);
+  if (axis->empty() || axis->size() == x_shape.size()) {
+    *reduce_all = true;
+  } else {
+    *reduce_all = false;
+  }
+  return;
+}
+} // namespace pybind
 
 } // namespace paddle
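
These helpers move argument pre-processing that the Python API layer used to perform into C++ (the file's header comment and the embedded docstring-style comments quote the original Python logic). For LogsumexpPreProcess, reduce_all becomes true exactly when axis is empty or lists every dimension of x. Below is a hedged sketch of how a generated logsumexp binding might call the eager overload; the parsing calls and variable names are assumptions, not part of this commit.

// Hypothetical call site (assumed names, not from this commit):
std::vector<int> axis = CastPyArg2Ints(axis_obj, "logsumexp", 1);
bool keepdim = CastPyArg2Boolean(keepdim_obj, "logsumexp", 2, false);
bool reduce_all = false;
LogsumexpPreProcess(&x, &axis, &reduce_all);
// reduce_all now mirrors the old Python rule:
//   reduce_all = (axis == [] or len(axis) == len(x.shape))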

paddle/fluid/pybind/arg_pre_process.h

Lines changed: 18 additions & 1 deletion
@@ -15,9 +15,26 @@
 #pragma once
 
 #include <Python.h>
+#include <vector>
+#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
+#include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/common/scalar.h"
+#include "paddle/pir/include/core/value.h"
 
 namespace paddle {
 
-namespace pybind {} // namespace pybind
+namespace pybind {
+using Tensor = paddle::Tensor;
+using Value = pir::Value;
+using IntArray = paddle::experimental::IntArray;
+using IntVector = std::vector<int64_t>;
+
+void RollPreProcess(Tensor* x, IntArray* shifts, IntVector* axis);
+void RollPreProcess(Value* x, Value* shifts, IntVector* axis);
+
+void LogsumexpPreProcess(Tensor* x, std::vector<int>* axis, bool* reduce_all);
+void LogsumexpPreProcess(Value* x, std::vector<int>* axis, bool* reduce_all);
+} // namespace pybind
 
 } // namespace paddle
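
The header declares each helper twice: an eager overload taking Tensor/IntArray and a PIR overload taking pir::Value, so dynamic-graph and static-graph bindings can call the same name and let C++ overload resolution pick the variant. As a hedged illustration, a roll binding on the eager path might look roughly like this; the cast helpers and variable names are assumptions, not part of this commit.

// Hypothetical call site (assumed names, not from this commit):
paddle::experimental::IntArray shifts =
    CastPyArg2IntArray(shifts_obj, "roll", 1);
std::vector<int64_t> axis = CastPyArg2Longs(axis_obj, "roll", 2);
RollPreProcess(&x, &shifts, &axis);
// raises InvalidArgument if any axis entry is outside [-x.rank, x.rank)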

paddle/fluid/pybind/args_mapper.cc

Lines changed: 131 additions & 1 deletion
@@ -20,11 +20,141 @@
 
 #include "paddle/fluid/pybind/args_mapper.h"
 #include "paddle/fluid/eager/utils.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h"
 #include "paddle/fluid/pybind/eager_utils.h"
 #include "paddle/fluid/pybind/op_function_common.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/enforce.h"
 namespace paddle {
-namespace pybind {} // namespace pybind
+namespace pybind {
+void ArgMaxMinMapper(PyObject* args,
+                     PyObject* kwargs,
+                     Tensor* x,
+                     paddle::experimental::Scalar* axis,
+                     bool* keepdims,
+                     bool* flatten,
+                     phi::DataType* dtype) {
+  // The Python params are (x, axis, keepdim, dtype, name), which have no
+  // flatten. The _C_ops params are (x, axis, keepdim, flatten, dtype), which
+  // have flatten but no name. We parse the Python params and convert them to
+  // the _C_ops params.
+  int nargs = args ? static_cast<int>(PyTuple_Size(args)) : 0;
+  int remaining_kwargs = kwargs ? static_cast<int>(PyDict_Size(kwargs)) : 0;
+  // The Python params count only covers (x, axis, keepdim, dtype); it does
+  // not include name.
+  const int max_args = 4;
+  CheckParamsCount(nargs, remaining_kwargs, max_args);
 
+  VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
+  // Get EagerTensors from args
+  *x = GetTensorFromArgsOrKWArgs("argmax",
+                                 "x",
+                                 args,
+                                 0,
+                                 kwargs,
+                                 {"x", "input"},
+                                 nargs,
+                                 &remaining_kwargs,
+                                 false);
+
+  // Parse Attributes if needed
+
+  PyObject* axis_obj = GetItemFromArgsOrKWArgs(
+      args, 1, kwargs, {"axis", "dim"}, nargs, &remaining_kwargs);
+  /**
+  flatten = False
+  if axis is None:
+      flatten = True
+      axis = 0
+  */
+  *flatten = false;
+  if (axis_obj == Py_None || axis_obj == nullptr) {
+    *flatten = true;
+    *axis = 0;
+  } else {
+    *axis = CastPyArg2Scalar(axis_obj, "argmax", 1);
+  }
+  PyObject* keepdims_obj = GetItemFromArgsOrKWArgs(
+      args, 2, kwargs, {"keepdim", "keepdims"}, nargs, &remaining_kwargs);
+  *keepdims = CastPyArg2Boolean(keepdims_obj, "argmax", 2, false);
+
+  PyObject* dtype_obj = GetItemFromArgsOrKWArgs(
+      args, 3, kwargs, {"dtype"}, nargs, &remaining_kwargs);
+  /**
+  if dtype is None:
+      raise ValueError(
+          "the value of 'dtype' in argmax could not be None, but received None")
+  */
+  PADDLE_ENFORCE_NE(
+      dtype_obj,
+      Py_None,
+      phi::errors::InvalidArgument("the value of 'dtype' in argmax and argmin "
+                                   "could not be None, but received None"));
+  *dtype = CastPyArg2DataType(dtype_obj, "argmax", 3, phi::DataType::INT64);
+  // Check remaining params validity if needed
+  CheckRemainingParamsValidity(args, kwargs, remaining_kwargs, nargs);
+
+  return;
+}
+void ArgMaxMinMapper(PyObject* args,
+                     PyObject* kwargs,
+                     pir::Value* x,
+                     pir::Value* axis,
+                     bool* keepdims,
+                     bool* flatten,
+                     phi::DataType* dtype) {
+  // Get total params count and check validity if needed
+  int nargs = args ? static_cast<int>(PyTuple_Size(args)) : 0;
+  int remaining_kwargs = kwargs ? static_cast<int>(PyDict_Size(kwargs)) : 0;
+  const int max_args = 4;
+  CheckParamsCount(nargs, remaining_kwargs, max_args);
+
+  // Get Value from args
+  PyObject* x_obj = GetItemFromArgsOrKWArgs(
+      args, 0, kwargs, {"x", "input"}, nargs, &remaining_kwargs);
+  *x = CastPyArg2Value(x_obj, "argmax", 0, false);
+
+  // Parse Attributes
+  PyObject* axis_obj = GetItemFromArgsOrKWArgs(
+      args, 1, kwargs, {"axis", "dim"}, nargs, &remaining_kwargs);
+  PyObject* keepdims_obj = GetItemFromArgsOrKWArgs(
+      args, 2, kwargs, {"keepdim", "keepdims"}, nargs, &remaining_kwargs);
+  PyObject* dtype_obj = GetItemFromArgsOrKWArgs(
+      args, 3, kwargs, {"dtype"}, nargs, &remaining_kwargs);
+
+  /**
+  flatten = False
+  if axis is None:
+      flatten = True
+      axis = 0
+  */
+  *flatten = false;
+  if (axis_obj == Py_None || axis_obj == nullptr) {
+    *flatten = true;
+    *axis = paddle::dialect::full(
+        std::vector<int64_t>{1}, 0, phi::DataType::INT64, phi::CPUPlace());
+  } else if (PyObject_CheckIRValue(axis_obj)) {
+    *axis = CastPyArg2Value(axis_obj, "argmax", 1);
+  } else {
+    int64_t axis_tmp = CastPyArg2Long(axis_obj, "argmax", 1);
+    *axis = paddle::dialect::full(std::vector<int64_t>{1},
+                                  axis_tmp,
+                                  phi::DataType::INT64,
+                                  phi::CPUPlace());
+  }
+  *keepdims = CastPyArg2Boolean(keepdims_obj, "argmax", 2, false);
+
+  PADDLE_ENFORCE_NE(
+      dtype_obj,
+      Py_None,
+      phi::errors::InvalidArgument("the value of 'dtype' in argmax and argmin "
+                                   "could not be None, but received None"));
+  *dtype = CastPyArg2DataType(dtype_obj, "argmax", 3, phi::DataType::INT64);
+
+  // Check remaining params validity if needed
+  CheckRemainingParamsValidity(args, kwargs, remaining_kwargs, nargs);
+  return;
+}
+
+} // namespace pybind
 } // namespace paddle
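
ArgMaxMinMapper bridges the public Python signature (x, axis, keepdim, dtype, name) and the _C_ops-style argument set (x, axis, keepdims, flatten, dtype): it resolves positional or keyword arguments (accepting the input/dim/keepdims aliases), turns axis=None into flatten=True with axis 0, and rejects dtype=None. A hedged sketch of how an eager argmax binding could consume it follows; the forwarding call is an assumption, not part of this commit.

// Hypothetical use inside an eager argmax binding (assumed names):
paddle::Tensor x;
paddle::experimental::Scalar axis;
bool keepdims = false;
bool flatten = false;
phi::DataType dtype = phi::DataType::INT64;
ArgMaxMinMapper(args, kwargs, &x, &axis, &keepdims, &flatten, &dtype);
// then forward the mapped arguments to the C++ op, e.g.
// auto out = argmax_ad_func(x, axis, keepdims, flatten, dtype);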

paddle/fluid/pybind/args_mapper.h

Lines changed: 21 additions & 1 deletion
@@ -16,8 +16,28 @@
 
 #include <Python.h>
 #include <vector>
+#include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/common/scalar.h"
+#include "paddle/pir/include/core/value.h"
 namespace paddle {
 
-namespace pybind {} // namespace pybind
+namespace pybind {
+void ArgMaxMinMapper(PyObject* args,
+                     PyObject* kwargs,
+                     Tensor* x,
+                     paddle::experimental::Scalar* axis,
+                     bool* keepdims,
+                     bool* flatten,
+                     phi::DataType* dtype);
+void ArgMaxMinMapper(PyObject* args,
+                     PyObject* kwargs,
+                     pir::Value* x,
+                     pir::Value* axis,
+                     bool* keepdims,
+                     bool* flatten,
+                     phi::DataType* dtype);
+
+} // namespace pybind
 
 } // namespace paddle
