Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions paddle/cinn/frontend/op_mappers/paddle/arg_min_max.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"
#include "paddle/cinn/frontend/var_type_utils.h"
#include "paddle/common/enforce.h"

namespace cinn {
namespace frontend {
Expand Down Expand Up @@ -47,24 +48,39 @@ Variable ArgImpl<ArgType::ArgMin>(NetBuilder* builder,
template <ArgType type>
void ArgOpMapperHelper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Argmax/Argmin op must be 1."));
auto x_name = op_desc.Input("X").front();

CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument(
"The output of Argmax/Argmin op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto x = ctx.GetVar(x_name);
auto axis = utils::GetAttrOrDefault<int64_t>(op_desc, "axis", -1);
CHECK(op_desc.HasAttr("axis"))
<< "Argmax/Argmin op should has attribute \"axis\"! Please check.";
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("axis"),
true,
phi::errors::InvalidArgument("Argmax/Argmin op should has attribute "
"\"axis\"! Please check."));

auto keepdims = utils::GetAttrOrDefault<bool>(op_desc, "keepdims", false);
CHECK(op_desc.HasAttr("keepdims"))
<< "Argmax/Argmin op should has attribute \"keepdims\"! Please check.";
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("keepdims"),
true,
phi::errors::InvalidArgument("Argmax/Argmin op should has attribute"
" \"keepdims\"! Please check."));

auto flatten = utils::GetAttrOrDefault<bool>(op_desc, "flatten", false);
CHECK(op_desc.HasAttr("flatten"))
<< "Argmax/Argmin op should has attribute \"flatten\"! Please check.";
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("flatten"),
true,
phi::errors::InvalidArgument("Argmax/Argmin op should has attribute"
" \"flatten\"! Please check."));

auto dtype = utils::GetPaddleDtype(
op_desc, "dtype", paddle::cpp::VarDescAPI::Type::INT64);
Expand Down
17 changes: 13 additions & 4 deletions paddle/cinn/frontend/op_mappers/paddle/argsort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,29 @@
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"
#include "paddle/cinn/utils/string.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void ArgsortOpMapper(const paddle::cpp::OpDesc& op_desc,
const cinn::frontend::OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of Argmax/Argmin op must be 1."));
auto x_name = op_desc.Input("X").front();

CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument(
"The output of Argmax/Argmin op must be 1."));
auto out_name = op_desc.Output("Out").front();

CHECK_EQ(op_desc.Output("Indices").size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output("Indices").size(),
1UL,
phi::errors::InvalidArgument(
"The output of Argmax/Argmin op must be 1."));
auto indices_name = op_desc.Output("Indices").front();

auto is_ascend =
Expand Down
17 changes: 13 additions & 4 deletions paddle/cinn/frontend/op_mappers/paddle/atan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,27 @@
#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"
#include "paddle/cinn/utils/string.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void Atan2OpMapper(const paddle::cpp::OpDesc& op_desc,
const cinn::frontend::OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X1").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X1").size(),
1UL,
phi::errors::InvalidArgument("The input of Atan2 op must be 1."));
auto x1_name = op_desc.Input("X1").front();
CHECK_EQ(op_desc.Input("X2").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X2").size(),
1UL,
phi::errors::InvalidArgument("The input of Atan2 op must be 1."));
auto x2_name = op_desc.Input("X2").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of Atan2 op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto x1 = ctx.GetVar(x1_name);
Expand Down
61 changes: 46 additions & 15 deletions paddle/cinn/frontend/op_mappers/paddle/batchnorm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {
Expand All @@ -29,7 +29,10 @@ void BatchNormOpMapper(const paddle::cpp::OpDesc& op_desc,
<< op_desc.Type();
return;
}
CHECK_EQ(op_desc.Output(pd_param_name).size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output(pd_param_name).size(),
1UL,
phi::errors::InvalidArgument("The output of batch_norm op must be 1."));
auto output_name = op_desc.Output(pd_param_name).front();

VLOG(4) << "The " << op_desc.Type() << "'s output " << pd_param_name
Expand All @@ -39,15 +42,30 @@ void BatchNormOpMapper(const paddle::cpp::OpDesc& op_desc,
ctx.AddVarModelToProgram(output_name, out->id, can_inplace);
};

CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of batch_norm op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Input("Scale").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Scale").size(),
1UL,
phi::errors::InvalidArgument("The input of batch_norm op must be 1."));
auto scale_name = op_desc.Input("Scale").front();
CHECK_EQ(op_desc.Input("Bias").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Bias").size(),
1UL,
phi::errors::InvalidArgument("The input of batch_norm op must be 1."));
auto bias_name = op_desc.Input("Bias").front();
CHECK_EQ(op_desc.Input("Mean").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Mean").size(),
1UL,
phi::errors::InvalidArgument("The input of batch_norm op must be 1."));
auto mean_name = op_desc.Input("Mean").front();
CHECK_EQ(op_desc.Input("Variance").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("Variance").size(),
1UL,
phi::errors::InvalidArgument("The input of batch_norm op must be 1."));
auto variance_name = op_desc.Input("Variance").front();

