Skip to content

Commit 8886423

Browse files
committed
Move float16 typecast hack to gptq marlin moe method
1 parent 565cc43 commit 8886423

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

vllm/model_executor/layers/quantization/gptq_marlin.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -584,6 +584,9 @@ def apply(
         from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
             fused_marlin_moe)

+        # The input must currently be float16
+        x = x.half()
+
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
```

vllm/model_executor/models/mixtral.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -99,7 +99,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = self.experts(hidden_states.half(), router_logits)
+        final_hidden_states = self.experts(hidden_states, router_logits)
         return final_hidden_states.view(orig_shape).to(orig_dtype)
```

0 commit comments

Comments
 (0)