Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 37 additions & 13 deletions paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,11 @@ def FindParsingFunctionFromAttributeType(atype):
' auto& {} = {}("{}", "{}", args, {}, {});\n'
)
PARSE_PYTHON_C_TENSORS_FROM_ARGS_OR_KWARGS_TEMPLATE = ' auto {} = GetTensorFromArgsOrKWArgs("{}", "{}", args, {}, kwargs,{},nargs,&remaining_kwargs,{});\n'

PARSE_PYTHON_C_OPTIONAL_TENSORS_FROM_ARGS_OR_KWARGS_TEMPLATE = ' auto {} = GetOptionalTensorFromArgsOrKWArgs("{}", "{}", args, {}, kwargs,{},nargs,&remaining_kwargs,{});\n'
CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE = (
' {} = {}("{}", "{}", args, {}, {}, mesh);\n'
)
CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_FROM_ARGS_OR_KWARGS_TEMPLATE = ' {} = {}("{}", "{}", args, {}, kwargs,{},nargs,&remaining_kwargs,{},mesh);\n'

CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITH_SINGLE_TENSOR_TEMPLATE = """
const phi::distributed::ProcessMesh* mesh = nullptr;
Expand Down Expand Up @@ -458,16 +459,27 @@ def _get_keywords(name, alias_map):
)
else:
if is_optional:
get_eager_tensor_str += (
PARSE_PYTHON_C_TENSORS_TEMPLATE.format(
if need_parse_python_api_args:
keywords = _get_keywords(name, args_alias_map)
get_eager_tensor_str += PARSE_PYTHON_C_OPTIONAL_TENSORS_FROM_ARGS_OR_KWARGS_TEMPLATE.format(
name,
"GetOptionalTensorFromArgs",
forward_api_name,
name,
pos,
keywords,
"true",
)
)
else:
get_eager_tensor_str += (
PARSE_PYTHON_C_TENSORS_TEMPLATE.format(
name,
"GetOptionalTensorFromArgs",
forward_api_name,
name,
pos,
"true",
)
)
else:
input_single_tensor_names = (
input_single_tensor_names + ", " + name
Expand Down Expand Up @@ -621,14 +633,26 @@ def pre_process_add_ampersand(s):
)
else:
if is_optional:
optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE.format(
name,
"GetOptionalTensorFromArgs",
forward_api_name,
name,
pos,
"true",
)
if need_parse_python_api_args:
keywords = _get_keywords(name, args_alias_map)
optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_FROM_ARGS_OR_KWARGS_TEMPLATE.format(
name,
"GetOptionalTensorFromArgsOrKWArgs",
forward_api_name,
name,
pos,
keywords,
"true",
)
else:
optional_and_vector_convert_code += CONVERT_TO_DISTTENSOR_AND_PARSE_PYTHON_C_TENSORS_TEMPLATE.format(
name,
"GetOptionalTensorFromArgs",
forward_api_name,
name,
pos,
"true",
)
if len(input_single_tensor_names) > 0:
convert_to_dist_str += CONVERT_INPUT_TENSORS_TO_DIST_TENSOR_WITH_SINGLE_TENSOR_TEMPLATE.format(
input_names=input_names,
Expand Down
45 changes: 45 additions & 0 deletions paddle/fluid/pybind/arg_pre_process.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,60 @@
// paddle/fluid/pybind/eager_op_function.cc. Mainly used to customize the
// processing of parameters originally done in the Python API
#include "paddle/fluid/pybind/arg_pre_process.h"
#include "paddle/common/ddim.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/pir/utils/general_functions.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace pybind {
constexpr char kStopGradientAttrName[] = "stop_gradient"; // NOLINT
void ExpandAsPreProcess(paddle::Tensor* x,
paddle::optional<paddle::Tensor>* y,
std::vector<int64_t>* target_shape) {
if (target_shape->empty() && y->get_ptr() == nullptr) {
PADDLE_THROW(common::errors::InvalidArgument(
"The y of expand_as api must be specified."));
}
if (y->get_ptr() == nullptr) return;
*target_shape = common::vectorize<int64_t>(y->get_ptr()->dims());
}
void ExpandAsPreProcess(pir::Value* x,
paddle::optional<pir::Value>* y,
std::vector<int64_t>* target_shape) {
if (target_shape->empty() && y->get_ptr() == nullptr) {
PADDLE_THROW(common::errors::InvalidArgument(
"The y of expand_as api must be specified."));
}
if (y->get_ptr() == nullptr) return;
*target_shape = pir::GetShapeFromValue(*(y->get_ptr()));

/**
* if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient:
* raise ValueError(
* "When the data type of input 'x' for expand_as is bool, "
* "you must set its stop_gradient to be False by "
* "some_var.stop_gradient = True, supporting "
* "some_var as the input 'x'."
* )
*
*/
auto dtype = pir::GetValueDtype(*x);
auto stop_gradient_attr =
x->attribute<pir::BoolAttribute>(kStopGradientAttrName);
auto stop_gradient = !stop_gradient_attr || stop_gradient_attr.data();
if (dtype == phi::DataType::BOOL && !stop_gradient) {
PADDLE_THROW(common::errors::InvalidArgument(
"When the data type of input 'x' for expand_as is bool, "
"you must set its stop_gradient to be False by "
"some_var.stop_gradient = True, supporting "
"some_var as the input 'x'."));
}
}
void RollPreProcess(Tensor* x, IntArray* shifts, IntVector* axis) {
int64_t len_origin_shape = x->dims().size();
if (axis != NULL) {
Expand Down
8 changes: 7 additions & 1 deletion paddle/fluid/pybind/arg_pre_process.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/pir/include/core/value.h"

#include "paddle/utils/optional.h"
namespace paddle {

namespace pybind {
Expand All @@ -30,6 +30,12 @@ using Value = pir::Value;
using IntArray = paddle::experimental::IntArray;
using IntVector = std::vector<int64_t>;

void ExpandAsPreProcess(paddle::Tensor* x,
paddle::optional<paddle::Tensor>* y,
std::vector<int64_t>* target_shape);
void ExpandAsPreProcess(Value* x,
paddle::optional<pir::Value>* y,
std::vector<int64_t>* target_shape);
void RollPreProcess(Tensor* x, IntArray* shifts, IntVector* axis);
void RollPreProcess(Value* x, Value* shifts, IntVector* axis);

Expand Down
42 changes: 42 additions & 0 deletions paddle/fluid/pybind/eager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,48 @@ paddle::optional<paddle::Tensor> GetOptionalTensorFromArgs(
}
}

paddle::optional<paddle::Tensor> GetOptionalTensorFromArgsOrKWArgs(
const std::string& op_type,
const std::string& arg_name,
PyObject* args,
ssize_t arg_idx,
PyObject* kwargs,
const std::vector<std::string>& keywords,
const int nargs,
int* remaining_kwargs,
bool dispensable,
const phi::distributed::ProcessMesh* mesh) {
PyObject* obj = GetItemFromArgsOrKWArgs(
args, arg_idx, kwargs, keywords, nargs, remaining_kwargs);

if (obj == nullptr || obj == Py_None) {
if (!dispensable) {
PADDLE_THROW(common::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got None",
op_type,
arg_name,
arg_idx));
}
return paddle::none;
}

if (PyObject_TypeCheck(obj, p_tensor_type)) {
if (mesh) {
ConvertToDistTensor(&(reinterpret_cast<TensorObject*>(obj)->tensor),
mesh);
}
return paddle::make_optional<paddle::Tensor>(
reinterpret_cast<TensorObject*>(obj)->tensor);
} else {
PADDLE_THROW(common::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got %s",
op_type,
arg_name,
arg_idx,
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
}
}

PyObject* ToPyObject(std::shared_ptr<egr::GradNodeBase> grad_node) {
py::object py_obj = py::cast(grad_node, py::return_value_policy::reference);
PyObject* py_grad_node = py_obj.release().ptr();
Expand Down
12 changes: 12 additions & 0 deletions paddle/fluid/pybind/eager_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,18 @@ paddle::optional<paddle::Tensor> GetOptionalTensorFromArgs(
bool dispensable = false,
const phi::distributed::ProcessMesh* mesh = nullptr);

paddle::optional<paddle::Tensor> GetOptionalTensorFromArgsOrKWArgs(
const std::string& op_type,
const std::string& arg_name,
PyObject* args,
ssize_t arg_idx,
PyObject* kwargs,
const std::vector<std::string>& keywords,
const int nargs,
int* remaining_kwargs,
bool dispensable = false,
const phi::distributed::ProcessMesh* mesh = nullptr);

paddle::Tensor& GetTensorFromArgs(const std::string& op_type,
const std::string& arg_name,
PyObject* args,
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/ops/yaml/python_api_info.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
args_alias :
use_default_mapping : True

- op : expand_as
name : [paddle.expand_as,paddle.Tensor.expand_as]
args_alias :
use_default_mapping : True
pre_process :
func : ExpandAsPreProcess(x,y,target_shape)
- op : logical_and
name : [paddle.logical_and, paddle.Tensor.logical_and]
args_alias:
Expand Down
40 changes: 40 additions & 0 deletions python/paddle/_paddle_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,46 @@ def any(
) -> Tensor
""",
)
add_doc_and_signature(
"expand_as",
"""

Expand the input tensor ``x`` to the same shape as the input tensor ``y``.

Both the number of dimensions of ``x`` and ``y`` must be less than or equal to 6, and the number of dimensions of ``y`` must be greater than or equal to that of ``x``. The dimension to expand must have a value of 0.

The following diagram illustrates how a one-dimensional tensor is transformed into a tensor with a shape of [2,3] through the expand_as operation. The target tensor has a shape of [2,3], and through expand_as, the one-dimensional tensor is expanded into a tensor with a shape of [2,3].

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/expand_as.png
:width: 800
:alt: expand_as API
:align: center

Args:
x (Tensor): The input tensor, its data type is bool, float32, float64, int32 or int64.
y (Tensor): The input tensor that gives the shape to expand to.
name (str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.

Returns:
N-D Tensor, A Tensor with the same shape as ``y``. The data type is the same as ``x``.

Examples:
.. code-block:: python

>>> import paddle

>>> data_x = paddle.to_tensor([1, 2, 3], 'int32')
>>> data_y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], 'int32')
>>> out = paddle.expand_as(data_x, data_y)
>>> print(out)
Tensor(shape=[2, 3], dtype=int32, place=Place(cpu), stop_gradient=True,
[[1, 2, 3],
[1, 2, 3]])
""",
"""
def expand_as(x: Tensor, y: Tensor, name: str | None = None) -> Tensor
""",
)

