Commit 1004205

[MTP] Refactor mtp predictor to avoid d2h operation (#27643)
Signed-off-by: MengqingCao <[email protected]>
1 parent ba33e88 commit 1004205

1 file changed (+1, -1 lines)

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def forward(
     ) -> torch.Tensor:
         assert inputs_embeds is not None
         # masking inputs at position 0, as not needed by MTP
-        inputs_embeds[positions == 0] = 0
+        inputs_embeds = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)
         inputs_embeds = self.enorm(inputs_embeds)
         previous_hidden_states = self.hnorm(previous_hidden_states)
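
The one-line change swaps a boolean-mask indexed assignment for an elementwise `torch.where` select. The sketch below is not part of the commit; it only checks, on small CPU tensors with made-up shapes (4 tokens, hidden size 3), that the two forms produce the same values. It does not measure device-to-host transfers, which is the motivation given in the commit title. The variable names `masked_old` and `masked_new` are illustrative only.

```python
import torch

# Hypothetical toy inputs: 4 tokens, hidden size 3 (not taken from the commit).
positions = torch.tensor([0, 1, 2, 0])
inputs_embeds = torch.arange(12, dtype=torch.float32).reshape(4, 3) + 1.0

# Old form: boolean-mask indexed assignment, which mutates the tensor in place.
masked_old = inputs_embeds.clone()
masked_old[positions == 0] = 0

# New form: unsqueeze(-1) turns the (4,) condition into (4, 1) so it
# broadcasts across the hidden dimension; torch.where returns a new tensor.
masked_new = torch.where(positions.unsqueeze(-1) == 0, 0, inputs_embeds)

# Both forms zero the rows whose position is 0 and keep the rest unchanged.
assert torch.equal(masked_old, masked_new)
```

One behavioral difference worth noting for reviewers: the old line wrote into the caller's `inputs_embeds` tensor, while the new line rebinds the local name to a fresh tensor and leaves the original storage untouched.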
