Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 130 additions & 34 deletions vllm_hpu_extension/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import math
import habana_frameworks.torch.core as htcore
from vllm_hpu_extension.flags import enabled_flags
from vllm.logger import init_logger

import habana_frameworks.torch.utils.experimental as htexp

is_hpu_gaudi2 = htexp._get_device_type(
Expand All @@ -22,6 +24,9 @@
FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max

import os

logger = init_logger(__name__)

# MAX_EXPERTS_PER_SLICE is needed for 1.20, up to 64 experts per slice
MAX_EXPERTS_PER_SLICE = int(os.environ.get("MAX_EXPERTS_PER_SLICE", -1))

Expand Down Expand Up @@ -392,6 +397,40 @@ def dispatch_bgmv_embedding(
out = x @ wb
y += out

def static_MoE(
    hidden_states,
    expert_routing_table,
    router_weights,
    experts_min,
    w13_weight,
    w2_weight,
    tokens_num,
    num_experts,
):
    """Static (dense) MoE: run every expert over every token, then combine
    the per-expert outputs using the router weights.

    Args:
        hidden_states: (tokens_num, hidden_dim) input activations.
        expert_routing_table: (tokens_num, top_k) global expert indices
            chosen by the router for each token.
        router_weights: (tokens_num, top_k) routing weights matching
            ``expert_routing_table``.
        experts_min: first global expert id hosted on this device; used to
            map global expert ids to local [0, num_experts) ids.
        w13_weight: stacked fused up/gate projection weights,
            (num_experts, 2 * intermediate, hidden_dim).
        w2_weight: stacked down-projection weights; its last dim is the
            intermediate size.
        tokens_num: number of tokens (rows of ``hidden_states``).
        num_experts: number of experts hosted on this device.

    Returns:
        (tokens_num, hidden_dim) combined expert output.
    """
    inter_dim = w2_weight.shape[2]

    # Scatter the sparse router weights into a dense (tokens, experts)
    # matrix; experts a token did not select keep weight 0, so their
    # contribution is zeroed out in the final weighted sum.
    local_expert_ids = (expert_routing_table - experts_min).to(torch.int64)
    dense_weights = hidden_states.new_zeros((tokens_num, num_experts))
    dense_weights.scatter_(-1, local_expert_ids, router_weights)
    # Shape (experts, tokens, 1) so it broadcasts over the hidden dim.
    dense_weights = dense_weights.transpose(0, 1).unsqueeze(-1)

    # Fused up/gate projection for all experts in a single matmul.
    w13_flat = w13_weight.view(-1, w13_weight.size(-1))
    fused = torch.matmul(hidden_states, w13_flat.transpose(0, 1))
    fused = fused.reshape(tokens_num, num_experts, 2, inter_dim)
    # NOTE(review): silu is applied to chunk 0 and the product taken with
    # chunk 1 — whether chunk 0 is "up" or "gate" depends on how w13 is
    # packed upstream; confirm against the weight-loading code.
    first_proj, second_proj = fused.unbind(dim=2)
    activated = (F.silu(first_proj) * second_proj).transpose(0, 1)

    # Per-expert down-projection, weighted by the router probabilities,
    # then reduced over the expert dimension.
    expert_out = torch.matmul(activated, w2_weight.transpose(1, 2))
    return (expert_out * dense_weights).sum(dim=0)

class MoeMatmul(torch.nn.Module):

Expand Down Expand Up @@ -424,6 +463,40 @@ def __init__(self, num_total_experts, experts_min: int = 0, experts_max: int = 8
self.moe_n_slice = 1 if self.num_experts <= max_expert_per_slice \
else self.num_experts // max_expert_per_slice
self.num_expert_per_group = self.num_experts // self.moe_n_slice

# if num_tokens exceed the VLLM_DYNAMIC_MOE_MIN_TOKENS,
# dynamic MoE is used since its performance is better than
# static MoE in this case.
self.dynamic_moe_min_tokens = int(
os.environ.get("VLLM_DYNAMIC_MOE_MIN_TOKENS", -1))
# if the number of expert on a single card is smaller than
# VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU, dynamic MoE
# is used since its performance is better than
# static MoE in this case.
self.dynamic_moe_max_num_expert_singleHpu = int(
os.environ.get("VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU", 32))

#self.w13_weight is a tensor of combined w13_list
self.w13_weight = None
#self.w2_weight is a tensor of combined w2_list
self.w2_weight = None

# if num_tokens exceed the VLLM_DYNAMIC_MOE_MIN_TOKENS,
# dynamic MoE is used since its performance is better than
# static MoE in this case.
self.dynamic_moe_min_tokens = int(
os.environ.get("VLLM_DYNAMIC_MOE_MIN_TOKENS", 256))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't this duplicated with 466 - 482?

# if the number of expert on a single card is smaller than
# VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU, dynamic MoE
# is used since its performance is better than
# static MoE in this case.
self.dynamic_moe_max_num_expert_singleHpu = int(
os.environ.get("VLLM_DYNAMIC_MOE_MIN_EXPERTS_SINGLEHPU", 32))

#self.w13_weight is a tensor of combined w13_list
self.w13_weight = None
#self.w2_weight is a tensor of combined w2_list
self.w2_weight = None

def forward(self,
hidden_states,
Expand All @@ -432,44 +505,67 @@ def forward(self,
permuted_weights=True,
activation="silu"):
# pre-processing for custom op inputs
bt, hidden_dim = hidden_states.shape
experts_range = range(self.num_experts)
w1_list = [self.w13_list[i].weight.squeeze() for i in experts_range]
w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]

if self.moe_n_slice == 1:
return torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=self.experts_min,
experts_max=self.experts_max)
for i in range(self.moe_n_slice):
w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
min_expert = self.experts_min + i * self.num_expert_per_group
max_expert = min_expert + self.num_expert_per_group - 1
slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list_slice,
w3=w2_list_slice,
permuted_weights=permuted_weights,
activation=activation,
experts_min=min_expert,
experts_max=max_expert)
if i == 0:
final_hidden_states = slice_final_hidden_states
else:
final_hidden_states += slice_final_hidden_states
htorch.core.mark_step()
return final_hidden_states


# When the number of input tokens (batch_size*sequence_length) exceeds
# dynamic_moe_min_tokens (default 256) or the number of the experts
# on the single card is smaller than dynamic_moe_max_num_expert_singleHpu
# (default 32), dynamic MoE is used since it delivers better performance
# than static MoE. Otherwise static MoE is used.
if bt > self.dynamic_moe_min_tokens or \
(self.num_experts <= self.dynamic_moe_max_num_expert_singleHpu):
if self.moe_n_slice == 1:
return torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=self.experts_min,
experts_max=self.experts_max)
for i in range(self.moe_n_slice):
w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
min_expert = self.experts_min + i * self.num_expert_per_group
max_expert = min_expert + self.num_expert_per_group - 1
slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list_slice,
w3=w2_list_slice,
permuted_weights=permuted_weights,
activation=activation,
experts_min=min_expert,
experts_max=max_expert)
if i == 0:
final_hidden_states = slice_final_hidden_states
else:
final_hidden_states += slice_final_hidden_states
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't use else. wrap static moe into a function and do:
if ...:
return static_MOE()
#existing code
...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you switch the order,

do:
if ...:
return static_MOE()

existing code

dynamic_moe()


instead of:

if ...:
dynamic_moe()
else:
return static_MoE()

if self.w13_weight is None:
self.w13_weight = torch.stack(w1_list)

if self.w2_weight is None:
self.w2_weight = torch.stack(w2_list)

final_hidden_states = static_MoE(
hidden_states,
expert_routing_table,
router_weights,
self.experts_min,
self.w13_weight,
self.w2_weight,
bt,
self.num_experts
)
return final_hidden_states.view(-1, hidden_states.shape[1])

class DynamicFusedMOE(torch.nn.Module):

def __init__(self, num_total_experts):
Expand Down