diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index 6086b1bb0..086516e9b 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -12,6 +12,8 @@ import math
 import habana_frameworks.torch.core as htcore
 from vllm_hpu_extension.flags import enabled_flags
+from vllm.logger import init_logger
+
 import habana_frameworks.torch.utils.experimental as htexp
 is_hpu_gaudi2 = htexp._get_device_type(
@@ -22,6 +24,9 @@
 FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
 import os
+
+logger = init_logger(__name__)
+
 # MAX_EXPERTS_PER_SLICE is needed for 1.20, up to 64 experts per slice
 MAX_EXPERTS_PER_SLICE = int(os.environ.get("MAX_EXPERTS_PER_SLICE", -1))
@@ -392,6 +397,40 @@ def dispatch_bgmv_embedding(
     out = x @ wb
     y += out
+def static_MoE(
+    hidden_states,
+    expert_routing_table,
+    router_weights,
+    experts_min,
+    w13_weight,
+    w2_weight,
+    tokens_num,
+    num_experts,
+):
+    selected_experts = (expert_routing_table - experts_min).to(torch.int64)
+    moe_intermediate = w2_weight.shape[2]
+    padded_weights = torch.zeros((tokens_num, num_experts),
+                                 dtype=hidden_states.dtype,
+                                 device=hidden_states.device)
+    padded_weights.scatter_(-1, selected_experts, router_weights)
+    padded_weights = padded_weights.transpose(0, 1).unsqueeze(-1)
+
+    up_gate_states = torch.matmul(
+        hidden_states,
+        w13_weight.view(-1, w13_weight.size(-1)).transpose(
+            0, 1))
+    up_gate_states = up_gate_states.reshape(tokens_num, num_experts, 2,
+                                            moe_intermediate)
+    up_states = up_gate_states[:, :, 0, :]
+    gate_states = up_gate_states[:, :, 1, :]
+    current_state_static = F.silu(up_states) * gate_states
+    current_state_static = current_state_static.transpose(0, 1)
+
+    current_hidden_states_static = torch.matmul(
+        current_state_static, w2_weight.transpose(
+            1, 2)) * padded_weights
+    final_hidden_states = current_hidden_states_static.sum(dim=0)
+    return final_hidden_states


 class MoeMatmul(torch.nn.Module):
@@ -424,6 +463,23 @@ def __init__(self, num_total_experts, experts_min: int = 0, experts_max: int = 8
         self.moe_n_slice = 1 if self.num_experts <= max_expert_per_slice \
             else self.num_experts // max_expert_per_slice
         self.num_expert_per_group = self.num_experts // self.moe_n_slice
+
+        # If the number of tokens exceeds VLLM_DYNAMIC_MOE_MIN_TOKENS,
+        # dynamic MoE is used, since its performance is better than
+        # static MoE in that case.
+        self.dynamic_moe_min_tokens = int(
+            os.environ.get("VLLM_DYNAMIC_MOE_MIN_TOKENS", 256))
+        # If the number of experts on a single card is smaller than
+        # VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU, dynamic MoE
+        # is used, since its performance is better than
+        # static MoE in that case.
+        self.dynamic_moe_max_num_expert_singleHpu = int(
+            os.environ.get("VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU", 32))
+
+        # self.w13_weight is a tensor of the combined w13_list
+        self.w13_weight = None
+        # self.w2_weight is a tensor of the combined w2_list
+        self.w2_weight = None

     def forward(self,
                 hidden_states,
@@ -432,44 +488,67 @@ def forward(self,
                 permuted_weights=True,
                 activation="silu"):
         # pre-processing for custom op inputs
+        bt, hidden_dim = hidden_states.shape
         experts_range = range(self.num_experts)
         w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
         w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
-        if self.moe_n_slice == 1:
-            return torch.ops.hpu.mixture_of_experts(
-                hidden_states=hidden_states,
-                expert_routing_table=expert_routing_table,
-                router_weights=router_weights,
-                w12=w1_list,
-                w3=w2_list,
-                permuted_weights=permuted_weights,
-                activation=activation,
-                experts_min=self.experts_min,
-                experts_max=self.experts_max)
-        for i in range(self.moe_n_slice):
-            w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
-            w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
-            min_expert = self.experts_min + i * self.num_expert_per_group
-            max_expert = min_expert + self.num_expert_per_group - 1
-            slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(
-                hidden_states=hidden_states,
-                expert_routing_table=expert_routing_table,
-                router_weights=router_weights,
-                w12=w1_list_slice,
-                w3=w2_list_slice,
-                permuted_weights=permuted_weights,
-                activation=activation,
-                experts_min=min_expert,
-                experts_max=max_expert)
-            if i == 0:
-                final_hidden_states = slice_final_hidden_states
-            else:
-                final_hidden_states += slice_final_hidden_states
-            htorch.core.mark_step()
-        return final_hidden_states
-
-
+        # When the number of input tokens (batch_size * sequence_length) exceeds
+        # dynamic_moe_min_tokens (default 256), or the number of experts on a
+        # single card is smaller than dynamic_moe_max_num_expert_singleHpu
+        # (default 32), dynamic MoE is used since it delivers better performance
+        # than static MoE. Otherwise static MoE is used.
+        if bt > self.dynamic_moe_min_tokens or \
+                (self.num_experts <= self.dynamic_moe_max_num_expert_singleHpu):
+            if self.moe_n_slice == 1:
+                return torch.ops.hpu.mixture_of_experts(
+                    hidden_states=hidden_states,
+                    expert_routing_table=expert_routing_table,
+                    router_weights=router_weights,
+                    w12=w1_list,
+                    w3=w2_list,
+                    permuted_weights=permuted_weights,
+                    activation=activation,
+                    experts_min=self.experts_min,
+                    experts_max=self.experts_max)
+            for i in range(self.moe_n_slice):
+                w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
+                w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
+                min_expert = self.experts_min + i * self.num_expert_per_group
+                max_expert = min_expert + self.num_expert_per_group - 1
+                slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(
+                    hidden_states=hidden_states,
+                    expert_routing_table=expert_routing_table,
+                    router_weights=router_weights,
+                    w12=w1_list_slice,
+                    w3=w2_list_slice,
+                    permuted_weights=permuted_weights,
+                    activation=activation,
+                    experts_min=min_expert,
+                    experts_max=max_expert)
+                if i == 0:
+                    final_hidden_states = slice_final_hidden_states
+                else:
+                    final_hidden_states += slice_final_hidden_states
+        else:
+            if self.w13_weight is None:
+                self.w13_weight = torch.stack(w1_list)
+
+            if self.w2_weight is None:
+                self.w2_weight = torch.stack(w2_list)
+
+            final_hidden_states = static_MoE(
+                hidden_states,
+                expert_routing_table,
+                router_weights,
+                self.experts_min,
+                self.w13_weight,
+                self.w2_weight,
+                bt,
+                self.num_experts
+            )
+        return final_hidden_states.view(-1, hidden_states.shape[1])
+

 class DynamicFusedMOE(torch.nn.Module):

     def __init__(self, num_total_experts):
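
A quick way to sanity-check the dense ("static") MoE math that static_MoE implements is to compare it against a naive per-expert loop on CPU. The sketch below is illustrative only and is not part of the patch; the toy sizes, the top-k router, and the tensor layouts (w13 stacked as (E, 2*I, H), w2 stacked as (E, H, I), experts_min taken as 0) are assumptions chosen to match what static_MoE expects.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, H, inter, E, top_k = 4, 8, 16, 4, 2       # assumed toy sizes
hidden_states = torch.randn(T, H)
w13_weight = torch.randn(E, 2 * inter, H)     # stacked per-expert up/gate projections
w2_weight = torch.randn(E, H, inter)          # stacked per-expert down projections
router_logits = torch.randn(T, E)
router_weights, routing_table = router_logits.softmax(-1).topk(top_k, dim=-1)

# Dense path, mirroring the body of static_MoE (experts_min == 0 here).
padded = torch.zeros(T, E).scatter_(-1, routing_table, router_weights)
padded = padded.transpose(0, 1).unsqueeze(-1)                               # (E, T, 1)
up_gate = hidden_states @ w13_weight.view(-1, H).transpose(0, 1)            # (T, E*2I)
up_gate = up_gate.reshape(T, E, 2, inter)
act = (F.silu(up_gate[:, :, 0, :]) * up_gate[:, :, 1, :]).transpose(0, 1)   # (E, T, I)
dense_out = (torch.matmul(act, w2_weight.transpose(1, 2)) * padded).sum(0)  # (T, H)

# Reference: loop over experts, applying each expert only to its routed tokens.
ref = torch.zeros(T, H)
for e in range(E):
    weight = (router_weights * (routing_table == e)).sum(-1)                # (T,)
    ug = hidden_states @ w13_weight[e].t()                                  # (T, 2I)
    a = F.silu(ug[:, :inter]) * ug[:, inter:]                               # (T, I)
    ref += weight.unsqueeze(-1) * (a @ w2_weight[e].t())                    # (T, H)

torch.testing.assert_close(dense_out, ref, rtol=1e-4, atol=1e-4)

The dense formulation spends extra FLOPs (every token visits every expert) in exchange for large, regular matmuls with no data-dependent gathering, which is why the patch prefers it only for small token counts and larger per-card expert counts, and otherwise keeps the dynamic mixture_of_experts path.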