@@ -139,7 +139,7 @@ class AvxConfig:
139139
140140@dataclass
141141class SpeculateConfig :
142- speculate_max_draft_token_num : int = ( 1 ,)
142+ speculate_max_draft_token_num : int = 5
143143 speculate_method : str = None
144144
145145
@@ -1690,7 +1690,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
16901690 self .quant_round_type = config .quant_round_type
16911691 self .quant_max_bound = config .quant_max_bound
16921692 self .quant_min_bound = config .quant_min_bound
1693- # self.use_gemm_dequant = False
1693+ self .use_gemm_dequant = False
16941694
16951695 self .qkv_out_scales = []
16961696 self .linear_out_scales = []
@@ -1928,7 +1928,6 @@ def compute_qkv_linear(self, ln_out, i):
19281928 if paddle .is_compiled_with_rocm ():
19291929 qkv_out = paddle .matmul (ln_out , self .qkv_weights [i ])
19301930 else :
1931- # TODO: add gemm_dequant after qkv_out
19321931 qkv_out = paddle .matmul (ln_out , self .qkv_weights [i ], False , True )
19331932 return qkv_out
19341933
@@ -2033,13 +2032,13 @@ def compute_out_linear(self, fmha_out, i):
20332032 out_linear_out = paddle .matmul (fmha_out , self .linear_weights [i ])
20342033 out_linear_out = dequant_int8 (out_linear_out , self .linear_out_scales [i ], self ._dtype )
20352034 else :
2036- try :
2035+ if self . use_gemm_dequant :
20372036 from paddlenlp_ops import gemm_dequant
20382037
20392038 out_linear_out = gemm_dequant (
20402039 fmha_out , self .linear_weights [i ], self .linear_out_scales [i ], self ._dtype
20412040 )
2042- except :
2041+ else :
20432042 out_linear_out = paddle .matmul (fmha_out , self .linear_weights [i ], False , True )
20442043 out_linear_out = dequant_int8 (out_linear_out , self .linear_out_scales [i ], self ._dtype )
20452044 return out_linear_out
@@ -2094,11 +2093,11 @@ def compute_ffn2(self, ffn1_out, i):
20942093 ffn2_out = paddle .matmul (ffn1_out , self .ffn2_weights [i ])
20952094 ffn2_out = dequant_int8 (ffn2_out , self .ffn2_out_scales [i ], self ._dtype )
20962095 else :
2097- try :
2096+ if self . use_gemm_dequant :
20982097 from paddlenlp_ops import gemm_dequant
20992098
21002099 ffn2_out = gemm_dequant (ffn1_out , self .ffn2_weights [i ], self .ffn2_out_scales [i ], self ._dtype )
2101- except :
2100+ else :
21022101 ffn2_out = paddle .matmul (ffn1_out , self .ffn2_weights [i ], False , True )
21032102 ffn2_out = dequant_int8 (ffn2_out , self .ffn2_out_scales [i ], self ._dtype )
21042103 return ffn2_out
0 commit comments