diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
index 2daa5941ea38bd..2a8be85d8f52f0 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -225,6 +225,26 @@ pir::Value assign(const pir::Value& x) {
   }
 }
 
+std::tuple<pir::Value, pir::Value> fused_gemm_epilogue(pir::Value x,
+                                                       pir::Value y,
+                                                       pir::Value bias,
+                                                       bool trans_x,
+                                                       bool trans_y,
+                                                       std::string activation) {
+  pir::IrContext* ctx = pir::IrContext::Instance();
+  pir::AttributeMap attribute_map = {
+      {"trans_x", pir::BoolAttribute::get(ctx, trans_x)},
+      {"trans_y", pir::BoolAttribute::get(ctx, trans_y)},
+      {"activation", pir::StrAttribute::get(ctx, activation)}};
+  auto fused_gemm_epilogue_op =
+      ApiBuilder::Instance()
+          .GetBuilder()
+          ->Build<paddle::dialect::FusedGemmEpilogueOp>(
+              x, y, bias, attribute_map);
+  return std::make_tuple(fused_gemm_epilogue_op.result(0),
+                         fused_gemm_epilogue_op.result(1));
+}
+
 pir::Value array_pop(pir::Value input, int index) {
   if (input.type().isa<paddle::dialect::DenseTensorArrayType>()) {
     paddle::dialect::ArrayPopOp array_pop_op =
diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.h b/paddle/fluid/pir/dialect/operator/ir/manual_api.h
index eda5df20a60b13..d58e8aaeb90770 100644
--- a/paddle/fluid/pir/dialect/operator/ir/manual_api.h
+++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -87,6 +87,12 @@ pir::Value slice_array_dense(pir::Value input, pir::Value starts);
 
 pir::Value assign(const pir::Value& x);
 
+std::tuple<pir::Value, pir::Value> fused_gemm_epilogue(pir::Value x,
+                                                       pir::Value y,
+                                                       pir::Value bias,
+                                                       bool trans_x,
+                                                       bool trans_y,
+                                                       std::string activation);
 pir::Value array_pop(pir::Value input, int index);
 
 }  // namespace dialect
diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
index 4593a7e7d7c427..cd9d6a38fd5f48 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -661,6 +661,16 @@
     view : (mean -> mean_out), (variance -> variance_out)
   backward : fused_bn_add_activation_grad
 
+- op : fused_softmax_mask
+  args : (Tensor x, Tensor mask)
+  output : Tensor(out)
+  infer_meta :
+    func : SoftmaxMaskFuseInferMeta
+  kernel :
+    func : fused_softmax_mask
+    data_type : x
+  backward: fused_softmax_mask_grad
+
 - op : fused_softmax_mask_upper_triangle
   args : (Tensor X)
   output : Tensor(Out)
diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
index 61d0d179c55474..19dc33f145a0b7 100644
--- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
+++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -324,6 +324,17 @@
     func: fused_feedforward_grad
   optional: linear1_bias, linear2_bias, ln1_scale, ln1_bias, ln1_out, ln1_mean, ln1_variance, ln2_scale, ln2_bias, ln2_mean, ln2_variance, dropout2_out, ln1_scale_grad, ln1_bias_grad, ln2_scale_grad, ln2_bias_grad, linear2_bias_grad
 
+- backward_op : fused_softmax_mask_grad
+  forward : fused_softmax_mask (Tensor x, Tensor mask) -> Tensor(out)
+  args : (Tensor out, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : GeneralUnaryGradInferMeta
+    param: [out]
+  kernel :
+    func : fused_softmax_mask_grad
+    data_type : out
+
 - backward_op : fused_softmax_mask_upper_triangle_grad
   forward : fused_softmax_mask_upper_triangle(Tensor X) -> Tensor(Out)
   args: (Tensor Out, Tensor Out_grad)
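Note: with the manual PIR API and the yaml entries above, the fused op is reachable from Python through `paddle._C_ops` with a positional `(x, y, bias, trans_x, trans_y, activation)` signature and a two-element result. A minimal usage sketch follows; the shapes are made up for illustration, and a CUDA build that ships the fused kernel is assumed rather than guaranteed by this diff:

    import paddle

    x = paddle.randn([8, 16], dtype='float32')
    w = paddle.randn([16, 32], dtype='float32')
    b = paddle.randn([32], dtype='float32')

    # out = x @ w + b; the second output backs the fused backward pass
    out, reserve_space = paddle._C_ops.fused_gemm_epilogue(
        x, w, b, False, False, 'relu'
    )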
diff --git a/paddle/fluid/pybind/manual_static_op_function.h b/paddle/fluid/pybind/manual_static_op_function.h
index f5eafc5f384e11..659cc66bbf47a1 100644
--- a/paddle/fluid/pybind/manual_static_op_function.h
+++ b/paddle/fluid/pybind/manual_static_op_function.h
@@ -754,6 +754,39 @@ static PyObject *run_custom_op(PyObject *self,
   }
 }
 
+static PyObject *static_api_fused_gemm_epilogue(PyObject *self,
+                                                PyObject *args,
+                                                PyObject *kwargs) {
+  try {
+    VLOG(6) << "Running Static API: fused_gemm_epilogue";
+
+    VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
+    // Get OpResult from args
+    PyObject *x_obj = PyTuple_GET_ITEM(args, 0);
+    auto x = CastPyArg2Value(x_obj, "fused_gemm_epilogue", 0);
+    PyObject *y_obj = PyTuple_GET_ITEM(args, 1);
+    auto y = CastPyArg2Value(y_obj, "fused_gemm_epilogue", 1);
+    PyObject *bias_obj = PyTuple_GET_ITEM(args, 2);
+    auto bias = CastPyArg2Value(bias_obj, "fused_gemm_epilogue", 2);
+
+    // Parse Attributes if needed
+    PyObject *trans_x_obj = PyTuple_GET_ITEM(args, 3);
+    bool trans_x = CastPyArg2Boolean(trans_x_obj, "fused_gemm_epilogue", 3);
+    PyObject *trans_y_obj = PyTuple_GET_ITEM(args, 4);
+    bool trans_y = CastPyArg2Boolean(trans_y_obj, "fused_gemm_epilogue", 4);
+    PyObject *activation_obj = PyTuple_GET_ITEM(args, 5);
+    std::string activation =
+        CastPyArg2String(activation_obj, "fused_gemm_epilogue", 5);
+
+    // Call ir static api
+    auto out = paddle::dialect::fused_gemm_epilogue(
+        x, y, bias, trans_x, trans_y, activation);
+    return ToPyObject(out);
+  } catch (...) {
+    ThrowExceptionToPython(std::current_exception());
+    return nullptr;
+  }
+}
+
 static PyObject *static_api_array_pop(PyObject *self,
                                       PyObject *args,
                                       PyObject *kwargs) {
@@ -778,6 +811,22 @@ static PyObject *static_api_array_pop(PyObject *self,
   }
 }
 
+extern PyObject *eager_api_fused_gemm_epilogue(PyObject *self,
+                                               PyObject *args,
+                                               PyObject *kwargs);
+
+static PyObject *fused_gemm_epilogue(PyObject *self,
+                                     PyObject *args,
+                                     PyObject *kwargs) {
+  if (egr::Controller::Instance().GetCurrentTracer() == nullptr) {
+    VLOG(6) << "Call static_api_fused_gemm_epilogue";
+    return static_api_fused_gemm_epilogue(self, args, kwargs);
+  } else {
+    VLOG(6) << "Call eager_api_fused_gemm_epilogue";
+    return eager_api_fused_gemm_epilogue(self, args, kwargs);
+  }
+}
+
 static PyMethodDef ManualOpsAPI[] = {
     {"set_parameter",
      (PyCFunction)(void (*)(void))static_api_set_parameter,
@@ -823,6 +872,10 @@ static PyMethodDef ManualOpsAPI[] = {
      (PyCFunction)(void (*)(void))static_api_slice_array_dense,
      METH_VARARGS | METH_KEYWORDS,
      "C++ interface function for slice_array_dense."},
+    {"fused_gemm_epilogue",
+     (PyCFunction)(void (*)(void))fused_gemm_epilogue,
+     METH_VARARGS | METH_KEYWORDS,
+     "C++ interface function for fused_gemm_epilogue."},
     {"_run_custom_op",
      (PyCFunction)(void (*)(void))run_custom_op,
      METH_VARARGS | METH_KEYWORDS,
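Note: the `fused_gemm_epilogue` entry registered in ManualOpsAPI dispatches on the current tracer, so one Python symbol serves both dygraph and static/PIR programs. A rough sketch of what that means for callers (shapes are illustrative; a CUDA build providing the kernel is assumed):

    import paddle

    # dygraph: a tracer is active, so the call routes to eager_api_fused_gemm_epilogue
    x = paddle.randn([8, 16])
    w = paddle.randn([16, 32])
    b = paddle.randn([32])
    out, _ = paddle._C_ops.fused_gemm_epilogue(x, w, b, False, False, 'none')

    # static / PIR: no tracer, so the same symbol routes to static_api_fused_gemm_epilogue
    paddle.enable_static()
    with paddle.static.program_guard(paddle.static.Program()):
        xs = paddle.static.data('x', [8, 16], 'float32')
        ws = paddle.static.data('w', [16, 32], 'float32')
        bs = paddle.static.data('b', [32], 'float32')
        outs, _ = paddle._C_ops.fused_gemm_epilogue(xs, ws, bs, False, False, 'none')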
diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc
index d9cbf034999115..3577576938886d 100644
--- a/paddle/phi/api/lib/api_custom_impl.cc
+++ b/paddle/phi/api/lib/api_custom_impl.cc
@@ -25,13 +25,18 @@ limitations under the License. */
 #include "paddle/phi/core/meta_tensor.h"
 #include "paddle/phi/infermeta/backward.h"
 #include "paddle/phi/infermeta/binary.h"
+#include "paddle/phi/infermeta/fusion.h"
 #include "paddle/phi/infermeta/multiary.h"
 #include "paddle/phi/infermeta/nullary.h"
 #include "paddle/phi/infermeta/unary.h"
+#include "paddle/utils/flags.h"
 #ifdef PADDLE_WITH_DISTRIBUTE
 #include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
 #include "paddle/phi/infermeta/spmd_rules/rules.h"
 #endif
+
+PD_DECLARE_int32(low_precision_op_list);
+
 namespace paddle {
 namespace experimental {
 
@@ -221,6 +226,114 @@ Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
   return out;
 }
 
+std::tuple<Tensor, Tensor> fused_gemm_epilogue_impl(
+    const Tensor& x,
+    const Tensor& y,
+    const Tensor& bias,
+    bool trans_x,
+    bool trans_y,
+    const std::string& activation) {
+  Backend kernel_backend = Backend::UNDEFINED;
+  DataLayout kernel_layout = DataLayout::UNDEFINED;
+  DataType kernel_data_type = DataType::UNDEFINED;
+
+  if (kernel_backend == Backend::UNDEFINED ||
+      kernel_layout == DataLayout::UNDEFINED ||
+      kernel_data_type == DataType::UNDEFINED) {
+    auto kernel_key_set = ParseKernelKeyByInputArgs(x, y, bias);
+    auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+    if (kernel_backend == Backend::UNDEFINED) {
+      kernel_backend = kernel_key.backend();
+    }
+    if (kernel_layout == DataLayout::UNDEFINED) {
+      kernel_layout = kernel_key.layout();
+    }
+    if (kernel_data_type == DataType::UNDEFINED) {
+      kernel_data_type = kernel_key.dtype();
+    }
+  }
+
+  VLOG(6) << "fused_gemm_epilogue API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "fused_gemm_epilogue",
+      {kernel_backend, kernel_layout, kernel_data_type},
+      true);
+  const auto& kernel = kernel_result.kernel;
+  if (FLAGS_low_precision_op_list) {
+    phi::KernelFactory::Instance().AddToLowPrecisionKernelList(
+        "fused_gemm_epilogue", kernel_data_type);
+  }
+  VLOG(6) << "fused_gemm_epilogue kernel: " << kernel;
+  // add actual_kernel_backend to select actual kernel backend after a
+  // potential falling-back to CPU
+  Backend actual_kernel_backend =
+      kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend;
+  auto* dev_ctx = GetDeviceContextByBackend(actual_kernel_backend);
+
+  auto input_x = PrepareData(
+      x,
+      GetKernelInputArgDef(kernel.InputAt(0), actual_kernel_backend),
+      {},
+      kernel_result.is_stride_kernel);
+  auto input_y = PrepareData(
+      y,
+      GetKernelInputArgDef(kernel.InputAt(1), actual_kernel_backend),
+      {},
+      kernel_result.is_stride_kernel);
+  auto input_bias = PrepareData(
+      bias,
+      GetKernelInputArgDef(kernel.InputAt(2), actual_kernel_backend),
+      {},
+      kernel_result.is_stride_kernel);
+
+  std::tuple<Tensor, Tensor> api_output;
+  auto kernel_out_0 = SetKernelOutput(&std::get<0>(api_output));
+  phi::DenseTensor* kernel_out_1 = nullptr;
+  if (activation != "none") {
+    kernel_out_1 = SetKernelOutput(&std::get<1>(api_output));
+  }
+
+  phi::MetaTensor meta_out_0(kernel_out_0, kernel_result.is_stride_kernel);
+  phi::MetaTensor meta_out_1(kernel_out_1, kernel_result.is_stride_kernel);
+
+  phi::FusedGemmEpilogueInferMeta(MakeMetaTensor(*input_x),
+                                  MakeMetaTensor(*input_y),
+                                  MakeMetaTensor(*input_bias),
+                                  trans_x,
+                                  trans_y,
+                                  activation,
+                                  kernel_out_0 ? &meta_out_0 : nullptr,
+                                  kernel_out_1 ? &meta_out_1 : nullptr);
+
+  using kernel_signature = void (*)(const phi::DeviceContext&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    const phi::DenseTensor&,
+                                    bool,
+                                    bool,
+                                    const std::string&,
+                                    phi::DenseTensor*,
+                                    phi::DenseTensor*);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+
+  (*kernel_fn)(*dev_ctx,
+               *input_x,
+               *input_y,
+               *input_bias,
+               trans_x,
+               trans_y,
+               activation,
+               kernel_out_0,
+               kernel_out_1);
+
+  if (kernel_result.has_fallback_cpu) {
+    TransDataBackend(kernel_out_0, kernel_backend, kernel_out_0);
+    TransDataBackend(kernel_out_1, kernel_backend, kernel_out_1);
+  }
+  return api_output;
+}
+
 ////////////////// Backward(grad) api impls //////////////////////
 
 void embedding_grad_impl(const Tensor& x,
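Note: in the impl above, the second kernel output is only materialized when `activation != "none"`, so Python callers should expect a meaningful `reserve_space` only for 'relu'/'gelu'. The sketch below is inferred from that branch in the static/PIR path and is not verified for the eager path:

    import paddle

    x = paddle.randn([4, 8])
    w = paddle.randn([8, 8])
    b = paddle.randn([8])

    _, rs_none = paddle._C_ops.fused_gemm_epilogue(x, w, b, False, False, 'none')
    _, rs_relu = paddle._C_ops.fused_gemm_epilogue(x, w, b, False, False, 'relu')
    # rs_none is expected to be empty/unset; rs_relu is expected to hold the
    # workspace consumed later by fused_gemm_epilogue_grad.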
diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h
index 474f4f981185f2..9335dac7c2575d 100644
--- a/paddle/phi/api/lib/api_custom_impl.h
+++ b/paddle/phi/api/lib/api_custom_impl.h
@@ -35,6 +35,14 @@ Tensor add_n_impl(const std::vector<Tensor>& x);
 
 Tensor copy_to_impl(const Tensor& x, Place place, bool blocking);
 
+std::tuple<Tensor, Tensor> fused_gemm_epilogue_impl(
+    const Tensor& x,
+    const Tensor& y,
+    const Tensor& bias,
+    bool trans_x,
+    bool trans_y,
+    const std::string& activation);
+
 ////////////////// Backward(grad) api impls //////////////////////
 
 void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad);
diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml
index 214b4d3d48c341..cbaa1d56a0ae74 100755
--- a/paddle/phi/api/yaml/legacy_backward.yaml
+++ b/paddle/phi/api/yaml/legacy_backward.yaml
@@ -281,6 +281,17 @@
     data_type : out_grad
   optional : reserve_space
 
+- backward_op : fused_softmax_mask_grad
+  forward : fused_softmax_mask (Tensor x, Tensor mask) -> Tensor(out)
+  args : (Tensor out, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : GeneralUnaryGradInferMeta
+    param: [out]
+  kernel :
+    func : fused_softmax_mask_grad
+    data_type : out
+
 - backward_op : fused_softmax_mask_upper_triangle_grad
   forward : fused_softmax_mask_upper_triangle(Tensor X) -> Tensor(Out)
   args: (Tensor Out, Tensor Out_grad)
@@ -855,6 +866,16 @@
     func: check_model_nan_inf
   data_type: out_grad
 
+- backward_op: fused_gemm_epilogue_grad
+  forward : fused_gemm_epilogue(Tensor x, Tensor y, Tensor bias, bool trans_x, bool trans_y, str activation) -> Tensor(out), Tensor(reserve_space)
+  args : (Tensor x, Tensor y, Tensor reserve_space, Tensor out_grad, bool trans_x, bool trans_y, str activation)
+  output : Tensor(x_grad), Tensor(y_grad), Tensor(bias_grad)
+  infer_meta :
+    func : FusedGemmEpilogueGradInferMeta
+  kernel:
+    func : fused_gemm_epilogue_grad
+  optional : reserve_space
+
- backward_op: unpool_grad
  forward: unpool (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) -> Tensor(out)
  args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format)
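Note: with `fused_gemm_epilogue_grad` (and `fused_softmax_mask_grad`) registered above, gradients should flow through the incubate wrappers in dygraph as well. A small sketch, assuming a CUDA build where the fused GEMM+bias kernel is available; shapes are illustrative:

    import paddle
    from paddle.incubate.nn.functional import fused_matmul_bias

    x = paddle.randn([4, 16])
    w = paddle.randn([16, 32])
    b = paddle.randn([32])
    x.stop_gradient = False

    out = fused_matmul_bias(x, w, b)
    out.sum().backward()   # backward is expected to route through fused_gemm_epilogue_grad
    print(x.grad.shape)    # [4, 16]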
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index 4a1e37caa9ab83..cf33905084ae4e 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -584,6 +584,23 @@
     view : (mean -> mean_out), (variance -> variance_out)
   backward : fused_bn_add_activation_grad
 
+- op : fused_gemm_epilogue
+  args : (Tensor x, Tensor y, Tensor bias, bool trans_x, bool trans_y, str activation)
+  output : Tensor(out), Tensor(reserve_space)
+  invoke : fused_gemm_epilogue_impl(x, y, bias, trans_x, trans_y, activation)
+  backward: fused_gemm_epilogue_grad
+  optional: reserve_space
+
+- op : fused_softmax_mask
+  args : (Tensor x, Tensor mask)
+  output : Tensor(out)
+  infer_meta :
+    func : SoftmaxMaskFuseInferMeta
+  kernel :
+    func : fused_softmax_mask
+    data_type : x
+  backward: fused_softmax_mask_grad
+
 - op : fused_softmax_mask_upper_triangle
   args : (Tensor X)
   output : Tensor(Out)
diff --git a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
index 05f3e2d241ab84..176c261aae25d3 100644
--- a/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
+++ b/python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -16,7 +16,6 @@
 
 import paddle
 from paddle.autograd import PyLayer
-from paddle.base import core
 from paddle.distributed import fleet
 from paddle.nn import functional as F
 
@@ -34,7 +33,7 @@
 
 
 def is_fused_matmul_bias_supported():
-    return hasattr(core.eager.ops.legacy, 'fused_gemm_epilogue')
+    return hasattr(paddle._C_ops, 'fused_gemm_epilogue')
 
 
 def is_fused_linear_param_grad_add_supported():
@@ -214,7 +213,10 @@ def forward(
             if not fuse_matmul_bias:
                 return paddle._C_ops.linear(x, weight, bias)
             else:
-                return paddle._legacy_C_ops.fused_gemm_epilogue(x, weight, bias)
+                result, _ = paddle._C_ops.fused_gemm_epilogue(
+                    x, weight, bias, False, False, "none"
+                )
+                return result
 
     @staticmethod
     def backward(ctx, dy):
diff --git a/python/paddle/distributed/fleet/utils/sequence_parallel_utils.py b/python/paddle/distributed/fleet/utils/sequence_parallel_utils.py
index 940d7408ff5be7..f499054bc8496e 100644
--- a/python/paddle/distributed/fleet/utils/sequence_parallel_utils.py
+++ b/python/paddle/distributed/fleet/utils/sequence_parallel_utils.py
@@ -17,7 +17,6 @@
 import paddle
 from paddle import distributed as dist
 from paddle.autograd import PyLayer
-from paddle.base import core
 from paddle.distributed import fleet
 from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
 from paddle.distributed.fleet.utils.hybrid_parallel_util import (
@@ -222,7 +221,7 @@ def is_fused_matmul_bias_supported():
         and not paddle.is_compiled_with_rocm()
         or paddle.is_compiled_with_xpu()
     ):
-        return hasattr(core.eager.ops.legacy, "fused_gemm_epilogue")
+        return hasattr(paddle._C_ops, "fused_gemm_epilogue")
     else:
         return False
 
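Note: the fleet utilities now detect support via the unified `paddle._C_ops` binding instead of the legacy eager namespace. The same capability check can be reused before opting into the fused path, mirroring the helpers changed above:

    import paddle

    def is_fused_matmul_bias_supported():
        # mirrors mp_layers.py / sequence_parallel_utils.py after this change
        return hasattr(paddle._C_ops, 'fused_gemm_epilogue')

    fuse = paddle.is_compiled_with_cuda() and is_fused_matmul_bias_supported()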
diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py
index 83d3b5a91d4ba6..1b894ce297a1c0 100644
--- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py
+++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from paddle import _legacy_C_ops
+from paddle import _C_ops
 from paddle.base.layer_helper import LayerHelper
-from paddle.framework import in_dynamic_mode
+from paddle.framework import in_dynamic_or_pir_mode
 from paddle.tensor.linalg import matmul
 
 
@@ -56,10 +56,11 @@ def fused_matmul_bias(
     """
     if bias is None:
         return matmul(x, y, transpose_x, transpose_y, name)
-    if in_dynamic_mode():
-        return _legacy_C_ops.fused_gemm_epilogue(
-            x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y
+    if in_dynamic_or_pir_mode():
+        out, _ = _C_ops.fused_gemm_epilogue(
+            x, y, bias, transpose_x, transpose_y, "none"
         )
+        return out
     helper = LayerHelper('fused_matmul_bias', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
@@ -145,18 +146,16 @@ def fused_linear_activation(
     if activation is None:
         activation = "none"
 
-    if in_dynamic_mode():
-        return _legacy_C_ops.fused_gemm_epilogue(
+    if in_dynamic_or_pir_mode():
+        out, _ = _C_ops.fused_gemm_epilogue(
             x,
             y,
             bias,
-            'trans_x',
             trans_x,
-            'trans_y',
             trans_y,
-            'activation',
             activation,
         )
+        return out
     helper = LayerHelper('fused_matmul_bias', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
diff --git a/python/paddle/incubate/operators/softmax_mask_fuse.py b/python/paddle/incubate/operators/softmax_mask_fuse.py
index 4a4e01d816272e..ba92e96d408e15 100644
--- a/python/paddle/incubate/operators/softmax_mask_fuse.py
+++ b/python/paddle/incubate/operators/softmax_mask_fuse.py
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from paddle import _legacy_C_ops
+from paddle import _C_ops
 from paddle.base.layer_helper import LayerHelper
-from paddle.framework import in_dynamic_mode
+from paddle.framework import in_dynamic_or_pir_mode
 
 
 def softmax_mask_fuse(x, mask, name=None):
@@ -56,8 +56,8 @@ def softmax_mask_fuse(x, mask, name=None):
     >>> rst.shape
     [2, 8, 8, 32]
     """
-    if in_dynamic_mode():
-        out = _legacy_C_ops.fused_softmax_mask(x, mask)
+    if in_dynamic_or_pir_mode():
+        out = _C_ops.fused_softmax_mask(x, mask)
         return out
     helper = LayerHelper('fused_softmax_mask', **locals())
     out = helper.create_variable_for_type_inference(dtype=x.dtype)
diff --git a/test/legacy_test/test_fused_gemm_epilogue_op.py b/test/legacy_test/test_fused_gemm_epilogue_op.py
index 7a3301a3981d5b..de065a813329b7 100644
--- a/test/legacy_test/test_fused_gemm_epilogue_op.py
+++ b/test/legacy_test/test_fused_gemm_epilogue_op.py
@@ -25,7 +25,7 @@
 
 
 def is_fused_gemm_epilogue_supported():
     if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
-        return hasattr(core.eager.ops, 'fused_gemm_epilogue')
+        return hasattr(paddle._C_ops, 'fused_gemm_epilogue')
     else:
         return False
 
@@ -588,17 +588,11 @@ def test_case_act(self):
         x.stop_gradient = False
         y.stop_gradient = False
 
-        out1 = fused_linear_activation(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'none'
-        )
+        out1 = fused_linear_activation(x, y, bias, False, False, 'none')
 
-        out2 = fused_linear_activation(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'relu'
-        )
+        out2 = fused_linear_activation(x, y, bias, False, False, 'relu')
 
-        out3 = fused_linear_activation(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'gelu'
-        )
+        out3 = fused_linear_activation(x, y, bias, False, False, 'gelu')
 
         out_np1 = get_output(x_np, y_np, bias_np, 'none')
         out_np2 = get_output(x_np, y_np, bias_np, 'relu')
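Note: after this change the incubate wrappers keep their public signatures; only the underlying call switches from `_legacy_C_ops` to `_C_ops` and drops the attribute-name strings. Typical calls look roughly like this (shapes are illustrative; a CUDA build with the fused kernel is assumed):

    import paddle
    from paddle.incubate.nn.functional import fused_linear_activation, fused_matmul_bias

    x = paddle.randn([4, 16])
    w = paddle.randn([16, 32])
    b = paddle.randn([32])

    y1 = fused_matmul_bias(x, w, b)                               # x @ w + b
    y2 = fused_linear_activation(x, w, b, False, False, 'relu')   # relu(x @ w + b)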
diff --git a/test/legacy_test/test_fused_matmul_bias.py b/test/legacy_test/test_fused_matmul_bias.py
index 85666710b0e450..2ba4ea4c761857 100644
--- a/test/legacy_test/test_fused_matmul_bias.py
+++ b/test/legacy_test/test_fused_matmul_bias.py
@@ -20,6 +20,7 @@
 from paddle.base import core
 from paddle.incubate.nn import FusedLinear
 from paddle.incubate.nn.functional import fused_linear, fused_matmul_bias
+from paddle.pir_utils import test_with_pir_api
 
 
 def is_fused_matmul_bias_supported():
@@ -153,6 +154,7 @@ def test_transpose(self):
     "fused_gemm_epilogue is only supported when CUDA version >= 11.6",
 )
 class TestStaticGraph(unittest.TestCase):
+    @test_with_pir_api
     def test_static_graph(self):
         paddle.enable_static()
         x = paddle.static.data(name='x', dtype='float32', shape=[-1, 100])
diff --git a/test/legacy_test/test_softmax_mask_fuse_op.py b/test/legacy_test/test_softmax_mask_fuse_op.py
index 495876a850588e..4f907602062fd7 100644
--- a/test/legacy_test/test_softmax_mask_fuse_op.py
+++ b/test/legacy_test/test_softmax_mask_fuse_op.py
@@ -20,6 +20,7 @@
 import paddle
 from paddle import base, incubate
 from paddle.base import core
+from paddle.pir_utils import test_with_pir_api
 
 paddle.enable_static()
 
@@ -51,10 +52,12 @@ def setUp(self):
         self.outputs = {'Out': rst}
 
     def test_check_output(self):
-        self.check_output_with_place(core.CPUPlace())
+        self.check_output_with_place(core.CPUPlace(), check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad_with_place(core.CPUPlace(), ["X"], "Out")
+        self.check_grad_with_place(
+            core.CPUPlace(), ["X"], "Out", check_pir=True
+        )
 
 
 @unittest.skipIf(
@@ -72,10 +75,12 @@ def setUp(self):
         self.outputs = {'Out': rst}
 
     def test_check_output(self):
-        self.check_output_with_place(core.CUDAPlace(0))
+        self.check_output_with_place(core.CUDAPlace(0), check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")
+        self.check_grad_with_place(
+            core.CUDAPlace(0), ["X"], "Out", check_pir=True
+        )
 
 
 @unittest.skipIf(
@@ -93,18 +98,23 @@ def setUp(self):
         self.outputs = {'Out': rst}
 
     def test_check_output(self):
-        self.check_output_with_place(core.CUDAPlace(0))
+        self.check_output_with_place(core.CUDAPlace(0), check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad_with_place(core.CUDAPlace(0), ["X"], "Out")
+        self.check_grad_with_place(
+            core.CUDAPlace(0), ["X"], "Out", check_pir=True
+        )
 
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
 )
 class TestDropoutBiasFuseOp3(unittest.TestCase):
+    @test_with_pir_api
     def test_static_result(self):
-        with base.program_guard(base.Program(), base.Program()):
+        with paddle.static.program_guard(
+            paddle.static.Program(), paddle.static.Program()
+        ):
             input_x = paddle.static.data(
                 name="x", shape=[1, 1, 8, 32], dtype="float32"
             )
@@ -120,7 +130,7 @@ def test_static_result(self):
 
             exe = base.Executor(base.CUDAPlace(0))
             fetches = exe.run(
-                base.default_main_program(),
+                paddle.static.default_main_program(),
                 feed={"x": x_in_np, "mask": mask_in_np},
                 fetch_list=[rst],
            )
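Note: `test_with_pir_api` runs the decorated test once under the old static graph and once under the new PIR program. The pattern used by the updated tests is roughly the sketch below; the class name, shapes, and feed data are made up for illustration and a CUDA device is assumed:

    import unittest
    import numpy as np
    import paddle
    from paddle.incubate.nn.functional import fused_matmul_bias
    from paddle.pir_utils import test_with_pir_api

    class TestFusedMatmulBiasStatic(unittest.TestCase):  # hypothetical test class
        @test_with_pir_api
        def test_static_graph(self):
            paddle.enable_static()
            with paddle.static.program_guard(paddle.static.Program()):
                x = paddle.static.data('x', [-1, 100], 'float32')
                w = paddle.static.data('w', [100, 10], 'float32')
                b = paddle.static.data('b', [10], 'float32')
                out = fused_matmul_bias(x, w, b)
                exe = paddle.static.Executor(paddle.CUDAPlace(0))
                res = exe.run(
                    feed={
                        'x': np.random.rand(2, 100).astype('float32'),
                        'w': np.random.rand(100, 10).astype('float32'),
                        'b': np.random.rand(10).astype('float32'),
                    },
                    fetch_list=[out],
                )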
diff --git a/test/xpu/test_fused_gemm_epilogue_op_xpu.py b/test/xpu/test_fused_gemm_epilogue_op_xpu.py
index fe1d08e36bc39e..3367da6e84abc0 100644
--- a/test/xpu/test_fused_gemm_epilogue_op_xpu.py
+++ b/test/xpu/test_fused_gemm_epilogue_op_xpu.py
@@ -24,7 +24,7 @@
 from op_test_xpu import XPUOpTest
 
 import paddle
-from paddle import _legacy_C_ops
+from paddle import _C_ops
 from paddle.base import core
 
 
@@ -251,15 +251,9 @@ def test_case_act(self):
         x.stop_gradient = False
         y.stop_gradient = False
 
-        out1 = _legacy_C_ops.fused_gemm_epilogue(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'none'
-        )
-        out2 = _legacy_C_ops.fused_gemm_epilogue(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'relu'
-        )
-        out3 = _legacy_C_ops.fused_gemm_epilogue(
-            x, y, bias, 'trans_x', False, 'trans_y', False, 'activation', 'gelu'
-        )
+        out1, _ = _C_ops.fused_gemm_epilogue(x, y, bias, False, False, 'none')
+        out2, _ = _C_ops.fused_gemm_epilogue(x, y, bias, False, False, 'relu')
+        out3, _ = _C_ops.fused_gemm_epilogue(x, y, bias, False, False, 'gelu')
 
         out_np1 = get_output(x_np, y_np, bias_np, 'none')
         out_np2 = get_output(x_np, y_np, bias_np, 'relu')