2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/CMakeLists.txt
@@ -314,8 +314,10 @@ cc_library(

#Note(risemeup1):compile some *.cc files which depend on primitive_vjp_experimental into op_dialect_vjp.a/lib
set(op_decomp_source_file ${PD_DIALECT_SOURCE_DIR}/op_decomp.cc)
# set(op_decomp_vjp_source_file ${PD_DIALECT_SOURCE_DIR}/op_decomp_vjp.cc)
set(op_dialect_vjp_srcs
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_decomp.cc
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_decomp_vjp.cc
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_op_vjp.cc
${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/op_dialect.cc
${op_decomp_source_file}
paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py
@@ -97,6 +97,11 @@
"unsqueeze",
]


# The xshape output will no longer be used after decomp, but None is returned to keep the output count the same as the original op.
decomp_ops_contain_unused_output = ["squeeze", "unsqueeze"]

decomp_vjp_interface_declare_gen_op_list = [
"add_grad",
"matmul_grad",
"relu_grad",
]
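
The comment above notes that squeeze/unsqueeze keep a dead xshape result after decomposition so the output count matches the original op. A minimal C++ sketch of that idea, assuming the placeholder is a default-constructed (null) pir::Value and using a hypothetical helper name:

#include <vector>

#include "paddle/pir/include/core/value.h"

// Hypothetical helper, not part of this PR: build the result grid for a
// decomposed squeeze/unsqueeze. Slot 0 carries the real "out" value; slot 1
// keeps a null pir::Value as a placeholder for the unused xshape output, so
// the result count stays identical to the original op.
std::vector<std::vector<pir::Value>> BuildSqueezeDecompResults(
    pir::Value decomposed_out) {
  std::vector<std::vector<pir::Value>> res(2);
  res[0].push_back(decomposed_out);  // out: the actual decomposition result
  res[1].push_back(pir::Value());    // xshape: unused after decomp
  return res;
}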
25 changes: 24 additions & 1 deletion paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -20,7 +20,10 @@
from distutils.util import strtobool

import yaml
from decomp_interface_gen_op_list import decomp_interface_declare_gen_op_list
from decomp_interface_gen_op_list import (
decomp_interface_declare_gen_op_list,
decomp_vjp_interface_declare_gen_op_list,
)
from gen_utils import to_pascal_case
from infer_symbolic_shape_gen import gen_infer_symbolic_shape_str
from op_all_func_gen import gen_op_all_func
@@ -85,6 +88,7 @@
#include <vector>

#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/layout_transformation.h"
@@ -1348,6 +1352,8 @@ def AutoCodeGen(
exclusive_interface_str_tmp = exclusive_interface_str
decomp_interface_str = "paddle::dialect::DecompInterface"
decomp_interface_declare_str = "\n static std::vector<std::vector<pir::Value>> Decomp(pir::Operation* op);"
decomp_vjp_interface_str = "paddle::dialect::DecompVjpInterface"
decomp_vjp_interface_declare_str = "\n static std::vector<std::vector<pir::Value>> DecompVjp(pir::Operation* op);"

# If op has inplace info, we will generate inplace op and non-inplace op.
for op_name in op_info.op_phi_name:
@@ -1392,6 +1398,23 @@
not in exclusive_interface_str
):
exclusive_interface_str += decomp_interface_declare_str
elif (
op_name in decomp_vjp_interface_declare_gen_op_list
and kernel_func_name
in decomp_vjp_interface_declare_gen_op_list
and dialect_name != "onednn_op"
):
if decomp_vjp_interface_str not in op_interfaces:
op_interfaces = op_interfaces + [
decomp_vjp_interface_str
]
if (
decomp_vjp_interface_declare_str
not in exclusive_interface_str
):
exclusive_interface_str += (
decomp_vjp_interface_declare_str
)
else:
op_interfaces = op_interfaces_tmp
exclusive_interface_str = exclusive_interface_str_tmp
52 changes: 52 additions & 0 deletions paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h
@@ -0,0 +1,52 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/pir/include/core/op_base.h"

namespace paddle {
namespace dialect {
class DecompVjpInterface : public pir::OpInterfaceBase<DecompVjpInterface> {
public:
struct Concept {
explicit Concept(
std::vector<std::vector<pir::Value>> (*decomp)(pir::Operation* op))
: decomp_(decomp) {}
std::vector<std::vector<pir::Value>> (*decomp_)(pir::Operation* op);
};

template <class ConcreteOp>
struct Model : public Concept {
static std::vector<std::vector<pir::Value>> DecompVjp(pir::Operation* op) {
return ConcreteOp::DecompVjp(op);
}
Model() : Concept(DecompVjp) {}
};

/// Constructor
DecompVjpInterface(pir::Operation* op, Concept* impl)
: pir::OpInterfaceBase<DecompVjpInterface>(op), impl_(impl) {}

std::vector<std::vector<pir::Value>> DecompVjp(pir::Operation* op) {
return impl_->decomp_(op);
}

private:
Concept* impl_;
};

} // namespace dialect
} // namespace paddle

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DecompVjpInterface)
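
For context, a caller-side sketch of how a decomposition pass might dispatch through this interface, mirroring the way DecompInterface is queried today (TryDecompVjp is a hypothetical name, not part of this PR):

#include <vector>

#include "paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h"

// Hypothetical free function: dyn_cast the operation to the interface, then
// call DecompVjp, which forwards through Concept::decomp_ to
// ConcreteOp::DecompVjp(op).
std::vector<std::vector<pir::Value>> TryDecompVjp(pir::Operation* op) {
  auto decomp_vjp_iface = op->dyn_cast<paddle::dialect::DecompVjpInterface>();
  if (!decomp_vjp_iface) {
    // The op did not register DecompVjpInterface; keep the original grad op.
    return {};
  }
  return decomp_vjp_iface.DecompVjp(op);
}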
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/interface/interface.cc
@@ -13,11 +13,13 @@
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h"
#include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferMetaInterface)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OpYamlInfoInterface)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DecompInterface)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DecompVjpInterface)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::GetKernelTypeForVarInterface)
213 changes: 213 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op_decomp_vjp.cc
@@ -0,0 +1,213 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/primitive/composite/composite.h"
#include "paddle/fluid/primitive/rule/vjp/details.h"
#include "paddle/fluid/primitive/rule/vjp/generated/generated_vjp.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/core/op_base.h"

// TODO(chenzhuo): the contents of this file will be auto-generated into
// pd_op_decomp_vjp.cc in the future.

namespace paddle {
namespace dialect {
using IntArray = paddle::experimental::IntArray;

std::vector<std::vector<pir::Value>> AddGradOp::DecompVjp(pir::Operation* op) {
VLOG(4) << "Decomp call add_grad's decomp interface begin";

AddGradOp op_obj = op->dyn_cast<AddGradOp>();
(void)op_obj;

FLAGS_tensor_operants_mode = "static";

VLOG(6) << "Decomp Prepare inputs of add_grad";

Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));
Tensor y(std::make_shared<primitive::LazyTensor>(op_obj.y()));
Tensor out_grad(std::make_shared<primitive::LazyTensor>(op_obj.out_grad()));

VLOG(6) << "Decomp prepare attributes of add_grad";
int axis = op->attribute("axis").dyn_cast<pir::Int32Attribute>().data();

VLOG(6) << "Decomp call add_grad's composite rule prepare";

std::vector<std::vector<bool>> stop_gradients(op->results().size());
if (op->HasAttribute(kAttrStopGradients)) {
auto stop_gradients_attr = op->attribute(kAttrStopGradients)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
stop_gradients[0].push_back(
stop_gradients_attr[0].dyn_cast<pir::BoolAttribute>().data());
stop_gradients[1].push_back(
stop_gradients_attr[1].dyn_cast<pir::BoolAttribute>().data());
VLOG(0) << " stop_gradients is set ";
} else {
stop_gradients[0].push_back(false);
stop_gradients[1].push_back(false);
VLOG(0) << " stop_gradients is not set ";
}

std::vector<std::vector<paddle::Tensor>> tensor_res;
for (auto arg : stop_gradients) {
tensor_res.push_back(std::vector<paddle::Tensor>(arg.size()));
}
std::string op_name = "add_grad";
FLAGS_tensor_operants_mode = "static";
VLOG(4) << "Call Pir Decomposed backward op add_grad";
paddle::Tensor* x_grad = !stop_gradients[0][0] ? &tensor_res[0][0] : nullptr;
paddle::Tensor* y_grad = !stop_gradients[1][0] ? &tensor_res[1][0] : nullptr;
paddle::primitive::details::add_grad<primitive::LazyTensor>(
x, y, out_grad, axis, x_grad, y_grad);
std::vector<std::vector<pir::Value>> res(tensor_res.size());
for (size_t i = 0; i < tensor_res.size(); ++i) {
res[i].resize(tensor_res[i].size());
for (size_t j = 0; j < tensor_res[i].size(); ++j) {
if (tensor_res[i][j].defined()) {
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
tensor_res[i][j].impl())
->value();
}
}
}
return res;
}

std::vector<std::vector<pir::Value>> ReluGradOp::DecompVjp(pir::Operation* op) {
VLOG(4) << "Decomp call relu_grad's decomp interface begin";

ReluGradOp op_obj = op->dyn_cast<ReluGradOp>();
(void)op_obj;

FLAGS_tensor_operants_mode = "static";

VLOG(6) << "Decomp Prepare inputs of relu_grad";

Tensor out(std::make_shared<primitive::LazyTensor>(op_obj.out()));
Tensor out_grad(std::make_shared<primitive::LazyTensor>(op_obj.out_grad()));

VLOG(6) << "Decomp prepare attributes of relu_grad";

VLOG(6) << "Decomp call relu_grad's composite rule prepare";

std::vector<std::vector<bool>> stop_gradients(op->results().size());
if (op->HasAttribute(kAttrStopGradients)) {
auto stop_gradients_attr = op->attribute(kAttrStopGradients)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
stop_gradients[0].push_back(
stop_gradients_attr[0].dyn_cast<pir::BoolAttribute>().data());
VLOG(0) << " stop_gradients is set ";
} else {
stop_gradients[0].push_back(false);
VLOG(0) << " stop_gradients is not set ";
}

std::vector<std::vector<paddle::Tensor>> tensor_res;
for (auto arg : stop_gradients) {
tensor_res.push_back(std::vector<paddle::Tensor>(arg.size()));
}
std::string op_name = "relu_grad";
FLAGS_tensor_operants_mode = "static";
VLOG(4) << "Call Pir Decomposed backward op relu_grad";
paddle::Tensor* x_grad = !stop_gradients[0][0] ? &tensor_res[0][0] : nullptr;
paddle::primitive::details::relu_grad<primitive::LazyTensor>(
out, out_grad, x_grad);
std::vector<std::vector<pir::Value>> res(tensor_res.size());
for (size_t i = 0; i < tensor_res.size(); ++i) {
res[i].resize(tensor_res[i].size());
for (size_t j = 0; j < tensor_res[i].size(); ++j) {
if (tensor_res[i][j].defined()) {
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
tensor_res[i][j].impl())
->value();
}
}
}
return res;
}

std::vector<std::vector<pir::Value>> MatmulGradOp::DecompVjp(
pir::Operation* op) {
VLOG(4) << "Decomp call matmul_grad's decomp interface begin";

MatmulGradOp op_obj = op->dyn_cast<MatmulGradOp>();
(void)op_obj;

FLAGS_tensor_operants_mode = "static";

VLOG(6) << "Decomp Prepare inputs of matmul_grad";

Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));
Tensor y(std::make_shared<primitive::LazyTensor>(op_obj.y()));
Tensor out_grad(std::make_shared<primitive::LazyTensor>(op_obj.out_grad()));

VLOG(6) << "Decomp prepare attributes of matmul_grad";
bool transpose_x =
op->attribute("transpose_x").dyn_cast<pir::BoolAttribute>().data();
bool transpose_y =
op->attribute("transpose_y").dyn_cast<pir::BoolAttribute>().data();

VLOG(6) << "Decomp call matmul_grad's composite rule prepare";

std::vector<std::vector<bool>> stop_gradients(op->results().size());
if (op->HasAttribute(kAttrStopGradients)) {
auto stop_gradients_attr = op->attribute(kAttrStopGradients)
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
stop_gradients[0].push_back(
stop_gradients_attr[0].dyn_cast<pir::BoolAttribute>().data());
stop_gradients[1].push_back(
stop_gradients_attr[1].dyn_cast<pir::BoolAttribute>().data());
VLOG(0) << " stop_gradients is set ";
} else {
stop_gradients[0].push_back(false);
stop_gradients[1].push_back(false);
VLOG(0) << " stop_gradients is not set ";
}

std::vector<std::vector<paddle::Tensor>> tensor_res;
for (auto arg : stop_gradients) {
tensor_res.push_back(std::vector<paddle::Tensor>(arg.size()));
}
std::string op_name = "matmul_grad";
FLAGS_tensor_operants_mode = "static";
VLOG(4) << "Call Pir Decomposed backward op matmul_grad";
paddle::Tensor* x_grad = !stop_gradients[0][0] ? &tensor_res[0][0] : nullptr;
paddle::Tensor* y_grad = !stop_gradients[1][0] ? &tensor_res[1][0] : nullptr;
paddle::primitive::details::matmul_grad<primitive::LazyTensor>(
x, y, out_grad, transpose_x, transpose_y, x_grad, y_grad);
std::vector<std::vector<pir::Value>> res(tensor_res.size());
for (size_t i = 0; i < tensor_res.size(); ++i) {
res[i].resize(tensor_res[i].size());
for (size_t j = 0; j < tensor_res[i].size(); ++j) {
if (tensor_res[i][j].defined()) {
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
tensor_res[i][j].impl())
->value();
}
}
}
return res;
}

} // namespace dialect
} // namespace paddle
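
One readability note on the file above: all three DecompVjp bodies end with the same loop that maps the composite-rule outputs back to pir::Value. A hedged sketch of a helper that could factor it out, written as if it lived inside namespace paddle::dialect in this file (TensorsToValues is a hypothetical name, not part of this PR):

// Hypothetical helper: convert LazyTensor-backed composite-rule results into
// the pir::Value grid expected by DecompVjp. Undefined tensors (outputs
// skipped because their stop_gradient flag is set) are left as null Values.
static std::vector<std::vector<pir::Value>> TensorsToValues(
    const std::vector<std::vector<paddle::Tensor>>& tensor_res) {
  std::vector<std::vector<pir::Value>> res(tensor_res.size());
  for (size_t i = 0; i < tensor_res.size(); ++i) {
    res[i].resize(tensor_res[i].size());
    for (size_t j = 0; j < tensor_res[i].size(); ++j) {
      if (tensor_res[i][j].defined()) {
        res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
                        tensor_res[i][j].impl())
                        ->value();
      }
    }
  }
  return res;
}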