58 changes: 29 additions & 29 deletions paddle/fluid/framework/ir/fused_attention_pass.cc
@@ -123,23 +123,23 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
fuse_qkv_split_out_v_node});

// core attention pattern
auto* qk_scale_node =
pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
auto* qk_scale_out_node =
pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
fuse_qkv_split_out_q_node->assert_is_op_input("scale", "X");
qk_scale_node->LinksFrom({fuse_qkv_split_out_q_node})
.LinksTo({qk_scale_out_node});

auto* qk_matmul_node =
pattern->NewNode(qk_matmul_op_repr())->assert_is_op("matmul_v2");
auto* qk_matmul_out_node =
pattern->NewNode(qk_matmul_out_repr())->assert_is_op_output("matmul_v2");
fuse_qkv_split_out_q_node->assert_is_op_input("matmul_v2", "X");
qk_scale_out_node->assert_is_op_input("matmul_v2", "X");
fuse_qkv_split_out_k_node->assert_is_op_input("matmul_v2", "Y");
qk_matmul_node
->LinksFrom({fuse_qkv_split_out_q_node, fuse_qkv_split_out_k_node})
qk_matmul_node->LinksFrom({qk_scale_out_node, fuse_qkv_split_out_k_node})
.LinksTo({qk_matmul_out_node});

auto* qk_scale_node =
pattern->NewNode(qk_scale_op_repr())->assert_is_op("scale");
auto* qk_scale_out_node =
pattern->NewNode(qk_scale_out_repr())->assert_is_op_output("scale");
qk_matmul_out_node->assert_is_op_input("scale", "X");
qk_scale_node->LinksFrom({qk_matmul_out_node}).LinksTo({qk_scale_out_node});

PDNode* add_mask_ele_add_out_node{nullptr};
if (has_attn_mask) {
auto* add_mask_ele_add_node = pattern->NewNode(add_mask_ele_add_op_repr())
@@ -149,9 +149,9 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
->assert_is_op_input("elementwise_add", "Y");
add_mask_ele_add_out_node = pattern->NewNode(add_mask_ele_add_out_repr())
->assert_is_op_output("elementwise_add");
qk_scale_out_node->assert_is_op_input("elementwise_add", "X");
qk_matmul_out_node->assert_is_op_input("elementwise_add", "X");
add_mask_ele_add_node
->LinksFrom({qk_scale_out_node, add_mask_ele_add_mask_node})
->LinksFrom({qk_matmul_out_node, add_mask_ele_add_mask_node})
.LinksTo({add_mask_ele_add_out_node});
}

@@ -164,8 +164,8 @@ PDNode* FusedAttentionPattern::operator()(PDNode* x,
qk_softmax_node->LinksFrom({add_mask_ele_add_out_node})
.LinksTo({qk_softmax_out_node});
} else {
qk_scale_out_node->assert_is_op_input("softmax", "X");
qk_softmax_node->LinksFrom({qk_scale_out_node})
qk_matmul_out_node->assert_is_op_input("softmax", "X");
qk_softmax_node->LinksFrom({qk_matmul_out_node})
.LinksTo({qk_softmax_out_node});
}

@@ -575,16 +575,8 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
.LinksTo({add_mask_ele_add_grad_x_grad_node});
}

PDNode* qk_scale_grad_input_node =
PDNode* qk_matmul_grad_input_node =
has_attn_mask ? add_mask_ele_add_grad_x_grad_node : qk_softmax_grad_out;
auto* qk_scale_grad_node =
pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
auto* qk_scale_grad_out_node =
pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
qk_scale_grad_input_node->assert_is_op_input("scale", "X");
qk_scale_grad_node->LinksFrom({qk_scale_grad_input_node})
.LinksTo({qk_scale_grad_out_node});

auto* qk_matmul_grad_node = pattern->NewNode(qk_matmul_grad_op_repr())
->assert_is_op("matmul_v2_grad");
auto* qk_matmul_grad_x_node = pattern->NewNode(qk_matmul_grad_x_repr())
@@ -597,24 +589,32 @@ PDNode* FusedAttentionGradPattern::operator()(PDNode* x,
auto* qk_matmul_grad_w_grad_node =
pattern->NewNode(qk_matmul_grad_w_grad_repr())
->assert_is_op_output("matmul_v2_grad", "Y@GRAD");
qk_scale_grad_out_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
qk_matmul_grad_input_node->assert_is_op_input("matmul_v2_grad", "Out@GRAD");
qk_matmul_grad_node
->LinksFrom({qk_scale_grad_out_node,
->LinksFrom({qk_matmul_grad_input_node,
qk_matmul_grad_x_node,
qk_matmul_grad_w_node})
.LinksTo({qk_matmul_grad_x_grad_node, qk_matmul_grad_w_grad_node});

auto* qk_scale_grad_node =
pattern->NewNode(qk_scale_grad_op_repr())->assert_is_op("scale");
auto* qk_scale_grad_out_node =
pattern->NewNode(qk_scale_grad_out_repr())->assert_is_op_output("scale");
qk_matmul_grad_x_grad_node->assert_is_op_input("scale", "X");
qk_scale_grad_node->LinksFrom({qk_matmul_grad_x_grad_node})
.LinksTo({qk_scale_grad_out_node});

// fuse qkv projection
auto* fuse_qkv_split_grad_node =
pattern->NewNode(fuse_qkv_split_grad_op_repr())->assert_is_op("concat");
auto* fuse_qkv_split_grad_out_node =
pattern->NewNode(fuse_qkv_split_grad_out_repr())
->assert_is_op_output("concat");
qk_matmul_grad_x_grad_node->assert_is_op_input("concat"); // q grad
qk_scale_grad_out_node->assert_is_op_input("concat"); // q grad
qk_matmul_grad_w_grad_node->assert_is_op_input("concat"); // k grad
qkv_matmul_grad_w_grad_node->assert_is_op_input("concat"); // v grad
fuse_qkv_split_grad_node
->LinksFrom({qk_matmul_grad_x_grad_node,
->LinksFrom({qk_scale_grad_out_node,
qk_matmul_grad_w_grad_node,
qkv_matmul_grad_w_grad_node})
.LinksTo({fuse_qkv_split_grad_out_node});
@@ -894,7 +894,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResFwd(
fused_attention_op_desc.SetAttr("transpose_qkv_wb", true);
std::vector<int> shape = PADDLE_GET_CONST(
std::vector<int>, fuse_qkv_reshape_op_node->Op()->GetAttr("shape"));
fused_attention_op_desc.SetAttr("num_heads", shape[2]);
fused_attention_op_desc.SetAttr("num_heads", shape[2] / 3);
GET_IR_NODE_FROM_SUBGRAPH(
fuse_qkv_matmul_out_node, fuse_qkv_matmul_out, fused_attention_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fuse_qkv_ele_add_bias_node,
Expand Down Expand Up @@ -1337,7 +1337,7 @@ ir::Graph* FusedAttentionsPass::PreMaskDropResBwd(
std::vector<int> shape =
PADDLE_GET_CONST(std::vector<int>,
fuse_qkv_reshape_grad_op_node->Op()->GetAttr("shape"));
fused_attention_grad_op_desc.SetAttr("num_heads", shape[2]);
fused_attention_grad_op_desc.SetAttr("num_heads", shape[2] / 3);
fused_attention_grad_op_desc.SetAttr("pre_layer_norm", true);
fused_attention_grad_op_desc.SetAttr("transpose_qkv_wb", true);

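The hunks above move the `scale` node in the matched forward pattern from one side of the QK `matmul_v2` node to the other, and reorder the backward pattern (`matmul_v2_grad` and `scale`) to match, presumably so the pattern follows the op order the updated test below emits. Both placements of the 1/sqrt(head_dim) factor are numerically equivalent, since scalar scaling commutes with matrix multiplication; a minimal standalone sketch (hypothetical shapes and names, not code from this PR):

import numpy as np

# Hypothetical attention shapes: [batch, num_heads, seq_len, head_dim]
batch, num_heads, seq_len, head_dim = 2, 12, 16, 64
rng = np.random.default_rng(0)
q = rng.standard_normal((batch, num_heads, seq_len, head_dim)).astype('float32')
k = rng.standard_normal((batch, num_heads, seq_len, head_dim)).astype('float32')
scale = head_dim ** -0.5

# Variant 1: scale Q first, then Q @ K^T  (scale -> matmul_v2 in the graph)
logits_scale_first = (q * scale) @ k.transpose(0, 1, 3, 2)

# Variant 2: Q @ K^T first, then scale the product  (matmul_v2 -> scale)
logits_scale_last = (q @ k.transpose(0, 1, 3, 2)) * scale

# Identical up to float32 rounding, so only the node order in the pattern differs.
assert np.allclose(logits_scale_first, logits_scale_last, rtol=1e-5, atol=1e-5)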
70 changes: 42 additions & 28 deletions python/paddle/fluid/tests/unittests/test_fused_attention_pass.py
@@ -53,7 +53,7 @@ def __init__(

self.qkv_proj = paddle.nn.Linear(embed_dim, 3 * embed_dim)
self.out_proj = paddle.nn.Linear(embed_dim, embed_dim)
self.dropout = paddle.nn.Dropout(0.1, mode="upscale_in_train")
self.dropout = paddle.nn.Dropout(1e-10, mode="upscale_in_train")

def forward(self, x, attn_mask=None):
residual = x
@@ -64,13 +64,13 @@ def forward(self, x, attn_mask=None):

# compute qkv
qkv = self.qkv_proj(x)
qkv = paddle.reshape(qkv, [0, 0, self.num_heads, 3 * self.head_dim])
qkv = paddle.reshape(qkv, [0, 0, 3 * self.num_heads, self.head_dim])
qkv = paddle.transpose(qkv, [0, 2, 1, 3])
q, k, v = paddle.split(qkv, num_or_sections=3, axis=-1)
q, k, v = paddle.split(qkv, num_or_sections=3, axis=1)

# compute core attention
q = paddle.scale(q, scale=self.head_dim**-0.5)
product = paddle.matmul(x=q, y=k, transpose_y=True)
product = paddle.scale(product, scale=self.head_dim**-0.5)
if attn_mask is not None:
product = product + attn_mask
weights = F.softmax(product)
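Note on the reshape/split change above: the two layouts differ in which axis carries the packed Q/K/V factor of 3. With a reshape to [0, 0, 3 * num_heads, head_dim] the packing sits on the head axis (so the split runs along axis 1 after the transpose, and the pass derives num_heads as shape[2] / 3), whereas with [0, 0, num_heads, 3 * head_dim] it sits on the feature axis (split along the last axis, num_heads == shape[2]). A rough numpy illustration of the axis bookkeeping only (hypothetical sizes; the two layouts imply different weight packings, so the resulting values are not interchangeable):

import numpy as np

batch, seq_len, num_heads, head_dim = 2, 8, 12, 64
qkv = np.random.rand(batch, seq_len, 3 * num_heads * head_dim).astype('float32')

# Layout A: [b, s, 3*num_heads, head_dim]  ->  shape[2] == 3 * num_heads
x_a = qkv.reshape(batch, seq_len, 3 * num_heads, head_dim)
x_a = x_a.transpose(0, 2, 1, 3)               # [b, 3*num_heads, s, head_dim]
q_a, k_a, v_a = np.split(x_a, 3, axis=1)      # split along the head axis
assert q_a.shape == (batch, num_heads, seq_len, head_dim)

# Layout B: [b, s, num_heads, 3*head_dim]  ->  shape[2] == num_heads
x_b = qkv.reshape(batch, seq_len, num_heads, 3 * head_dim)
x_b = x_b.transpose(0, 2, 1, 3)               # [b, num_heads, s, 3*head_dim]
q_b, k_b, v_b = np.split(x_b, 3, axis=-1)     # split along the last axis
assert q_b.shape == (batch, num_heads, seq_len, head_dim)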
@@ -104,21 +104,28 @@ def setUp(self):
self.pre_ln = True
self.attn_dropout = True
self.add_mask = True
self.x_data = None
self.mask_data = None

def test_pass(self):
def get_rst(self, use_pass=False):
batch_size = 2
seq_len = 1024
hidden_size = 768
num_heads = 12

x_data = np.random.rand(batch_size, seq_len, seq_len).astype('float32')
mask_data = np.random.rand(
batch_size, num_heads, seq_len, seq_len
).astype('float32')
np.random.seed(1234)
if self.x_data is None:
self.x_data = np.random.rand(batch_size, seq_len, seq_len).astype(
'float32'
)
self.mask_data = np.random.rand(
batch_size, num_heads, seq_len, seq_len
).astype('float32')

main_prog = paddle.static.Program()
main_prog.random_seed = 1234
startup_prog = paddle.static.Program()
startup_prog.random_seed = 1234

with paddle.static.program_guard(main_prog, startup_prog):
data = paddle.static.data(
Expand Down Expand Up @@ -150,29 +157,36 @@ def test_pass(self):
sgd_optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
sgd_optimizer.minimize(loss)

pass_manager = PassManager([new_pass("fused_attention")])
pass_manager.apply([main_prog], [startup_prog])

ops = main_prog.global_block().ops
assert ops[2].type == 'fused_attention'
assert ops[3].type == 'reduce_mean'
assert ops[5].type == 'reduce_mean_grad'
assert ops[6].type == 'fused_attention_grad'
# two ops for linear, one op for reduce mean
# one fill constant
# one op for reduce mean grad, two ops for linear bwd
# the eighth op should be the optimizer
assert ops[9].type == 'sgd'
if use_pass:
pass_manager = PassManager([new_pass("fused_attention")])
pass_manager.apply([main_prog], [startup_prog])

ops = main_prog.global_block().ops
assert ops[2].type == 'fused_attention'
assert ops[3].type == 'reduce_mean'
assert ops[5].type == 'reduce_mean_grad'
assert ops[6].type == 'fused_attention_grad'
# two ops for linear, one op for reduce mean
# one fill constant
# one op for reduce mean grad, two ops for linear bwd
# the eighth op should be the optimizer
assert ops[9].type == 'sgd'

exe = paddle.static.Executor()
exe.run(startup_prog)
rst = exe.run(
main_prog,
feed={'x': x_data, 'attn_mask': mask_data},
fetch_list=[loss],
)
for i in range(2):
rst = exe.run(
main_prog,
feed={'x': self.x_data, 'attn_mask': self.mask_data},
fetch_list=[loss],
)
return rst

def test_pass(self):
fused_rst = self.get_rst(use_pass=True)
non_fused_rst = self.get_rst()
assert np.allclose(fused_rst, non_fused_rst)


if __name__ == "__main__":
np.random.seed(0)
unittest.main()
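The restructured test builds the same seeded program twice via get_rst, applies the fused_attention pass to only one copy, runs both, and compares the fetched losses with np.allclose; the dropout probability is presumably lowered to 1e-10 so the stochastic dropout mask cannot make the fused and unfused runs diverge while the dropout op still appears in the graph for the pattern to match. A condensed sketch of that comparison structure (paraphrasing the test above, with a hypothetical build_and_run callable, not a drop-in replacement):

import numpy as np

def check_pass_preserves_loss(build_and_run, rtol=1e-5, atol=1e-8):
    # build_and_run(use_pass) is assumed to construct a program with fixed
    # random seeds, optionally apply the graph pass, execute it, and return
    # the fetched loss values, mirroring TestFusedAttentionPass.get_rst.
    fused = build_and_run(use_pass=True)
    non_fused = build_and_run(use_pass=False)
    assert np.allclose(fused, non_fused, rtol=rtol, atol=atol)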