Merged
33 changes: 17 additions & 16 deletions benchmarks/kernels/benchmark_moe.py
@@ -88,22 +88,23 @@ def prepare(i: int):
input_gating.copy_(gating_output[i])

def run():
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
override_config=config,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
from vllm.model_executor.layers.fused_moe import override_config
with override_config(config):
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)

# JIT compilation & warmup
run()
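For context: the benchmark no longer passes the tuned config into `fused_moe` via an `override_config=` keyword; it wraps the call in the `override_config` context manager that this PR adds to `vllm/model_executor/layers/fused_moe/__init__.py` (see that hunk below). A minimal standalone sketch of the same pattern with a quick usage check; it mirrors the added code but runs without vLLM or a GPU, and the config dict is a toy value:

```python
from contextlib import contextmanager
from typing import Any, Dict, Optional

# Module-level slot holding the temporarily overridden kernel config.
_config: Optional[Dict[str, Any]] = None


@contextmanager
def override_config(config: Dict[str, Any]):
    # Install `config` for the duration of the block, then restore
    # whatever was active before entering it.
    global _config
    old_config = _config
    _config = config
    yield
    _config = old_config


def get_config() -> Optional[Dict[str, Any]]:
    return _config


with override_config({"BLOCK_SIZE_M": 64}):      # toy config value
    assert get_config() == {"BLOCK_SIZE_M": 64}
assert get_config() is None                      # restored on exit
```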
4 changes: 2 additions & 2 deletions tests/compile/test_basic_correctness.py
@@ -13,11 +13,11 @@
@pytest.mark.parametrize(
"model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
[
("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASH_ATTN", "generate", True),
("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
["--quantization", "compressed-tensors"
], 1, 1, "FLASH_ATTN", "generate", True),
("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True),
Member (Author) commented on this line: "it does not work with FLASHINFER due to some head size config."

# TODO: add multi-modality test for llava
("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
])
21 changes: 10 additions & 11 deletions tests/kernels/test_awq_marlin.py
@@ -5,11 +5,10 @@
import pytest
import torch

import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe,
torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
awq_marlin_quantize)
@@ -81,7 +80,7 @@ def test_fused_marlin_moe_awq(
score = torch.randn((m, e), device="cuda", dtype=dtype)

topk_weights, topk_ids = fused_topk(a, score, topk, False)
marlin_output = fused_marlin_moe(
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
@@ -150,14 +149,14 @@ def test_single_marlin_moe_multiply_awq(

score = torch.randn((m, e), device="cuda", dtype=dtype)

marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
topk,
renormalize=False,
w_zeros=zp,
num_bits=num_bits)
marlin_output = torch.ops.vllm.single_marlin_moe(a,
qweight,
scales,
score,
topk,
renormalize=False,
w_zeros=zp,
num_bits=num_bits)

torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)

7 changes: 3 additions & 4 deletions tests/kernels/test_moe.py
@@ -7,12 +7,11 @@
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev,
torch_moe, torch_moe_single)
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, moe_align_block_size)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
@@ -191,7 +190,7 @@ def test_fused_marlin_moe(
topk,
renormalize=False,
)
marlin_output = fused_marlin_moe(
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,
qweight2,
@@ -306,7 +305,7 @@ def test_single_marlin_moe_multiply(
sort_indices = stack_and_dev(sort_indices_l)

score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(
marlin_output = torch.ops.vllm.single_marlin_moe(
a,
qweight,
scales,
26 changes: 22 additions & 4 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -1,23 +1,41 @@
from contextlib import contextmanager
from typing import Any, Dict, Optional

from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON

_config: Optional[Dict[str, Any]] = None


@contextmanager
def override_config(config):
global _config
old_config = _config
_config = config
yield
_config = old_config


def get_config() -> Optional[Dict[str, Any]]:
return _config


__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
"FusedMoeWeightScaleSupported",
"override_config",
"get_config",
]

if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe, single_marlin_moe)
import vllm.model_executor.layers.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)

__all__ += [
"fused_marlin_moe",
"single_marlin_moe",
"fused_moe",
"fused_topk",
"fused_experts",
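The kernel-side consumer of `get_config()` is not part of the hunks shown here (the `fused_moe.py` diff is not included above), so the following is only a hypothetical sketch of how a config chooser could honor an active override before falling back to tuned defaults. `choose_moe_config` and `load_tuned_config` are illustrative names, not vLLM APIs; only `get_config` comes from this PR:

```python
from typing import Any, Dict, Optional

from vllm.model_executor.layers.fused_moe import get_config


def load_tuned_config(shape_key: str) -> Dict[str, Any]:
    # Placeholder fallback; in vLLM the defaults come from tuned
    # JSON config files shipped alongside the fused MoE kernels.
    return {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}


def choose_moe_config(shape_key: str) -> Dict[str, Any]:
    # Prefer a config installed via `with override_config(...):`,
    # otherwise fall back to the statically tuned configuration.
    override: Optional[Dict[str, Any]] = get_config()
    if override is not None:
        return override
    return load_tuned_config(shape_key)
```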
51 changes: 42 additions & 9 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -1,6 +1,6 @@
"""Fused MoE utilities for GPTQ."""
import functools
from typing import Any, Dict, Optional
from typing import Optional

import torch

@@ -18,6 +18,7 @@ def get_scalar_type(num_bits: int, has_zp: bool):
return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128


@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[])
def single_marlin_moe(
hidden_states: torch.Tensor,
w: torch.Tensor,
@@ -28,7 +29,6 @@ def single_marlin_moe(
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
@@ -49,8 +49,6 @@
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_bits (bool): The number of bits in expert weights quantization.

Returns:
@@ -79,7 +77,6 @@
w.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)

@@ -122,6 +119,24 @@ def single_marlin_moe(
return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1)


@single_marlin_moe.register_fake
def _(
hidden_states: torch.Tensor,
w: torch.Tensor,
scales: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
g_idx: Optional[torch.Tensor] = None,
sort_indices: Optional[torch.Tensor] = None,
w_zeros: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)


@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[])
def fused_marlin_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -137,7 +152,6 @@ def fused_marlin_moe(
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
override_config: Optional[Dict[str, Any]] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
@@ -161,8 +175,6 @@ def fused_marlin_moe(
permutation.
- topk_weights (torch.Tensor): Top-k weights.
- topk_ids (torch.Tensor): Indices of topk-k elements.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1.
- w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2.
- num_bits (bool): The number of bits in expert weights quantization.
@@ -209,7 +221,6 @@ def fused_marlin_moe(
w2.shape,
topk_ids.shape[1],
None,
override_config=override_config,
is_marlin=True,
)
config = get_config_func(M)
@@ -311,3 +322,25 @@ def fused_marlin_moe(

return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)


@fused_marlin_moe.register_fake
def _(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
g_idx1: Optional[torch.Tensor] = None,
g_idx2: Optional[torch.Tensor] = None,
sort_indices1: Optional[torch.Tensor] = None,
sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None,
num_bits: int = 8,
is_k_full: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
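
As a self-contained illustration of the registration pattern used in this file, here is a toy op built the same way, assuming PyTorch 2.4+ where `torch.library.custom_op`, `register_fake`, and `torch.library.opcheck` are available. The real implementation runs eagerly, while the fake implementation only reports output metadata so `torch.compile` can trace the op without executing the kernel. The `demo::` namespace and the op itself are illustrative, not part of vLLM:

```python
import torch


@torch.library.custom_op("demo::scaled_copy", mutates_args=[])
def scaled_copy(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Real (eager) implementation.
    return x * scale


@scaled_copy.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Fake/meta implementation: shapes and dtypes only, no compute.
    return torch.empty_like(x)


# Sanity-check the registration (schema + fake impl), in the spirit of
# the opcheck helper imported in the updated kernel tests.
torch.library.opcheck(torch.ops.demo.scaled_copy.default,
                      (torch.randn(4), 2.0))


@torch.compile(fullgraph=True)
def f(x: torch.Tensor) -> torch.Tensor:
    # Traced as an opaque call, just like torch.ops.vllm.fused_marlin_moe
    # in the updated tests.
    return torch.ops.demo.scaled_copy(x, 2.0)


print(f(torch.ones(4)))  # tensor([2., 2., 2., 2.])
```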