Skip to content
Draft
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 90 additions & 32 deletions vllm_hpu_extension/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import math
import habana_frameworks.torch.core as htcore
from vllm_hpu_extension.flags import enabled_flags

import habana_frameworks.torch.utils.experimental as htexp

is_hpu_gaudi2 = htexp._get_device_type(
Expand All @@ -22,6 +23,9 @@
FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max

import os

logger = init_logger(__name__)

# MAX_EXPERTS_PER_SLICE is needed for 1.20, up to 64 experts per slice
MAX_EXPERTS_PER_SLICE = int(os.environ.get("MAX_EXPERTS_PER_SLICE", -1))

Expand Down Expand Up @@ -424,6 +428,23 @@ def __init__(self, num_total_experts, experts_min: int = 0, experts_max: int = 8
self.moe_n_slice = 1 if self.num_experts <= max_expert_per_slice \
else self.num_experts // max_expert_per_slice
self.num_expert_per_group = self.num_experts // self.moe_n_slice

        # if num_tokens exceeds the VLLM_DYNAMIC_MOE_MIN_TOKENS,
# dynamic MoE is used since its performance is better than
# static MoE in this case.
self.dynamic_moe_min_tokens = int(
os.environ.get("VLLM_DYNAMIC_MOE_MIN_TOKENS", -1))
# if the number of expert on a single card is smaller than
# VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU, dynamic MoE
# is used since its performance is better than
# static MoE in this case.
self.dynamic_moe_max_num_expert_singleHpu = int(
os.environ.get("VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU", 32))

#self.w13_weight is a tensor of combined w13_list
self.w13_weight = None
#self.w2_weight is a tensor of combined w2_list
self.w2_weight = None

def forward(self,
hidden_states,
Expand All @@ -432,44 +453,81 @@ def forward(self,
permuted_weights=True,
activation="silu"):
# pre-processing for custom op inputs
bt, hidden_dim = hidden_states.shape
experts_range = range(self.num_experts)
w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]

if self.moe_n_slice == 1:
return torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=self.experts_min,
experts_max=self.experts_max)
for i in range(self.moe_n_slice):
w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
min_expert = self.experts_min + i * self.num_expert_per_group
max_expert = min_expert + self.num_expert_per_group - 1
slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list_slice,
w3=w2_list_slice,
permuted_weights=permuted_weights,
activation=activation,
experts_min=min_expert,
experts_max=max_expert)
if i == 0:
final_hidden_states = slice_final_hidden_states
else:
final_hidden_states += slice_final_hidden_states
htorch.core.mark_step()
return final_hidden_states
        # When the number of input tokens (batch_size*sequence_length) exceeds
# dynamic_moe_min_tokens (default -1) or the number of the experts
# on the single card is smaller than dynamic_moe_max_num_expert_singleHpu
# (default 32), dynamic MoE is used since it delivers better performance
# than static MoE. Otherwise static MoE is used.
if bt > self.dynamic_moe_min_tokens or \
(self.num_experts <= self.dynamic_moe_max_num_expert_singleHpu):
if self.moe_n_slice == 1:
return torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=self.experts_min,
experts_max=self.experts_max)
for i in range(self.moe_n_slice):
w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
min_expert = self.experts_min + i * self.num_expert_per_group
max_expert = min_expert + self.num_expert_per_group - 1
slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list_slice,
w3=w2_list_slice,
permuted_weights=permuted_weights,
activation=activation,
experts_min=min_expert,
experts_max=max_expert)
if i == 0:
final_hidden_states = slice_final_hidden_states
else:
final_hidden_states += slice_final_hidden_states
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't use else. wrap static moe into a function and do:
if ...:
return static_MOE()
#existing code
...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you switch the sequence,

do:
if ...:
return static_MOE()

existing codes

dynamic_moe()


instead of:

if ...:
dynamic_moe()
else:
return static_moe()

if self.w13_weight is None:
self.w13_weight = torch.stack(w1_list)

if self.w2_weight is None:
self.w2_weight = torch.stack(w2_list)

selected_experts = (expert_routing_table - self.experts_min).to(torch.int64)
moe_intermediate = self.w2_weight.shape[2]
padded_weights = torch.zeros((bt, self.num_experts),
dtype=hidden_states.dtype,
device=hidden_states.device)
padded_weights.scatter_(-1, selected_experts, router_weights)
padded_weights = padded_weights.transpose(0, 1).unsqueeze(-1)

up_gate_states = torch.matmul(
hidden_states,
self.w13_weight.view(-1, self.w13_weight.size(-1)).transpose(
0, 1))
up_gate_states = up_gate_states.reshape(bt, self.num_experts, 2,
moe_intermediate)
up_states = up_gate_states[:, :, 0, :]
gate_states = up_gate_states[:, :, 1, :]
current_state_static = F.silu(up_states) * gate_states
current_state_static = current_state_static.transpose(0, 1)

current_hidden_states_static = torch.matmul(
current_state_static, self.w2_weight.transpose(
1, 2)) * padded_weights
final_hidden_states = current_hidden_states_static.sum(dim=0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about wrap static_moe as a function and invoke here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The INC engineer suggested not to wrap it.


return final_hidden_states.view(-1, hidden_states.shape[1])

class DynamicFusedMOE(torch.nn.Module):

def __init__(self, num_total_experts):
Expand Down