-
Notifications
You must be signed in to change notification settings - Fork 48
Optimized MoE on Gaudi #159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 11 commits
2761ab8
c7f8752
c6d128d
321a69a
a549bfb
71ff585
3ae83f3
fd1e0eb
6e94680
4a7b82a
316f0e0
9b60beb
f06441a
7bc3e28
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,6 +12,7 @@ | |
| import math | ||
| import habana_frameworks.torch.core as htcore | ||
| from vllm_hpu_extension.flags import enabled_flags | ||
|
|
||
| import habana_frameworks.torch.utils.experimental as htexp | ||
|
|
||
| is_hpu_gaudi2 = htexp._get_device_type( | ||
|
|
@@ -22,6 +23,9 @@ | |
| FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max | ||
|
|
||
| import os | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
| # MAX_EXPERTS_PER_SLICE is needed for 1.20, up to 64 experts per slice | ||
| MAX_EXPERTS_PER_SLICE = int(os.environ.get("MAX_EXPERTS_PER_SLICE", -1)) | ||
|
|
||
|
|
@@ -424,6 +428,23 @@ def __init__(self, num_total_experts, experts_min: int = 0, experts_max: int = 8 | |
| self.moe_n_slice = 1 if self.num_experts <= max_expert_per_slice \ | ||
| else self.num_experts // max_expert_per_slice | ||
| self.num_expert_per_group = self.num_experts // self.moe_n_slice | ||
|
|
||
| # If num_tokens exceeds VLLM_DYNAMIC_MOE_MIN_TOKENS, | ||
| # dynamic MoE is used since its performance is better than | ||
| # static MoE in this case. | ||
| self.dynamic_moe_min_tokens = int( | ||
| os.environ.get("VLLM_DYNAMIC_MOE_MIN_TOKENS", -1)) | ||
| # If the number of experts on a single card does not exceed | ||
| # VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU, dynamic MoE | ||
| # is used since its performance is better than | ||
| # static MoE in this case. | ||
| self.dynamic_moe_max_num_expert_singleHpu = int( | ||
| os.environ.get("VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU", 32)) | ||
|
|
||
| #self.w13_weight is a tensor of combined w13_list | ||
| self.w13_weight = None | ||
| #self.w2_weight is a tensor of combined w2_list | ||
| self.w2_weight = None | ||
|
|
||
| def forward(self, | ||
| hidden_states, | ||
|
|
@@ -432,44 +453,81 @@ def forward(self, | |
| permuted_weights=True, | ||
| activation="silu"): | ||
| # pre-processing for custom op inputs | ||
| bt, hidden_dim = hidden_states.shape | ||
gyou2021 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| experts_range = range(self.num_experts) | ||
| w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range] | ||
| w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range] | ||
|
|
||
xuechendi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if self.moe_n_slice == 1: | ||
| return torch.ops.hpu.mixture_of_experts( | ||
| hidden_states=hidden_states, | ||
| expert_routing_table=expert_routing_table, | ||
| router_weights=router_weights, | ||
| w12=w1_list, | ||
| w3=w2_list, | ||
| permuted_weights=permuted_weights, | ||
| activation=activation, | ||
| experts_min=self.experts_min, | ||
| experts_max=self.experts_max) | ||
| for i in range(self.moe_n_slice): | ||
| w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] | ||
| w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] | ||
| min_expert = self.experts_min + i * self.num_expert_per_group | ||
| max_expert = min_expert + self.num_expert_per_group - 1 | ||
| slice_final_hidden_states = torch.ops.hpu.mixture_of_experts( | ||
| hidden_states=hidden_states, | ||
| expert_routing_table=expert_routing_table, | ||
| router_weights=router_weights, | ||
| w12=w1_list_slice, | ||
| w3=w2_list_slice, | ||
| permuted_weights=permuted_weights, | ||
| activation=activation, | ||
| experts_min=min_expert, | ||
| experts_max=max_expert) | ||
| if i == 0: | ||
| final_hidden_states = slice_final_hidden_states | ||
| else: | ||
| final_hidden_states += slice_final_hidden_states | ||
| htorch.core.mark_step() | ||
| return final_hidden_states | ||
| # When the number of input tokens (batch_size * sequence_length) exceeds | ||
| # dynamic_moe_min_tokens (default -1) or the number of experts | ||
| # on a single card does not exceed dynamic_moe_max_num_expert_singleHpu | ||
| # (default 32), dynamic MoE is used since it delivers better performance | ||
| # than static MoE. Otherwise static MoE is used. | ||
| if bt > self.dynamic_moe_min_tokens or \ | ||
| (self.num_experts <= self.dynamic_moe_max_num_expert_singleHpu): | ||
| if self.moe_n_slice == 1: | ||
| return torch.ops.hpu.mixture_of_experts( | ||
| hidden_states=hidden_states, | ||
| expert_routing_table=expert_routing_table, | ||
| router_weights=router_weights, | ||
| w12=w1_list, | ||
| w3=w2_list, | ||
| permuted_weights=permuted_weights, | ||
| activation=activation, | ||
| experts_min=self.experts_min, | ||
| experts_max=self.experts_max) | ||
| for i in range(self.moe_n_slice): | ||
| w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] | ||
| w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] | ||
| min_expert = self.experts_min + i * self.num_expert_per_group | ||
| max_expert = min_expert + self.num_expert_per_group - 1 | ||
| slice_final_hidden_states = torch.ops.hpu.mixture_of_experts( | ||
| hidden_states=hidden_states, | ||
| expert_routing_table=expert_routing_table, | ||
| router_weights=router_weights, | ||
| w12=w1_list_slice, | ||
| w3=w2_list_slice, | ||
| permuted_weights=permuted_weights, | ||
| activation=activation, | ||
| experts_min=min_expert, | ||
| experts_max=max_expert) | ||
| if i == 0: | ||
| final_hidden_states = slice_final_hidden_states | ||
| else: | ||
| final_hidden_states += slice_final_hidden_states | ||
| else: | ||
gyou2021 marked this conversation as resolved.
Show resolved
Hide resolved
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't use else. Wrap static MoE into a function and do:
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modified.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you switch the order? Do: existing code, then dynamic_moe(), instead of: if ...: |
||
| if self.w13_weight is None: | ||
| self.w13_weight = torch.stack(w1_list) | ||
|
|
||
| if self.w2_weight is None: | ||
| self.w2_weight = torch.stack(w2_list) | ||
|
|
||
| selected_experts = (expert_routing_table - self.experts_min).to(torch.int64) | ||
| moe_intermediate = self.w2_weight.shape[2] | ||
| padded_weights = torch.zeros((bt, self.num_experts), | ||
| dtype=hidden_states.dtype, | ||
| device=hidden_states.device) | ||
| padded_weights.scatter_(-1, selected_experts, router_weights) | ||
| padded_weights = padded_weights.transpose(0, 1).unsqueeze(-1) | ||
|
|
||
| up_gate_states = torch.matmul( | ||
| hidden_states, | ||
| self.w13_weight.view(-1, self.w13_weight.size(-1)).transpose( | ||
| 0, 1)) | ||
| up_gate_states = up_gate_states.reshape(bt, self.num_experts, 2, | ||
| moe_intermediate) | ||
| up_states = up_gate_states[:, :, 0, :] | ||
| gate_states = up_gate_states[:, :, 1, :] | ||
| current_state_static = F.silu(up_states) * gate_states | ||
| current_state_static = current_state_static.transpose(0, 1) | ||
|
|
||
| current_hidden_states_static = torch.matmul( | ||
| current_state_static, self.w2_weight.transpose( | ||
| 1, 2)) * padded_weights | ||
| final_hidden_states = current_hidden_states_static.sum(dim=0) | ||
|
||
|
|
||
| return final_hidden_states.view(-1, hidden_states.shape[1]) | ||
|
|
||
| class DynamicFusedMOE(torch.nn.Module): | ||
|
|
||
| def __init__(self, num_total_experts): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.