Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions python/sglang/srt/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,14 +243,13 @@ def forward(self, hidden_states):
# NOTE: For some unknown reason, router_gemm seems to degrade accept length.
if (
_is_cuda
- and not self.is_nextn
- and hidden_states.shape[0] < 4
+ and hidden_states.shape[0] <= 16
and hidden_states.shape[1] == 7168
and self.weight.shape[0] == 256
and _device_sm >= 90
):
- logits = dsv3_router_gemm(hidden_states, self.weight).to(
- hidden_states.dtype
+ logits = dsv3_router_gemm(
+ hidden_states, self.weight, out_dtype=torch.bfloat16
)
else:
logits = F.linear(hidden_states, self.weight, None)
Expand Down