@@ -43,7 +43,7 @@ def find_final_nodes(program):
     return final_nodes
 
 
-def _is_mha(pattern_ops, pattern_ops_type):
+def _is_mha(pattern_ops, pattern_ops_type, skip_quant_tensor_list=[]):
     """ judge whether this pattern is multihead attention """
     if pattern_ops_type.count('softmax') != 1 or pattern_ops_type.count(
             'fetch') > 0:
@@ -53,6 +53,7 @@ def _is_mha(pattern_ops, pattern_ops_type):
     for op in pattern_ops:
         if op.type() in ['matmul', 'matmul_v2']:
             if not is_dynamic_weight_op(op):
+                skip_quant_tensor_list.extend(op._op.input('X'))
                 matmul_num += 1
     if matmul_num == 2:
         return True
@@ -81,6 +82,7 @@ def _is_ffn(pattern_ops, pattern_ops_type):
 def get_patterns(program, only_final_node=True):
     """ distinguish the pattern in the program and get distillation node """
     distill_node = []
+    skip_quant_tensor_list = []
     patterns = {}
     graph = GraphWrapper(program)
     block_num = 0
@@ -110,7 +112,8 @@ def get_patterns(program, only_final_node=True):
                     pattern_name = shortcut_start_op.type() + '$' + str(op.idx(
                     ))
 
-                    if _is_mha(pattern_ops, pattern_ops_type):
+                    if _is_mha(pattern_ops, pattern_ops_type,
+                               skip_quant_tensor_list):
                         model_type = 'transformer'
                         pattern_name = 'MHA$' + str(block_num)
 
@@ -145,4 +148,12 @@ def get_patterns(program, only_final_node=True):
         distill_node.append('teacher_' + out_var.name())
         distill_node.append(out_var.name())
 
+    #### skip quant matmul in attention
+    if model_type == 'transformer':
+        for block_id in range(len(program.blocks)):
+            for op in program.blocks[block_id].ops:
+                for inp_name in op.input_arg_names:
+                    if inp_name in skip_quant_tensor_list:
+                        op._set_attr("op_namescope", "skip_quant")
+
     return patterns, distill_node, model_type
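
A minimal usage sketch of what this change enables, for context (assumptions: the import path paddleslim.common.patterns and the model path "./infer_model" are illustrative only, and Paddle's quantization passes are taken to skip ops whose op_namescope contains "skip_quant", which is the convention this change relies on):

import paddle
from paddleslim.common.patterns import get_patterns  # assumed import path

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
# Load a saved static-graph inference model (path is hypothetical).
program, feed_names, fetch_targets = paddle.static.load_inference_model(
    "./infer_model", exe)

patterns, distill_node, model_type = get_patterns(program)

# After get_patterns, ops consuming the tensors collected in
# skip_quant_tensor_list (the 'X' inputs of the attention matmuls counted by
# _is_mha) carry the "skip_quant" namescope; listing them shows which matmuls
# would be left un-quantized.
if model_type == 'transformer':
    for block in program.blocks:
        for op in block.ops:
            if op.has_attr("op_namescope") and "skip_quant" in op.attr(
                    "op_namescope"):
                print(op.type, op.input_arg_names)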