Skip to content

Commit 2ddc11c

Browse files
committed
simp
1 parent 4be77a1 commit 2ddc11c

1 file changed

Lines changed: 1 addition & 5 deletions

File tree

python/sglang/srt/models/deepseek_v2.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2198,9 +2198,6 @@ def post_load_weights(self, is_nextn=False, weight_names=None):
21982198
# This may affect the accuracy of fp8 model.
21992199
# Fix deepseek v3 blockwise bmm by using deep_gemm
22002200
use_deep_gemm_bmm = False
2201-
# model_dtype = torch.get_default_dtype()
2202-
print(f"HACK!!! {torch.get_default_dtype()=} but force model_dtype=bf16")
2203-
model_dtype = torch.bfloat16
22042201

22052202
if w.dtype in (
22062203
torch.float8_e4m3fn,
@@ -2226,7 +2223,6 @@ def post_load_weights(self, is_nextn=False, weight_names=None):
22262223
_is_cuda
22272224
and weight_block_size[0] == 128
22282225
and weight_block_size[1] == 128
2229-
and model_dtype == torch.bfloat16
22302226
):
22312227
if (
22322228
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
@@ -2240,7 +2236,7 @@ def post_load_weights(self, is_nextn=False, weight_names=None):
22402236
weight,
22412237
weight_scale,
22422238
weight_block_size,
2243-
model_dtype,
2239+
torch.bfloat16,
22442240
)
22452241
else:
22462242
w, scale = block_quant_to_tensor_quant(

0 commit comments

Comments
 (0)