@@ -721,6 +721,102 @@ def replace(x, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotar
 
     return pattern, replace
 
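+# Fuses one parallel (allreduce-based) LLaMA decoder layer into a single
+# llama_paralle_layer_adaptor op: rms_norm, QKV matmul, rotary embedding,
+# cache-KV write, variable-length attention, the out-projection and FFN
+# matmuls, and both c_allreduce_sum ops are matched and rewritten together.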
+@ir.RegisterPass
+def llama65B_fuse_attention_dynamic_parallel_layer():
+    def pattern(x, residual, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight):
+        rms_norm_0 = ir.PassDesc.OP.rms_norm(norm_weight=ln_scale, residual=residual, x=x)
+        qkv = ir.PassDesc.OP.matmul_v2(X=rms_norm_0.Output("out"), Y=qkv_weight)
+        qkv_split = ir.PassDesc.OP.qkv_transpose_split(input_ids=input_ids, padding_offset=padding_offset, qkv=qkv, seq_lens=seq_len_encoder)
+        q = qkv_split.Output("q_out")[0]
+        k = qkv_split.Output("k_out")[0]
+        v = qkv_split.Output("v_out")[0]
+        scale1 = ir.PassDesc.OP.scale(X=seq_len_encoder)
+        write_cache_kv = ir.PassDesc.OP.write_cache_kv(cache_kv=cache_kv, input_k=k, input_v=v, sequence_lengths=scale1)
+        scale2 = ir.PassDesc.OP.scale(X=seq_len_encoder)
+        attention = ir.PassDesc.OP.variable_length_memory_efficient_attention(key=k, kv_seq_lens=scale2, mask=mask, query=q, seq_lens=seq_len_encoder, value=v)
+
+        transpose_remove_padding = ir.PassDesc.OP.transpose_remove_padding(input=attention, padding_offset=padding_offset, seq_lens=seq_len_encoder)
+        matmul_0 = ir.PassDesc.OP.matmul_v2(X=transpose_remove_padding, Y=out_proj_weight)
+
+        allreduce = ir.PassDesc.OP.c_allreduce_sum(X=matmul_0)
+
+        rms_norm_1 = ir.PassDesc.OP.rms_norm(norm_weight=ffn_in_scale, residual=rms_norm_0.Output("residual_out")[0], x=allreduce)
+        matmul_1 = ir.PassDesc.OP.matmul_v2(X=rms_norm_1.Output("out"), Y=ffn1_weight)
+        fused_bias_act = ir.PassDesc.OP.fused_bias_act(x=matmul_1)
+
+        matmul_2 = ir.PassDesc.OP.matmul_v2(X=fused_bias_act, Y=ffn2_weight)
+        hidden = ir.PassDesc.OP.c_allreduce_sum(X=matmul_2)
+        residual_out = rms_norm_1.Output("residual_out")[0]
+
+        encode_rotary_qk = ir.PassDesc.OP.encode_rotary_qk(kv=k, q=q, rotary_emb=rotary_emb, seq_lens=seq_len_encoder)
+        rotary_kv_out = encode_rotary_qk.Output("rotary_kv_out")[0]
+        rotary_q_out = encode_rotary_qk.Output("rotary_q_out")[0]
+
+        return write_cache_kv, q, k, v, hidden, residual_out, rotary_kv_out, rotary_q_out
+
+    def replace(x, residual, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight):
+        llama_layer = llama_paralle_layer_adaptor(x, residual, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight)
+
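+        # Return the fused op's outputs in the order the pattern's outputs expect:
+        # write_cache_kv, q, k, v, hidden, residual_out, rotary_kv_out, rotary_q_out.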
+        return (llama_layer[3],
+                llama_layer[4],
+                llama_layer[5],
+                llama_layer[6],
+                llama_layer[0],
+                llama_layer[7],
+                llama_layer[1],
+                llama_layer[2])
+
+    return pattern, replace
+
+
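+# Variant of the pass above for the first decoder layer, which has no incoming
+# residual: the first rms_norm takes only x, the post-attention rms_norm reuses
+# x as its residual input, and the adaptor is called with None in place of the
+# residual argument.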
+@ir.RegisterPass
+def llama65B_fuse_attention_dynamic_first_parallel_layer():
+    def pattern(x, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight):
+        rms_norm_0 = ir.PassDesc.OP.rms_norm(norm_weight=ln_scale, x=x)
+        qkv = ir.PassDesc.OP.matmul_v2(X=rms_norm_0.Output("out"), Y=qkv_weight)
+        qkv_split = ir.PassDesc.OP.qkv_transpose_split(input_ids=input_ids, padding_offset=padding_offset, qkv=qkv, seq_lens=seq_len_encoder)
+        q = qkv_split.Output("q_out")[0]
+        k = qkv_split.Output("k_out")[0]
+        v = qkv_split.Output("v_out")[0]
+        scale1 = ir.PassDesc.OP.scale(X=seq_len_encoder)
+        write_cache_kv = ir.PassDesc.OP.write_cache_kv(cache_kv=cache_kv, input_k=k, input_v=v, sequence_lengths=scale1)
+        scale2 = ir.PassDesc.OP.scale(X=seq_len_encoder)
+        attention = ir.PassDesc.OP.variable_length_memory_efficient_attention(key=k, kv_seq_lens=scale2, mask=mask, query=q, seq_lens=seq_len_encoder, value=v)
+
+        transpose_remove_padding = ir.PassDesc.OP.transpose_remove_padding(input=attention, padding_offset=padding_offset, seq_lens=seq_len_encoder)
+        matmul_0 = ir.PassDesc.OP.matmul_v2(X=transpose_remove_padding, Y=out_proj_weight)
+
+        allreduce = ir.PassDesc.OP.c_allreduce_sum(X=matmul_0)
+
+        rms_norm_1 = ir.PassDesc.OP.rms_norm(norm_weight=ffn_in_scale, residual=x, x=allreduce)
+        matmul_1 = ir.PassDesc.OP.matmul_v2(X=rms_norm_1.Output("out"), Y=ffn1_weight)
+        fused_bias_act = ir.PassDesc.OP.fused_bias_act(x=matmul_1)
+
+        matmul_2 = ir.PassDesc.OP.matmul_v2(X=fused_bias_act, Y=ffn2_weight)
+        hidden = ir.PassDesc.OP.c_allreduce_sum(X=matmul_2)
+        residual_out = rms_norm_1.Output("residual_out")[0]
+
+        encode_rotary_qk = ir.PassDesc.OP.encode_rotary_qk(kv=k, q=q, rotary_emb=rotary_emb, seq_lens=seq_len_encoder)
+        rotary_kv_out = encode_rotary_qk.Output("rotary_kv_out")[0]
+        rotary_q_out = encode_rotary_qk.Output("rotary_q_out")[0]
+
+        return write_cache_kv, q, k, v, hidden, residual_out, rotary_kv_out, rotary_q_out
+
+    def replace(x, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight):
+        llama_layer = llama_paralle_layer_adaptor(x, None, input_ids, padding_offset, seq_len_encoder, cache_kv, mask, rotary_emb, ln_scale, qkv_weight, out_proj_weight, ffn_in_scale, ffn1_weight, ffn2_weight)
+
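+        # Same output reordering as in llama65B_fuse_attention_dynamic_parallel_layer.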
+        return (llama_layer[3],
+                llama_layer[4],
+                llama_layer[5],
+                llama_layer[6],
+                llama_layer[0],
+                llama_layer[7],
+                llama_layer[1],
+                llama_layer[2])
+
+    return pattern, replace
+
+
 @ir.RegisterPass
 def llama_layer_tail():
     def pattern(x, norm_weight):
@@ -729,5 +825,5 @@ def pattern(x, norm_weight):
 
     def replace(x, norm_weight):
         norm = ir.PassDesc.OP.llama_lmhead(Hidden=x, NormWeight=norm_weight)
-        return norm
-    return pattern, replace
+        return norm
+    return pattern, replace