# shenwei

Expand Down
86 changes: 1 addition & 85 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
ShapeLike,
TensorOrTensors,
)

from paddle._C_ops import expand_as # noqa: F401
from paddle.utils.decorator_utils import ForbidKeywordsDecorator

__all__ = []
Expand Down Expand Up @@ -4832,90 +4832,6 @@ def repeat(
return tile(input, repeat_times=repeats)


def expand_as(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
"""

Expand the input tensor ``x`` to the same shape as the input tensor ``y``.

Both the number of dimensions of ``x`` and ``y`` must be less than or equal to 6, and the number of dimensions of ``y`` must be greater than or equal to that of ``x``. The dimension to expand must have a value of 0.

The following diagram illustrates how a one-dimensional tensor is transformed into a tensor with a shape of [2,3] through the expand_as operation. The target tensor has a shape of [2,3], and through expand_as, the one-dimensional tensor is expanded into a tensor with a shape of [2,3].

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/expand_as.png
:width: 800
:alt: expand_as API
:align: center

Args:
x (Tensor): The input tensor, its data type is bool, float32, float64, int32 or int64.
y (Tensor): The input tensor that gives the shape to expand to.
name (str|None, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.

Returns:
N-D Tensor, A Tensor with the same shape as ``y``. The data type is the same as ``x``.

Examples:
.. code-block:: python

>>> import paddle

>>> data_x = paddle.to_tensor([1, 2, 3], 'int32')
>>> data_y = paddle.to_tensor([[1, 2, 3], [4, 5, 6]], 'int32')
>>> out = paddle.expand_as(data_x, data_y)
>>> print(out)
Tensor(shape=[2, 3], dtype=int32, place=Place(cpu), stop_gradient=True,
[[1, 2, 3],
[1, 2, 3]])
"""
if in_dynamic_mode():
return _C_ops.expand_as(x, None, y.shape)
elif in_pir_mode():
if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient:
raise ValueError(
"When the data type of input 'x' for expand_as is bool, "
"you must set its stop_gradient to be False by "
"some_var.stop_gradient = True, supporting "
"some_var as the input 'x'."
)
return _C_ops.expand_as(x, y, y.shape)
else:
check_variable_and_dtype(
x,
'x',
[
'bool',
'float32',
'float64',
'int32',
'int64',
'float16',
'uint16',
],
'expand_as',
)
check_type(y, 'y', Variable, 'expand_as')

if convert_dtype(x.dtype) == 'bool' and not x.stop_gradient:
raise ValueError(
"When the data type of input 'x' for expand_as is bool, "
"you must set its stop_gradient to be False by "
"some_var.stop_gradient = True, supporting "
"some_var as the input 'x'."
)
inputs = {"X": [x], "Y": [y]}

helper = LayerHelper('expand_as', **locals())
dtype = helper.input_dtype(input_param_name='x')
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='expand_as_v2',
inputs=inputs,
attrs={'target_shape': y.shape},
outputs={'Out': out},
)
return out


@ParamAliasDecorator({"x": ["input"], "shape": ["size"]})
def broadcast_to(
x: Tensor,
Expand Down
Loading
Loading