diff --git a/vllm/envs.py b/vllm/envs.py
index 45dae28347e5..bd6abca2629e 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -40,6 +40,7 @@
     VERBOSE: bool = False
     VLLM_SYNC_SERVER_ACCUM_REQUESTS: int = 1
     VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1
+    VLLM_MOE_PADDING: bool = True
 
 # The begin-* and end* here are used by the documentation generator
 # to extract the used env vars.
@@ -229,6 +230,10 @@
     # Poll for new requests every this many steps
     "VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS":
     lambda: int(os.getenv("VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS", "1")),
+
+    # Whether to pad the weights for the MoE kernel
+    "VLLM_MOE_PADDING":
+    lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "1"))),
 }
 
 # end-env-vars-definition
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 7a3c6ec77335..e759d63b588b 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -10,9 +10,11 @@
 
 import vllm._moe_C as moe_kernels
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
+padding_size = 128 if envs.VLLM_MOE_PADDING else 0
 
 
 @triton.jit
@@ -262,7 +264,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padding_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -365,7 +367,8 @@ def fused_experts(hidden_states: torch.Tensor,
                   a1_scale: Optional[torch.Tensor] = None,
                   a2_scale: Optional[torch.Tensor] = None):
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[
+        1] == w1.shape[2] - padding_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -381,7 +384,7 @@ def fused_experts(hidden_states: torch.Tensor,
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
+        configs = get_moe_configs(E, w2.shape[2] - padding_size,
                                   "float8" if use_fp8 else None)
 
         if configs:
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 2f4237339486..c34077fa2bfa 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -24,10 +24,12 @@
 from typing import Iterable, List, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import MixtralConfig
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_tensor_model_parallel_rank,
@@ -181,6 +183,13 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
     def process_weights_after_loading(self):
         # Fp8 is the only case where we need to process after loading.
         if not self.use_fp8:
+            if envs.VLLM_MOE_PADDING:
+                self.w13_weight = nn.Parameter(F.pad(self.w13_weight.data,
+                                                     (0, 128), "constant", 0),
+                                               requires_grad=False)
+                self.w2_weight = nn.Parameter(F.pad(self.w2_weight.data,
+                                                    (0, 128), "constant", 0),
+                                              requires_grad=False)
             return
 
         # If checkpoint is fp16, quantize here.
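
For reference only, not part of the patch: a minimal sketch, with hypothetical shapes, of what the padding does to an expert weight tensor. F.pad with a pad of (0, 128) extends only the last dimension, which is why the kernel launch and the get_moe_configs lookup above subtract padding_size to recover the true inner dimension.

# Illustration only -- hypothetical shapes, not vLLM code.
import torch
import torch.nn.functional as F

PADDING = 128                           # mirrors padding_size when VLLM_MOE_PADDING=1
num_experts, rows, cols = 8, 64, 256    # example layout [num_experts, rows, cols]

w = torch.randn(num_experts, rows, cols)
w_padded = F.pad(w, (0, PADDING), "constant", 0)   # pads the last dimension only

assert w_padded.shape == (num_experts, rows, cols + PADDING)
# The true inner dimension is recovered the same way the diff does it:
assert w_padded.shape[2] - PADDING == cols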