Skip to content
Draft
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 49 additions & 13 deletions vllm_hpu_extension/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import habana_frameworks.torch.core as htcore
from vllm_hpu_extension.flags import enabled_flags
from vllm.logger import init_logger
import os

# Token-count threshold (batch * seq) above which the dynamic MoE path is
# taken; tunable through the VLLM_DYNAMIC_MOE_MIN_TOKENS environment variable.
dynamic_moe_min_tokens = int(os.environ.get("VLLM_DYNAMIC_MOE_MIN_TOKENS", "256"))
logger = init_logger(__name__)


Expand Down Expand Up @@ -355,30 +358,62 @@ def forward(self,
hidden_states,
expert_routing_table,
router_weights,
layer,
permuted_weights=True,
activation="silu"):
# pre-processing for custom op inputs
experts_range = range(self.num_experts)
w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=0,
experts_max=self.num_experts - 1)
bt, hidden_dim = hidden_states.shape
num_experts = layer.w13_weight.shape[0]
moe_intermediate = layer.w2_weight.shape[2]
ep_shift = layer.ep_rank * num_experts
selected_experts = (expert_routing_table - ep_shift).to(torch.int64)
if bt > dynamic_moe_min_tokens:
experts_range = range(num_experts)
w1_list = [layer.w13_weight[i].squeeze() for i in experts_range]
w2_list = [layer.w2_weight[i].squeeze() for i in experts_range]
final_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=selected_experts,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=True,
activation="silu",
experts_min=0,
experts_max=(num_experts - 1),
)
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't use else. wrap static moe into a function and do:
if ...:
return static_MOE()
#existing code
...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please switch the order,

do:
if ...:
return static_MOE()

existing code

dynamic_moe()


instead of:

if ...:
dynamic_moe()
else:
return static_moe()

padded_weights = torch.zeros((bt, num_experts),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights.scatter_(-1, selected_experts, router_weights)
padded_weights = padded_weights.transpose(0, 1).unsqueeze(-1)

up_gate_states = torch.matmul(
hidden_states,
layer.w13_weight.view(-1, layer.w13_weight.size(-1)).transpose(
0, 1))
up_gate_states = up_gate_states.reshape(bt, num_experts, 2,
moe_intermediate)
up_states = up_gate_states[:, :, 0, :]
gate_states = up_gate_states[:, :, 1, :]
current_state_static = F.silu(up_states) * gate_states
current_state_static = current_state_static.transpose(0, 1)

current_hidden_states_static = torch.matmul(
current_state_static, layer.w2_weight.transpose(
1, 2)) * padded_weights
final_hidden_states = current_hidden_states_static.sum(dim=0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about wrapping static_moe as a function and invoking it here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The INC engineer suggested not to wrap it.


return final_hidden_states.view(-1, hidden_states.shape[1])

class DynamicFusedMOE(torch.nn.Module):

def __init__(self, num_total_experts):
    """Build a MoE module delegating to the fused HPU mixture-of-experts op.

    Args:
        num_total_experts: total number of experts handed to the
            underlying ``VllmMixtureOfExpertsOp``.
    """
    super().__init__()
    # All routing/compute is delegated to this op in forward().
    self.MoeOp = VllmMixtureOfExpertsOp(num_total_experts)

def forward(self, hidden_states, score, topk):
def forward(self, hidden_states, score, topk, layer):
htorch.core.mark_step()
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
routing_weights, selected_experts = torch.topk(routing_weights,
Expand All @@ -391,6 +426,7 @@ def forward(self, hidden_states, score, topk):
hidden_states=hidden_states,
expert_routing_table=selected_experts,
router_weights=routing_weights,
layer=layer,
permuted_weights=True,
activation="silu",
)
Expand Down