Commit 970e06a

ElizaWszola authored and dsikka committed
Move output type conversion to gptq method as well
1 parent 1faab90 commit 970e06a

File tree

2 files changed: +3 −3 lines

vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/models/mixtral.py

vllm/model_executor/layers/quantization/gptq_marlin.py

Lines changed: 2 additions & 1 deletion
@@ -585,6 +585,7 @@ def apply(
             fused_marlin_moe)
 
         # The input must currently be float16
+        orig_dtype = x.dtype
         x = x.half()
 
         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -610,4 +611,4 @@ def apply(
             topk_ids,
             w1_scale=layer.w13_scales,
             w2_scale=layer.w2_scales,
-        )
+        ).to(orig_dtype)
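
The change moves the dtype round trip into the quantized MoE method itself: capture the caller's dtype, cast to float16 for the fused Marlin kernel, and cast back on return. A minimal sketch of that pattern, with run_fp16_kernel as a hypothetical stand-in for fused_marlin_moe:

    import torch

    def run_fp16_kernel(x: torch.Tensor) -> torch.Tensor:
        # Placeholder for the fused Marlin MoE kernel, which currently
        # only accepts float16 inputs; a real kernel would do the
        # quantized expert computation here.
        return x

    def apply_marlin_moe(x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype        # remember the caller's dtype (e.g. bfloat16)
        x = x.half()                # the kernel requires float16
        out = run_fp16_kernel(x)    # stand-in for fused_marlin_moe(...)
        return out.to(orig_dtype)   # restore the caller's dtype on the way out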

vllm/model_executor/models/mixtral.py

Lines changed: 1 addition & 2 deletions
@@ -95,12 +95,11 @@ def __init__(self,
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
-        orig_dtype = hidden_states.dtype
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = self.experts(hidden_states, router_logits)
-        return final_hidden_states.view(orig_shape).to(orig_dtype)
+        return final_hidden_states.view(orig_shape)
 
 
 class MixtralAttention(nn.Module):
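
With the cast now owned by the quantized apply method, callers such as Mixtral's forward no longer track dtype themselves. A quick check of the invariant this relies on, reusing the hypothetical sketch above (not vllm's actual API):

    x = torch.randn(8, 16, dtype=torch.bfloat16)
    y = apply_marlin_moe(x)
    assert y.dtype == x.dtype  # the caller's dtype survives the float16 round trip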
