Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 143 additions & 140 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,9 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
raise NotImplementedError

@abstractmethod
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
def apply(self, layer: torch.nn.Module, x: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
**kwargs) -> torch.Tensor:
raise NotImplementedError


Expand Down Expand Up @@ -61,66 +55,45 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
) -> torch.Tensor:
return self.forward(x, layer.w13_weight, layer.w2_weight,
router_logits, top_k, renormalize,
use_grouped_topk, num_expert_group, topk_group)

def forward_cuda(
self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
return fused_moe(x,
w1,
w2,
router_logits,
top_k,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group)
def apply(self, layer: torch.nn.Module, x: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
**kwargs) -> torch.Tensor:

return self.forward(x=x,
layer=layer,
topk_weights=topk_weights,
topk_ids=topk_ids,
**kwargs)

def forward_cuda(self, layer: torch.nn.Module, x: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
**kwargs) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True)

def forward_cpu(self, *args, **kwargs):
raise NotImplementedError(
"The CPU backend currently does not support MoE.")

def forward_tpu(
self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
) -> torch.Tensor:
def forward_tpu(self, layer: torch.nn.Module, x: torch.Tensor,
use_grouped_topk: bool, topk: int,
gating_output: torch.Tensor, renormalize: bool,
**kwargs) -> torch.Tensor:

from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
return fused_moe(x, w1, w2, router_logits, top_k, renormalize)
return fused_moe(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=topk,
gating_output=gating_output,
renormalize=renormalize)


class FusedMoE(torch.nn.Module):
Expand Down Expand Up @@ -195,67 +168,90 @@ def __init__(

def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: int, expert_id: int):
param_data = param.data
shard_id: str, expert_id: int) -> None:
if shard_id not in ["w1", "w2", "w3"]:
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")

# Special case for fp8 scales.
if getattr(param, "is_fp8_scale", False):
self._load_fp8_scale(param.data, loaded_weight, weight_name,
shard_id, expert_id)
return

expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()

# If transposed, weight is saved as [input_dim, output_dim]
# Otherwise, weight is saved as [output_dim, input_dim]
# Default is not transposed/input dim is dim 1
input_dim = getattr(param, "input_dim", 1)
output_dim = getattr(param, "output_dim", 0)

# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
if (shard_id == "w2"):
shard_dim = input_dim
shard_size = expert_data.shape[shard_dim]
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
elif (shard_id == "w1" or shard_id == "w3"):
shard_dim = output_dim
shard_size = expert_data.shape[output_dim] // 2
offset = shard_size * tp_rank
loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)

# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
expert_data.copy_(loaded_weight)
# w3, up_proj: Load into second logical weight of w13.
elif shard_id == "w3":
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
# w2, down_proj: Load into only logical weight of w2.
elif shard_id == "w2":
expert_data.copy_(loaded_weight)
else:
raise ValueError(
f"Expected shard_id w1,w2 or w3 but got {shard_id}")

# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
# shard_id 0 == gate_proj / w1
# shard_id 2 == up_proj / w3
if shard_id == 0 or shard_id == 2:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == 0 else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
# shard_id 1 == down_proj / w2
else:
param_data[expert_id] = loaded_weight
# Weights
def _select_experts(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk)

# DeekSeekv2 uses grouped_top_k
if self.use_grouped_topk:
assert self.topk_group is not None
assert self.num_expert_group is not None
topk_weights, topk_ids = grouped_topk(hidden_states, router_logits,
self.top_k, self.renormalize,
self.num_expert_group,
self.topk_group)
else:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.intermediate_size_per_partition
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

# w1, gate_proj case: Load into first shard of w13.
if shard_id == 0:
param_data[expert_id,
0:shard_size, :] = loaded_weight[shard, :]
# w3, up_proj case: Load into second shard of w13.
elif shard_id == 2:
param_data[expert_id, shard_size:2 *
shard_size, :] = loaded_weight[shard, :]
# w2, down_proj case: Load into only shard of w2.
elif shard_id == 1:
param_data[expert_id, :, :] = loaded_weight[:, shard]
else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
topk_weights, topk_ids = fused_topk(hidden_states, router_logits,
self.top_k, self.renormalize)

return topk_weights, topk_ids

def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
assert self.quant_method is not None

topk_weights, topk_ids = self._select_experts(
hidden_states=hidden_states, router_logits=router_logits)

# Matrix multiply.
final_hidden_states = self.quant_method.apply(
self,
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk=self.top_k,
gating_output=router_logits,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group)
use_grouped_topk=self.use_grouped_topk)

if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
Expand All @@ -267,35 +263,42 @@ def forward(self, hidden_states: torch.Tensor,
def make_expert_params_mapping(
cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
ckpt_up_proj_name: str,
num_experts: int) -> List[Tuple[str, str, int, int]]:

gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
gate_down_up = [
ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
]
num_experts: int) -> List[Tuple[str, str, int, str]]:

return [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_scale"
if weight_name in gate_up else "experts.w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id,
shard_id) for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_weight"
if weight_name in gate_up else "experts.w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
] + [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.a13_scale"
if weight_name in gate_up else "experts.a2_scale",
f"experts.{expert_id}.{weight_name}.input_scale", expert_id,
shard_id) for expert_id in range(num_experts)
for shard_id, weight_name in enumerate(gate_down_up)
("experts.w13_" if weight_name
in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_",
f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
for expert_id in range(num_experts) for shard_id, weight_name in [
("w1", ckpt_gate_proj_name),
("w2", ckpt_down_proj_name),
("w3", ckpt_up_proj_name),
]
]

def _load_fp8_scale(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
param_data = param.data

# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
# If we are in merged column case (gate_up_proj)
if shard_id == "w1" or shard_id == "w3":
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
Loading