@@ -260,8 +260,11 @@ def apply(
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens

import deep_gemm as dg
assert hidden_states.ndim == 3
assert self.block_shape is not None
@@ -287,7 +290,6 @@ def apply(
masked_m=expert_num_tokens,
expected_m=expected_m)

assert expert_num_tokens is not None
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
expert_num_tokens)

@@ -129,12 +129,12 @@ def apply(
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
experts = (self.batched_deep_gemm_experts
if self.allow_deep_gemm else self.batched_triton_experts)
assert experts is not None
experts.apply(output, hidden_states, w1, w2, topk_ids, activation,
global_num_experts, expert_map, w1_scale, w2_scale,
w1_zp, w2_zp, a1q_scale, a2_scale, workspace13,
workspace2, expert_num_tokens)
workspace2, expert_tokens_meta)
8 changes: 7 additions & 1 deletion vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -303,11 +303,17 @@ def apply(
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE"
assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE"

expert_num_tokens = None
if expert_tokens_meta is not None:
expert_num_tokens = expert_tokens_meta.expert_num_tokens

activation_callable = lambda o, i: self.activation(activation, o, i)

in_dtype = hidden_states.dtype
run_cutlass_moe_fp8(
output, hidden_states, w1, w2, topk_ids, activation_callable,
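Two unwrapping conventions recur across the experts implementations touched by this change: kernels that cannot run without per-expert counts assert that the metadata is present, while paths like the CUTLASS one above treat it as optional and fall back to None. Below is a minimal sketch of both shapes, assuming the imports these files already use (torch, Optional, and modular_kernel as mk); the helper names are made up for illustration and are not part of the diff.

def require_expert_num_tokens(
        meta: Optional[mk.ExpertTokensMetadata]) -> torch.Tensor:
    # Batched kernels (e.g. the masked deep_gemm path above) cannot size
    # their per-expert work without the counts, so absence is a hard error.
    assert meta is not None, "per-expert token counts are required"
    return meta.expert_num_tokens


def maybe_expert_num_tokens(
        meta: Optional[mk.ExpertTokensMetadata]) -> Optional[torch.Tensor]:
    # The CUTLASS path forwards the counts only when the prepare step
    # produced them, and passes None through otherwise.
    return meta.expert_num_tokens if meta is not None else None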
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/deep_gemm_moe.py
@@ -119,7 +119,7 @@ def apply(
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
import deep_gemm as dg
assert self.block_shape is not None
@@ -62,8 +62,9 @@ def _do_dispatch(self, tokens: torch.Tensor,

has_scales = token_scales is not None

(num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens,
is_token_in_rank, event) = self.buffer.get_dispatch_layout(
(num_tokens_per_rank, num_tokens_per_rdma_rank,
dispatch_expert_num_tokens, is_token_in_rank,
event) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids,
num_experts=num_experts,
previous_event=None,
@@ -83,7 +84,7 @@ def _do_dispatch(self, tokens: torch.Tensor,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=expert_num_tokens,
num_tokens_per_expert=dispatch_expert_num_tokens,
topk_idx=rank_topk_ids,
topk_weights=rank_topk_weights,
# expert_alignment rounds the number of tokens per expert
@@ -115,7 +116,13 @@ def _do_dispatch(self, tokens: torch.Tensor,
num_experts - 1 if self.rank_expert_offset == 0 else 0,
expert_topk_ids + self.rank_expert_offset)

return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
# Makes a GPU-CPU copy.
# TODO (varun): Maybe it is better to re-compute the expert_num_tokens
# on GPU.
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
expert_num_tokens_per_expert_list, device=expert_x.device)

return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)

def prepare(
@@ -129,8 +136,9 @@ def prepare(
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:

if apply_router_weight_on_input:
topk = topk_ids.size(1)
@@ -149,7 +157,7 @@ def prepare(
)
if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1)
(expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
(expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1q,
token_scales=a1q_scale,
@@ -159,7 +167,7 @@ def prepare(
else:
# DeepEP kernels only support dispatching per-token-quant
# quantization. dispatch in bfloat16.
(expert_x, _, expert_num_tokens, expert_topk_ids,
(expert_x, _, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1,
token_scales=None,
@@ -176,7 +184,7 @@ def prepare(
per_act_token_quant=False,
block_shape=quant_config.block_shape)

return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids,
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)

def _apply_weights_and_reduce(self, num_tokens: int,
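For the high-throughput DeepEP path above, make_from_list materializes the per-expert counts (returned by the dispatch as a Python list) as a CPU tensor and copies it to the device once, so both views live in the metadata object; that single host-to-device transfer is what the copy comment and TODO refer to. A small usage sketch with made-up counts and device follows (mk is the modular_kernel module alias used in these files):

meta = mk.ExpertTokensMetadata.make_from_list([3, 0, 7, 1], device="cuda")
gpu_counts = meta.expert_num_tokens      # int32 tensor on the device, consumed by the kernels
cpu_counts = meta.expert_num_tokens_cpu  # same counts kept on the host, no later sync needed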
@@ -119,8 +119,9 @@ def prepare(
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:

hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
@@ -158,7 +159,10 @@ def prepare(
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)

return (expert_x, expert_x_scale, expert_num_tokens, None, None)
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)

return (expert_x, expert_x_scale, expert_tokens_meta, None, None)

def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
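In this prepare path the counts already exist as a device tensor, so only the GPU field is populated and expert_num_tokens_cpu stays None; a consumer that needs host-side counts has to handle that explicitly. A sketch of that consumer-side fallback (illustrative only; the placeholder tensor stands in for the counts this path produces, and the .cpu() call forces a device sync):

expert_num_tokens = torch.zeros(4, dtype=torch.int32)  # placeholder per-expert counts
meta = mk.ExpertTokensMetadata(expert_num_tokens=expert_num_tokens,
                               expert_num_tokens_cpu=None)
host_counts = (meta.expert_num_tokens_cpu
               if meta.expert_num_tokens_cpu is not None
               else meta.expert_num_tokens.cpu())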
76 changes: 32 additions & 44 deletions vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -505,8 +505,9 @@ def prepare(
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
assert a1.dim() == 2
assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0)
@@ -587,7 +588,10 @@ def prepare(

assert b_a1_scale is None or b_a1_scale.ndim == 3

return b_a1, b_a1_scale, tokens_per_expert, None, None
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=tokens_per_expert, expert_num_tokens_cpu=None)

return b_a1, b_a1_scale, expert_tokens_meta, None, None

def finalize(
self,
@@ -691,28 +695,19 @@ def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
else:
return t.to(f32) * group_broadcast(scale, t.shape)

def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
):
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
assert hidden_states.dim() == 3
assert expert_num_tokens is not None
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens

num_local_experts = w1.size(0)
assert num_local_experts == w1.size(0), (
@@ -895,26 +890,16 @@ def workspace_shapes(
output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output, a.dtype)

def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor],
w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
):
def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w2: torch.Tensor, topk_ids: torch.Tensor,
activation: str, global_num_experts: int,
expert_map: Optional[torch.Tensor],
w1_scale: Optional[torch.Tensor],
w2_scale: Optional[torch.Tensor], w1_zp: Optional[torch.Tensor],
w2_zp: Optional[torch.Tensor], a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor], workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata]):
# Check constraints.
if self.use_int4_w4a16:
assert hidden_states.size(-1) // 2 == w1.size(2), (
@@ -931,6 +916,9 @@ def apply(
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
]
assert expert_tokens_meta is not None

expert_num_tokens = expert_tokens_meta.expert_num_tokens

E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
hidden_states, w1, w2, topk_ids)
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1598,7 +1598,7 @@ def apply(
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
):
# Check constraints.
if self.use_int4_w4a16:
43 changes: 34 additions & 9 deletions vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import Optional, final
@@ -95,6 +96,26 @@ class FusedMoEActivationFormat(Enum):
BatchedExperts = "batched_experts",


@dataclass
class ExpertTokensMetadata:
"""
Metadata regarding expert-token routing.
"""
expert_num_tokens: torch.Tensor
expert_num_tokens_cpu: Optional[torch.Tensor]

@staticmethod
def make_from_list(expert_num_tokens_list: list[int],
device: str) -> "ExpertTokensMetadata":
expert_num_tokens_cpu = torch.tensor(expert_num_tokens_list,
device="cpu",
dtype=torch.int32)
return ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens_cpu.to(device,
non_blocking=True),
expert_num_tokens_cpu=expert_num_tokens_cpu)


# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
"""
@@ -114,8 +135,9 @@ def prepare(
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
Optional[torch.Tensor], Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Perform any quantization (and/or) dispatching needed
for this kernel.
@@ -134,7 +156,8 @@ def prepare(
Returns a tuple of:
- quantized + dispatched a.
- quantized + dispatched a1_scales.
- Optional tensor as big as number of local experts that contains the
- Optional ExpertTokensMetadata containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
- Optional dispatched expert topk IDs
- Optional dispatched expert topk weight
@@ -318,7 +341,7 @@ def apply(
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_num_tokens: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
):
"""
This function computes the intermediate result of a Mixture of Experts
@@ -351,8 +374,10 @@ def apply(
must be large enough to hold output of either MoE gemm.
- workspace2 (torch.Tensor): A scratch tensor used for the activation
function.
- expert_num_tokens: An optional tensor containing the number of tokens
assigned to each expert when using batched experts format input.
- expert_tokens_meta (Optional[ExpertTokensMetadata]) - An optional
ExpertTokensMetadata object containing gpu/cpu tensors
as big as the number of local experts with the information about the
number of tokens assigned to each local expert.
"""
raise NotImplementedError

@@ -458,7 +483,7 @@ def forward(
if global_num_experts == -1:
global_num_experts = local_num_experts

(a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1,
a1_scale,
@@ -542,7 +567,7 @@ def forward(
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
expert_tokens_meta=expert_tokens_meta,
)
else:
# The leading output dimension may not be equal to M, so
@@ -589,7 +614,7 @@ def forward(
a2_scale=curr_a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_num_tokens=expert_num_tokens,
expert_tokens_meta=expert_tokens_meta,
)

self.prepare_finalize.finalize(output, fused_out, topk_weights,
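Putting the modular_kernel.py pieces together: prepare() now returns an ExpertTokensMetadata object in place of the old bare expert_num_tokens tensor, forward() threads it through unchanged, and each experts implementation unwraps what it needs. The following is a condensed, self-contained sketch of that flow; the bincount-based counting and the trimmed signatures are illustrative assumptions, not the real vLLM code.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ExpertTokensMetadata:
    expert_num_tokens: torch.Tensor                # per-local-expert counts (device tensor in vLLM)
    expert_num_tokens_cpu: Optional[torch.Tensor]  # same counts on the host, when available


def prepare(a1: torch.Tensor, topk_ids: torch.Tensor,
            num_local_experts: int) -> tuple[torch.Tensor, ExpertTokensMetadata]:
    # Count how many tokens were routed to each local expert and wrap the
    # result once, instead of passing a bare tensor through every layer.
    counts = torch.bincount(topk_ids.flatten(),
                            minlength=num_local_experts).to(torch.int32)
    return a1, ExpertTokensMetadata(expert_num_tokens=counts,
                                    expert_num_tokens_cpu=None)


def experts_apply(hidden_states: torch.Tensor,
                  expert_tokens_meta: Optional[ExpertTokensMetadata]) -> torch.Tensor:
    # Kernels that need the counts unwrap them here; kernels that do not
    # simply ignore the argument, which is why it stays Optional.
    if expert_tokens_meta is not None:
        # e.g. a batched kernel would use these counts to mask its per-expert GEMMs
        expert_num_tokens = expert_tokens_meta.expert_num_tokens
        assert expert_num_tokens.numel() > 0
    return hidden_states


tokens = torch.randn(8, 16)
topk_ids = torch.randint(0, 4, (8, 2))
a1q, meta = prepare(tokens, topk_ids, num_local_experts=4)
out = experts_apply(a1q, meta)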