Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
d7eb6a7
fix compatible in index_select, logical_*
cangtianhuang Aug 19, 2025
c85eaa1
fix format
cangtianhuang Aug 19, 2025
b4e13ca
refine
cangtianhuang Aug 20, 2025
be756dd
fix tests
cangtianhuang Aug 20, 2025
bd62fc6
Merge remote-tracking branch 'upstream/develop' into comp_index
cangtianhuang Aug 20, 2025
af42b4b
add out for index_select
cangtianhuang Aug 20, 2025
61acc8f
add out in signature for index_select
cangtianhuang Aug 20, 2025
9919331
fix
cangtianhuang Aug 21, 2025
39fbd83
fix tests
cangtianhuang Aug 21, 2025
0576301
fix docs
cangtianhuang Aug 21, 2025
fde0a19
Merge remote-tracking branch 'upstream/develop' into comp_index
cangtianhuang Aug 21, 2025
a66afcc
Merge remote-tracking branch 'upstream/develop' into comp_index
cangtianhuang Aug 22, 2025
0701cd3
add numpy.dtype and str_dtype to Paddle DataType
zhengshengning Aug 22, 2025
118757b
paddle.roll, paddle.flatten and paddle.Tensor.flatten sink into C++
zhengshengning Aug 22, 2025
b5b1883
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhengshengning Aug 22, 2025
e680d08
fix tests
cangtianhuang Aug 23, 2025
93fef56
PyObject can be a mixed type in static graph mode
zhengshengning Aug 24, 2025
933a8a3
merge develop
zhengshengning Aug 24, 2025
c958dcd
Delete invalid code
zhengshengning Aug 24, 2025
3625362
remove sum test
zhengshengning Aug 24, 2025
e93d1ce
merge develop
zhengshengning Aug 25, 2025
7750c5f
revert index_select
cangtianhuang Aug 25, 2025
8c73a6e
fix pad [int value int]
zhengshengning Aug 25, 2025
213f5d9
Merge remote-tracking branch 'upstream/develop' into comp_index
cangtianhuang Aug 25, 2025
f5be1f9
move ops.yaml
cangtianhuang Aug 25, 2025
c052535
revert docs
cangtianhuang Aug 25, 2025
f000806
fix
zhengshengning Aug 25, 2025
238651a
fix
zhengshengning Aug 25, 2025
0ad0f4f
Merge remote-tracking branch 'upstream/develop' into comp_index
cangtianhuang Aug 25, 2025
0779bb5
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhengshengning Aug 25, 2025
2ef01bf
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhengshengning Aug 25, 2025
4ef6baf
fix
zhengshengning Aug 25, 2025
a20de6c
fix
zhengshengning Aug 25, 2025
c1996b5
remove sum
zhengshengning Aug 26, 2025
30630e3
fix
zhengshengning Aug 26, 2025
d3792c9
merge develop
zhengshengning Aug 26, 2025
1a11f22
add blank line
zhengshengning Aug 26, 2025
9264cd3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhengshengning Aug 26, 2025
03df5c2
Merge remote-tracking branch 'upstream/develop' into comp_index
cangtianhuang Aug 26, 2025
59c7cf5
fix docs
cangtianhuang Aug 26, 2025
441dcb2
Merge branch 'axis_mixed_type' into merge_3_branch
zhengshengning Aug 26, 2025
1276646
merge zhengshengning:c_alias_roll_new_2
zhengshengning Aug 26, 2025
535a79e
merge cangtianhuang:comp_index
zhengshengning Aug 26, 2025
b1bdb45
merge develop
zhengshengning Aug 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/python_c_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,8 @@
{mutable_cast_attrs}
}}else if (PyObject_CheckIRVectorOfValue({name}_obj)){{
{mutable_vector_cast_attrs}
}}else if (PyObject_CheckIRVectorOfValueOrLong({name}_obj)){{
{mix_vector_cast_attrs}
}}else{{
{no_mutable_cast_attrs}
}}"""
Expand Down Expand Up @@ -525,6 +527,18 @@ def _gen_cast_attrs(self, op_info, op_name):
name=name
)

mix_vector_cast_str = MUTABLE_ATTR_CAST_TEMPLATE.format(
type='std::vector<pir::Value>',
name_=name + '_tmp',
name=name,
cast_func='CastPyArg2VectorOfValueOrLong',
api_name=op_name,
index=input_size + i,
)
mix_vector_cast_str += BUILTIN_STACK_OP_TEMPLATE.format(
name=name
)

else:
mutable_cast_str = MUTABLE_ATTR_CAST_TEMPLATE.format(
type='',
Expand Down Expand Up @@ -570,6 +584,7 @@ def _gen_cast_attrs(self, op_info, op_name):
name=name,
mutable_cast_attrs=mutable_cast_str,
mutable_vector_cast_attrs=mutable_vector_cast_str,
mix_vector_cast_attrs=mix_vector_cast_str,
no_mutable_cast_attrs=no_mutable_cast_str,
)
else:
Expand Down
47 changes: 43 additions & 4 deletions paddle/fluid/pybind/arg_pre_process.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,53 @@
// processing of parameters originally done in the Python API
#include "paddle/fluid/pybind/arg_pre_process.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/utils/general_functions.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace pybind {
void LogsumexpPreProcess(Tensor *x, std::vector<int> *axis, bool *reduce_all) {
// Argument pre-processing for paddle.roll (eager path): validates that every
// entry of `axis` lies in [-rank, rank) of `x`, mirroring the range check the
// Python-level API used to perform. `shifts` is part of the signature for
// parity with the generated caller; it is not validated here.
void RollPreProcess(Tensor* x, IntArray* shifts, IntVector* axis) {
  const int64_t len_origin_shape = x->dims().size();
  // A null `axis` means "roll the flattened tensor" — nothing to validate.
  // NOTE: the previous `axis = new IntVector()` fallback only reassigned the
  // local pointer (invisible to the caller) and leaked the allocation, so it
  // has been removed.
  if (axis == nullptr) {
    return;
  }
  const int64_t axis_len = axis->size();
  for (int64_t i = 0; i < axis_len; i++) {
    PADDLE_ENFORCE_EQ(
        ((*axis)[i] < len_origin_shape && (*axis)[i] >= -len_origin_shape),
        true,
        common::errors::InvalidArgument("axis is out of range, it should be "
                                        "in range [%d, %d), but received %ld",
                                        -len_origin_shape,
                                        len_origin_shape,
                                        (*axis)[i]));
  }
}
// Static-graph (PIR) overload of the paddle.roll argument pre-processing:
// the same [-rank, rank) validation of `axis`, with the rank taken from the
// pir::Value's shape. `shifts` is accepted for signature parity with the
// generated caller and is not validated here.
void RollPreProcess(Value* x, Value* shifts, IntVector* axis) {
  std::vector<int64_t> x_shape = pir::GetShapeFromValue(*x);
  const int64_t len_origin_shape = x_shape.size();
  // A null `axis` means "roll the flattened tensor" — nothing to validate.
  // NOTE: the previous `axis = new IntVector()` fallback only reassigned the
  // local pointer (invisible to the caller) and leaked the allocation, so it
  // has been removed.
  if (axis == nullptr) {
    return;
  }
  const int64_t axis_len = axis->size();
  for (int64_t i = 0; i < axis_len; i++) {
    PADDLE_ENFORCE_EQ(
        ((*axis)[i] < len_origin_shape && (*axis)[i] >= -len_origin_shape),
        true,
        common::errors::InvalidArgument("axis is out of range, it should be "
                                        "in range [%d, %d), but received %ld",
                                        -len_origin_shape,
                                        len_origin_shape,
                                        (*axis)[i]));
  }
}

void LogsumexpPreProcess(Tensor* x, std::vector<int>* axis, bool* reduce_all) {
/**
if axis == [] or len(axis) == len(x.shape):
reduce_all = True
Expand All @@ -41,9 +80,9 @@ void LogsumexpPreProcess(Tensor *x, std::vector<int> *axis, bool *reduce_all) {
return;
}

void LogsumexpPreProcess(pir::Value *x,
std::vector<int> *axis,
bool *reduce_all) {
void LogsumexpPreProcess(pir::Value* x,
std::vector<int>* axis,
bool* reduce_all) {
std::vector<int64_t> x_shape = pir::GetShapeFromValue(*x);
if (axis->empty() || axis->size() == x_shape.size()) {
*reduce_all = true;
Expand Down
11 changes: 9 additions & 2 deletions paddle/fluid/pybind/arg_pre_process.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <Python.h>
#include <vector>
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h"
Expand All @@ -24,10 +25,16 @@
namespace paddle {

namespace pybind {
// Shorthand aliases shared by the argument pre-processing declarations below.
using Tensor = paddle::Tensor;
using Value = pir::Value;
using IntArray = paddle::experimental::IntArray;
using IntVector = std::vector<int64_t>;

// Pre-processing for paddle.roll arguments. Two overloads: eager-mode
// (Tensor/IntArray) and static-graph PIR mode (pir::Value).
void RollPreProcess(Tensor* x, IntArray* shifts, IntVector* axis);
void RollPreProcess(Value* x, Value* shifts, IntVector* axis);

// Pre-processing for paddle.logsumexp arguments; sets *reduce_all when the
// reduction covers every axis of x. Eager and static-graph overloads.
void LogsumexpPreProcess(Tensor* x, std::vector<int>* axis, bool* reduce_all);
void LogsumexpPreProcess(Value* x, std::vector<int>* axis, bool* reduce_all);
} // namespace pybind

} // namespace paddle
115 changes: 115 additions & 0 deletions paddle/fluid/pybind/eager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ limitations under the License. */
#include "paddle/fluid/jit/function.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_api.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/utils/general_functions.h"
#include "paddle/fluid/pir/utils/name_analysis.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/eager.h"
Expand Down Expand Up @@ -232,6 +234,39 @@ bool PyObject_CheckIRVectorOfValue(PyObject* obj) {
}
}

// Returns true only for a non-empty list/tuple that MIXES pir::Value items
// with Python ints (at least one of each, and nothing else). Homogeneous
// sequences are deliberately rejected here: a pure-Value sequence is handled
// by PyObject_CheckIRVectorOfValue in the generated dispatch that precedes
// this check.
bool PyObject_CheckIRVectorOfValueOrLong(PyObject* obj) {
  const bool is_sequence = PyList_Check(obj) || PyTuple_Check(obj);
  if (!is_sequence) {
    return false;
  }

  const Py_ssize_t size = PySequence_Size(obj);
  if (size == 0) {
    return false;
  }

  bool saw_value = false;
  bool saw_long = false;
  for (Py_ssize_t idx = 0; idx < size; ++idx) {
    PyObject* elem = PySequence_GetItem(obj, idx);  // new reference
    if (elem == nullptr) {
      return false;
    }
    const bool elem_is_value = PyObject_CheckIRValue(elem);
    const bool elem_is_long = !elem_is_value && PyObject_CheckLong(elem);
    Py_DECREF(elem);  // release the reference before branching
    if (elem_is_value) {
      saw_value = true;
    } else if (elem_is_long) {
      saw_long = true;
    } else {
      return false;  // any foreign element type disqualifies the sequence
    }
  }

  return saw_value && saw_long;
}

bool CastPyArg2AttrBoolean(PyObject* obj, ssize_t arg_pos) {
if (obj == Py_None || obj == Py_False) {
return false; // To be compatible with QA integration testing. Some
Expand Down Expand Up @@ -2276,6 +2311,86 @@ std::vector<pir::Value> CastPyArg2VectorOfValue(PyObject* obj,
return value_list;
}

// Casts a Python list/tuple that mixes pir::Value items and Python ints into
// a homogeneous std::vector<pir::Value>: plain ints are materialized as
// full() ops whose dtype/shape are borrowed from the first pir::Value found
// in the sequence (defaulting to an INT64 scalar when none is present).
// Py_None items are skipped. Throws InvalidType/InvalidArgument on any other
// element type, on a non-sequence argument, or on an empty non-dispensable
// sequence.
std::vector<pir::Value> CastPyArg2VectorOfValueOrLong(
    PyObject* obj,
    const std::string& op_type,
    size_t arg_pos,
    bool dispensable) {
  std::vector<pir::Value> value_list;

  if (!PyList_Check(obj) && !PyTuple_Check(obj)) {
    PADDLE_THROW(common::errors::InvalidType(
        "%s(): argument (position %d) must be "
        "Vector<>, but got %s",
        op_type,
        arg_pos + 1,
        reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
  }

  Py_ssize_t len = PySequence_Size(obj);
  if (len == 0 && !dispensable) {
    PADDLE_THROW(
        common::errors::InvalidArgument("%s(): argument (position %d) must be "
                                        "list of Value, but got empty list",
                                        op_type,
                                        arg_pos + 1));
  }

  // First pass: borrow dtype/shape from the first pir::Value in the sequence
  // so that int elements can be turned into compatible full() tensors.
  phi::DataType dtype = phi::DataType::INT64;
  std::vector<int64_t> shape;
  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(obj, i);  // new reference
    if (!item) {
      continue;
    }

    item = CastPyArg2ValuePreHook(item);

    if (PyObject_TypeCheck(item, g_ir_value_pytype)) {
      pir::Value val = ::pybind11::handle(item).cast<pir::Value>();
      dtype = paddle::dialect::GetValueDataType(val);
      shape = pir::GetShapeFromValue(val);
      Py_DECREF(item);
      break;
    }

    Py_DECREF(item);
  }

  // Second pass: build the Value vector, converting ints via full().
  for (Py_ssize_t i = 0; i < len; ++i) {
    PyObject* item = PySequence_GetItem(obj, i);  // new reference
    if (!item) {
      PADDLE_THROW(common::errors::Fatal(
          "%s(): failed to get item from sequence at position %d",
          op_type,
          static_cast<int>(i)));
    }

    item = CastPyArg2ValuePreHook(item);

    if (PyObject_CheckIRValue(item)) {
      value_list.emplace_back(::pybind11::handle(item).cast<pir::Value>());
    } else if (PyObject_CheckLong(item)) {
      int64_t k_tmp = CastPyArg2Long(item, op_type, arg_pos);
      value_list.emplace_back(
          paddle::dialect::full(shape, k_tmp, dtype, phi::CPUPlace()));
    } else if (item == Py_None) {
      // Fix: the skipped item still holds the new reference returned by
      // PySequence_GetItem; release it before continuing to avoid a leak.
      Py_DECREF(item);
      continue;
    } else {
      PADDLE_THROW(common::errors::InvalidType(
          "%s(): argument (position %d) must be vector<Value>, "
          "but got vector<%s>",
          op_type,
          arg_pos + 1,
          reinterpret_cast<PyTypeObject*>(item->ob_type)->tp_name));
    }

    Py_DECREF(item);
  }

  return value_list;
}

paddle::optional<std::vector<pir::Value>> CastPyArg2OptionalVectorOfValue(
PyObject* obj,
const std::string& op_type,
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/pybind/eager_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ int TensorDtype2NumpyDtype(phi::DataType dtype);
bool PyObject_CheckStr(PyObject* obj);
bool PyObject_CheckIRValue(PyObject* obj);
bool PyObject_CheckIRVectorOfValue(PyObject* obj);
bool PyObject_CheckIRVectorOfValueOrLong(PyObject* obj);
bool CastPyArg2AttrBoolean(PyObject* obj, ssize_t arg_pos);
int CastPyArg2AttrInt(PyObject* obj, ssize_t arg_pos);
int64_t CastPyArg2AttrLong(PyObject* obj, ssize_t arg_pos);
Expand Down Expand Up @@ -100,6 +101,11 @@ std::vector<pir::Value> CastPyArg2VectorOfValue(PyObject* obj,
const std::string& op_type,
size_t arg_pos,
bool dispensable = false);
// Like CastPyArg2VectorOfValue, but additionally accepts Python ints mixed
// into the sequence; ints are converted to pir::Value via full() (dtype/shape
// taken from the first Value element — see the definition).
std::vector<pir::Value> CastPyArg2VectorOfValueOrLong(
    PyObject* obj,
    const std::string& op_type,
    size_t arg_pos,
    bool dispensable = false);
paddle::optional<std::vector<pir::Value>> CastPyArg2OptionalVectorOfValue(
PyObject* obj,
const std::string& op_type,
Expand Down
7 changes: 7 additions & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4681,6 +4681,13 @@

- op : roll
args : (Tensor x, IntArray shifts={}, int64_t[] axis={})
python_api:
name : [paddle.roll, paddle.Tensor.roll]
args_alias:
axis : [dims]
use_default_mapping : True
pre_process:
func : RollPreProcess(x, shifts, axis)
output : Tensor(out)
infer_meta :
func : RollInferMeta
Expand Down
20 changes: 20 additions & 0 deletions paddle/phi/ops/yaml/python_api_info.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,23 @@
name : [paddle.amax,paddle.Tensor.amax]
args_alias :
use_default_mapping : True

# Python-facing API registration for the logical_* ops: each entry binds the
# C++ op to its `paddle.*` free-function and `paddle.Tensor.*` method aliases
# and reuses the default C++/Python argument mapping (no renamed arguments).
- op : logical_and
  name : [paddle.logical_and, paddle.Tensor.logical_and]
  args_alias:
    use_default_mapping : True

- op : logical_or
  name : [paddle.logical_or, paddle.Tensor.logical_or]
  args_alias:
    use_default_mapping : True

- op : logical_xor
  name : [paddle.logical_xor, paddle.Tensor.logical_xor]
  args_alias:
    use_default_mapping : True

# logical_not is unary, but registration is otherwise identical to the
# binary logical ops above.
- op : logical_not
  name : [paddle.logical_not, paddle.Tensor.logical_not]
  args_alias:
    use_default_mapping : True
Loading