5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -40,6 +40,7 @@
VERBOSE: bool = False
VLLM_SYNC_SERVER_ACCUM_REQUESTS: int = 1
VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1
VLLM_MOE_PADDING: bool = True

# The begin-* and end* here are used by the documentation generator
# to extract the used env vars.
@@ -229,6 +230,10 @@
# Poll for new requests every this many steps
"VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS":
lambda: int(os.getenv("VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS", "1")),

# Whether to pad the weights for the MoE kernel
"VLLM_MOE_PADDING":
lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "1"))),
}

# end-env-vars-definition
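For reference, a minimal sketch (not part of the diff) of how this flag is consumed: it mirrors the lambda above and the `padding_size` line added to `fused_moe.py` below.

```python
import os

# Any non-zero integer enables padding; the default "1" means padding is on.
VLLM_MOE_PADDING = bool(int(os.getenv("VLLM_MOE_PADDING", "1")))

# fused_moe.py derives the pad width from the flag at import time.
padding_size = 128 if VLLM_MOE_PADDING else 0

print(padding_size)  # 128 unless VLLM_MOE_PADDING=0 is exported before import
```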
9 changes: 6 additions & 3 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -10,9 +10,11 @@

import vllm._moe_C as moe_kernels
from vllm import _custom_ops as ops
from vllm import envs
from vllm.logger import init_logger

logger = init_logger(__name__)
padding_size = 128 if envs.VLLM_MOE_PADDING else 0


@triton.jit
@@ -262,7 +264,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
B.shape[2] - padding_size,
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
@@ -365,7 +367,8 @@ def fused_experts(hidden_states: torch.Tensor,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None):
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert hidden_states.shape[
1] == w1.shape[2] - padding_size, "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -381,7 +384,7 @@ def fused_experts(hidden_states: torch.Tensor,
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2],
configs = get_moe_configs(E, w2.shape[2] - padding_size,
"float8" if use_fp8 else None)

if configs:
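The three `- padding_size` adjustments above all serve one purpose: once the expert weights carry 128 extra zero columns in their last (contiguous) dimension, every place that reads that dimension as a logical size, namely the size handed to the Triton kernel, the hidden-size assert, and the tuned-config lookup, has to subtract the padding back out. A standalone sketch of that bookkeeping with toy shapes (the tensor names and sizes here are assumptions for illustration, not taken from the diff):

```python
import torch
import torch.nn.functional as F

padding_size = 128
num_experts, hidden_size, intermediate_size = 4, 64, 224

# Expert weights padded along their last dimension, as mixtral.py now does.
w1 = F.pad(torch.randn(num_experts, 2 * intermediate_size, hidden_size),
           (0, padding_size))            # (E, 2*I, H + 128)
w2 = F.pad(torch.randn(num_experts, hidden_size, intermediate_size),
           (0, padding_size))            # (E, H, I + 128)

hidden_states = torch.randn(16, hidden_size)

# The updated assert: compare against the unpadded hidden size.
assert hidden_states.shape[1] == w1.shape[2] - padding_size

# The kernel and config lookup are likewise keyed on logical, unpadded sizes.
kernel_dim = w1.shape[2] - padding_size  # what invoke_fused_moe_kernel now passes for B.shape[2]
config_key = w2.shape[2] - padding_size  # second argument to get_moe_configs
```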
9 changes: 9 additions & 0 deletions vllm/model_executor/models/mixtral.py
@@ -24,10 +24,12 @@
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig

from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
@@ -181,6 +183,13 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
def process_weights_after_loading(self):
# Fp8 is the only case where we need to process after loading.
if not self.use_fp8:
if envs.VLLM_MOE_PADDING:
self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data,
(0, 128), "constant", 0),
requires_grad=False)
self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data,
(0, 128), "constant", 0),
requires_grad=False)
return

# If checkpoint is fp16, quantize here.
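A self-contained sketch of the padding step above, using toy shapes (the sizes are illustrative, not Mixtral's real ones): `F.pad` with a two-element pad tuple pads only the last dimension, so each expert weight gains 128 trailing zero columns in its contiguous dimension, and the parameters are rebuilt with `requires_grad=False` just as in the diff.

```python
import torch
import torch.nn.functional as F
from torch import nn

num_experts, hidden_size, intermediate_size = 4, 64, 224
w13 = torch.randn(num_experts, 2 * intermediate_size, hidden_size)
w2 = torch.randn(num_experts, hidden_size, intermediate_size)

# (0, 128) = (pad_left, pad_right) applied to the last dimension only.
w13_padded = nn.Parameter(F.pad(w13, (0, 128), "constant", 0),
                          requires_grad=False)
w2_padded = nn.Parameter(F.pad(w2, (0, 128), "constant", 0),
                         requires_grad=False)

assert w13_padded.shape == (num_experts, 2 * intermediate_size, hidden_size + 128)
assert w2_padded.shape == (num_experts, hidden_size, intermediate_size + 128)
```

Note that the padding only applies on the non-FP8 path guarded by `if not self.use_fp8:`; the FP8 branch below keeps its own post-loading processing.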