Merged
20 changes: 20 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.cc
@@ -225,6 +225,26 @@ pir::Value assign(const pir::Value& x) {
}
}

std::tuple<pir::Value, pir::Value> fused_gemm_epilogue(pir::Value x,
pir::Value y,
pir::Value bias,
bool trans_x,
bool trans_y,
std::string activation) {
pir::IrContext* ctx = pir::IrContext::Instance();
pir::AttributeMap attribute_map = {
{"trans_x", pir::BoolAttribute::get(ctx, trans_x)},
{"trans_y", pir::BoolAttribute::get(ctx, trans_y)},
{"activation", pir::StrAttribute::get(ctx, activation)}};
auto fused_gemm_epilogue_op =
ApiBuilder::Instance()
.GetBuilder()
->Build<paddle::dialect::FusedGemmEpilogueOp>(
x, y, bias, attribute_map);
return std::make_tuple(fused_gemm_epilogue_op.result(0),
fused_gemm_epilogue_op.result(1));
}

pir::Value array_pop(pir::Value input, int index) {
if (input.type().isa<paddle::dialect::DenseTensorArrayType>()) {
paddle::dialect::ArrayPopOp array_pop_op =
6 changes: 6 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_api.h
@@ -87,6 +87,12 @@ pir::Value slice_array_dense(pir::Value input, pir::Value starts);

pir::Value assign(const pir::Value& x);

std::tuple<pir::Value, pir::Value> fused_gemm_epilogue(pir::Value x,
pir::Value y,
pir::Value bias,
bool trans_x,
bool trans_y,
std::string activation);
pir::Value array_pop(pir::Value input, int index);

} // namespace dialect
10 changes: 10 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
@@ -661,6 +661,16 @@
view : (mean -> mean_out), (variance -> variance_out)
backward : fused_bn_add_activation_grad

- op : fused_softmax_mask
args : (Tensor x, Tensor mask)
output : Tensor(out)
infer_meta :
func : SoftmaxMaskFuseInferMeta
kernel :
func : fused_softmax_mask
data_type : x
backward: fused_softmax_mask_grad
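
For context: `fused_softmax_mask` computes `softmax(x + mask)` along the last axis in a single kernel. A minimal dygraph sketch, assuming a CUDA build where the fused kernel is registered (the pre-existing `paddle.incubate.softmax_mask_fuse` wrapper is the public entry point; the shapes and fp16 dtype are illustrative):

import paddle
import paddle.nn.functional as F

# Attention-style inputs: scores plus an additive mask broadcast over heads.
x = paddle.randn([2, 8, 128, 128], dtype='float16')
mask = paddle.randn([2, 1, 128, 128], dtype='float16')

fused = paddle.incubate.softmax_mask_fuse(x, mask)  # one fused kernel
reference = F.softmax(x + mask, axis=-1)            # unfused equivalent
# The two results should agree up to fp16 rounding error.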

- op : fused_softmax_mask_upper_triangle
args : (Tensor X)
output : Tensor(Out)
11 changes: 11 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
@@ -324,6 +324,17 @@
func: fused_feedforward_grad
optional: linear1_bias, linear2_bias, ln1_scale, ln1_bias, ln1_out, ln1_mean, ln1_variance, ln2_scale, ln2_bias, ln2_mean, ln2_variance, dropout2_out, ln1_scale_grad, ln1_bias_grad, ln2_scale_grad, ln2_bias_grad, linear2_bias_grad

- backward_op : fused_softmax_mask_grad
forward : fused_softmax_mask (Tensor x, Tensor mask) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : GeneralUnaryGradInferMeta
param: [out]
kernel :
func : fused_softmax_mask_grad
data_type : out
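
Note that the backward takes only `out` and `out_grad`, not `x` or `mask`: a softmax gradient can be reconstructed from the forward output alone, and the additive mask contributes no extra term. A hedged reference sketch of the math the `fused_softmax_mask_grad` kernel is expected to implement:

import paddle

def softmax_mask_grad_reference(out, out_grad):
    # x_grad = out * (out_grad - sum(out * out_grad, axis=-1, keepdim=True)),
    # i.e. the standard softmax backward, evaluated from the forward output.
    inner = (out * out_grad).sum(axis=-1, keepdim=True)
    return out * (out_grad - inner)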

- backward_op : fused_softmax_mask_upper_triangle_grad
forward : fused_softmax_mask_upper_triangle(Tensor X) -> Tensor(Out)
args: (Tensor Out, Tensor Out_grad)
53 changes: 53 additions & 0 deletions paddle/fluid/pybind/manual_static_op_function.h
@@ -754,6 +754,39 @@ static PyObject *run_custom_op(PyObject *self,
}
}

static PyObject *static_api_fused_gemm_epilogue(PyObject *self,
PyObject *args,
PyObject *kwargs) {
try {
VLOG(6) << "Running Static API: fused_gemm_epilogue";

VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
    // Get input Values from args
PyObject *x_obj = PyTuple_GET_ITEM(args, 0);
auto x = CastPyArg2Value(x_obj, "fused_gemm_epilogue", 0);
PyObject *y_obj = PyTuple_GET_ITEM(args, 1);
auto y = CastPyArg2Value(y_obj, "fused_gemm_epilogue", 1);
PyObject *bias_obj = PyTuple_GET_ITEM(args, 2);
auto bias = CastPyArg2Value(bias_obj, "fused_gemm_epilogue", 2);

// Parse Attributes if needed
PyObject *trans_x_obj = PyTuple_GET_ITEM(args, 3);
bool trans_x = CastPyArg2Boolean(trans_x_obj, "fused_gemm_epilogue", 3);
PyObject *trans_y_obj = PyTuple_GET_ITEM(args, 4);
bool trans_y = CastPyArg2Boolean(trans_y_obj, "fused_gemm_epilogue", 4);
PyObject *activation_obj = PyTuple_GET_ITEM(args, 5);
std::string activation =
CastPyArg2String(activation_obj, "fused_gemm_epilogue", 5);

// Call ir static api
auto out = paddle::dialect::fused_gemm_epilogue(
x, y, bias, trans_x, trans_y, activation);
return ToPyObject(out);
} catch (...) {
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
static PyObject *static_api_array_pop(PyObject *self,
PyObject *args,
PyObject *kwargs) {
@@ -778,6 +811,22 @@ static PyObject *static_api_array_pop(PyObject *self,
}
}

extern PyObject *eager_api_fused_gemm_epilogue(PyObject *self,
PyObject *args,
PyObject *kwargs);

static PyObject *fused_gemm_epilogue(PyObject *self,
PyObject *args,
PyObject *kwargs) {
if (egr::Controller::Instance().GetCurrentTracer() == nullptr) {
VLOG(6) << "Call static_api_fused_gemm_epilogue";
return static_api_fused_gemm_epilogue(self, args, kwargs);
} else {
VLOG(6) << "Call eager_api_fused_gemm_epilogue";
return eager_api_fused_gemm_epilogue(self, args, kwargs);
}
}
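
The wrapper above routes on the eager tracer: under dygraph a tracer is active, so the eager binding runs; inside a static program guard the tracer is null and the PIR builder path is taken instead. A hedged sketch of what this means at the Python level (same signature in both modes):

import paddle

def call_fused_gemm_epilogue(x, y, bias):
    # Same Python entry point in both modes; the C wrapper picks the branch.
    return paddle._C_ops.fused_gemm_epilogue(x, y, bias, False, False, "none")

# Under dygraph (a tracer is active) this reaches eager_api_fused_gemm_epilogue;
# inside paddle.static.program_guard(...) the tracer is None, so
# static_api_fused_gemm_epilogue builds the PIR op defined earlier in this PR.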

static PyMethodDef ManualOpsAPI[] = {
{"set_parameter",
(PyCFunction)(void (*)(void))static_api_set_parameter,
@@ -823,6 +872,10 @@ static PyMethodDef ManualOpsAPI[] = {
(PyCFunction)(void (*)(void))static_api_slice_array_dense,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for slice_array_dense."},
{"fused_gemm_epilogue",
(PyCFunction)(void (*)(void))fused_gemm_epilogue,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for fused_gemm_epilogue."},
{"_run_custom_op",
(PyCFunction)(void (*)(void))run_custom_op,
METH_VARARGS | METH_KEYWORDS,
113 changes: 113 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.cc
@@ -25,13 +25,18 @@ limitations under the License. */
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/fusion.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/utils/flags.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h"
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#endif

PD_DECLARE_int32(low_precision_op_list);

namespace paddle {
namespace experimental {

@@ -221,6 +226,114 @@ Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) {
return out;
}

std::tuple<Tensor, Tensor> fused_gemm_epilogue_impl(
const Tensor& x,
const Tensor& y,
const Tensor& bias,
bool trans_x,
bool trans_y,
const std::string& activation) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;

if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x, y, bias);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}

VLOG(6) << "fused_gemm_epilogue API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
auto kernel_result = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"fused_gemm_epilogue",
{kernel_backend, kernel_layout, kernel_data_type},
true);
const auto& kernel = kernel_result.kernel;
if (FLAGS_low_precision_op_list) {
phi::KernelFactory::Instance().AddToLowPrecisionKernelList(
"fused_gemm_epilogue", kernel_data_type);
}
VLOG(6) << "fused_gemm_epilogue kernel: " << kernel;
  // Select the actual kernel backend after a potential fallback to CPU.
Backend actual_kernel_backend =
kernel_result.has_fallback_cpu ? Backend::CPU : kernel_backend;
auto* dev_ctx = GetDeviceContextByBackend(actual_kernel_backend);

auto input_x = PrepareData(
x,
GetKernelInputArgDef(kernel.InputAt(0), actual_kernel_backend),
{},
kernel_result.is_stride_kernel);
auto input_y = PrepareData(
y,
GetKernelInputArgDef(kernel.InputAt(1), actual_kernel_backend),
{},
kernel_result.is_stride_kernel);
auto input_bias = PrepareData(
bias,
GetKernelInputArgDef(kernel.InputAt(2), actual_kernel_backend),
{},
kernel_result.is_stride_kernel);

std::tuple<Tensor, Tensor> api_output;
auto kernel_out_0 = SetKernelOutput(&std::get<0>(api_output));
phi::DenseTensor* kernel_out_1 = nullptr;
if (activation != "none") {
kernel_out_1 = SetKernelOutput(&std::get<1>(api_output));
}

phi::MetaTensor meta_out_0(kernel_out_0, kernel_result.is_stride_kernel);
phi::MetaTensor meta_out_1(kernel_out_1, kernel_result.is_stride_kernel);

phi::FusedGemmEpilogueInferMeta(MakeMetaTensor(*input_x),
MakeMetaTensor(*input_y),
MakeMetaTensor(*input_bias),
trans_x,
trans_y,
activation,
kernel_out_0 ? &meta_out_0 : nullptr,
kernel_out_1 ? &meta_out_1 : nullptr);

using kernel_signature = void (*)(const phi::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
bool,
bool,
const std::string&,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();

(*kernel_fn)(*dev_ctx,
*input_x,
*input_y,
*input_bias,
trans_x,
trans_y,
activation,
kernel_out_0,
kernel_out_1);

if (kernel_result.has_fallback_cpu) {
TransDataBackend(kernel_out_0, kernel_backend, kernel_out_0);
TransDataBackend(kernel_out_1, kernel_backend, kernel_out_1);
}
return api_output;
}
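
As a sanity check on what this impl computes, a hedged dygraph sketch comparing the fused path against its unfused reference (assumes a CUDA build where the cuBLASLt-backed `fused_gemm_epilogue` kernel is available; note `reserve_space` is only allocated above when `activation != "none"`):

import paddle

x = paddle.randn([8, 16], dtype='float16')
y = paddle.randn([16, 32], dtype='float16')
bias = paddle.randn([32], dtype='float16')

out, reserve_space = paddle._C_ops.fused_gemm_epilogue(
    x, y, bias, False, False, "relu"
)
reference = paddle.nn.functional.relu(paddle.matmul(x, y) + bias)
# out and reference should agree up to fp16 rounding error.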

////////////////// Backward(grad) api impls //////////////////////

void embedding_grad_impl(const Tensor& x,
8 changes: 8 additions & 0 deletions paddle/phi/api/lib/api_custom_impl.h
@@ -35,6 +35,14 @@ Tensor add_n_impl(const std::vector<Tensor>& x);

Tensor copy_to_impl(const Tensor& x, Place place, bool blocking);

std::tuple<Tensor, Tensor> fused_gemm_epilogue_impl(
const Tensor& x,
const Tensor& y,
const Tensor& bias,
bool trans_x,
bool trans_y,
const std::string& activation);

////////////////// Backward(grad) api impls //////////////////////

void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad);
21 changes: 21 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -281,6 +281,17 @@
data_type : out_grad
optional : reserve_space

- backward_op : fused_softmax_mask_grad
forward : fused_softmax_mask (Tensor x, Tensor mask) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : GeneralUnaryGradInferMeta
param: [out]
kernel :
func : fused_softmax_mask_grad
data_type : out

- backward_op : fused_softmax_mask_upper_triangle_grad
forward : fused_softmax_mask_upper_triangle(Tensor X) -> Tensor(Out)
args: (Tensor Out, Tensor Out_grad)
@@ -855,6 +866,16 @@
func: check_model_nan_inf
data_type: out_grad

- backward_op: fused_gemm_epilogue_grad
forward : fused_gemm_epilogue(Tensor x, Tensor y, Tensor bias, bool trans_x, bool trans_y, str activation) -> Tensor(out), Tensor(reserve_space)
args : (Tensor x, Tensor y, Tensor reserve_space, Tensor out_grad, bool trans_x, bool trans_y, str activation)
output : Tensor(x_grad), Tensor(y_grad), Tensor(bias_grad)
infer_meta :
func : FusedGemmEpilogueGradInferMeta
kernel:
func : fused_gemm_epilogue_grad
optional : reserve_space
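
For orientation, these are the gradients the grad op fuses, written out for `trans_x = trans_y = False` and activation "none" (with a real activation, `reserve_space` supplies the forward-side values needed to apply its derivative first). A hedged reference sketch:

import paddle

def fused_gemm_epilogue_grad_reference(x, y, out_grad):
    # Forward was out = x @ y + bias, so:
    x_grad = paddle.matmul(out_grad, y, transpose_y=True)
    y_grad = paddle.matmul(x, out_grad, transpose_x=True)
    bias_grad = out_grad.sum(axis=0)  # bias was broadcast over rows
    return x_grad, y_grad, bias_grad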

- backward_op: unpool_grad
forward: unpool (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) -> Tensor(out)
args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format)
17 changes: 17 additions & 0 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -584,6 +584,23 @@
view : (mean -> mean_out), (variance -> variance_out)
backward : fused_bn_add_activation_grad

- op : fused_gemm_epilogue
args : (Tensor x, Tensor y, Tensor bias, bool trans_x, bool trans_y, str activation)
output : Tensor(out), Tensor(reserve_space)
invoke : fused_gemm_epilogue_impl(x, y, bias, trans_x, trans_y, activation)
backward: fused_gemm_epilogue_grad
optional: reserve_space

- op : fused_softmax_mask
args : (Tensor x, Tensor mask)
output : Tensor(out)
infer_meta :
func : SoftmaxMaskFuseInferMeta
kernel :
func : fused_softmax_mask
data_type : x
backward: fused_softmax_mask_grad

- op : fused_softmax_mask_upper_triangle
args : (Tensor X)
output : Tensor(Out)
8 changes: 5 additions & 3 deletions python/paddle/distributed/fleet/layers/mpu/mp_layers.py
@@ -16,7 +16,6 @@

import paddle
from paddle.autograd import PyLayer
from paddle.base import core
from paddle.distributed import fleet
from paddle.nn import functional as F

Expand All @@ -34,7 +33,7 @@


def is_fused_matmul_bias_supported():
return hasattr(core.eager.ops.legacy, 'fused_gemm_epilogue')
return hasattr(paddle._C_ops, 'fused_gemm_epilogue')


def is_fused_linear_param_grad_add_supported():
@@ -214,7 +213,10 @@ def forward(
if not fuse_matmul_bias:
return paddle._C_ops.linear(x, weight, bias)
else:
return paddle._legacy_C_ops.fused_gemm_epilogue(x, weight, bias)
result, _ = paddle._C_ops.fused_gemm_epilogue(
x, weight, bias, False, False, "none"
)
return result
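
The call-site change reflects the new signature: the legacy op returned a single tensor, while the unified `_C_ops` entry returns `(out, reserve_space)`, hence the unpack. A minimal sketch of the migrated call under the same no-transpose, no-activation configuration used here:

import paddle

def fused_linear(x, weight, bias):
    # (out, reserve_space); reserve_space is unused when activation is "none",
    # so it is discarded, matching the forward above.
    out, _ = paddle._C_ops.fused_gemm_epilogue(
        x, weight, bias, False, False, "none"
    )
    return out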

@staticmethod
def backward(ctx, dy):