
Commit ad90acb

Merge branch 'develop' into 9_19_llama

2 parents 6519b33 + fcddea6

File tree: 3 files changed (+113, -7 lines)


backends/npu/custom_op/llama_process_op.cc

Lines changed: 3 additions & 3 deletions

@@ -607,7 +607,7 @@ void fused_get_rotary_embedding(const int64_t* position_ids,
   }
 }
 
-std::vector<std::vector<int64_t>> GetRoPEInferShape(const std::vector<int64_t>& head_dim_shape_tensor_shape,
+std::vector<std::vector<int64_t>> GetRoPEInferShape(const std::vector<int64_t>& head_dim_shape_tensor_shape,
                                                      const std::vector<int64_t>& input_ids_shape,
                                                      const std::vector<int64_t>& position_ids_shape) {
   const int64_t batch_size = position_ids_shape[0];
@@ -617,8 +617,8 @@ std::vector<std::vector<int64_t>> GetRoPEInferShape(const std::vector<int64_t>&
   return {out_shape};
 }
 
-std::vector<paddle::DataType> GetRoPEInferDtype(const paddle::DataType& head_dim_shape_tensor_dtype,
-                                                 const paddle::DataType& input_ids_dtype,
+std::vector<paddle::DataType> GetRoPEInferDtype(const paddle::DataType& head_dim_shape_tensor_dtype,
+                                                 const paddle::DataType& input_ids_dtype,
                                                  const paddle::DataType& position_ids_dtype) {
   // RoPE output dtype is Float.
   return {paddle::DataType::FLOAT32};

backends/npu/passes/common.py

Lines changed: 13 additions & 2 deletions

@@ -20,7 +20,7 @@ def register_pass(pass_builder, pass_name):
     paddle.base.core.register_subgraph_pass(pass_name)
 
 def addPasses(pass_builder, model_type):
-    if model_type == "llama_mp8_dynamic_batch":
+    if model_type == "llama7B_mp8_dynamic_batch":
         register_pass(pass_builder, "llama_fuse_attention_dynamic_parallel_layer1")
         register_pass(pass_builder, "llama_fuse_attention_dynamic_parallel_layer2")
         register_pass(pass_builder, "llama_fuse_attention_dynamic_first_parallel_layer")
@@ -30,6 +30,17 @@ def addPasses(pass_builder, model_type):
         register_pass(pass_builder, "remove_get_padding_offset")
         register_pass(pass_builder, "remove_get_token_penalty_multi_scores")
         register_pass(pass_builder, "llama_layer_tail")
-
+
+    elif model_type == "llama65B_mp8_dynamic_batch":
+        register_pass(pass_builder, "llama_fuse_attention_dynamic_parallel_layer1")
+        register_pass(pass_builder, "llama_fuse_attention_dynamic_parallel_layer2")
+        register_pass(pass_builder, "llama65B_fuse_attention_dynamic_first_parallel_layer")
+        register_pass(pass_builder, "llama65B_fuse_attention_dynamic_parallel_layer")
+        register_pass(pass_builder, "remove_fused_bias_residual_layernorm")
+        register_pass(pass_builder, "remove_rebuild_padding")
+        register_pass(pass_builder, "remove_get_padding_offset")
+        register_pass(pass_builder, "remove_get_token_penalty_multi_scores")
+        register_pass(pass_builder, "llama_layer_tail")
+
     else:
         print("NPU pass not support")

backends/npu/passes/llama_pass.py

Lines changed: 97 additions & 2 deletions

@@ -721,6 +721,102 @@ def replace(x, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotar
 
     return pattern, replace
 
+@ir.RegisterPass
+def llama65B_fuse_attention_dynamic_parallel_layer():
+    def pattern(x, residual, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight):
+        rms_norm_0 = ir.PassDesc.OP.rms_norm(norm_weight=ln_scale, residual=residual, x=x)
+        qkv = ir.PassDesc.OP.matmul_v2(X=rms_norm_0.Output("out"), Y=qkv_weight)
+        qkv_split = ir.PassDesc.OP.qkv_transpose_split(input_ids=input_ids, padding_offset=padding_offset, qkv=qkv, seq_lens=seq_len_encoder)
+        q = qkv_split.Output("q_out")[0]
+        k = qkv_split.Output("k_out")[0]
+        v = qkv_split.Output("v_out")[0]
+        scale1 = ir.PassDesc.OP.scale(X=seq_len_encoder)
+        write_cache_kv = ir.PassDesc.OP.write_cache_kv(cache_kv=cache_kv, input_k=k, input_v=v, sequence_lengths=scale1)
+        scale2 = ir.PassDesc.OP.scale(X=seq_len_encoder)
+        attention = ir.PassDesc.OP.variable_length_memory_efficient_attention(key=k, kv_seq_lens=scale2, mask=mask, query=q, seq_lens=seq_len_encoder, value=v)
+
+        transpose_remove_padding = ir.PassDesc.OP.transpose_remove_padding(input=attention, padding_offset=padding_offset, seq_lens=seq_len_encoder)
+        matmul_0 = ir.PassDesc.OP.matmul_v2(X=transpose_remove_padding, Y=out_proj_weight)
+
+        allreduce = ir.PassDesc.OP.c_allreduce_sum(X=matmul_0)
+
+        rms_norm_1 = ir.PassDesc.OP.rms_norm(norm_weight=ffn_in_scale, residual=rms_norm_0.Output("residual_out")[0], x=allreduce)
+        matmul_1 = ir.PassDesc.OP.matmul_v2(X=rms_norm_1.Output("out"), Y=ffn1_weight)
+        fused_bias_act = ir.PassDesc.OP.fused_bias_act(x=matmul_1)
+
+        matmul_2 = ir.PassDesc.OP.matmul_v2(X=fused_bias_act, Y=ffn2_weight)
+        hidden = ir.PassDesc.OP.c_allreduce_sum(X=matmul_2)
+        residual_out = rms_norm_1.Output("residual_out")[0]
+
+        encode_rotary_qk = ir.PassDesc.OP.encode_rotary_qk(kv=k, q=q, rotary_emb=rotary_emb, seq_lens=seq_len_encoder)
+        rotary_kv_out = encode_rotary_qk.Output("rotary_kv_out")[0]
+        rotary_q_out = encode_rotary_qk.Output("rotary_q_out")[0]
+
+        return write_cache_kv, q, k, v, hidden, residual_out, rotary_kv_out, rotary_q_out
+
+    def replace(x, residual, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight):
+        llama_layer = llama_paralle_layer_adaptor(x, residual, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight)
+
+        return (llama_layer[3],
+                llama_layer[4],
+                llama_layer[5],
+                llama_layer[6],
+                llama_layer[0],
+                llama_layer[7],
+                llama_layer[1],
+                llama_layer[2])
+
+    return pattern, replace
+
+
+@ir.RegisterPass
+def llama65B_fuse_attention_dynamic_first_parallel_layer():
+    def pattern(x, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight):
+        rms_norm_0 = ir.PassDesc.OP.rms_norm(norm_weight=ln_scale, x=x)
+        qkv = ir.PassDesc.OP.matmul_v2(X=rms_norm_0.Output("out"), Y=qkv_weight)
+        qkv_split = ir.PassDesc.OP.qkv_transpose_split(input_ids=input_ids, padding_offset=padding_offset, qkv=qkv, seq_lens=seq_len_encoder)
+        q = qkv_split.Output("q_out")[0]
+        k = qkv_split.Output("k_out")[0]
+        v = qkv_split.Output("v_out")[0]
+        scale1 = ir.PassDesc.OP.scale(X=seq_len_encoder)
+        write_cache_kv = ir.PassDesc.OP.write_cache_kv(cache_kv=cache_kv, input_k=k, input_v=v, sequence_lengths=scale1)
+        scale2 = ir.PassDesc.OP.scale(X=seq_len_encoder)
+        attention = ir.PassDesc.OP.variable_length_memory_efficient_attention(key=k, kv_seq_lens=scale2, mask=mask, query=q, seq_lens=seq_len_encoder, value=v)
+
+        transpose_remove_padding = ir.PassDesc.OP.transpose_remove_padding(input=attention, padding_offset=padding_offset, seq_lens=seq_len_encoder)
+        matmul_0 = ir.PassDesc.OP.matmul_v2(X=transpose_remove_padding, Y=out_proj_weight)
+
+        allreduce = ir.PassDesc.OP.c_allreduce_sum(X=matmul_0)
+
+        rms_norm_1 = ir.PassDesc.OP.rms_norm(norm_weight=ffn_in_scale, residual=x, x=allreduce)
+        matmul_1 = ir.PassDesc.OP.matmul_v2(X=rms_norm_1.Output("out"), Y=ffn1_weight)
+        fused_bias_act = ir.PassDesc.OP.fused_bias_act(x=matmul_1)
+
+        matmul_2 = ir.PassDesc.OP.matmul_v2(X=fused_bias_act, Y=ffn2_weight)
+        hidden = ir.PassDesc.OP.c_allreduce_sum(X=matmul_2)
+        residual_out = rms_norm_1.Output("residual_out")[0]
+
+        encode_rotary_qk = ir.PassDesc.OP.encode_rotary_qk(kv=k, q=q, rotary_emb=rotary_emb, seq_lens=seq_len_encoder)
+        rotary_kv_out = encode_rotary_qk.Output("rotary_kv_out")[0]
+        rotary_q_out = encode_rotary_qk.Output("rotary_q_out")[0]
+
+        return write_cache_kv, q, k, v, hidden, residual_out, rotary_kv_out, rotary_q_out
+
+    def replace(x, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight):
+        llama_layer = llama_paralle_layer_adaptor(x, None, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight)
+
+        return (llama_layer[3],
+                llama_layer[4],
+                llama_layer[5],
+                llama_layer[6],
+                llama_layer[0],
+                llama_layer[7],
+                llama_layer[1],
+                llama_layer[2])
+
+    return pattern, replace
+
+
 @ir.RegisterPass
 def llama_layer_tail():
     def pattern(x, norm_weight):
@@ -729,5 +825,4 @@ def pattern(x, norm_weight):
 
     def replace(x, norm_weight):
         norm = ir.PassDesc.OP.llama_lmhead(Hidden=x, NormWeight=norm_weight)
-        return norm
-    return pattern, replace
+        return norm
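
Both new passes reuse the existing llama_paralle_layer_adaptor helper (not shown in this diff): the first-layer variant passes None for the residual and feeds x itself as the residual of rms_norm_1, while the per-layer variant threads rms_norm_0's residual_out through. Their replace functions only reorder the adaptor's outputs to line up with the eight values each pattern returns; the mapping below is read directly off the code above and is recorded as an inferred contract, not something stated elsewhere in the commit.

# Inferred output contract of llama_paralle_layer_adaptor, obtained by pairing
# the pattern's return order with the replace tuples above. The constant name
# is illustrative only; the helper itself is defined elsewhere in llama_pass.py.
PATTERN_TO_ADAPTOR_SLOT = {
    "write_cache_kv": 3,
    "q": 4,
    "k": 5,
    "v": 6,
    "hidden": 0,
    "residual_out": 7,
    "rotary_kv_out": 1,
    "rotary_q_out": 2,
}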
