-
Notifications
You must be signed in to change notification settings - Fork 48
Optimized MoE on Gaudi #159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
2761ab8
c7f8752
c6d128d
321a69a
a549bfb
71ff585
3ae83f3
fd1e0eb
6e94680
4a7b82a
316f0e0
9b60beb
f06441a
7bc3e28
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,10 @@ | |
| import habana_frameworks.torch.core as htcore | ||
| from vllm_hpu_extension.flags import enabled_flags | ||
| from vllm.logger import init_logger | ||
| import os | ||
|
|
||
| dynamic_moe_min_tokens = int( | ||
| os.environ.get("VLLM_DYNAMIC_MOE_MIN_TOKENS", 256)) | ||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
|
|
@@ -355,30 +358,62 @@ def forward(self, | |
| hidden_states, | ||
| expert_routing_table, | ||
| router_weights, | ||
| layer, | ||
gyou2021 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| permuted_weights=True, | ||
| activation="silu"): | ||
| # pre-processing for custom op inputs | ||
| experts_range = range(self.num_experts) | ||
| w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range] | ||
| w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range] | ||
| return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states, | ||
| expert_routing_table=expert_routing_table, | ||
| router_weights=router_weights, | ||
| w12=w1_list, | ||
| w3=w2_list, | ||
| permuted_weights=permuted_weights, | ||
| activation=activation, | ||
| experts_min=0, | ||
| experts_max=self.num_experts - 1) | ||
| bt, hidden_dim = hidden_states.shape | ||
gyou2021 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| num_experts = layer.w13_weight.shape[0] | ||
gyou2021 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| moe_intermediate = layer.w2_weight.shape[2] | ||
| ep_shift = layer.ep_rank * num_experts | ||
| selected_experts = (expert_routing_table - ep_shift).to(torch.int64) | ||
| if bt > dynamic_moe_min_tokens: | ||
| experts_range = range(num_experts) | ||
| w1_list = [layer.w13_weight[i].squeeze() for i in experts_range] | ||
| w2_list = [layer.w2_weight[i].squeeze() for i in experts_range] | ||
| final_hidden_states = torch.ops.hpu.mixture_of_experts( | ||
| hidden_states=hidden_states, | ||
| expert_routing_table=selected_experts, | ||
| router_weights=router_weights, | ||
| w12=w1_list, | ||
| w3=w2_list, | ||
| permuted_weights=True, | ||
| activation="silu", | ||
| experts_min=0, | ||
| experts_max=(num_experts - 1), | ||
| ) | ||
| else: | ||
gyou2021 marked this conversation as resolved.
Show resolved
Hide resolved
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't use else. Wrap static MoE into a function and do:
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Modified.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. May you switch the sequence? Do: existing code, then `dynamic_moe()`, instead of: `if ...:` |
||
| padded_weights = torch.zeros((bt, num_experts), | ||
| dtype=hidden_states.dtype, | ||
| device=hidden_states.device) | ||
| padded_weights.scatter_(-1, selected_experts, router_weights) | ||
| padded_weights = padded_weights.transpose(0, 1).unsqueeze(-1) | ||
|
|
||
| up_gate_states = torch.matmul( | ||
| hidden_states, | ||
| layer.w13_weight.view(-1, layer.w13_weight.size(-1)).transpose( | ||
| 0, 1)) | ||
| up_gate_states = up_gate_states.reshape(bt, num_experts, 2, | ||
| moe_intermediate) | ||
| up_states = up_gate_states[:, :, 0, :] | ||
| gate_states = up_gate_states[:, :, 1, :] | ||
| current_state_static = F.silu(up_states) * gate_states | ||
| current_state_static = current_state_static.transpose(0, 1) | ||
|
|
||
| current_hidden_states_static = torch.matmul( | ||
| current_state_static, layer.w2_weight.transpose( | ||
| 1, 2)) * padded_weights | ||
| final_hidden_states = current_hidden_states_static.sum(dim=0) | ||
|
||
|
|
||
| return final_hidden_states.view(-1, hidden_states.shape[1]) | ||
|
|
||
| class DynamicFusedMOE(torch.nn.Module): | ||
|
|
||
| def __init__(self, num_total_experts): | ||
| super().__init__() | ||
| self.MoeOp = VllmMixtureOfExpertsOp(num_total_experts) | ||
|
|
||
| def forward(self, hidden_states, score, topk): | ||
| def forward(self, hidden_states, score, topk, layer): | ||
| htorch.core.mark_step() | ||
| routing_weights = F.softmax(score, dim=1, dtype=torch.float32) | ||
| routing_weights, selected_experts = torch.topk(routing_weights, | ||
|
|
@@ -391,6 +426,7 @@ def forward(self, hidden_states, score, topk): | |
| hidden_states=hidden_states, | ||
| expert_routing_table=selected_experts, | ||
| router_weights=routing_weights, | ||
| layer=layer, | ||
| permuted_weights=True, | ||
| activation="silu", | ||
| ) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.