diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index 76a5745a4f51..23e2059e6c89 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -103,6 +103,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
+        num_tokens = hidden_states.size(0)
 
         if self.is_sequence_parallel:
             hidden_states = sequence_parallel_chunk(hidden_states)
@@ -114,7 +115,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.is_sequence_parallel:
             final_hidden_states = tensor_model_parallel_all_gather(
                 final_hidden_states, 0)
-            num_tokens = orig_shape[0]
             final_hidden_states = final_hidden_states[:num_tokens]
 
         return final_hidden_states.view(orig_shape)