
Conversation

@0x45f (Contributor) commented Jan 31, 2024

PR types

Others

PR changes

Others

Description

Pcard-67164

This is the second PR adding AMP support under PIR. It mainly implements automatic generation of the AMP logic inside the PIR APIs.

  • Taking matmul as an example, the generated AMP logic is shown below:
pir::Value matmul(const pir::Value& x, const pir::Value& y, bool transpose_x,
                  bool transpose_y) {
  // AMP Logic
  if (egr::Controller::Instance().GetCurrentAMPState()->GetAmpLevel() !=
      paddle::imperative::AmpLevel::O0) {
    VLOG(5) << "Check and Prepare For AMP";
    auto op_name = phi::TransToFluidOpName("matmul");
    std::vector<std::vector<pir::Value>> amp_values_vector = {{x}, {y}};
    auto amp_dst_dtype =
        paddle::dialect::GetAmpDestDtype("matmul", amp_values_vector);
    auto new_x =
        paddle::dialect::PirAmpAutoCast("x", x, amp_dst_dtype, op_name);
    auto new_y =
        paddle::dialect::PirAmpAutoCast("y", y, amp_dst_dtype, op_name);

    {
      paddle::imperative::AutoCastGuard guard(
          egr::Controller::Instance().GetCurrentAMPState(),
          paddle::imperative::AmpLevel::O0);
      return paddle::dialect::matmul(new_x, new_y, transpose_x, transpose_y);
    }
  }

  CheckValueDataType(y, "y", "matmul");
  paddle::dialect::MatmulOp matmul_op =
      ApiBuilder::Instance().GetBuilder()->Build<paddle::dialect::MatmulOp>(
          x, y, transpose_x, transpose_y);
  return matmul_op.result(0);
}
  • The unit test builds a linear+mean network under O1 mode; the resulting program is shown below. matmul runs in fp16 while add and mean stay in fp32, matching the dynamic-graph behavior.
{
 (%0) = "builtin.parameter" () {is_persisable:[true],parameter_name:"linear_0.b_0",stop_gradient:[false]} : () -> pd_op.tensor<5xf32>
 (%1) = "builtin.parameter" () {is_persisable:[true],parameter_name:"linear_0.w_0",stop_gradient:[false]} : () -> pd_op.tensor<4x5xf32>
 (%2) = "pd_op.data" () {dtype:(pd_op.DataType)float32,name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[3,4],stop_gradient:[true]} : () -> pd_op.tensor<3x4xf32>
 (%3) = "pd_op.cast" (%2) {dtype:(pd_op.DataType)float16,stop_gradient:[true]} : (pd_op.tensor<3x4xf32>) -> pd_op.tensor<3x4xf16>
 (%4) = "pd_op.cast" (%1) {dtype:(pd_op.DataType)float16,stop_gradient:[false]} : (pd_op.tensor<4x5xf32>) -> pd_op.tensor<4x5xf16>
 (%5) = "pd_op.matmul" (%3, %4) {stop_gradient:[false],transpose_x:false,transpose_y:false} : (pd_op.tensor<3x4xf16>, pd_op.tensor<4x5xf16>) -> pd_op.tensor<3x5xf16>
 (%6) = "pd_op.cast" (%5) {dtype:(pd_op.DataType)float32,stop_gradient:[false]} : (pd_op.tensor<3x5xf16>) -> pd_op.tensor<3x5xf32>
 (%7) = "pd_op.add" (%6, %0) {stop_gradient:[false]} : (pd_op.tensor<3x5xf32>, pd_op.tensor<5xf32>) -> pd_op.tensor<3x5xf32>
 (%8) = "pd_op.mean" (%7) {axis:(pd_op.IntArray)[],keepdim:false,stop_gradient:[false]} : (pd_op.tensor<3x5xf32>) -> pd_op.tensor<f32>
}
  • The newly added PIR functions GetPromoteType, Cast, NeedCast, PirAmpAutoCast, GetAmpDestDtype, etc. follow the same logic as their dynamic-graph counterparts (minus the place-related checks that exist in the dynamic graph); a simplified sketch of PirAmpAutoCast follows below.
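For orientation only, here is a minimal sketch of what PirAmpAutoCast conceptually does, inferred from the generated matmul code and the program dump above; the NeedCast check and the Cast call are simplified assumptions, not the exact Paddle implementation:

// Sketch only (not the actual implementation): cast a single input value to
// the AMP destination dtype when needed, mirroring eager-mode AmpAutoCast.
pir::Value PirAmpAutoCast(const std::string& input_name,
                          const pir::Value& value,
                          phi::DataType dst_dtype,
                          const std::string& op_name) {
  // NeedCast is assumed to return false when the value already has dst_dtype
  // or when this input must stay in fp32 for this op.
  if (!paddle::dialect::NeedCast(value, dst_dtype)) {
    return value;
  }
  VLOG(6) << "AMP autocast input(" << input_name << ") of op(" << op_name
          << ") to dst_dtype(" << phi::DataTypeToString(dst_dtype) << ").";
  // Inserts a pd_op.cast into the program, which is where the cast ops in the
  // unit-test program above come from.
  return paddle::dialect::Cast(value, dst_dtype);
}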

@0x45f changed the title from "Gen pir amp code" to "[PIR AMP]Gen AMP logic code in PIR APIs" on Jan 31, 2024
namespace paddle {
namespace dialect {

phi::DataType GetPromoteType(
Contributor:

Much of the logic in this file is the same as the handling in amp_utils.h under eager mode, and it looks like there is per-op logic for every op. Could the code be reused, so that a single change takes effect for both the static graph and the dynamic graph?

Contributor:

Reuse is worth considering here; the shared logic under eager mode may need to be extracted into a separate directory so that both eager and PIR can call it.

Contributor Author (@0x45f):

The input types of the related functions differ between the dynamic graph and PIR (e.g. tensor vs. value), so the code is not reused here. I'm not sure whether templates could let a single copy of the code be shared; I'll try that in a separate follow-up PR, and I'll address the other review suggestions there as well.

const phi::DataType& amp_dtype) {
auto dst_type = amp_dtype;
// only consider the dtype of input(X).
if (op_name == "batch_norm" || op_name == "layer_norm" ||
Contributor:

You could add an anonymous namespace to this file containing a const unordered_set and use it to replace this logic; the two if statements here could then be joined with &&.

Line 28 can be optimized at the same time:

const auto& HandleSpecialOp = [&](){....};
HandleSpecialOp();
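
A minimal sketch of the suggested pattern (the set name and helper below are illustrative, mirroring the conditions quoted above rather than actual PR code):

#include <string>
#include <unordered_set>

namespace {
// Ops whose AMP handling only considers the dtype of input(X); keeping them in
// one const set replaces the chained op_name == "..." comparisons.
const std::unordered_set<std::string> kInputXOnlyOps = {
    "batch_norm", "layer_norm", "sync_batch_norm", "weight_only_linear"};
}  // namespace

bool IsInputXOnlyOp(const std::string& op_name) {
  // With this helper the two nested ifs can be merged into a single condition
  // joined with &&.
  return kInputXOnlyOps.count(op_name) > 0;
}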

if (egr::Controller::Instance().GetCurrentAMPState()->GetAmpDtype() ==
"float16") {
if (op_name == "fused_attention") {
for (size_t i = 0; i < amp_values_vector.size(); i++) {
@Aurelius84 (Contributor) commented Feb 1, 2024:

The for loop here can be extracted into a lambda and placed at line 38:

const auto& HandleFuseAttention = [&](){...};

"float16") {
if (op_name == "fused_attention") {
for (size_t i = 0; i < amp_values_vector.size(); i++) {
if (i != 3 || i != 4 || i != 9 || i != 10) {
Contributor:

Magic numbers should be avoided in the code; the 3, 4, 9, 10 here should be given named variables inside the lambda, for example:

const size_t xxx_index = 3;

Or define:

 const std::unordered_set<size_t> skip_value_indices = {/*xxx_index=*/ 3, ...};
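
Combining this with the lambda suggestion above, a sketch of the refactor could look like the following (names and the skip behavior are illustrative placeholders, not code from the PR):

// Sketch only: the fused_attention loop moves into a lambda, the skipped
// positions get a named const set instead of magic numbers, and size() is
// hoisted out of the loop.
const auto& HandleFusedAttention = [&]() {
  const std::unordered_set<size_t> skip_value_indices = {3, 4, 9, 10};
  const size_t value_length = amp_values_vector.size();
  for (size_t i = 0; i < value_length; ++i) {
    if (skip_value_indices.count(i) == 0) {
      // cast / inspect amp_values_vector[i] as in the quoted loop
    }
  }
};
HandleFusedAttention();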

}
}
} else if (op_name == "fused_feedforward") {
for (size_t i = 0; i < amp_values_vector.size(); i++) {
@Aurelius84 (Contributor) commented Feb 1, 2024:

Same as above. Also, the size() call should be moved outside the loop so it is not evaluated on every iteration, e.g.:

const size_t value_length = amp_values_vector.size();

<< " input(" << input_name << " to dst_dtype("
<< phi::DataTypeToString(dst_dtype) << ").";
if ((op_name == "batch_norm" || op_name == "layer_norm" ||
op_name == "sync_batch_norm" || op_name == "weight_only_linear") &&
Contributor:

Same as above.

}
if ((op_name == "fused_attention" || op_name == "fused_feedforward")) {
if (input_name == "LnScale" || input_name == "LnBias" ||
input_name == "Ln2Scale" || input_name == "Ln2Bias" ||
Contributor:

A const set could be defined in an anonymous namespace here as well.

}

if (use_promote) {
if (paddle::imperative::AmpOperators::Instance()
Contributor:

The two branches of the outermost if-else can each be extracted into a lambda to reduce nesting.
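
Roughly, the suggestion amounts to something like the following (branch bodies are placeholders; only use_promote is taken from the quoted code):

// Sketch only: each branch of the outer if-else becomes a small lambda so the
// surrounding function body stays flat instead of deeply nested.
const auto& GetPromotedDstDtype = [&]() -> phi::DataType {
  // placeholder for the original promote-path logic (AmpOperators allow/block
  // lists, promotion across the input dtypes)
  return phi::DataType::FLOAT32;
};
const auto& GetNonPromotedDstDtype = [&]() -> phi::DataType {
  // placeholder for the original non-promote-path logic
  return phi::DataType::FLOAT16;
};

phi::DataType dst_dtype =
    use_promote ? GetPromotedDstDtype() : GetNonPromotedDstDtype();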

@Aurelius84 (Contributor) left a review comment:

The optimization ideas from the review comments will need to be addressed in a separate PR.

@wanghuancoder (Contributor) left a review comment:

LGTM

@0x45f merged commit 915bf2f into PaddlePaddle:develop on Feb 1, 2024
@0x45f deleted the gen-pir-amp-code branch on February 1, 2024 11:29