Skip to content

Commit 5b58365

Browse files
committed
Do not use gemm_dequant by default and fix a bug
1 parent 0b4b810 commit 5b58365

1 file changed

Lines changed: 6 additions & 7 deletions

File tree

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ class AvxConfig:
139139

140140
@dataclass
141141
class SpeculateConfig:
142-
speculate_max_draft_token_num: int = (1,)
142+
speculate_max_draft_token_num: int = 5
143143
speculate_method: str = None
144144

145145

@@ -1690,7 +1690,7 @@ def __init__(self, config: FusedMultiTransformerConfig):
16901690
self.quant_round_type = config.quant_round_type
16911691
self.quant_max_bound = config.quant_max_bound
16921692
self.quant_min_bound = config.quant_min_bound
1693-
# self.use_gemm_dequant = False
1693+
self.use_gemm_dequant = False
16941694

16951695
self.qkv_out_scales = []
16961696
self.linear_out_scales = []
@@ -1928,7 +1928,6 @@ def compute_qkv_linear(self, ln_out, i):
19281928
if paddle.is_compiled_with_rocm():
19291929
qkv_out = paddle.matmul(ln_out, self.qkv_weights[i])
19301930
else:
1931-
# TODO: add gemm_dequant after qkv_out
19321931
qkv_out = paddle.matmul(ln_out, self.qkv_weights[i], False, True)
19331932
return qkv_out
19341933

@@ -2033,13 +2032,13 @@ def compute_out_linear(self, fmha_out, i):
20332032
out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i])
20342033
out_linear_out = dequant_int8(out_linear_out, self.linear_out_scales[i], self._dtype)
20352034
else:
2036-
try:
2035+
if self.use_gemm_dequant:
20372036
from paddlenlp_ops import gemm_dequant
20382037

20392038
out_linear_out = gemm_dequant(
20402039
fmha_out, self.linear_weights[i], self.linear_out_scales[i], self._dtype
20412040
)
2042-
except:
2041+
else:
20432042
out_linear_out = paddle.matmul(fmha_out, self.linear_weights[i], False, True)
20442043
out_linear_out = dequant_int8(out_linear_out, self.linear_out_scales[i], self._dtype)
20452044
return out_linear_out
@@ -2094,11 +2093,11 @@ def compute_ffn2(self, ffn1_out, i):
20942093
ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i])
20952094
ffn2_out = dequant_int8(ffn2_out, self.ffn2_out_scales[i], self._dtype)
20962095
else:
2097-
try:
2096+
if self.use_gemm_dequant:
20982097
from paddlenlp_ops import gemm_dequant
20992098

21002099
ffn2_out = gemm_dequant(ffn1_out, self.ffn2_weights[i], self.ffn2_out_scales[i], self._dtype)
2101-
except:
2100+
else:
21022101
ffn2_out = paddle.matmul(ffn1_out, self.ffn2_weights[i], False, True)
21032102
ffn2_out = dequant_int8(ffn2_out, self.ffn2_out_scales[i], self._dtype)
21042103
return ffn2_out

0 commit comments

Comments
 (0)