From c33e37a2153b9957d65337028ea7b272dfbac3cf Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 21 Jan 2024 23:26:12 -0800 Subject: [PATCH 01/17] Fused MOE for Mixtral --- vllm/model_executor/layers/moe.py | 452 ++++++++++++++++++++++++++ vllm/model_executor/models/mixtral.py | 161 ++------- 2 files changed, 487 insertions(+), 126 deletions(-) create mode 100644 vllm/model_executor/layers/moe.py diff --git a/vllm/model_executor/layers/moe.py b/vllm/model_executor/layers/moe.py new file mode 100644 index 000000000000..39c307a18ff4 --- /dev/null +++ b/vllm/model_executor/layers/moe.py @@ -0,0 +1,452 @@ +from typing import Tuple + +import torch +from torch import nn +import torch.nn.functional as F +import triton +import triton.language as tl + +from vllm._C import ops +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.utils import set_weight_attrs + + +class MoE(nn.Module): + """a tensor-parallel MOE implementation that shards each expert across + all ranks. + + Each expert's weights are sharded across all ranks. The forward pass + will first expand and group the hidden states by experts, then compute + the per-rank MLP output of each expert using grouped gemm, and finally + reduce the output across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + ): + super().__init__() + tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // tp_size + + self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + linear_method=None) + + self.w1s = nn.Parameter( + torch.empty(self.num_total_experts, + self.intermediate_size, + self.hidden_size, + device="cuda")) + self.w2s = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device="cuda")) + self.w3s = nn.Parameter( + torch.empty(self.num_total_experts, + self.intermediate_size, + self.hidden_size, + device="cuda")) + + # TODO: Currently this is fake data but should be + # [self.w1s, self.w3s] concatenated along the intermediate + # size dimension. 
+ self.ws = nn.Parameter( + torch.randn(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda")) + + set_weight_attrs(self.w1s, { + "weight_loader": self.weight_loader, + "tp_type": "column" + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + "tp_type": "row" + }) + set_weight_attrs(self.w3s, { + "weight_loader": self.weight_loader, + "tp_type": "column" + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + if getattr(param, "tp_type", None) == "row": + shard_size = param_data.shape[2] + w_shard = loaded_weight[:,(tp_rank * shard_size): (tp_rank+1) * shard_size] + else: + shard_size = param_data.shape[1] + w_shard = loaded_weight[(tp_rank * shard_size): (tp_rank+1) * shard_size,:] + assert param_data[expert_id].shape == w_shard.shape, \ + f"{param_data[expert_id].shape}, {w_shard.shape}" + param_data[expert_id].copy_(w_shard) + + def fused_moe_infer(self, hidden_states: torch.Tensor, + selected_experts: torch.Tensor, + routing_weights: torch.Tensor) -> torch.Tensor: + return fused_moe(hidden_states, + # self.w1s, + self.ws, + self.w2s, + # self.w3s, + routing_weights, + selected_experts, + inplace=True) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (batch * sequence_length, n_experts) + router_logits, _ = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = self.fused_moe_infer(hidden_states, + selected_experts, + routing_weights) + + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(batch_size, sequence_length, + hidden_size) + + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + M, + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_weight, + stride_token_id, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. 
+ - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, + and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. + - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. + This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` + by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + # + off_experts = tl.load(expert_ids_ptr + pid_m) * stride_be + b_ptrs = b_ptr + off_experts + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. 
+ a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +# TODO: rewrite in CPP +def alig_block_size( + topk_ids: torch.Tensor, block_size: int, + num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor): + """ + Aligns the token distribution across experts to be compatible with block size for matrix multiplication. + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. + This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions align correctly. + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [1, 2, 3, 4] at the end, resulting in [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3 | 1, 2, 3, 4]. + - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 13, 1, 7, 11, 14, 2, 5, 8, 15]. + Tokens 12-15 are non-existent (padding) and are ignored in the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. + """ + cnts = torch.zeros(topk_ids.shape[0], + num_experts, + dtype=topk_ids.dtype, + device=topk_ids.device) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + tokens_per_expert_post_alig = torch.floor_divide( + tokens_per_expert + block_size - 1, block_size) * block_size + + cumsum = tokens_per_expert_post_alig.cumsum(0) + num_tokens_post_padded = cumsum[-1].clone() + max_tokens_post_padded = ( + topk_ids.numel() + num_experts * + (block_size - 1)) if topk_ids.numel() > num_experts else ( + topk_ids.numel() + 1) * block_size + + # we just store the expert id of each single block but each token, + # as each token in the same block will be process by the same expert. 
+ expert_ids = torch.zeros(max( + (max_tokens_post_padded + block_size - 1) // block_size + 1, + num_experts), + dtype=topk_ids.dtype, + device=topk_ids.device) + + cumsum.div_(block_size, rounding_mode="floor") + ones = torch.ones_like(expert_ids) + expert_ids.scatter_add_(0, cumsum, ones) + expert_ids = expert_ids.cumsum(0) + + cumsum = (tokens_per_expert_post_alig - tokens_per_expert).cumsum(0) + + padded_tokens = torch.zeros(max_tokens_post_padded - topk_ids.numel(), + dtype=topk_ids.dtype, + device=topk_ids.device) + ones = torch.ones_like(padded_tokens) + padded_tokens.scatter_add_(0, cumsum[:-1], ones) + padded_tokens = padded_tokens.cumsum(0) + + sorted_token_ids = torch.cat([topk_ids.view(-1), padded_tokens]).argsort() + + return sorted_token_ids, expert_ids, num_tokens_post_padded + + +def fused_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace=False): + """ + This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. + We used three shared cache variables across all layers to save gpu memory, which is more effective in a static graph context. + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_weights (torch.Tensor): The weights for the top-k selected experts. + - topk_ids (torch.Tensor): The indices of the top-k selected experts. + - inplace (bool): If True, perform the operation in-place. Defaults to False. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions" + assert hidden_states.is_contiguous(), "Matrix A must be contiguous" + assert w1.is_contiguous(), "Matrix B must be contiguous" + assert hidden_states.dtype in [torch.float16, torch.bfloat16] + M, K = hidden_states.shape + E, N, K = w1.shape + + config = { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + } + + if topk_ids.numel() <= w1.shape[0]: + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1 + } + + intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + sorted_token_ids, expert_ids, num_tokens_post_padded = alig_block_size( + topk_ids, config['BLOCK_SIZE_M'], E) + # 1D launch kernel where each block gets its own program. 
+ grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + 'BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + + fused_moe_kernel[grid]( + hidden_states, + w1, + intermediate_cache1, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + M, + N, + K, + sorted_token_ids.shape[0], + topk_ids.numel(), + hidden_states.stride(0), + hidden_states.stride(1), + w1.stride(0), + w1.stride(2), + w1.stride(1), + intermediate_cache1.stride(1), + intermediate_cache1.stride(2), + topk_weights.stride(1), + sorted_token_ids.stride(0), + MUL_ROUTED_WEIGHT=False, + top_k=topk_ids.shape[1], + compute_type=tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16, + **config, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + 'BLOCK_SIZE_M']) * triton.cdiv(w2.shape[1], META['BLOCK_SIZE_N']), ) + fused_moe_kernel[grid]( + intermediate_cache2, + w2, + intermediate_cache3, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + M, + w2.shape[1], + w2.shape[2], + sorted_token_ids.shape[0], + topk_ids.numel(), + intermediate_cache2.stride(0), + intermediate_cache2.stride(1), + w2.stride(0), + w2.stride(2), + w2.stride(1), + intermediate_cache3.stride(1), + intermediate_cache3.stride(2), + topk_weights.stride(1), + sorted_token_ids.stride(0), + MUL_ROUTED_WEIGHT=True, + top_k=1, # + compute_type=tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16, + **config, + ) + if inplace: + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) \ No newline at end of file diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a8dadce24aa1..91196cb2e0a2 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -1,3 +1,4 @@ + # coding=utf-8 # Adapted from # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py @@ -23,10 +24,7 @@ """Inference-only Mixtral model.""" from typing import List, Optional, Tuple -import numpy as np - import torch -import torch.nn.functional as F from torch import nn from transformers import MixtralConfig @@ -35,17 +33,15 @@ from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, - ReplicatedLinear, QKVParallelLinear, RowParallelLinear) +from vllm.model_executor.layers.moe import MoE from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) @@ -54,112 +50,6 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] -class MixtralMLP(nn.Module): - - def __init__( - self, - num_experts: int, - hidden_size: int, - 
intermediate_size: int, - linear_method: Optional[LinearMethodBase] = None, - ) -> None: - super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - linear_method=linear_method) - self.w2 = ReplicatedLinear(self.ffn_dim, - self.hidden_dim, - bias=False, - linear_method=linear_method) - self.w3 = ReplicatedLinear(self.hidden_dim, - self.ffn_dim, - bias=False, - linear_method=linear_method) - - # TODO: Use vllm's SiluAndMul - self.act_fn = nn.SiLU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralMoE(nn.Module): - - def __init__( - self, - config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, - ): - super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}.") - # Split experts equally between ranks - self.expert_indicies = np.array_split(range( - self.num_total_experts), self.tp_size)[self.rank].tolist() - if not self.expert_indicies: - raise ValueError( - f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList([ - MixtralMLP(self.num_total_experts, - config.hidden_size, - config.intermediate_size, - linear_method=linear_method) - if idx in self.expert_indicies else None - for idx in range(self.num_total_experts) - ]) - self.gate = ReplicatedLinear(config.hidden_size, - self.num_total_experts, - bias=False, - linear_method=None) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits, _ = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = None - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - expert_mask = (selected_experts == expert_idx) - expert_weights = (routing_weights * expert_mask).sum(dim=-1, - keepdim=True) - - current_hidden_states = expert_layer(hidden_states).mul_( - expert_weights) - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states.add_(current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states).view( - batch_size, sequence_length, hidden_dim) - - class MixtralAttention(nn.Module): def __init__(self, @@ -257,8 +147,10 @@ def __init__( rope_theta=rope_theta, sliding_window=config.sliding_window, linear_method=linear_method) - self.block_sparse_moe = MixtralMoE(config=config, - linear_method=linear_method) + self.block_sparse_moe = MoE(num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + 
hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -378,6 +270,13 @@ def load_weights(self, ("qkv_proj", "v_proj", "v"), ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + (f"{weight_name}s", f"experts.{expert_id}.{weight_name}.weight", + expert_id) for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, @@ -387,6 +286,7 @@ def load_weights(self, fall_back_to_pt=False): if "rotary_emb.inv_freq" in name: continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue @@ -399,14 +299,23 @@ def load_weights(self, weight_loader(param, loaded_weight, shard_id) break else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Skip experts that are not assigned to this worker. - if ("block_sparse_moe.experts." in name - and name not in params_dict): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ("block_sparse_moe.experts." in name + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 03f5dc8c25fd020408bf18d31b8fff315bc9e779 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Jan 2024 02:54:19 -0800 Subject: [PATCH 02/17] fix weight loading --- vllm/model_executor/layers/moe.py | 46 ++++++++------------------- vllm/model_executor/models/mixtral.py | 7 ++-- 2 files changed, 17 insertions(+), 36 deletions(-) diff --git a/vllm/model_executor/layers/moe.py b/vllm/model_executor/layers/moe.py index 39c307a18ff4..bc9da4adbced 100644 --- a/vllm/model_executor/layers/moe.py +++ b/vllm/model_executor/layers/moe.py @@ -44,9 +44,9 @@ def __init__( bias=False, linear_method=None) - self.w1s = nn.Parameter( + self.ws = nn.Parameter( torch.empty(self.num_total_experts, - self.intermediate_size, + 2 * self.intermediate_size, self.hidden_size, device="cuda")) self.w2s = nn.Parameter( @@ -54,47 +54,27 @@ def __init__( self.hidden_size, self.intermediate_size, device="cuda")) - self.w3s = nn.Parameter( - torch.empty(self.num_total_experts, - self.intermediate_size, - self.hidden_size, - device="cuda")) - - # TODO: Currently this is fake data but should be - # [self.w1s, self.w3s] concatenated along the intermediate - # size dimension. 
- self.ws = nn.Parameter( - torch.randn(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - device="cuda")) - set_weight_attrs(self.w1s, { + set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, - "tp_type": "column" }) set_weight_attrs(self.w2s, { "weight_loader": self.weight_loader, - "tp_type": "row" - }) - set_weight_attrs(self.w3s, { - "weight_loader": self.weight_loader, - "tp_type": "column" }) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - expert_id: int): + weight_name: str, expert_id: int): tp_rank = get_tensor_model_parallel_rank() param_data = param.data - if getattr(param, "tp_type", None) == "row": - shard_size = param_data.shape[2] - w_shard = loaded_weight[:,(tp_rank * shard_size): (tp_rank+1) * shard_size] - else: - shard_size = param_data.shape[1] - w_shard = loaded_weight[(tp_rank * shard_size): (tp_rank+1) * shard_size,:] - assert param_data[expert_id].shape == w_shard.shape, \ - f"{param_data[expert_id].shape}, {w_shard.shape}" - param_data[expert_id].copy_(w_shard) + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank+1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id,0:shard_size,:] = loaded_weight[shard,:] + if weight_name.endswith("w3.weight"): + param_data[expert_id,shard_size:2*shard_size,:] = loaded_weight[shard,:] + if weight_name.endswith("w2.weight"): + param_data[expert_id,:,:] = loaded_weight[:,shard] + def fused_moe_infer(self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 91196cb2e0a2..1d9d39445a57 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -272,8 +272,9 @@ def load_weights(self, expert_params_mapping = [ # (param_name, weight_name, expert_id) - (f"{weight_name}s", f"experts.{expert_id}.{weight_name}.weight", - expert_id) for expert_id in range(self.config.num_local_experts) + (f"ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] ] @@ -305,7 +306,7 @@ def load_weights(self, name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, expert_id=expert_id) + weight_loader(param, loaded_weight, weight_name, expert_id=expert_id) break else: # Skip loading extra bias for GPTQ models. 
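The weight layout introduced in this commit can be checked outside of vllm. The sketch below is illustrative only (the shapes, tensors, and the loop are made up and are not part of the patch): it verifies that sharding w1/w3 along the intermediate dimension, stacking the shards into a single `ws` tensor, and sharding w2 along its input dimension reproduces the unsharded SwiGLU expert MLP once the per-rank partial outputs are summed, which is what the all-reduce in `MoE.forward` does across tensor-parallel ranks.

# Standalone sketch (not part of the patch): emulate the sharding performed by
# MoE.weight_loader above and confirm that summing the per-rank outputs
# (i.e. the all-reduce) recovers the full expert MLP.
import torch
import torch.nn.functional as F

hidden, intermediate, tp_size = 16, 32, 4
x = torch.randn(hidden)
w1 = torch.randn(intermediate, hidden)   # gate projection
w3 = torch.randn(intermediate, hidden)   # up projection
w2 = torch.randn(hidden, intermediate)   # down projection

# Reference: full, unsharded SwiGLU expert MLP.
ref = w2 @ (F.silu(w1 @ x) * (w3 @ x))

shard_size = intermediate // tp_size
partial = torch.zeros(hidden)
for rank in range(tp_size):
    s = slice(rank * shard_size, (rank + 1) * shard_size)
    # Per-rank fused weight laid out like `ws`: w1 shard on top, w3 shard below.
    ws = torch.cat([w1[s, :], w3[s, :]], dim=0)   # [2 * shard_size, hidden]
    w2s = w2[:, s]                                # [hidden, shard_size]
    gate_up = ws @ x
    act = F.silu(gate_up[:shard_size]) * gate_up[shard_size:]
    partial += w2s @ act                          # summing == all-reduce

assert torch.allclose(partial, ref, atol=1e-4)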
From 2867f346274c1f850597a42f447b2ef7e4266592 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Jan 2024 03:02:57 -0800 Subject: [PATCH 03/17] lint and cleanup --- vllm/model_executor/layers/moe.py | 4 +--- vllm/model_executor/models/mixtral.py | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/moe.py b/vllm/model_executor/layers/moe.py index bc9da4adbced..57e027442f47 100644 --- a/vllm/model_executor/layers/moe.py +++ b/vllm/model_executor/layers/moe.py @@ -1,5 +1,3 @@ -from typing import Tuple - import torch from torch import nn import torch.nn.functional as F @@ -429,4 +427,4 @@ def fused_moe(hidden_states: torch.Tensor, dim=1, out=hidden_states) return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) \ No newline at end of file + dim=1) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 1d9d39445a57..9613a5d727b9 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -1,4 +1,3 @@ - # coding=utf-8 # Adapted from # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py @@ -272,7 +271,7 @@ def load_weights(self, expert_params_mapping = [ # (param_name, weight_name, expert_id) - (f"ws" if weight_name in ["w1", "w3"] else "w2s", + ("ws" if weight_name in ["w1", "w3"] else "w2s", f"experts.{expert_id}.{weight_name}.weight", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] From a82f73f072ad2024b86508c35b6d4fb178f9c95f Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Jan 2024 03:14:02 -0800 Subject: [PATCH 04/17] style --- vllm/model_executor/layers/moe.py | 28 +++++++++++++-------------- vllm/model_executor/models/mixtral.py | 5 ++++- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/moe.py b/vllm/model_executor/layers/moe.py index 57e027442f47..30a88f28d344 100644 --- a/vllm/model_executor/layers/moe.py +++ b/vllm/model_executor/layers/moe.py @@ -65,26 +65,25 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, tp_rank = get_tensor_model_parallel_rank() param_data = param.data shard_size = self.intermediate_size - shard = slice(tp_rank * shard_size, (tp_rank+1) * shard_size) + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) if weight_name.endswith("w1.weight"): - param_data[expert_id,0:shard_size,:] = loaded_weight[shard,:] + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w3.weight"): - param_data[expert_id,shard_size:2*shard_size,:] = loaded_weight[shard,:] + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w2.weight"): - param_data[expert_id,:,:] = loaded_weight[:,shard] - + param_data[expert_id, :, :] = loaded_weight[:, shard] def fused_moe_infer(self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - return fused_moe(hidden_states, - # self.w1s, - self.ws, - self.w2s, - # self.w3s, - routing_weights, - selected_experts, - inplace=True) + return fused_moe( + hidden_states, + self.ws, + self.w2s, + routing_weights, + selected_experts, + inplace=True) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_size = hidden_states.shape @@ -101,7 +100,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 
final_hidden_states = self.fused_moe_infer(hidden_states, selected_experts, routing_weights) - + final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -109,7 +108,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_size) - @triton.jit def fused_moe_kernel( # Pointers to matrices diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 9613a5d727b9..ed81fa07c308 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -305,7 +305,10 @@ def load_weights(self, name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, weight_name, expert_id=expert_id) + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) break else: # Skip loading extra bias for GPTQ models. From 406a188f3c86f886a22d475e753a48e8c8fd37bc Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Jan 2024 03:16:52 -0800 Subject: [PATCH 05/17] yapf --- vllm/model_executor/layers/moe.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/moe.py b/vllm/model_executor/layers/moe.py index 30a88f28d344..13cb4bd9c93c 100644 --- a/vllm/model_executor/layers/moe.py +++ b/vllm/model_executor/layers/moe.py @@ -77,13 +77,12 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, def fused_moe_infer(self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: - return fused_moe( - hidden_states, - self.ws, - self.w2s, - routing_weights, - selected_experts, - inplace=True) + return fused_moe(hidden_states, + self.ws, + self.w2s, + routing_weights, + selected_experts, + inplace=True) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_size = hidden_states.shape From 707479d68a2068c5bde805558b01abaf6fc1ba34 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Jan 2024 03:29:45 -0800 Subject: [PATCH 06/17] update comment --- vllm/model_executor/layers/moe.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/moe.py b/vllm/model_executor/layers/moe.py index 13cb4bd9c93c..e4364e9f893a 100644 --- a/vllm/model_executor/layers/moe.py +++ b/vllm/model_executor/layers/moe.py @@ -14,13 +14,12 @@ class MoE(nn.Module): - """a tensor-parallel MOE implementation that shards each expert across + """A tensor-parallel MOE implementation that shards each expert across all ranks. - Each expert's weights are sharded across all ranks. The forward pass - will first expand and group the hidden states by experts, then compute - the per-rank MLP output of each expert using grouped gemm, and finally - reduce the output across ranks. + Each expert's weights are sharded across all ranks and a fused MOE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
""" def __init__( From f283c98fde108a4f946479122a5b2a96f93e7235 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 22 Jan 2024 16:00:52 -0800 Subject: [PATCH 07/17] update fused_moe implementation --- csrc/dispatch_utils.h | 11 + csrc/moe_alig_block_size_kernels.cu | 85 +++++++ csrc/ops.h | 9 + csrc/pybind.cpp | 4 + setup.py | 1 + vllm/model_executor/layers/fused_moe.py | 297 ++++++++++++++++++++++ vllm/model_executor/layers/moe.py | 324 +----------------------- 7 files changed, 408 insertions(+), 323 deletions(-) create mode 100644 csrc/moe_alig_block_size_kernels.cu create mode 100644 vllm/model_executor/layers/fused_moe.py diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 0ae9cd641598..bed640b48b8c 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -14,3 +14,14 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) diff --git a/csrc/moe_alig_block_size_kernels.cu b/csrc/moe_alig_block_size_kernels.cu new file mode 100644 index 000000000000..57b03f2016b0 --- /dev/null +++ b/csrc/moe_alig_block_size_kernels.cu @@ -0,0 +1,85 @@ +#include +#include + +#include +#include + +#include "cuda_compat.h" +#include "dispatch_utils.h" + +const static size_t NUM_MAX_EXPERTS = 64; + +namespace vllm { +template +__global__ void moe_alig_block_size_kernel(scalar_t *__restrict__ topk_ids, + int32_t *sorted_token_ids, + int32_t *expert_ids, + int32_t *total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel) { + const size_t tokens_per_thread = ((numel + blockDim.x - 1) / blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS]; + __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1]; + for(int i = 0;i < num_experts;i++){ + tokens_cnts[threadIdx.x + 1][i] = 0; + } + + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[threadIdx.x + 1][topk_ids[i]]; + } + + __syncthreads(); + + tokens_cnts[0][threadIdx.x] = 0; + for(int i=1;i<=blockDim.x;++i){ + tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x]; + } + + __syncthreads(); + + if(threadIdx.x ==0){ + cumsum[0] = 0; + for(int i=1;i<=num_experts;++i){ + cumsum[i] = cumsum[i-1] + (tokens_cnts[blockDim.x][i - 1] + block_size - 1) / block_size * block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + for(int i= cumsum[threadIdx.x];i<<<1, num_experts, 0, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, + topk_ids.numel()); + }); +} diff --git a/csrc/ops.h b/csrc/ops.h index 9340a60da141..872201171675 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -89,3 +89,12 @@ torch::Tensor gptq_gemm( void gptq_shuffle( torch::Tensor q_weight, torch::Tensor q_perm); + +void moe_alig_block_size( + torch::Tensor topk_ids, + int num_experts, + int block_size, + torch::Tensor sorted_token_ids, + 
torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad + ); diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 95f557686f33..659d32a842fa 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -55,6 +55,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); + ops.def( + "moe_alig_block_size", + &moe_alig_block_size, + "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); // Cache ops pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); diff --git a/setup.py b/setup.py index fb37a8d95231..181ec7591e12 100644 --- a/setup.py +++ b/setup.py @@ -251,6 +251,7 @@ def get_torch_arch_list() -> Set[str]: "csrc/quantization/squeezellm/quant_cuda_kernel.cu", "csrc/quantization/gptq/q_gemm.cu", "csrc/cuda_utils_kernels.cu", + "csrc/moe_alig_block_size_kernels.cu", "csrc/pybind.cpp", ] diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe.py new file mode 100644 index 000000000000..004134e964eb --- /dev/null +++ b/vllm/model_executor/layers/fused_moe.py @@ -0,0 +1,297 @@ +"""Fused MoE kernel.""" +import torch +import triton +import triton.language as tl +from vllm._C import ops + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + M, + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_weight, + stride_token_id, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, + and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. + - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. + This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` + by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. 
+ """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + # + off_experts = tl.load(expert_ids_ptr + pid_m) * stride_be + b_ptrs = b_ptr + off_experts + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def alig_block_size( + topk_ids: torch.Tensor, block_size: int, + num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor): + """ + Aligns the token distribution across experts to be compatible with block size for matrix multiplication. + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. 
+ Returns: + - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. + This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions align correctly. + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. + """ + sorted_ids = torch.empty( + (topk_ids.numel() + num_experts * (block_size - 1), ), + dtype=torch.int32, + device=topk_ids.device) + expert_ids = torch.empty((topk_ids.numel() + num_experts, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + ops.moe_alig_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) + return sorted_ids, expert_ids, num_tokens_post_pad + + +def fused_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace=False): + """ + This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - topk_weights (torch.Tensor): The weights for the top-k selected experts. + - topk_ids (torch.Tensor): The indices of the top-k selected experts. + - inplace (bool): If True, perform the operation in-place. Defaults to False. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
+ assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions" + assert hidden_states.is_contiguous(), "Matrix A must be contiguous" + assert w1.is_contiguous(), "Matrix B must be contiguous" + assert hidden_states.dtype in [torch.float16, torch.bfloat16] + M, K = hidden_states.shape + E, N, K = w1.shape + + config = { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + } + + if topk_ids.numel() <= w1.shape[0]: + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1 + } + + intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + sorted_token_ids, expert_ids, num_tokens_post_padded = alig_block_size( + topk_ids, config['BLOCK_SIZE_M'], E) + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + 'BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + + fused_moe_kernel[grid]( + hidden_states, + w1, + intermediate_cache1, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + M, + N, + K, + sorted_token_ids.shape[0], + topk_ids.numel(), + hidden_states.stride(0), + hidden_states.stride(1), + w1.stride(0), + w1.stride(2), + w1.stride(1), + intermediate_cache1.stride(1), + intermediate_cache1.stride(2), + topk_weights.stride(1), + sorted_token_ids.stride(0), + MUL_ROUTED_WEIGHT=False, + top_k=topk_ids.shape[1], + compute_type=tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16, + **config, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + 'BLOCK_SIZE_M']) * triton.cdiv(w2.shape[1], META['BLOCK_SIZE_N']), ) + fused_moe_kernel[grid]( + intermediate_cache2, + w2, + intermediate_cache3, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + M, + w2.shape[1], + w2.shape[2], + sorted_token_ids.shape[0], + topk_ids.numel(), + intermediate_cache2.stride(0), + intermediate_cache2.stride(1), + w2.stride(0), + w2.stride(2), + w2.stride(1), + intermediate_cache3.stride(1), + intermediate_cache3.stride(2), + topk_weights.stride(1), + sorted_token_ids.stride(0), + MUL_ROUTED_WEIGHT=True, + top_k=1, # + compute_type=tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16, + **config, + ) + if inplace: + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) diff --git a/vllm/model_executor/layers/moe.py b/vllm/model_executor/layers/moe.py index e4364e9f893a..e109d3df0427 100644 --- a/vllm/model_executor/layers/moe.py +++ b/vllm/model_executor/layers/moe.py @@ -1,11 +1,9 @@ import torch from torch import nn import torch.nn.functional as F -import triton -import triton.language as tl -from vllm._C import ops from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( @@ -104,323 +102,3 @@ def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states.view(batch_size, sequence_length, hidden_size) - - -@triton.jit -def fused_moe_kernel( - # Pointers to matrices - a_ptr, - b_ptr, - c_ptr, - topk_weights_ptr, - sorted_token_ids_ptr, - expert_ids_ptr, - num_tokens_post_padded_ptr, - # Matrix dimensions - M, - N, - K, - EM, - num_valid_tokens, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, - stride_ak, - stride_be, - stride_bk, - stride_bn, - stride_cm, - stride_cn, - stride_weight, - stride_token_id, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - MUL_ROUTED_WEIGHT: tl.constexpr, - top_k: tl.constexpr, - compute_type: tl.constexpr, -): - """ - Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. - Key Parameters: - - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. - - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. - - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, - and N is the output feature dimension. - - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. - - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. - This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` - by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. - """ - # ----------------------------------------------------------- - # Map program ids `pid` to the block of C it should compute. - # This is done in a grouped ordering to promote L2 data reuse. - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - # ---------------------------------------------------------- - # Create pointers for the first blocks of A and B. 
- # We will advance this pointer as we move in the K direction - # and accumulate - # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers - # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers - num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) - if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: - return - offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) - token_mask = offs_token < num_valid_tokens - - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + - offs_k[None, :] * stride_ak) - - # - off_experts = tl.load(expert_ids_ptr + pid_m) * stride_be - b_ptrs = b_ptr + off_experts + (offs_k[:, None] * stride_bk + - offs_bn[None, :] * stride_bn) - - # ----------------------------------------------------------- - # Iterate to compute a block of the C matrix. - # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block - # of fp32 values for higher accuracy. - # `accumulator` will be converted back to fp16 after the loop. - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - # Load the next block of A and B, generate a mask by checking the K dimension. - a = tl.load(a_ptrs, - mask=token_mask[:, None] & - (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0) - b = tl.load(b_ptrs, - mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, - other=0.0) - # We accumulate along the K dimension. - accumulator += tl.dot(a, b) - # Advance the ptrs to the next K block. - a_ptrs += BLOCK_SIZE_K * stride_ak - b_ptrs += BLOCK_SIZE_K * stride_bk - - if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight, - mask=token_mask, - other=0) - accumulator = accumulator * moe_weight[:, None] - - accumulator = accumulator.to(compute_type) - # ----------------------------------------------------------- - # Write back the block of the output - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ - None, :] - c_mask = token_mask[:, None] & (offs_cn[None, :] < N) - tl.store(c_ptrs, accumulator, mask=c_mask) - - -# TODO: rewrite in CPP -def alig_block_size( - topk_ids: torch.Tensor, block_size: int, - num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor): - """ - Aligns the token distribution across experts to be compatible with block size for matrix multiplication. - Parameters: - - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. - - block_size: The block size used in block matrix multiplication. - - num_experts: The total number of experts. - Returns: - - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. - - expert_ids: A tensor indicating the assigned expert index for each block. - - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. - This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. - Padding ensures that during block matrix multiplication, the dimensions align correctly. 
- Example: - Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: - - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. - - As block_size is 4, we pad 1 token for each expert. - - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. - - Then append padding tokens [1, 2, 3, 4] at the end, resulting in [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3 | 1, 2, 3, 4]. - - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 13, 1, 7, 11, 14, 2, 5, 8, 15]. - Tokens 12-15 are non-existent (padding) and are ignored in the subsequent matrix multiplication. - - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. - """ - cnts = torch.zeros(topk_ids.shape[0], - num_experts, - dtype=topk_ids.dtype, - device=topk_ids.device) - cnts.scatter_(1, topk_ids, 1) - tokens_per_expert = cnts.sum(dim=0) - tokens_per_expert_post_alig = torch.floor_divide( - tokens_per_expert + block_size - 1, block_size) * block_size - - cumsum = tokens_per_expert_post_alig.cumsum(0) - num_tokens_post_padded = cumsum[-1].clone() - max_tokens_post_padded = ( - topk_ids.numel() + num_experts * - (block_size - 1)) if topk_ids.numel() > num_experts else ( - topk_ids.numel() + 1) * block_size - - # we just store the expert id of each single block but each token, - # as each token in the same block will be process by the same expert. - expert_ids = torch.zeros(max( - (max_tokens_post_padded + block_size - 1) // block_size + 1, - num_experts), - dtype=topk_ids.dtype, - device=topk_ids.device) - - cumsum.div_(block_size, rounding_mode="floor") - ones = torch.ones_like(expert_ids) - expert_ids.scatter_add_(0, cumsum, ones) - expert_ids = expert_ids.cumsum(0) - - cumsum = (tokens_per_expert_post_alig - tokens_per_expert).cumsum(0) - - padded_tokens = torch.zeros(max_tokens_post_padded - topk_ids.numel(), - dtype=topk_ids.dtype, - device=topk_ids.device) - ones = torch.ones_like(padded_tokens) - padded_tokens.scatter_add_(0, cumsum[:-1], ones) - padded_tokens = padded_tokens.cumsum(0) - - sorted_token_ids = torch.cat([topk_ids.view(-1), padded_tokens]).argsort() - - return sorted_token_ids, expert_ids, num_tokens_post_padded - - -def fused_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - inplace=False): - """ - This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. - We used three shared cache variables across all layers to save gpu memory, which is more effective in a static graph context. - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - topk_weights (torch.Tensor): The weights for the top-k selected experts. - - topk_ids (torch.Tensor): The indices of the top-k selected experts. - - inplace (bool): If True, perform the operation in-place. Defaults to False. - - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. 
- assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions" - assert hidden_states.is_contiguous(), "Matrix A must be contiguous" - assert w1.is_contiguous(), "Matrix B must be contiguous" - assert hidden_states.dtype in [torch.float16, torch.bfloat16] - M, K = hidden_states.shape - E, N, K = w1.shape - - config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'GROUP_SIZE_M': 8 - } - - if topk_ids.numel() <= w1.shape[0]: - config = { - 'BLOCK_SIZE_M': 16, - 'BLOCK_SIZE_N': 32, - 'BLOCK_SIZE_K': 64, - 'GROUP_SIZE_M': 1 - } - - intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), - device=hidden_states.device, - dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) - - sorted_token_ids, expert_ids, num_tokens_post_padded = alig_block_size( - topk_ids, config['BLOCK_SIZE_M'], E) - # 1D launch kernel where each block gets its own program. - grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - 'BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - - fused_moe_kernel[grid]( - hidden_states, - w1, - intermediate_cache1, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - M, - N, - K, - sorted_token_ids.shape[0], - topk_ids.numel(), - hidden_states.stride(0), - hidden_states.stride(1), - w1.stride(0), - w1.stride(2), - w1.stride(1), - intermediate_cache1.stride(1), - intermediate_cache1.stride(2), - topk_weights.stride(1), - sorted_token_ids.stride(0), - MUL_ROUTED_WEIGHT=False, - top_k=topk_ids.shape[1], - compute_type=tl.bfloat16 - if hidden_states.dtype == torch.bfloat16 else tl.float16, - **config, - ) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - - grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - 'BLOCK_SIZE_M']) * triton.cdiv(w2.shape[1], META['BLOCK_SIZE_N']), ) - fused_moe_kernel[grid]( - intermediate_cache2, - w2, - intermediate_cache3, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - M, - w2.shape[1], - w2.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - intermediate_cache2.stride(0), - intermediate_cache2.stride(1), - w2.stride(0), - w2.stride(2), - w2.stride(1), - intermediate_cache3.stride(1), - intermediate_cache3.stride(2), - topk_weights.stride(1), - sorted_token_ids.stride(0), - MUL_ROUTED_WEIGHT=True, - top_k=1, # - compute_type=tl.bfloat16 - if hidden_states.dtype == torch.bfloat16 else tl.float16, - **config, - ) - if inplace: - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=hidden_states) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) From adeb80f2ae0b82cbd3bc9b3476d1868ffbc8f4fc Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Wed, 24 Jan 2024 11:17:01 -0800 Subject: [PATCH 08/17] optimize block sizes --- vllm/model_executor/layers/fused_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe.py index 004134e964eb..15fce1c9cb40 100644 --- a/vllm/model_executor/layers/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe.py @@ -198,9 +198,9 @@ def fused_moe(hidden_states: torch.Tensor, E, N, K = w1.shape config = { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, + 
'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8 } From 23284e35d40aba0ce89a3a4a3c8855c7b21f6b4c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Sun, 28 Jan 2024 16:59:32 -0800 Subject: [PATCH 09/17] add dtype --- vllm/model_executor/layers/moe.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/moe.py b/vllm/model_executor/layers/moe.py index e109d3df0427..eed760112b1b 100644 --- a/vllm/model_executor/layers/moe.py +++ b/vllm/model_executor/layers/moe.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from torch import nn import torch.nn.functional as F @@ -26,6 +28,7 @@ def __init__( top_k: int, hidden_size: int, intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, ): super().__init__() tp_size = get_tensor_model_parallel_world_size() @@ -34,21 +37,28 @@ def __init__( self.hidden_size = hidden_size self.intermediate_size = intermediate_size // tp_size + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + self.gate = ReplicatedLinear(self.hidden_size, self.num_total_experts, bias=False, + params_dtype=self.params_dtype, linear_method=None) self.ws = nn.Parameter( torch.empty(self.num_total_experts, 2 * self.intermediate_size, self.hidden_size, - device="cuda")) + device="cuda", + dtype=self.params_dtype)) self.w2s = nn.Parameter( torch.empty(self.num_total_experts, self.hidden_size, self.intermediate_size, - device="cuda")) + device="cuda", + dtype=self.params_dtype)) set_weight_attrs(self.ws, { "weight_loader": self.weight_loader, From a1f01073f26b3ab1e9b9675625ba02b83b69fcab Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 29 Jan 2024 02:34:40 -0800 Subject: [PATCH 10/17] update --- ...els.cu => moe_align_block_size_kernels.cu} | 49 ++++-- csrc/ops.h | 2 +- csrc/pybind.cpp | 4 +- setup.py | 2 +- vllm/model_executor/layers/fused_moe.py | 149 ++++++++---------- 5 files changed, 109 insertions(+), 97 deletions(-) rename csrc/{moe_alig_block_size_kernels.cu => moe_align_block_size_kernels.cu} (53%) diff --git a/csrc/moe_alig_block_size_kernels.cu b/csrc/moe_align_block_size_kernels.cu similarity index 53% rename from csrc/moe_alig_block_size_kernels.cu rename to csrc/moe_align_block_size_kernels.cu index 57b03f2016b0..de6a0ec0a972 100644 --- a/csrc/moe_alig_block_size_kernels.cu +++ b/csrc/moe_align_block_size_kernels.cu @@ -8,53 +8,76 @@ #include "dispatch_utils.h" const static size_t NUM_MAX_EXPERTS = 64; +#define CEILDIV(x,y) (((x) + (y) - 1) / (y)) namespace vllm { template -__global__ void moe_alig_block_size_kernel(scalar_t *__restrict__ topk_ids, +__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids, int32_t *sorted_token_ids, int32_t *expert_ids, int32_t *total_tokens_post_pad, int32_t num_experts, int32_t block_size, size_t numel) { - const size_t tokens_per_thread = ((numel + blockDim.x - 1) / blockDim.x); + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t start_idx = threadIdx.x * tokens_per_thread; __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS]; __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1]; - for(int i = 0;i < num_experts;i++){ + for (int i = 0; i < num_experts; ++i) { tokens_cnts[threadIdx.x + 1][i] = 0; } + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are assigned + * to expert expert_index. 
+ */ for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { ++tokens_cnts[threadIdx.x + 1][topk_ids[i]]; } __syncthreads(); + // For each expert we accumulate the token counts from the different threads. tokens_cnts[0][threadIdx.x] = 0; - for(int i=1;i<=blockDim.x;++i){ + for (int i = 1; i <= blockDim.x; ++i) { tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x]; } __syncthreads(); - - if(threadIdx.x ==0){ + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { cumsum[0] = 0; - for(int i=1;i<=num_experts;++i){ - cumsum[i] = cumsum[i-1] + (tokens_cnts[blockDim.x][i - 1] + block_size - 1) / block_size * block_size; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size; } *total_tokens_post_pad = cumsum[num_experts]; } __syncthreads(); - for(int i= cumsum[threadIdx.x];i<<<1, num_experts, 0, stream>>>( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + vllm::moe_align_block_size_kernel<<<1, num_experts, 0, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), diff --git a/csrc/ops.h b/csrc/ops.h index ad6bb12c7fbf..5baa971048bb 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -98,7 +98,7 @@ void gptq_shuffle( torch::Tensor q_weight, torch::Tensor q_perm); -void moe_alig_block_size( +void moe_align_block_size( torch::Tensor topk_ids, int num_experts, int block_size, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 8af416679802..b9d030eeaf01 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -57,8 +57,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); ops.def( - "moe_alig_block_size", - &moe_alig_block_size, + "moe_align_block_size", + &moe_align_block_size, "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); // Cache ops diff --git a/setup.py b/setup.py index fe8cec49b490..9a65296907bb 100644 --- a/setup.py +++ b/setup.py @@ -302,7 +302,7 @@ def get_torch_arch_list() -> Set[str]: "csrc/quantization/squeezellm/quant_cuda_kernel.cu", "csrc/quantization/gptq/q_gemm.cu", "csrc/cuda_utils_kernels.cu", - "csrc/moe_alig_block_size_kernels.cu", + "csrc/moe_align_block_size_kernels.cu", "csrc/pybind.cpp", ] diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe.py index 15fce1c9cb40..0b1143bb037e 100644 --- a/vllm/model_executor/layers/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe.py @@ -16,7 +16,6 @@ def fused_moe_kernel( expert_ids_ptr, num_tokens_post_padded_ptr, # Matrix dimensions - M, N, K, EM, @@ -31,8 +30,6 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, - stride_weight, - stride_token_id, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -44,6 +41,7 @@ def fused_moe_kernel( ): """ Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. + Key Parameters: - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. 
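For readers following the pointer arithmetic above, the computation this kernel fuses can be written unfused in plain PyTorch. The sketch below is not part of the patch: the function name is made up, the per-expert Python loop stands in for the block-aligned Triton launch, and the sorted_token_ids/expert_ids bookkeeping is replaced by a torch.where lookup.

import torch

def moe_matmul_reference(A: torch.Tensor,            # (num_tokens, K) token activations
                         B: torch.Tensor,            # (E, N, K) stacked expert weights
                         topk_ids: torch.Tensor,     # (num_tokens, top_k) expert index per slot
                         topk_weights: torch.Tensor, # (num_tokens, top_k) routing weights
                         mul_routed_weight: bool = False) -> torch.Tensor:
    # C[token, slot] = A[token] @ B[expert chosen in that slot].T,
    # optionally scaled by the routing weight (MUL_ROUTED_WEIGHT in the kernel).
    num_tokens, top_k = topk_ids.shape
    E, N, K = B.shape
    C = torch.zeros(num_tokens, top_k, N, dtype=A.dtype, device=A.device)
    for e in range(E):
        token_idx, slot_idx = torch.where(topk_ids == e)  # tokens routed to expert e
        if token_idx.numel() == 0:
            continue
        out = A[token_idx] @ B[e].t()                     # one grouped GEMM per expert
        if mul_routed_weight:
            out = out * topk_weights[token_idx, slot_idx].unsqueeze(-1)
        C[token_idx, slot_idx] = out.to(C.dtype)
    return C

In the fused kernel the same grouping is achieved by sorting token indices by expert (sorted_token_ids) and tagging each BLOCK_SIZE_M block with its expert (expert_ids), so a single launch covers all experts.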
@@ -85,10 +83,9 @@ def fused_moe_kernel( a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) - # - off_experts = tl.load(expert_ids_ptr + pid_m) * stride_be - b_ptrs = b_ptr + off_experts + (offs_k[:, None] * stride_bk + - offs_bn[None, :] * stride_bn) + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. @@ -113,7 +110,7 @@ def fused_moe_kernel( b_ptrs += BLOCK_SIZE_K * stride_bk if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token * stride_weight, + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) accumulator = accumulator * moe_weight[:, None] @@ -128,21 +125,25 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) -def alig_block_size( +def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor): """ Aligns the token distribution across experts to be compatible with block size for matrix multiplication. + Parameters: - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. - block_size: The block size used in block matrix multiplication. - num_experts: The total number of experts. + Returns: - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. - expert_ids: A tensor indicating the assigned expert index for each block. - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. + This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. Padding ensures that during block matrix multiplication, the dimensions align correctly. + Example: Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. 
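For reference, the alignment performed by the CUDA kernel (and by the old Python helper it replaces) can be reproduced on the CPU as below. This is an illustrative sketch only, not part of the patch; the function name is invented, and padding slots are filled with topk_ids.numel() so that the kernel's token_mask (offs_token < num_valid_tokens) skips them.

import torch

def moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int,
                                   num_experts: int):
    # Group flattened token positions by expert, pad each group up to a
    # multiple of block_size, and record one expert id per block.
    flat = topk_ids.view(-1)
    numel = flat.numel()
    pad_id = numel  # any index >= numel is treated as padding by the kernel
    sorted_token_ids, expert_ids = [], []
    for e in range(num_experts):
        positions = torch.nonzero(flat == e, as_tuple=False).view(-1).tolist()
        padded_len = -(-len(positions) // block_size) * block_size  # ceil to block_size
        positions += [pad_id] * (padded_len - len(positions))
        sorted_token_ids += positions
        expert_ids += [e] * (padded_len // block_size)
    num_tokens_post_padded = torch.tensor([len(sorted_token_ids)], dtype=torch.int32)
    return (torch.tensor(sorted_token_ids, dtype=torch.int32),
            torch.tensor(expert_ids, dtype=torch.int32),
            num_tokens_post_padded)

With topk_ids = torch.tensor([[1, 2, 3], [0, 1, 3], [0, 2, 3], [0, 1, 2]]) and block_size = 4, each of the four experts has 3 tokens padded to one block of 4, giving 16 tokens post-padding; this matches the docstring example above up to its 1-based expert indices.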
@@ -164,11 +165,50 @@ def alig_block_size( num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_alig_block_size(topk_ids, num_experts, block_size, sorted_ids, - expert_ids, num_tokens_post_pad) + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) return sorted_ids, expert_ids, num_tokens_post_pad +def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, config: dict): + + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) + + fused_moe_kernel[grid]( + A, + B, + C, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2], + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, + **config, + ) + + def fused_moe(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -191,16 +231,17 @@ def fused_moe(hidden_states: torch.Tensor, """ # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Incompatible dimensions" - assert hidden_states.is_contiguous(), "Matrix A must be contiguous" - assert w1.is_contiguous(), "Matrix B must be contiguous" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float16, torch.bfloat16] - M, K = hidden_states.shape - E, N, K = w1.shape + M, _ = hidden_states.shape + E, N, _ = w1.shape config = { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 64, + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } @@ -222,73 +263,21 @@ def fused_moe(hidden_states: torch.Tensor, device=hidden_states.device, dtype=hidden_states.dtype) - sorted_token_ids, expert_ids, num_tokens_post_padded = alig_block_size( + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['BLOCK_SIZE_M'], E) - # 1D launch kernel where each block gets its own program. 
- grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - 'BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - fused_moe_kernel[grid]( - hidden_states, - w1, - intermediate_cache1, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - M, - N, - K, - sorted_token_ids.shape[0], - topk_ids.numel(), - hidden_states.stride(0), - hidden_states.stride(1), - w1.stride(0), - w1.stride(2), - w1.stride(1), - intermediate_cache1.stride(1), - intermediate_cache1.stride(2), - topk_weights.stride(1), - sorted_token_ids.stride(0), - MUL_ROUTED_WEIGHT=False, - top_k=topk_ids.shape[1], - compute_type=tl.bfloat16 - if hidden_states.dtype == torch.bfloat16 else tl.float16, - **config, - ) + invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, + topk_weights, topk_ids, sorted_token_ids, + expert_ids, num_tokens_post_padded, False, + topk_ids.shape[1], config) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) - grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ - 'BLOCK_SIZE_M']) * triton.cdiv(w2.shape[1], META['BLOCK_SIZE_N']), ) - fused_moe_kernel[grid]( - intermediate_cache2, - w2, - intermediate_cache3, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - M, - w2.shape[1], - w2.shape[2], - sorted_token_ids.shape[0], - topk_ids.numel(), - intermediate_cache2.stride(0), - intermediate_cache2.stride(1), - w2.stride(0), - w2.stride(2), - w2.stride(1), - intermediate_cache3.stride(1), - intermediate_cache3.stride(2), - topk_weights.stride(1), - sorted_token_ids.stride(0), - MUL_ROUTED_WEIGHT=True, - top_k=1, # - compute_type=tl.bfloat16 - if hidden_states.dtype == torch.bfloat16 else tl.float16, - **config, - ) + invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, + topk_weights, topk_ids, sorted_token_ids, + expert_ids, num_tokens_post_padded, True, 1, + config) + if inplace: return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1, From a77a2c1f02d762b5cd102398b43c45b1cbd50700 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 30 Jan 2024 05:25:04 +0000 Subject: [PATCH 11/17] Resolve merge error --- csrc/dispatch_utils.h | 11 ----------- csrc/ops.h | 3 +-- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index b0d5035538ae..91abd9e85b4b 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -15,17 +15,6 @@ AT_DISPATCH_SWITCH( \ TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) -#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) - -#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH( \ - TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) - #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) 
\ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ diff --git a/csrc/ops.h b/csrc/ops.h index fad176da420f..2bcd0c2efc5c 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -106,8 +106,7 @@ void moe_align_block_size( int block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad - ); + torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM using fptr_t = uint64_t; From 2233c2abea5df0bc391ad3e93c5f54f2c1e65cfc Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 29 Jan 2024 21:45:29 -0800 Subject: [PATCH 12/17] remove extra function --- vllm/model_executor/layers/moe.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/layers/moe.py b/vllm/model_executor/layers/moe.py index eed760112b1b..30bfa8ed8c1b 100644 --- a/vllm/model_executor/layers/moe.py +++ b/vllm/model_executor/layers/moe.py @@ -81,16 +81,6 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] - def fused_moe_infer(self, hidden_states: torch.Tensor, - selected_experts: torch.Tensor, - routing_weights: torch.Tensor) -> torch.Tensor: - return fused_moe(hidden_states, - self.ws, - self.w2s, - routing_weights, - selected_experts, - inplace=True) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) @@ -103,9 +93,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - final_hidden_states = self.fused_moe_infer(hidden_states, - selected_experts, - routing_weights) + final_hidden_states = fused_moe( + hidden_states, self.ws, self.w2s, + routing_weights, selected_experts, + routing_weights, inplace=True) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) From 44cc7d049d5e244745703ed409f56b63174030ac Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 29 Jan 2024 21:48:14 -0800 Subject: [PATCH 13/17] update --- vllm/model_executor/models/mixtral.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index ed81fa07c308..9eac002570e8 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -314,10 +314,6 @@ def load_weights(self, # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip experts that are not assigned to this worker. - if ("block_sparse_moe.experts." 
in name - and name not in params_dict): - continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) From 7b076aaca9d1590bc9898c9fcc2e5acaed491fc9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 29 Jan 2024 21:53:50 -0800 Subject: [PATCH 14/17] move mixtral moe layer back --- vllm/model_executor/layers/moe.py | 105 ------------------------- vllm/model_executor/models/mixtral.py | 109 ++++++++++++++++++++++++-- 2 files changed, 103 insertions(+), 111 deletions(-) delete mode 100644 vllm/model_executor/layers/moe.py diff --git a/vllm/model_executor/layers/moe.py b/vllm/model_executor/layers/moe.py deleted file mode 100644 index 30bfa8ed8c1b..000000000000 --- a/vllm/model_executor/layers/moe.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import Optional - -import torch -from torch import nn -import torch.nn.functional as F - -from vllm.model_executor.layers.linear import ReplicatedLinear -from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.utils import set_weight_attrs - - -class MoE(nn.Module): - """A tensor-parallel MOE implementation that shards each expert across - all ranks. - - Each expert's weights are sharded across all ranks and a fused MOE - kernel is used for the forward pass, and finally we reduce the outputs - across ranks. - """ - - def __init__( - self, - num_experts: int, - top_k: int, - hidden_size: int, - intermediate_size: int, - params_dtype: Optional[torch.dtype] = None, - ): - super().__init__() - tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = num_experts - self.top_k = top_k - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size // tp_size - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - - self.gate = ReplicatedLinear(self.hidden_size, - self.num_total_experts, - bias=False, - params_dtype=self.params_dtype, - linear_method=None) - - self.ws = nn.Parameter( - torch.empty(self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - device="cuda", - dtype=self.params_dtype)) - self.w2s = nn.Parameter( - torch.empty(self.num_total_experts, - self.hidden_size, - self.intermediate_size, - device="cuda", - dtype=self.params_dtype)) - - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) - - def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, - weight_name: str, expert_id: int): - tp_rank = get_tensor_model_parallel_rank() - param_data = param.data - shard_size = self.intermediate_size - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w3.weight"): - param_data[expert_id, - shard_size:2 * shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, sequence_length, hidden_size = hidden_states.shape - hidden_states = hidden_states.view(-1, self.hidden_size) - # router_logits: (batch * sequence_length, n_experts) - 
router_logits, _ = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, - self.top_k, - dim=-1) - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - - final_hidden_states = fused_moe( - hidden_states, self.ws, self.w2s, - routing_weights, selected_experts, - routing_weights, inplace=True) - - final_hidden_states = tensor_model_parallel_all_reduce( - final_hidden_states) - - return final_hidden_states.view(batch_size, sequence_length, - hidden_size) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 9eac002570e8..c55a835cd5ff 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -30,18 +30,22 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) -from vllm.model_executor.layers.moe import MoE from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -49,6 +53,98 @@ KVCache = Tuple[torch.Tensor, torch.Tensor] +class MixtralMoE(nn.Module): + """A tensor-parallel MoE implementation for Mixtral that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
+ """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // tp_size + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + linear_method=None) + + self.ws = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (batch * sequence_length, n_experts) + router_logits, _ = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = fused_moe( + hidden_states, self.ws, self.w2s, + routing_weights, selected_experts, + routing_weights, inplace=True) + + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(batch_size, sequence_length, + hidden_size) + + class MixtralAttention(nn.Module): def __init__(self, @@ -146,10 +242,11 @@ def __init__( rope_theta=rope_theta, sliding_window=config.sliding_window, linear_method=linear_method) - self.block_sparse_moe = MoE(num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size) + self.block_sparse_moe = MixtralMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, From b4cb78c3df644e03571a56afd819940cf608c4c6 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 29 Jan 2024 22:00:04 -0800 Subject: [PATCH 15/17] put import back --- vllm/model_executor/models/mixtral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/mixtral.py 
b/vllm/model_executor/models/mixtral.py index c55a835cd5ff..591f9c5dac49 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -24,6 +24,7 @@ from typing import List, Optional, Tuple import torch +import torch.nn.functional as F from torch import nn from transformers import MixtralConfig From 111c1b5014e0f64f9b09f2f0d1e5d8c84c48bce0 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 29 Jan 2024 22:04:56 -0800 Subject: [PATCH 16/17] yapf --- vllm/model_executor/models/mixtral.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 591f9c5dac49..d6bb17662ae5 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -134,10 +134,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: dim=-1) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - final_hidden_states = fused_moe( - hidden_states, self.ws, self.w2s, - routing_weights, selected_experts, - routing_weights, inplace=True) + final_hidden_states = fused_moe(hidden_states, + self.ws, + self.w2s, + routing_weights, + selected_experts, + routing_weights, + inplace=True) final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) From d30e844d46ea60805b1e8cbf0b096e256db2c7d1 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 29 Jan 2024 22:13:54 -0800 Subject: [PATCH 17/17] fix typo --- vllm/model_executor/models/mixtral.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index d6bb17662ae5..f36c35fd27ad 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -139,7 +139,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self.w2s, routing_weights, selected_experts, - routing_weights, inplace=True) final_hidden_states = tensor_model_parallel_all_reduce(
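With the final fix applied, a minimal end-to-end invocation of the fused path looks roughly as follows. All shapes and the random weights are hypothetical and only illustrate the call signature established by this series; running it requires a CUDA device and a vLLM build that includes these patches.

import torch
from vllm.model_executor.layers.fused_moe import fused_moe

# Hypothetical sizes, loosely Mixtral-like; adjust to the model actually loaded.
num_tokens, hidden_size, intermediate_size = 8, 4096, 14336
num_experts, top_k = 8, 2

hidden_states = torch.randn(num_tokens, hidden_size,
                            dtype=torch.float16, device="cuda")
ws = torch.randn(num_experts, 2 * intermediate_size, hidden_size,
                 dtype=torch.float16, device="cuda")   # fused [w1; w3]
w2s = torch.randn(num_experts, hidden_size, intermediate_size,
                  dtype=torch.float16, device="cuda")

# Routing, mirroring MixtralMoE.forward: softmax over experts, then top-k.
router_logits = torch.randn(num_tokens, num_experts, device="cuda")
routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

out = fused_moe(hidden_states, ws, w2s,
                routing_weights, selected_experts, inplace=True)
print(out.shape)  # (num_tokens, hidden_size)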