auto epsilon = utils::GetAttrOrDefault<float>(op_desc, "epsilon", 1e-5f);
Expand Down Expand Up @@ -105,8 +123,11 @@ void BatchNormOpMapper(const paddle::cpp::OpDesc& op_desc,
add_output("VarianceOut", variance_out, true);
} else {
VLOG(4) << "Invoke batch_norm OpMapper with train mode";
CHECK_EQ(outs.size(), 5U)
<< "batch_norm in train mode should only has 5 output! Please check.";
PADDLE_ENFORCE_EQ(outs.size(),
5U,
phi::errors::InvalidArgument(
"batch_norm in train mode should only has 5 output! "
"Please check."));

add_output("Y", outs[0]);
add_output("SavedMean", outs[1]);
Expand All @@ -122,7 +143,10 @@ void BatchNormGradOpMapper(const paddle::cpp::OpDesc& op_desc,
std::unordered_map<std::string, std::string> input_names_map;
auto get_input_var =
[&op_desc, &ctx, &input_names_map](const std::string& op_name) {
CHECK_EQ(op_desc.Input(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Input(op_name).size(),
1UL,
phi::errors::InvalidArgument(
"The input of batch_norm_grad op must be 1."));
auto var_name = op_desc.Input(op_name).front();
input_names_map.emplace(op_name, var_name);
return ctx.GetVar(var_name);
Expand All @@ -132,12 +156,17 @@ void BatchNormGradOpMapper(const paddle::cpp::OpDesc& op_desc,
auto get_output_name =
[&op_desc, &output_names_map](const std::string& op_name) -> std::string {
if (op_desc.Output(op_name).empty()) {
CHECK_NE(op_name, paddle::GradVarName("X"))
<< "The input X should not empty.";
PADDLE_ENFORCE_NE(
op_name,
paddle::GradVarName("X"),
phi::errors::InvalidArgument("The input X should not empty."));
return "";
}

CHECK_EQ(op_desc.Output(op_name).size(), 1UL);
PADDLE_ENFORCE_EQ(op_desc.Output(op_name).size(),
1UL,
phi::errors::InvalidArgument(
"The output of batch_norm_grad op must be 1."));
auto var_name = op_desc.Output(op_name).front();
output_names_map.emplace(op_name, var_name);
return var_name;
Expand Down Expand Up @@ -174,8 +203,10 @@ void BatchNormGradOpMapper(const paddle::cpp::OpDesc& op_desc,
// batch norm grad, output(grad_x, grad_scale, grad_bias)
auto outs = ctx.Builder()->BatchNormGrad(
dy, x, scale, saved_mean, saved_variance, epsilon, data_layout);
CHECK_EQ(outs.size(), 3ul)
<< "batch_norm_grad APIs should return 3 Variable!";
PADDLE_ENFORCE_EQ(outs.size(),
3ul,
phi::errors::InvalidArgument(
"batch_norm_grad APIs should return 3 Variable!"));

for (int i = 0; i < outs.size(); i++) {
if (output_names[i].empty()) {
Expand Down
39 changes: 24 additions & 15 deletions paddle/cinn/frontend/op_mappers/paddle/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,34 @@

#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

#define BINARY_OPMAPPER_FUNCTION(OP_NAME) \
void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, \
const OpMapperContext& ctx) { \
CHECK_EQ(op_desc.Input("X").size(), 1UL); \
auto x_name = op_desc.Input("X").front(); \
CHECK_EQ(op_desc.Input("Y").size(), 1UL); \
auto y_name = op_desc.Input("Y").front(); \
CHECK_EQ(op_desc.Output("Out").size(), 1UL); \
auto out_name = op_desc.Output("Out").front(); \
auto x = ctx.GetVar(x_name); \
auto y = ctx.GetVar(y_name); \
auto out = ctx.Builder()->OP_NAME(x, y); \
ctx.AddVar(out_name, out); \
ctx.AddVarModelToProgram(out_name, out->id); \
#define BINARY_OPMAPPER_FUNCTION(OP_NAME) \
void OP_NAME##OpMapper(const paddle::cpp::OpDesc& op_desc, \
const OpMapperContext& ctx) { \
PADDLE_ENFORCE_EQ( \
op_desc.Input("X").size(), \
1UL, \
phi::errors::InvalidArgument("The input of op must be 1.")); \
auto x_name = op_desc.Input("X").front(); \
PADDLE_ENFORCE_EQ( \
op_desc.Input("Y").size(), \
1UL, \
phi::errors::InvalidArgument("The input of op must be 1.")); \
auto y_name = op_desc.Input("Y").front(); \
PADDLE_ENFORCE_EQ( \
op_desc.Output("Out").size(), \
1UL, \
phi::errors::InvalidArgument("The output of op must be 1.")); \
auto out_name = op_desc.Output("Out").front(); \
auto x = ctx.GetVar(x_name); \
auto y = ctx.GetVar(y_name); \
auto out = ctx.Builder()->OP_NAME(x, y); \
ctx.AddVar(out_name, out); \
ctx.AddVarModelToProgram(out_name, out->id); \
}

BINARY_OPMAPPER_FUNCTION(LogicalAnd)
Expand Down
12 changes: 9 additions & 3 deletions paddle/cinn/frontend/op_mappers/paddle/cholesky.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@

#include "paddle/cinn/frontend/op_mapper_registry.h"
#include "paddle/cinn/frontend/op_mappers/common_utils.h"

#include "paddle/common/enforce.h"
namespace cinn {
namespace frontend {
namespace paddle_mappers {

void CholeskyOpMapper(const paddle::cpp::OpDesc& op_desc,
const OpMapperContext& ctx) {
CHECK_EQ(op_desc.Input("X").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Input("X").size(),
1UL,
phi::errors::InvalidArgument("The input of cholesky op must be 1."));
auto x_name = op_desc.Input("X").front();
CHECK_EQ(op_desc.Output("Out").size(), 1UL);
PADDLE_ENFORCE_EQ(
op_desc.Output("Out").size(),
1UL,
phi::errors::InvalidArgument("The output of cholesky op must be 1."));
auto out_name = op_desc.Output("Out").front();

auto upper = utils::GetAttrOrDefault<bool>(op_desc, "upper", false);
Expand Down
Loading