From 2761ab85a85e0a8a68be51476001c25553874e31 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Fri, 18 Apr 2025 10:53:40 +0000 Subject: [PATCH 1/7] 1. Optimized MoE on Gaudi. 2. Enabled EP on Gaudi. Signed-off-by: gyou2021 --- vllm_hpu_extension/ops.py | 62 +++++++++++++++++++++++++++++++-------- 1 file changed, 49 insertions(+), 13 deletions(-) diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index d28dbda8f..340ed42eb 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -13,7 +13,10 @@ import habana_frameworks.torch.core as htcore from vllm_hpu_extension.flags import enabled_flags from vllm.logger import init_logger +import os +dynamic_moe_min_tokens = int( +os.environ.get("VLLM_DYNAMIC_MOE_MIN_TOKENS", 256)) logger = init_logger(__name__) @@ -355,22 +358,54 @@ def forward(self, hidden_states, expert_routing_table, router_weights, + layer, permuted_weights=True, activation="silu"): # pre-processing for custom op inputs - experts_range = range(self.num_experts) - w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range] - w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range] - return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states, - expert_routing_table=expert_routing_table, - router_weights=router_weights, - w12=w1_list, - w3=w2_list, - permuted_weights=permuted_weights, - activation=activation, - experts_min=0, - experts_max=self.num_experts - 1) + bt, hidden_dim = hidden_states.shape + num_experts = layer.w13_weight.shape[0] + moe_intermediate = layer.w2_weight.shape[2] + ep_shift = layer.ep_rank * num_experts + selected_experts = (expert_routing_table - ep_shift).to(torch.int64) + if bt > dynamic_moe_min_tokens: + experts_range = range(num_experts) + w1_list = [layer.w13_weight[i].squeeze() for i in experts_range] + w2_list = [layer.w2_weight[i].squeeze() for i in experts_range] + final_hidden_states = torch.ops.hpu.mixture_of_experts( + hidden_states=hidden_states, + expert_routing_table=selected_experts, + router_weights=router_weights, + w12=w1_list, + w3=w2_list, + permuted_weights=True, + activation="silu", + experts_min=0, + experts_max=(num_experts - 1), + ) + else: + padded_weights = torch.zeros((bt, num_experts), + dtype=hidden_states.dtype, + device=hidden_states.device) + padded_weights.scatter_(-1, selected_experts, router_weights) + padded_weights = padded_weights.transpose(0, 1).unsqueeze(-1) + up_gate_states = torch.matmul( + hidden_states, + layer.w13_weight.view(-1, layer.w13_weight.size(-1)).transpose( + 0, 1)) + up_gate_states = up_gate_states.reshape(bt, num_experts, 2, + moe_intermediate) + up_states = up_gate_states[:, :, 0, :] + gate_states = up_gate_states[:, :, 1, :] + current_state_static = F.silu(up_states) * gate_states + current_state_static = current_state_static.transpose(0, 1) + + current_hidden_states_static = torch.matmul( + current_state_static, layer.w2_weight.transpose( + 1, 2)) * padded_weights + final_hidden_states = current_hidden_states_static.sum(dim=0) + + return final_hidden_states.view(-1, hidden_states.shape[1]) class DynamicFusedMOE(torch.nn.Module): @@ -378,7 +413,7 @@ def __init__(self, num_total_experts): super().__init__() self.MoeOp = VllmMixtureOfExpertsOp(num_total_experts) - def forward(self, hidden_states, score, topk): + def forward(self, hidden_states, score, topk, layer): htorch.core.mark_step() routing_weights = F.softmax(score, dim=1, dtype=torch.float32) routing_weights, selected_experts = torch.topk(routing_weights, @@ -391,6 +426,7 @@ def forward(self, hidden_states, score, topk): hidden_states=hidden_states, expert_routing_table=selected_experts, router_weights=routing_weights, + layer=layer, permuted_weights=True, activation="silu", ) From 321a69ad263660464a8fa3d4f06b1987fb902ff6 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Wed, 7 May 2025 11:04:58 +0000 Subject: [PATCH 2/7] Fixed format. Signed-off-by: gyou2021 --- vllm_hpu_extension/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index 94a10b7b6..45571dc8b 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -438,7 +438,7 @@ def forward(self, hidden_states, score, topk): final_hidden_states = self.MoeOp( hidden_states=hidden_states, expert_routing_table=selected_experts, - router_weights=routing_weights, + router_weights=routing_weights, permuted_weights=True, activation="silu", ) From 71ff585b4dd52681ca8629aeffe69b72957a8d72 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Fri, 23 May 2025 09:05:25 +0000 Subject: [PATCH 3/7] Optimized MoE on Gaudi Signed-off-by: gyou2021 --- vllm_hpu_extension/ops.py | 123 ++++++++++++++++++++++++++++---------- 1 file changed, 91 insertions(+), 32 deletions(-) diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index 6e19ce11e..028e217ab 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -12,6 +12,10 @@ import math import habana_frameworks.torch.core as htcore from vllm_hpu_extension.flags import enabled_flags +from vllm.logger import init_logger +import os + +logger = init_logger(__name__) import os # MAX_EXPERTS_PER_SLICE is needed for 1.20, up to 64 experts per slice @@ -415,6 +419,23 @@ def __init__(self, num_total_experts, experts_min: int = 0, experts_max: int = 8 else self.num_experts // max_expert_per_slice self.num_expert_per_group = self.num_experts // self.moe_n_slice + # if num_tokens exceed the VLLM_DYNAMIC_MOE_MIN_TOKENS, + # dynamic MoE is used since its performance is better than + # static MoE in this case. + self.dynamic_moe_min_tokens = int( + os.environ.get("VLLM_DYNAMIC_MOE_MIN_TOKENS", -1)) + # if the number of expert on a single card is smaller than + # VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU, dynamic MoE + # is used since its performance is better than + # static MoE in this case. + self.dynamic_moe_max_num_expert_singleHpu = int( + os.environ.get("VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU", 32)) + + #self.w13_weight is a tensor of combined w13_list + self.w13_weight = None + #self.w2_weight is a tensor of combined w2_list + self.w2_weight = None + def forward(self, hidden_states, expert_routing_table, @@ -422,42 +443,80 @@ def forward(self, permuted_weights=True, activation="silu"): # pre-processing for custom op inputs + bt, hidden_dim = hidden_states.shape experts_range = range(self.num_experts) w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range] w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range] - if self.moe_n_slice == 1: - return torch.ops.hpu.mixture_of_experts( - hidden_states=hidden_states, - expert_routing_table=expert_routing_table, - router_weights=router_weights, - w12=w1_list, - w3=w2_list, - permuted_weights=permuted_weights, - activation=activation, - experts_min=self.experts_min, - experts_max=self.experts_max) - for i in range(self.moe_n_slice): - w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] - w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] - min_expert = self.experts_min + i * self.num_expert_per_group - max_expert = min_expert + self.num_expert_per_group - 1 - slice_final_hidden_states = torch.ops.hpu.mixture_of_experts( - hidden_states=hidden_states, - expert_routing_table=expert_routing_table, - router_weights=router_weights, - w12=w1_list_slice, - w3=w2_list_slice, - permuted_weights=permuted_weights, - activation=activation, - experts_min=min_expert, - experts_max=max_expert) - if i == 0: - final_hidden_states = slice_final_hidden_states - else: - final_hidden_states += slice_final_hidden_states - htorch.core.mark_step() - return final_hidden_states + # When the number of input tokens (batch_size*seqence_length) exceeds + # dynamic_moe_min_tokens (default 256) or the number of the experts + # on the single card is smaller than dynamic_moe_max_num_expert_singleHpu + # (default 32), dynamic MoE is used since it delivers better performance + # than static MoE. Otherwise static MoE is used. + if bt > self.dynamic_moe_min_tokens or \ + (self.num_experts <= self.dynamic_moe_max_num_expert_singleHpu): + if self.moe_n_slice == 1: + return torch.ops.hpu.mixture_of_experts( + hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list, + w3=w2_list, + permuted_weights=permuted_weights, + activation=activation, + experts_min=self.experts_min, + experts_max=self.experts_max) + for i in range(self.moe_n_slice): + w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] + w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] + min_expert = self.experts_min + i * self.num_expert_per_group + max_expert = min_expert + self.num_expert_per_group - 1 + slice_final_hidden_states = torch.ops.hpu.mixture_of_experts( + hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list_slice, + w3=w2_list_slice, + permuted_weights=permuted_weights, + activation=activation, + experts_min=min_expert, + experts_max=max_expert) + if i == 0: + final_hidden_states = slice_final_hidden_states + else: + final_hidden_states += slice_final_hidden_states + else: + if self.w13_weight is None: + self.w13_weight = torch.stack(w1_list) + + if self.w2_weight is None: + self.w2_weight = torch.stack(w2_list) + + selected_experts = (expert_routing_table - self.experts_min).to(torch.int64) + moe_intermediate = self.w2_weight.shape[2] + padded_weights = torch.zeros((bt, self.num_experts), + dtype=hidden_states.dtype, + device=hidden_states.device) + padded_weights.scatter_(-1, selected_experts, router_weights) + padded_weights = padded_weights.transpose(0, 1).unsqueeze(-1) + + up_gate_states = torch.matmul( + hidden_states, + self.w13_weight.view(-1, self.w13_weight.size(-1)).transpose( + 0, 1)) + up_gate_states = up_gate_states.reshape(bt, self.num_experts, 2, + moe_intermediate) + up_states = up_gate_states[:, :, 0, :] + gate_states = up_gate_states[:, :, 1, :] + current_state_static = F.silu(up_states) * gate_states + current_state_static = current_state_static.transpose(0, 1) + + current_hidden_states_static = torch.matmul( + current_state_static, self.w2_weight.transpose( + 1, 2)) * padded_weights + final_hidden_states = current_hidden_states_static.sum(dim=0) + + return final_hidden_states.view(-1, hidden_states.shape[1]) class DynamicFusedMOE(torch.nn.Module): From 6e9468009d10a4d2310b766ae65628de8ebfdbf6 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Mon, 26 May 2025 06:28:07 +0000 Subject: [PATCH 4/7] Removed functions for INC based FP8 inference. Signed-off-by: gyou2021 --- vllm_hpu_extension/ops.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index c60696e57..36ff8931b 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -418,10 +418,6 @@ def __init__(self, num_total_experts, experts_min: int = 0, experts_max: int = 8 else self.num_experts // max_expert_per_slice self.num_expert_per_group = self.num_experts // self.moe_n_slice - def set_weights(self, w13,w2): - self.w13_weight = w13 - self.w2_weight = w2 - # if num_tokens exceed the VLLM_DYNAMIC_MOE_MIN_TOKENS, # dynamic MoE is used since its performance is better than # static MoE in this case. @@ -442,7 +438,7 @@ def set_weights(self, w13,w2): def forward(self, hidden_states, expert_routing_table, - router_weights, + router_weights, permuted_weights=True, activation="silu"): # pre-processing for custom op inputs @@ -527,9 +523,6 @@ def __init__(self, num_total_experts): super().__init__() self.MoeOp = VllmMixtureOfExpertsOp(num_total_experts) - def set_MoeOp_weights(self, w13, w2): - self.MoeOp.set_weights(w13, w2) - def forward(self, hidden_states, score, topk): htorch.core.mark_step() routing_weights = F.softmax(score, dim=1, dtype=torch.float32) From 316f0e0fe541429d648fde0bcd3495c8f9daed57 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Mon, 26 May 2025 06:32:11 +0000 Subject: [PATCH 5/7] Modified the comment. Signed-off-by: gyou2021 --- vllm_hpu_extension/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index 30f05b0f5..822543e33 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -459,7 +459,7 @@ def forward(self, w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range] # When the number of input tokens (batch_size*seqence_length) exceeds - # dynamic_moe_min_tokens (default 256) or the number of the experts + # dynamic_moe_min_tokens (default -1) or the number of the experts # on the single card is smaller than dynamic_moe_max_num_expert_singleHpu # (default 32), dynamic MoE is used since it delivers better performance # than static MoE. Otherwise static MoE is used. From 9b60bebec9228d28315a71a0fe6b89f2cd235ef6 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Mon, 26 May 2025 08:41:36 +0000 Subject: [PATCH 6/7] rebased to b8a0e5 to be compatible with vllm-for dev/qwen3-habanamain; added static moe. Signed-off-by: gyou2021 --- vllm_hpu_extension/ops.py | 124 ++++++++++++++++++++++++++++---------- 1 file changed, 91 insertions(+), 33 deletions(-) diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index fdb71600c..f74c87e77 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -12,8 +12,11 @@ import math import habana_frameworks.torch.core as htcore from vllm_hpu_extension.flags import enabled_flags - +from vllm.logger import init_logger import os + +logger = init_logger(__name__) + # MAX_EXPERTS_PER_SLICE is needed for 1.20, up to 64 experts per slice MAX_EXPERTS_PER_SLICE = os.environ.get("MAX_EXPERTS_PER_SLICE", -1) @@ -416,6 +419,23 @@ def __init__(self, num_total_experts, experts_min: int = 0, experts_max: int = 8 else self.num_experts // max_expert_per_slice self.num_expert_per_group = self.num_experts // self.moe_n_slice + # if num_tokens exceed the VLLM_DYNAMIC_MOE_MIN_TOKENS, + # dynamic MoE is used since its performance is better than + # static MoE in this case. + self.dynamic_moe_min_tokens = int( + os.environ.get("VLLM_DYNAMIC_MOE_MIN_TOKENS", 256)) + # if the number of expert on a single card is smaller than + # VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU, dynamic MoE + # is used since its performance is better than + # static MoE in this case. + self.dynamic_moe_max_num_expert_singleHpu = int( + os.environ.get("VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU", 32)) + + #self.w13_weight is a tensor of combined w13_list + self.w13_weight = None + #self.w2_weight is a tensor of combined w2_list + self.w2_weight = None + def forward(self, hidden_states, expert_routing_table, @@ -423,42 +443,80 @@ def forward(self, permuted_weights=True, activation="silu"): # pre-processing for custom op inputs + bt, hidden_dim = hidden_states.shape experts_range = range(self.num_experts) w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range] w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range] - if self.moe_n_slice == 1: - return torch.ops.hpu.mixture_of_experts( - hidden_states=hidden_states, - expert_routing_table=expert_routing_table, - router_weights=router_weights, - w12=w1_list, - w3=w2_list, - permuted_weights=permuted_weights, - activation=activation, - experts_min=self.experts_min, - experts_max=self.experts_max) - for i in range(self.moe_n_slice): - w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] - w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] - min_expert = self.experts_min + i * self.num_expert_per_group - max_expert = min_expert + self.num_expert_per_group - 1 - slice_final_hidden_states = torch.ops.hpu.mixture_of_experts( - hidden_states=hidden_states, - expert_routing_table=expert_routing_table, - router_weights=router_weights, - w12=w1_list_slice, - w3=w2_list_slice, - permuted_weights=permuted_weights, - activation=activation, - experts_min=min_expert, - experts_max=max_expert) - if i == 0: - final_hidden_states = slice_final_hidden_states - else: - final_hidden_states += slice_final_hidden_states - htorch.core.mark_step() - return final_hidden_states + # When the number of input tokens (batch_size*seqence_length) exceeds + # dynamic_moe_min_tokens (default 256) or the number of the experts + # on the single card is smaller than dynamic_moe_max_num_expert_singleHpu + # (default 32), dynamic MoE is used since it delivers better performance + # than static MoE. Otherwise static MoE is used. + if bt > self.dynamic_moe_min_tokens or \ + (self.num_experts <= self.dynamic_moe_max_num_expert_singleHpu): + if self.moe_n_slice == 1: + return torch.ops.hpu.mixture_of_experts( + hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list, + w3=w2_list, + permuted_weights=permuted_weights, + activation=activation, + experts_min=self.experts_min, + experts_max=self.experts_max) + for i in range(self.moe_n_slice): + w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] + w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] + min_expert = self.experts_min + i * self.num_expert_per_group + max_expert = min_expert + self.num_expert_per_group - 1 + slice_final_hidden_states = torch.ops.hpu.mixture_of_experts( + hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list_slice, + w3=w2_list_slice, + permuted_weights=permuted_weights, + activation=activation, + experts_min=min_expert, + experts_max=max_expert) + if i == 0: + final_hidden_states = slice_final_hidden_states + else: + final_hidden_states += slice_final_hidden_states + else: + if self.w13_weight is None: + self.w13_weight = torch.stack(w1_list) + + if self.w2_weight is None: + self.w2_weight = torch.stack(w2_list) + + selected_experts = (expert_routing_table - self.experts_min).to(torch.int64) + moe_intermediate = self.w2_weight.shape[2] + padded_weights = torch.zeros((bt, self.num_experts), + dtype=hidden_states.dtype, + device=hidden_states.device) + padded_weights.scatter_(-1, selected_experts, router_weights) + padded_weights = padded_weights.transpose(0, 1).unsqueeze(-1) + + up_gate_states = torch.matmul( + hidden_states, + self.w13_weight.view(-1, self.w13_weight.size(-1)).transpose( + 0, 1)) + up_gate_states = up_gate_states.reshape(bt, self.num_experts, 2, + moe_intermediate) + up_states = up_gate_states[:, :, 0, :] + gate_states = up_gate_states[:, :, 1, :] + current_state_static = F.silu(up_states) * gate_states + current_state_static = current_state_static.transpose(0, 1) + + current_hidden_states_static = torch.matmul( + current_state_static, self.w2_weight.transpose( + 1, 2)) * padded_weights + final_hidden_states = current_hidden_states_static.sum(dim=0) + + return final_hidden_states.view(-1, hidden_states.shape[1]) class DynamicFusedMOE(torch.nn.Module): From f06441afc705f36e1c1fb535aa29f9cefa00bc32 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Wed, 4 Jun 2025 09:37:39 +0000 Subject: [PATCH 7/7] Refactor code. Signed-off-by: gyou2021 --- vllm_hpu_extension/ops.py | 68 +++++++++++++++++++++++++-------------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py index f74c87e77..c4b7f5c67 100644 --- a/vllm_hpu_extension/ops.py +++ b/vllm_hpu_extension/ops.py @@ -386,6 +386,40 @@ def dispatch_bgmv_embedding( out = x @ wb y += out +def static_MoE( + hidden_states, + expert_routing_table, + router_weights, + experts_min, + w13_weight, + w2_weight, + tokens_num, + num_experts, +): + selected_experts = (expert_routing_table - experts_min).to(torch.int64) + moe_intermediate = w2_weight.shape[2] + padded_weights = torch.zeros((tokens_num, num_experts), + dtype=hidden_states.dtype, + device=hidden_states.device) + padded_weights.scatter_(-1, selected_experts, router_weights) + padded_weights = padded_weights.transpose(0, 1).unsqueeze(-1) + + up_gate_states = torch.matmul( + hidden_states, + w13_weight.view(-1, w13_weight.size(-1)).transpose( + 0, 1)) + up_gate_states = up_gate_states.reshape(tokens_num, num_experts, 2, + moe_intermediate) + up_states = up_gate_states[:, :, 0, :] + gate_states = up_gate_states[:, :, 1, :] + current_state_static = F.silu(up_states) * gate_states + current_state_static = current_state_static.transpose(0, 1) + + current_hidden_states_static = torch.matmul( + current_state_static, w2_weight.transpose( + 1, 2)) * padded_weights + final_hidden_states = current_hidden_states_static.sum(dim=0) + return final_hidden_states class MoeMatmul(torch.nn.Module): @@ -492,30 +526,16 @@ def forward(self, if self.w2_weight is None: self.w2_weight = torch.stack(w2_list) - selected_experts = (expert_routing_table - self.experts_min).to(torch.int64) - moe_intermediate = self.w2_weight.shape[2] - padded_weights = torch.zeros((bt, self.num_experts), - dtype=hidden_states.dtype, - device=hidden_states.device) - padded_weights.scatter_(-1, selected_experts, router_weights) - padded_weights = padded_weights.transpose(0, 1).unsqueeze(-1) - - up_gate_states = torch.matmul( - hidden_states, - self.w13_weight.view(-1, self.w13_weight.size(-1)).transpose( - 0, 1)) - up_gate_states = up_gate_states.reshape(bt, self.num_experts, 2, - moe_intermediate) - up_states = up_gate_states[:, :, 0, :] - gate_states = up_gate_states[:, :, 1, :] - current_state_static = F.silu(up_states) * gate_states - current_state_static = current_state_static.transpose(0, 1) - - current_hidden_states_static = torch.matmul( - current_state_static, self.w2_weight.transpose( - 1, 2)) * padded_weights - final_hidden_states = current_hidden_states_static.sum(dim=0) - + final_hidden_states = static_MoE( + hidden_states, + expert_routing_table, + router_weights, + self.experts_min, + self.w13_weight, + self.w2_weight, + bt, + self.num_experts + ) return final_hidden_states.view(-1, hidden_states.shape[1])