From 8b52115c4982c49620a5dcb09f39e1a45fda8a9e Mon Sep 17 00:00:00 2001
From: YuanRisheng
Date: Wed, 29 May 2024 08:24:30 +0000
Subject: [PATCH] fix flash attn pass

---
 paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.cc b/paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.cc
index 0cbc5e0bd93bce..be8202356036ff 100644
--- a/paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.cc
+++ b/paddle/fluid/pir/transforms/gpu/fused_flash_attn_pass.cc
@@ -139,7 +139,7 @@ class FlashAttnPatternQscaleWithMask : public paddle::drr::DrrPatternBase {
       }
       // mask's shape [bs, 1, seq_len, seq_len]
       auto mask_add = pir::GetShapeFromValue(match_ctx.Tensor("mask"));
-      if (mask_add.size() != 4 || mask_add.at(1) != 1) {
+      if (mask_add.size() != 4 || mask_add.at(1) != 1 || mask_add.at(0) != -1) {
         return false;
       }
 
@@ -285,7 +285,7 @@ class FlashAttnPatternOutscaleWithMask : public paddle::drr::DrrPatternBase {
       }
       // mask's shape [bs, 1, seq_len, seq_len]
       auto mask_add = pir::GetShapeFromValue(match_ctx.Tensor("mask"));
-      if (mask_add.size() != 4 || mask_add.at(1) != 1) {
+      if (mask_add.size() != 4 || mask_add.at(1) != 1 || mask_add.at(0) != -1) {
         return false;
       }
 
@@ -556,7 +556,7 @@ class TransposeSliceFlashAttnPattern : public paddle::drr::DrrPatternBase {
       }
       // mask's shape [bs, 1, seq_len, seq_len]
       auto mask_add = pir::GetShapeFromValue(match_ctx.Tensor("mask"));
-      if (mask_add.size() != 4 || mask_add.at(1) != 1) {
+      if (mask_add.size() != 4 || mask_add.at(1) != 1 || mask_add.at(0) != -1) {
         return false;
       }
 