vllm/model_executor/layers/fused_moe/fused_moe.py (36 additions, 12 deletions)
@@ -52,6 +52,7 @@ def fused_moe_kernel(
top_k: tl.constexpr,
compute_type: tl.constexpr,
use_fp8: tl.constexpr,
use_int8: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
@@ -82,7 +83,7 @@ def fused_moe_kernel(
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
pid = tl.program_id(axis=0).to(tl.int64)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -118,13 +119,16 @@ def fused_moe_kernel(
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)

if use_int8:
a_scale = tl.load(a_scale_ptr + off_experts)
b_scale = tl.load(b_scale_ptr + off_experts * stride_cm + offs_bn)

# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
dtype=tl.int32 if use_int8 else tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
@@ -138,6 +142,9 @@
# We accumulate along the K dimension.
if use_fp8:
accumulator = tl.dot(a, b, acc=accumulator)
elif use_int8:
a = tl.math.llrint((a / a_scale)).to(tl.int8)
accumulator = tl.dot(a, b, acc=accumulator, out_dtype=accumulator.dtype)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
@@ -150,7 +157,8 @@
other=0)
accumulator = accumulator * moe_weight[:, None]

if use_fp8:
if use_fp8 or use_int8:
accumulator = accumulator.to(tl.float32)
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
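
For reference, here is a minimal PyTorch sketch of the math the int8 path above performs; it is not the Triton kernel, and the scale granularity (per-tensor activation scale, per-output-channel weight scale) is an assumption for the sketch: quantize the activations, accumulate the int8 matmul in an integer accumulator, then rescale by `a_scale * b_scale` before casting to the compute type.

```python
# Hedged sketch of the int8 path's math (plain PyTorch, not the Triton kernel).
# Scale granularity is an assumption: per-tensor activation scale,
# per-output-channel weight scale.
import torch

M, K, N = 4, 64, 32
a = torch.randn(M, K)
b_q = torch.randint(-127, 128, (K, N), dtype=torch.int8)
b_scale = torch.rand(N) * 0.01 + 1e-3               # per-channel weight scales

a_scale = a.abs().max() / 127.0                      # activation scale
a_q = torch.round(a / a_scale).clamp(-128, 127).to(torch.int8)  # llrint + cast

# The kernel keeps an int32 accumulator; int64 is used here only because CPU
# integer matmul is most portable in that dtype.
acc = a_q.to(torch.int64) @ b_q.to(torch.int64)
c = acc.to(torch.float32) * a_scale * b_scale        # dequantize, then cast to compute_type
```
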
@@ -229,16 +237,19 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any], compute_type: tl.dtype,
use_fp8: bool) -> None:
use_fp8: bool,
use_int8: bool) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1

if not use_fp8:
assert A_scale is None
assert B_scale is None
else:
if use_fp8:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
elif use_int8:
assert B_scale is not None
else:
assert A_scale is None
assert B_scale is None

grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
@@ -268,6 +279,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
top_k=top_k,
compute_type=compute_type,
use_fp8=use_fp8,
use_int8=use_int8,
**config,
)

@@ -434,6 +446,7 @@ def fused_experts(hidden_states: torch.Tensor,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
use_int8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
@@ -455,12 +468,19 @@
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)

if use_fp8:
dtype = "float8"
elif use_int8:
dtype = "int8"
else:
dtype = None

get_config_func = functools.partial(
try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
dtype,
override_config=override_config,
)

@@ -524,7 +544,8 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8)
use_fp8=use_fp8,
use_int8=use_int8)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

@@ -542,7 +563,8 @@
1,
config,
compute_type=compute_type,
use_fp8=use_fp8)
use_fp8=use_fp8,
use_int8=use_int8)

torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
@@ -563,6 +585,7 @@ def fused_moe(
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
use_fp8: bool = False,
use_int8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
@@ -618,6 +641,7 @@ def fused_moe(
inplace=inplace,
override_config=override_config,
use_fp8=use_fp8,
use_int8=use_int8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
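
As a usage illustration of the new `use_int8` flag, a hedged sketch of a `fused_moe` call with int8 expert weights follows. The sizes, scale shapes, and scale values are assumptions for the sketch, not taken from the diff, and a CUDA device plus real quantized weights are needed for a meaningful run.

```python
# Illustrative call of the extended API; shapes and scales are assumptions.
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

E, K, N, M, top_k = 8, 1024, 2048, 16, 2   # experts, hidden, intermediate, tokens
x = torch.randn(M, K, dtype=torch.float16, device="cuda")
w13 = torch.randint(-127, 128, (E, 2 * N, K), dtype=torch.int8, device="cuda")
w2 = torch.randint(-127, 128, (E, K, N), dtype=torch.int8, device="cuda")
router_logits = torch.randn(M, E, dtype=torch.float16, device="cuda")

out = fused_moe(
    x, w13, w2, router_logits, top_k,
    renormalize=True,
    use_int8=True,
    w1_scale=torch.full((E, 2 * N, 1), 0.01, device="cuda"),  # per-expert, per-channel (assumed layout)
    w2_scale=torch.full((E, K, 1), 0.01, device="cuda"),
    a1_scale=torch.ones(E, device="cuda"),                    # per-expert activation scales
    a2_scale=torch.ones(E, device="cuda"),
)
```
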
vllm/model_executor/layers/fused_moe/layer.py (149 additions, 5 deletions)
@@ -10,6 +10,8 @@
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)
@@ -123,6 +125,138 @@ def forward_tpu(
return fused_moe(x, w1, w2, router_logits, top_k, renormalize)


class W8A8QuantizedFusedMoEMethod(FusedMoEMethodBase):
"""MoE method without quantization."""


def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
scheme_map = extra_weight_attrs["quant_config"].target_scheme_map["Linear"]
self.strategy = scheme_map["weights"].strategy
self.is_static_input_scheme = not scheme_map["input_activations"].dynamic

self.quant_config = extra_weight_attrs["quant_config"]
self.weight_loader = extra_weight_attrs["weight_loader"]

self.logical_widths_13 = [intermediate_size * 2]
self.logical_widths_2 = [hidden_size]
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=torch.int8),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)

set_weight_attrs(w13_weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": self.weight_loader,
})

# WEIGHT SCALE
layer_kwargs = {"weight_loader": self.weight_loader, "num_experts": num_experts}
if self.strategy == QuantizationStrategy.CHANNEL:
scale = create_per_channel_scale_param([intermediate_size * 2],
**layer_kwargs)
else:
assert self.strategy == QuantizationStrategy.TENSOR
scale = torch.nn.Parameter(torch.empty((num_experts, 2), dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
layer.register_parameter("w13_scale", scale)


# INPUT SCALE
if self.is_static_input_scheme:
scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
layer.register_parameter("a13_scale", scale)


# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=torch.int8),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)

set_weight_attrs(w2_weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": self.weight_loader,
})

# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
scale = create_per_channel_scale_param([hidden_size],
**layer_kwargs)

else:
assert self.strategy == QuantizationStrategy.TENSOR
scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
layer.register_parameter("w2_scale", scale)

# INPUT SCALE
if self.is_static_input_scheme:
scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
set_weight_attrs(scale, {
"needs_scalar_to_array": True,
**layer_kwargs
})
layer.register_parameter("a2_scale", scale)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
return fused_moe(x,
layer.w13_weight,
layer.w2_weight,
router_logits,
top_k,
renormalize=renormalize,
inplace=True,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
use_int8=True,
w1_scale=layer.w13_scale,
w2_scale=layer.w2_scale,
a1_scale=layer.a13_scale,
a2_scale=layer.a2_scale)
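
For orientation, here is a rough, hedged sketch of the tensor shapes `create_weights` registers under the TENSOR weight strategy with a static input scheme; the sizes are made up, and the CHANNEL strategy would instead yield `w13_scale` of shape `(E, 2*N, 1)` and `w2_scale` of shape `(E, K, 1)` via `create_per_channel_scale_param`.

```python
# Rough shape summary; E, K, N are made-up sizes.
import torch

E, K, N = 4, 128, 256   # experts, hidden_size, intermediate_size (per partition)
params = {
    "w13_weight": torch.empty(E, 2 * N, K, dtype=torch.int8),
    "w2_weight": torch.empty(E, K, N, dtype=torch.int8),
    "w13_scale": torch.empty(E, 2, dtype=torch.float32),  # one scale per fused shard (w1, w3)
    "w2_scale": torch.ones(E, dtype=torch.float32),
    "a13_scale": torch.ones(E, dtype=torch.float32),      # static input scales
    "a2_scale": torch.ones(E, dtype=torch.float32),
}
for name, t in params.items():
    print(f"{name}: {tuple(t.shape)}, {t.dtype}")
```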



class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models.

@@ -177,10 +311,12 @@ def __init__(
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod())
elif quant_config.quant_format == CompressionFormat.int_quantized.value:
self.quant_method: Optional[QuantizeMethodBase] = (
W8A8QuantizedFusedMoEMethod())
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
@@ -191,24 +327,29 @@
hidden_size=hidden_size,
intermediate_size=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=self.weight_loader)
weight_loader=self.weight_loader,
quant_config=quant_config)

def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: int, expert_id: int):
param_data = param.data
if isinstance(self.quant_method, W8A8QuantizedFusedMoEMethod):
weight_quant_strategy = (self.quant_method.quant_config.
target_scheme_map["Linear"]["weights"].strategy)
else:
weight_quant_strategy = None

# Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name:
if param_data[expert_id] != 1 and (param_data[expert_id] -
loaded_weight).abs() > 1e-5:
loaded_weight.to(param_data.device)).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param_data[expert_id]} "
f"vs. {loaded_weight}")
param_data[expert_id] = loaded_weight
# Weight scales
elif "weight_scale" in weight_name:
elif "weight_scale" in weight_name and weight_quant_strategy == QuantizationStrategy.TENSOR:
# If we are in merged column case (gate_up_proj)
# shard_id 0 == gate_proj / w1
# shard_id 2 == up_proj / w3
@@ -237,7 +378,10 @@ def weight_loader(self, param: torch.nn.Parameter,
shard_size, :] = loaded_weight[shard, :]
# w2, down_proj case: Load into only shard of w2.
elif shard_id == 1:
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "weight_scale" in weight_name and weight_quant_strategy == QuantizationStrategy.CHANNEL:
param_data[expert_id, :, :] = loaded_weight
else:
param_data[expert_id, :, :] = loaded_weight[:, shard]
else:
raise ValueError(
f"Shard id must be in [0,1,2] but got {shard_id}")
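
As a reading aid, a small hedged sketch of how the merged gate_up_proj shards land in `w13_weight` per expert inside `weight_loader`: shard_id 0 (gate_proj / w1) fills the first `shard_size` rows, shard_id 2 (up_proj / w3) fills the next `shard_size` rows, and the down_proj case (shard_id 1) writes into `w2_weight` instead.

```python
# Shapes here are assumptions for illustration, independent of the diff.
import torch

E, N, K = 2, 8, 4        # experts, intermediate size (per partition), hidden size
w13 = torch.zeros(E, 2 * N, K, dtype=torch.int8)
gate = torch.randint(-127, 128, (N, K), dtype=torch.int8)  # loaded w1 shard
up = torch.randint(-127, 128, (N, K), dtype=torch.int8)    # loaded w3 shard

expert_id, shard_size = 0, N
w13[expert_id, 0:shard_size, :] = gate                 # shard_id == 0 (gate_proj / w1)
w13[expert_id, shard_size:2 * shard_size, :] = up      # shard_id == 2 (up_proj / w3)
```
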
vllm/model_executor/layers/quantization/utils/w8a8_utils.py (9 additions, 3 deletions)
@@ -45,9 +45,15 @@ def create_per_tensor_scale_param(

def create_per_channel_scale_param(output_partition_sizes: List[int],
**extra_weight_attrs) -> Parameter:
scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
requires_grad=False)
num_expert = extra_weight_attrs.get("num_experts", 1)
if num_expert == 1:
scale = Parameter(torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
requires_grad=False)
else:
scale = Parameter(torch.empty((num_expert, sum(output_partition_sizes), 1),
dtype=torch.float32),
requires_grad=False)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs})
return scale
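
For clarity, a hedged, standalone sketch of the two shapes the updated `create_per_channel_scale_param` produces: the existing 2-D `(sum(sizes), 1)` scale when no `num_experts` is supplied, and a 3-D `(num_experts, sum(sizes), 1)` scale for the fused-MoE case. Only the shape logic is mirrored here; the weight-attribute handling of the original helper is omitted.

```python
# Standalone re-implementation of the shape selection only.
import torch
from torch.nn import Parameter

def per_channel_scale_shape_sketch(output_partition_sizes, num_experts=1):
    rows = sum(output_partition_sizes)
    shape = (rows, 1) if num_experts == 1 else (num_experts, rows, 1)
    # Same float32-min placeholder initialization as the original helper.
    return Parameter(torch.full(shape, torch.finfo(torch.float32).min,
                                dtype=torch.float32), requires_grad=False)

print(per_channel_scale_shape_sketch([1024, 1024]).shape)                 # (2048, 1)
print(per_channel_scale_shape_sketch([1024, 1024], num_experts=8).shape)  # (8, 2048, 1)
```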