diff --git a/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml b/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml
index 8fb2e49185c92d..43cc29d376c06b 100644
--- a/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml
@@ -17,7 +17,7 @@
     support_tensor : [start, end, step]
 
 - op : sequence_mask
-  args: (Tensor x, Scalar(int) max_len, int out_dtype)
+  args: (Tensor x, Scalar(int) max_len, DataType out_dtype)
   output: Tensor(y)
   infer_meta:
     func: SequenceMaskScalarInferMeta
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 7daada9e82a3c3..a8ab9e00f83be8 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -991,7 +991,7 @@
   backward : rrelu_grad
 
 - op : sequence_mask
-  args: (Tensor x, Scalar(int) max_len, int out_dtype)
+  args: (Tensor x, Scalar(int) max_len, DataType out_dtype)
   output: Tensor(y)
   infer_meta:
     func: SequenceMaskScalarInferMeta
diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc
index 2c739b2a48f3fb..4a6680ed1ef248 100644
--- a/paddle/phi/infermeta/binary.cc
+++ b/paddle/phi/infermeta/binary.cc
@@ -2857,7 +2857,7 @@ void ShuffleBatchInferMeta(const MetaTensor& x,
 void SequenceMaskInferMeta(const MetaTensor& x,
                            const MetaTensor& max_len_tensor,
                            int maxlen,
-                           int out_dtype,
+                           DataType out_dtype,
                            MetaTensor* y) {
   auto dim = common::vectorize(x.dims());
 
@@ -2868,8 +2868,7 @@ void SequenceMaskInferMeta(const MetaTensor& x,
   }
 
   y->set_dims(common::make_ddim(dim));
-  auto out_phi_dtype = phi::TransToPhiDataType(out_dtype);
-  y->set_dtype(out_phi_dtype);
+  y->set_dtype(out_dtype);
 }
 
 void SoftmaxMaskFuseInferMeta(const MetaTensor& x,
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 11aea8f63659a7..5f3cb48f3d3a8b 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/scalar.h"
 #include "paddle/phi/core/meta_tensor.h"
@@ -457,7 +458,7 @@ void SearchsortedInferMeta(const MetaTensor& sorted_sequence,
 void SequenceMaskInferMeta(const MetaTensor& x,
                            const MetaTensor& max_len_tensor,
                            int maxlen,
-                           int out_dtype,
+                           DataType out_dtype,
                            MetaTensor* y);
 
 void ShuffleBatchInferMeta(const MetaTensor& x,
diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc
index 0fc2f8b3e541ca..287d70615fafd1 100644
--- a/paddle/phi/infermeta/unary.cc
+++ b/paddle/phi/infermeta/unary.cc
@@ -4097,14 +4097,13 @@ void SplitWithNumInferMeta(const MetaTensor& x,
 
 void SequenceMaskScalarInferMeta(const MetaTensor& x,
                                  const Scalar& max_len,
-                                 int out_dtype,
+                                 DataType out_dtype,
                                  MetaTensor* y) {
   auto dim = phi::vectorize(x.dims());
   int maxlen = max_len.to<int>();
   dim.push_back(maxlen > 0 ? maxlen : -1);
   y->set_dims(phi::make_ddim(dim));
-  auto out_phi_dtype = phi::TransToPhiDataType(out_dtype);
-  y->set_dtype(out_phi_dtype);
+  y->set_dtype(out_dtype);
 }
 
 void SquaredL2NormInferMeta(const MetaTensor& x, MetaTensor* out) {
diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h
index 53026c1ee3d1f5..21859b7e78401e 100644
--- a/paddle/phi/infermeta/unary.h
+++ b/paddle/phi/infermeta/unary.h
@@ -581,7 +581,7 @@ void RReluGradInferMeta(const MetaTensor& out_grad,
 
 void SequenceMaskScalarInferMeta(const MetaTensor& x,
                                  const Scalar& max_len,
-                                 int out_dtype,
+                                 DataType out_dtype,
                                  MetaTensor* y);
 
 void SetValueInferMeta(const MetaTensor& x, MetaTensor* out);
diff --git a/paddle/phi/kernels/impl/sequence_mask_kernel_impl.h b/paddle/phi/kernels/impl/sequence_mask_kernel_impl.h
index 20fc0bda1f9184..2c000fef986698 100644
--- a/paddle/phi/kernels/impl/sequence_mask_kernel_impl.h
+++ b/paddle/phi/kernels/impl/sequence_mask_kernel_impl.h
@@ -31,7 +31,7 @@ template
 void SequenceMaskScalarKernel(const Context& ctx,
                               const DenseTensor& x,
                               const Scalar& max_len,
-                              int out_dtype,
+                              DataType out_dtype,
                               DenseTensor* y) {
   int maxlen = max_len.to<int>();
   auto* x_data = x.data();
@@ -58,7 +58,7 @@ void SequenceMaskScalarKernel(const Context& ctx,
     y->Resize(common::make_ddim(y_dim));
   }
 
-  phi::VisitDataType(phi::TransToPhiDataType(out_dtype),
+  phi::VisitDataType(out_dtype,
                      phi::funcs::SequenceMaskFunctor(
                          ctx, x_data, y, x_numel * maxlen, maxlen));
 }
@@ -68,7 +68,7 @@ void SequenceMaskKernel(const Context& ctx,
                         const DenseTensor& x,
                         const paddle::optional<DenseTensor>& max_len_tensor,
                         int maxlen,
-                        int out_dtype,
+                        DataType out_dtype,
                         DenseTensor* y) {
   if (max_len_tensor) {
     bool is_gpu_place = ctx.GetPlace().GetType() == phi::AllocationType::GPU;
diff --git a/paddle/phi/kernels/sequence_mask_kernel.h b/paddle/phi/kernels/sequence_mask_kernel.h
index ff547ee2522ecf..438d9f44310bf0 100644
--- a/paddle/phi/kernels/sequence_mask_kernel.h
+++ b/paddle/phi/kernels/sequence_mask_kernel.h
@@ -19,7 +19,7 @@ void SequenceMaskKernel(const Context& ctx,
                         const DenseTensor& x,
                         const paddle::optional<DenseTensor>& max_len_tensor,
                         int maxlen,
-                        int out_dtype,
+                        DataType out_dtype,
                         DenseTensor* y);
 
 }  // namespace phi
diff --git a/python/paddle/base/framework.py b/python/paddle/base/framework.py
index f8d689c1e578e0..e00a1827361aab 100644
--- a/python/paddle/base/framework.py
+++ b/python/paddle/base/framework.py
@@ -235,11 +235,6 @@ def __setattr__(self, name, val):
 }
 
 
-# FIXME(dev): We haven't fully verified eager mode on XPU et.al but
-# only GPU/CPU. Remove this after we improve this feature.
-_is_first_import_ = True
-
-
 def in_dygraph_mode():
     """
diff --git a/python/paddle/nn/functional/extension.py b/python/paddle/nn/functional/extension.py
index dac7ba30d93fdf..2709a7b090e982 100644
--- a/python/paddle/nn/functional/extension.py
+++ b/python/paddle/nn/functional/extension.py
@@ -103,10 +103,11 @@ def sequence_mask(x, maxlen=None, dtype='int64', name=None):
     if in_dynamic_or_pir_mode():
         if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
             dtype = convert_np_dtype_to_dtype_(dtype)
-        if maxlen is not None:
-            out = _C_ops.sequence_mask(x, maxlen, dtype)
-            out.stop_gradient = True
-            return out
+        if maxlen is None:
+            maxlen = -1
+        out = _C_ops.sequence_mask(x, maxlen, dtype)
+        out.stop_gradient = True
+        return out
 
     helper = LayerHelper('sequence_mask', **locals())
     out = helper.create_variable_for_type_inference(dtype=dtype)
diff --git a/python/paddle/pir/core.py b/python/paddle/pir/core.py
index c6807201bc019d..5d3a3354fe6439 100644
--- a/python/paddle/pir/core.py
+++ b/python/paddle/pir/core.py
@@ -45,21 +45,6 @@
     VarDesc.VarType.COMPLEX128: DataType.COMPLEX128,
 }
 
-datatype_to_vartype = {
-    DataType.FLOAT32: VarDesc.VarType.FP32,
-    DataType.FLOAT64: VarDesc.VarType.FP64,
-    DataType.FLOAT16: VarDesc.VarType.FP16,
-    DataType.BFLOAT16: VarDesc.VarType.BF16,
-    DataType.INT32: VarDesc.VarType.INT32,
-    DataType.INT16: VarDesc.VarType.INT16,
-    DataType.INT64: VarDesc.VarType.INT64,
-    DataType.BOOL: VarDesc.VarType.BOOL,
-    DataType.UINT8: VarDesc.VarType.UINT8,
-    DataType.INT8: VarDesc.VarType.INT8,
-    DataType.COMPLEX64: VarDesc.VarType.COMPLEX64,
-    DataType.COMPLEX128: VarDesc.VarType.COMPLEX128,
-}
-
 np_type_to_paddle_type = {
     np.dtype("float32"): DataType.FLOAT32,
     np.dtype("float64"): DataType.FLOAT64,
diff --git a/test/legacy_test/prim_op_test.py b/test/legacy_test/prim_op_test.py
index fa755f84a18ce4..4498c51b64de71 100644
--- a/test/legacy_test/prim_op_test.py
+++ b/test/legacy_test/prim_op_test.py
@@ -32,12 +32,13 @@
     canonicalize_attrs,
     in_dygraph_mode,
     in_pir_mode,
+    paddle_type_to_proto_type,
     use_pir_api,
 )
 from paddle.decomposition import decompose
 from paddle.incubate.autograd import primapi
 from paddle.jit.dy2static.utils import parse_arg_and_kwargs
-from paddle.pir.core import datatype_to_vartype, vartype_to_datatype
+from paddle.pir.core import vartype_to_datatype
 
 
 def flatten(nest_list):
@@ -156,7 +157,7 @@ def convert_dtype(dtype, target_dtype):
             isinstance(dtype, paddle.pir.core.DataType)
             and target_dtype is core.VarDesc.VarType
         ):
-            return datatype_to_vartype[dtype]
+            return paddle_type_to_proto_type[dtype]
         return dtype
 
     # NOTE(xiongkun): the logic of constructing parameters:
@@ -221,11 +222,10 @@ def convert_dtype(dtype, target_dtype):
             else:
                 results.append(tmp)
         assert len(results) == len(api_params)
-        # TODO(SigureMo): remove this in #60423
-        if hasattr(api, "__name__") and api.__name__ != "sequence_mask_wraper":
-            results = paddle.utils.map_structure(
-                partial(convert_dtype, target_dtype=target_dtype), results
-            )
+
+        results = paddle.utils.map_structure(
+            partial(convert_dtype, target_dtype=target_dtype), results
+        )
         return results
 
     @classmethod
diff --git a/test/sequence/test_sequence_mask.py b/test/sequence/test_sequence_mask.py
index 81a856c826270d..c7580c6dcc5a47 100644
--- a/test/sequence/test_sequence_mask.py
+++ b/test/sequence/test_sequence_mask.py
@@ -31,7 +31,7 @@
 from paddle.pir_utils import test_with_pir_api
 
 
-def sequence_mask_wraper(x, maxlen_tensor=None, maxlen=-1, mask_dtype='int64'):
+def sequence_mask_wrapper(x, maxlen_tensor=None, maxlen=-1, mask_dtype='int64'):
     if maxlen_tensor is not None:
         maxlen = maxlen_tensor
     return paddle.nn.functional.sequence_mask(
@@ -42,7 +42,7 @@ def sequence_mask_wraper(x, maxlen_tensor=None, maxlen=-1, mask_dtype='int64'):
 class SequenceMaskTestBase(OpTest):
     def initDefaultParameters(self):
         self.op_type = 'sequence_mask'
-        self.python_api = sequence_mask_wraper
+        self.python_api = sequence_mask_wrapper
         self.maxlen = 10
         self.mask_dtype = 'int64'
         self.x = [[0, 3, 4], [5, 7, 9]]
@@ -112,7 +112,7 @@ def initParameters(self):
 class SequenceMaskTestBase_tensor_attr(OpTest):
     def initDefaultParameters(self):
         self.op_type = 'sequence_mask'
-        self.python_api = sequence_mask_wraper
+        self.python_api = sequence_mask_wrapper
         self.maxlen = 10
         self.maxlen_tensor = np.ones((1), 'int32') * 10
         self.mask_dtype = 'int64'
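For reference, a minimal usage sketch (not part of the patch) of the Python API touched above. It assumes the updated dynamic/PIR path in `paddle.nn.functional.sequence_mask`, where `maxlen=None` is now forwarded to `_C_ops.sequence_mask` as `-1` so the mask width falls back to the maximum length in `x`:

```python
# Usage sketch only -- not part of the patch above.
import paddle

lengths = paddle.to_tensor([2, 4, 3], dtype='int64')

# With maxlen=None the call now goes through the same _C_ops.sequence_mask
# path as an explicit maxlen (forwarded as -1), so the mask width falls
# back to max(lengths) == 4.
mask = paddle.nn.functional.sequence_mask(lengths, maxlen=None, dtype='int64')
print(mask.numpy())
# Expected:
# [[1 1 0 0]
#  [1 1 1 1]
#  [1 1 1 0]]
```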