Skip to content
Draft
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 72 additions & 33 deletions vllm_hpu_extension/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@
import math
import habana_frameworks.torch.core as htcore
from vllm_hpu_extension.flags import enabled_flags
from vllm.logger import init_logger
import os

dynamic_moe_min_tokens = int(
os.environ.get("VLLM_DYNAMIC_MOE_MIN_TOKENS", 256))
logger = init_logger(__name__)

import os
# MAX_EXPERTS_PER_SLICE is needed for 1.20, up to 64 experts per slice
Expand Down Expand Up @@ -415,58 +421,91 @@ def __init__(self, num_total_experts, experts_min: int = 0, experts_max: int = 8
self.moe_n_slice = 1 if self.num_experts <= max_expert_per_slice \
else self.num_experts // max_expert_per_slice
self.num_expert_per_group = self.num_experts // self.moe_n_slice

def set_weights(self, w13, w2):
    """Register the dense MoE weight tensors.

    Stores `w13` (fused gate/up projection) and `w2` (down projection)
    on the op; the static (dense) path in `forward` reads them back as
    `self.w13_weight` / `self.w2_weight`.
    """
    self.w2_weight = w2
    self.w13_weight = w13

def forward(self,
            hidden_states,
            expert_routing_table,
            router_weights,
            permuted_weights=True,
            activation="silu"):
    """Run mixture-of-experts over `hidden_states`.

    Two execution paths, selected by token count:

    * **Dynamic** (``bt > dynamic_moe_min_tokens``, threshold read from the
      ``VLLM_DYNAMIC_MOE_MIN_TOKENS`` env var at import time): dispatch to
      the fused HPU custom op ``torch.ops.hpu.mixture_of_experts``, sliced
      into ``self.moe_n_slice`` expert groups when the expert count exceeds
      the per-slice limit.
    * **Static** (small batches): a dense computation over *all* experts
      using the stacked tensors installed by ``set_weights``; per-token
      router weights are scattered into a dense (bt, num_experts) mask so
      unselected experts contribute zero.

    Args:
        hidden_states: (bt, hidden_dim) input activations.
        expert_routing_table: integer top-k expert indices per token
            (bt, k); absolute indices, offset by ``self.experts_min``.
        router_weights: (bt, k) routing weights matching the table.
        permuted_weights / activation: forwarded to the fused op only.

    Returns:
        (bt, hidden_dim) combined expert outputs.

    NOTE(review): the static path hard-codes silu and ignores the
    `activation` argument — confirm this is intended.
    NOTE(review): the diff appears to drop the per-slice
    ``htorch.core.mark_step()`` from the dynamic loop — confirm against
    the final file.
    """
    bt, hidden_dim = hidden_states.shape

    if bt > dynamic_moe_min_tokens:
        # Dynamic path. The per-expert weight lists are only needed by the
        # fused op, so build them here rather than unconditionally.
        experts_range = range(self.num_experts)
        w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
        w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]

        if self.moe_n_slice == 1:
            return torch.ops.hpu.mixture_of_experts(
                hidden_states=hidden_states,
                expert_routing_table=expert_routing_table,
                router_weights=router_weights,
                w12=w1_list,
                w3=w2_list,
                permuted_weights=permuted_weights,
                activation=activation,
                experts_min=self.experts_min,
                experts_max=self.experts_max)

        # Multiple slices: run the fused op per expert group and sum the
        # partial outputs (each op only contributes for its expert range).
        for i in range(self.moe_n_slice):
            lo = i * self.num_expert_per_group
            hi = lo + self.num_expert_per_group
            min_expert = self.experts_min + lo
            max_expert = min_expert + self.num_expert_per_group - 1
            slice_out = torch.ops.hpu.mixture_of_experts(
                hidden_states=hidden_states,
                expert_routing_table=expert_routing_table,
                router_weights=router_weights,
                w12=w1_list[lo:hi],
                w3=w2_list[lo:hi],
                permuted_weights=permuted_weights,
                activation=activation,
                experts_min=min_expert,
                experts_max=max_expert)
            if i == 0:
                final_hidden_states = slice_out
            else:
                final_hidden_states += slice_out
    else:
        # Static (dense) path: compute every expert for every token, then
        # zero out unselected experts via a scattered weight mask.
        selected_experts = (expert_routing_table -
                            self.experts_min).to(torch.int64)
        moe_intermediate = self.w2_weight.shape[2]
        padded_weights = torch.zeros((bt, self.num_experts),
                                     dtype=hidden_states.dtype,
                                     device=hidden_states.device)
        padded_weights.scatter_(-1, selected_experts, router_weights)
        # (bt, E) -> (E, bt, 1) so it broadcasts over hidden_dim below.
        padded_weights = padded_weights.transpose(0, 1).unsqueeze(-1)

        # Single matmul against all experts' fused gate/up weights, then
        # split into the up and gate halves.
        up_gate_states = torch.matmul(
            hidden_states,
            self.w13_weight.view(-1, self.w13_weight.size(-1)).transpose(
                0, 1))
        up_gate_states = up_gate_states.reshape(bt, self.num_experts, 2,
                                                moe_intermediate)
        up_states = up_gate_states[:, :, 0, :]
        gate_states = up_gate_states[:, :, 1, :]
        current_state_static = F.silu(up_states) * gate_states
        current_state_static = current_state_static.transpose(0, 1)

        current_hidden_states_static = torch.matmul(
            current_state_static, self.w2_weight.transpose(
                1, 2)) * padded_weights
        # Sum over the expert axis; masked-out experts contribute zero.
        final_hidden_states = current_hidden_states_static.sum(dim=0)

    return final_hidden_states.view(-1, hidden_states.shape[1])

class DynamicFusedMOE(torch.nn.Module):

def __init__(self, num_total_experts):
    # Thin nn.Module wrapper around VllmMixtureOfExpertsOp;
    # `num_total_experts` is forwarded unchanged (the op's experts_min /
    # experts_max keep their defaults).
    super().__init__()
    self.MoeOp = VllmMixtureOfExpertsOp(num_total_experts)

def set_MoeOp_weights(self, w13, w2):
    # Delegate to the wrapped op: installs the stacked w13 (gate/up) and
    # w2 (down) tensors used by its static/dense MoE path.
    self.MoeOp.set_weights(w13, w2)

def forward(self, hidden_states, score, topk):
htorch.core.mark_step()
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
Expand Down