From 566d57ff1e27adf46b81a8fee63e056165a0b2fb Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 30 Aug 2024 14:38:21 -0700 Subject: [PATCH 01/75] add llamav tokeninizer and redirect loader to it --- examples/offline_inference_vision_language.py | 8 + vllm/transformers_utils/tokenizer.py | 9 +- .../transformers_utils/tokenizers/__init__.py | 1 + vllm/transformers_utils/tokenizers/llamavl.py | 199 ++++++++++++++++++ 4 files changed, 214 insertions(+), 3 deletions(-) create mode 100644 vllm/transformers_utils/tokenizers/llamavl.py diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 9a0e9d4bc536..899ad420586b 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -11,6 +11,8 @@ from vllm.assets.image import ImageAsset from vllm.utils import FlexibleArgumentParser +from functools import partial + # Input image and question image = ImageAsset("cherry_blossom").pil_image.convert("RGB") question = "What is the content of this image?" @@ -158,6 +160,10 @@ def run_blip2(question): stop_token_ids = None return llm, prompt, stop_token_ids +def run_llama(question, size: str): + checkpoint_dir = "/data/zhang-chen/llama/checkpoints" # update checkpoint path here + llm = LLM(model=f"{checkpoint_dir}/Meta-Llama-3.2-{size}-Vision-Early/") # update checkpoint path here + raise NotImplementedError model_example_map = { "llava": run_llava, @@ -169,6 +175,8 @@ def run_blip2(question): "minicpmv": run_minicpmv, "blip-2": run_blip2, "internvl_chat": run_internvl, + "llama-3.2-11b": partial(run_llama, size="11B"), + "llama-3.2-90b": partial(run_llama, size="90B"), } diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 2866975850db..ef0064c6f320 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -11,13 +11,14 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizers import (BaichuanTokenizer, - MistralTokenizer) + MistralTokenizer, + LlamaVLTokenizer) from vllm.utils import make_async logger = init_logger(__name__) AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, - MistralTokenizer] + MistralTokenizer, LlamaVLTokenizer] def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: @@ -111,10 +112,12 @@ def get_tokenizer( 'encoding and decoding.', FutureWarning, stacklevel=2) - + print("get tokenizer, tokenizer_name:", tokenizer_name, "Meta-Llama-3.2-11B-Vision-Early" in tokenizer_name,) if tokenizer_mode == "mistral": tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), revision=revision) + elif "Meta-Llama-3.2-11B-Vision-Early" in str(tokenizer_name) or "Meta-Llama-3.2-90B-Vision-Early" in str(tokenizer_name): + tokenizer = LlamaVLTokenizer.from_pretrained(str(tokenizer_name)) else: try: tokenizer = AutoTokenizer.from_pretrained( diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py index 9433f2d48f6f..2fa055890da8 100644 --- a/vllm/transformers_utils/tokenizers/__init__.py +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -1,4 +1,5 @@ from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.transformers_utils.tokenizers.llamavl import LlamaVLTokenizer __all__ = ["BaichuanTokenizer", "MistralTokenizer"] diff --git a/vllm/transformers_utils/tokenizers/llamavl.py b/vllm/transformers_utils/tokenizers/llamavl.py new file mode 100644 index 000000000000..9ae273da1db7 --- /dev/null +++ b/vllm/transformers_utils/tokenizers/llamavl.py @@ -0,0 +1,199 @@ +import os +from logging import getLogger +from pathlib import Path +from typing import ( + AbstractSet, + cast, + Collection, + Dict, + Iterator, + List, + Literal, + Optional, + Sequence, + Union, +) +from transformers.tokenization_utils import PreTrainedTokenizer + + +import tiktoken + +from tiktoken.load import load_tiktoken_bpe + +logger = getLogger(__name__) + + +# The tiktoken tokenizer can handle <=400k chars without +# pyo3_runtime.PanicException. +TIKTOKEN_MAX_ENCODE_CHARS = 400_000 + +# https://github.com/openai/tiktoken/issues/195 +# Here we iterate over subsequences and split if we exceed the limit +# of max consecutive non-whitespace or whitespace characters. +MAX_NO_WHITESPACES_CHARS = 25_000 + + +class LlamaVLTokenizer(PreTrainedTokenizer): + """ + Tokenizing and encoding/decoding text using the Tiktoken tokenizer. + """ + + special_tokens: Dict[str, int] + + num_reserved_special_tokens = 256 + + pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 + + def __init__(self, model_path: str): + """ + Initializes the Tokenizer with a Tiktoken model. + + Args: + model_path (str): The path to the Tiktoken model file. + """ + assert os.path.isfile(model_path), model_path + + mergeable_ranks = load_tiktoken_bpe(model_path) + num_base_tokens = len(mergeable_ranks) + special_tokens = [ + "<|begin_of_text|>", + "<|end_of_text|>", + "<|reserved_special_token_0|>", + "<|reserved_special_token_1|>", + "<|finetune_right_pad_id|>", + "<|step_id|>", + "<|start_header_id|>", + "<|end_header_id|>", + "<|eom_id|>", # end of message + "<|eot_id|>", # end of turn + "<|python_tag|>", + "<|image|>", + ] + reserved_tokens = [ + f"<|reserved_special_token_{2 + i}|>" + for i in range(self.num_reserved_special_tokens - len(special_tokens)) + ] + special_tokens = special_tokens + reserved_tokens + + self.special_tokens = { + token: num_base_tokens + i for i, token in enumerate(special_tokens) + } + self.model = tiktoken.Encoding( + name=Path(model_path).name, + pat_str=self.pat_str, + mergeable_ranks=mergeable_ranks, + special_tokens=self.special_tokens, + ) + + self.n_words: int = num_base_tokens + len(special_tokens) + # BOS / EOS token IDs + self.bos_id: int = self.special_tokens["<|begin_of_text|>"] + self.eos_id: int = self.special_tokens["<|end_of_text|>"] + self.eot_id: int = self.special_tokens["<|eot_id|>"] + self.eom_id: int = self.special_tokens["<|eom_id|>"] + self.python_tag_id = self.special_tokens["<|python_tag|>"] + self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] + self.stop_tokens = [ + self.eos_id, + self.special_tokens["<|eom_id|>"], + self.special_tokens["<|eot_id|>"], + ] + + def encode( + self, + s: str, + *, + bos: bool, + eos: bool, + allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, + disallowed_special: Union[Literal["all"], Collection[str]] = (), + ) -> List[int]: + """ + Encodes a string into a list of token IDs. + + Args: + s (str): The input string to be encoded. + bos (bool): Whether to prepend the beginning-of-sequence token. + eos (bool): Whether to append the end-of-sequence token. + allowed_special ("all"|set[str]): allowed special tokens in string + disallowed_special ("all"|set[str]): special tokens that raise an error when in string + + Returns: + list[int]: A list of token IDs. + + By default, setting disallowed_special=() encodes a string by ignoring + special tokens. Specifically: + - Setting `disallowed_special` to () will cause all text corresponding + to special tokens to be encoded as natural text (insteading of raising + an error). + - Setting `allowed_special` to "all" will treat all text corresponding + to special tokens to be encoded as special tokens. + """ + if allowed_special is None: + allowed_special = set() + assert type(s) is str + + substrs = ( + substr + for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) + for substr in self._split_whitespaces_or_nonwhitespaces( + s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS + ) + ) + t: List[int] = [] + for substr in substrs: + t.extend( + self.model.encode( + substr, + allowed_special=allowed_special, + disallowed_special=disallowed_special, + ) + ) + if bos: + t.insert(0, self.bos_id) + if eos: + t.append(self.eos_id) + return t + + def decode(self, t: Sequence[int]) -> str: + """ + Decodes a list of token IDs into a string. + + Args: + t (List[int]): The list of token IDs to be decoded. + + Returns: + str: The decoded string. + """ + # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. + return self.model.decode(cast(List[int], t)) + + @staticmethod + def _split_whitespaces_or_nonwhitespaces( + s: str, max_consecutive_slice_len: int + ) -> Iterator[str]: + """ + Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` + consecutive whitespaces or consecutive non-whitespaces. + """ + current_slice_len = 0 + current_slice_is_space = s[0].isspace() if len(s) > 0 else False + slice_start = 0 + + for i in range(len(s)): + is_now_space = s[i].isspace() + + if current_slice_is_space ^ is_now_space: + current_slice_len = 1 + current_slice_is_space = is_now_space + else: + current_slice_len += 1 + if current_slice_len > max_consecutive_slice_len: + yield s[slice_start:i] + slice_start = i + current_slice_len = 1 + yield s[slice_start:] + + @classmethod + def from_pretrained(cls, model_path: str) -> "LlamaVLTokenizer": + return cls(os.path.join(model_path, "tokenizer.model")) \ No newline at end of file From 218145ab7196b40de0f2e25b85943ed12b7266f6 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 1 Sep 2024 17:14:34 -0700 Subject: [PATCH 02/75] start to load shape --- examples/offline_inference_vision_language.py | 2 +- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/llamavl.py | 49 +++++++++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 vllm/model_executor/models/llamavl.py diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 899ad420586b..76b1bed9a421 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -162,7 +162,7 @@ def run_blip2(question): def run_llama(question, size: str): checkpoint_dir = "/data/zhang-chen/llama/checkpoints" # update checkpoint path here - llm = LLM(model=f"{checkpoint_dir}/Meta-Llama-3.2-{size}-Vision-Early/") # update checkpoint path here + llm = LLM(model=f"{checkpoint_dir}/Meta-Llama-3.2-{size}-Vision-Early/") raise NotImplementedError model_example_map = { diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 8591c276b001..7580d16c38ea 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -85,6 +85,7 @@ "PaliGemmaForConditionalGeneration"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "UltravoxModel": ("ultravox", "UltravoxModel"), + "LlamaVLForCausalLM": ("llamavl", "LlamaVLForCausalLM"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py new file mode 100644 index 000000000000..7d3e391d6ff0 --- /dev/null +++ b/vllm/model_executor/models/llamavl.py @@ -0,0 +1,49 @@ +import itertools +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import torch +import torch.nn as nn +# from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors, SamplerOutput +from .interfaces import SupportsMultiModal + +logger = init_logger(__name__) + +def get_max_llama_image_tokens(ctx: InputContext) -> int: + logger.warning("need further check on max llama image tokens") + print("ctx", type(ctx)) + print(ctx) + return 1025 * 2 + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llama_image_tokens) +class LlamaVLForCausalLM(nn.Module, SupportsMultiModal): + def __init__(self, config, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + print("config", type(config)) + print(config) + print("multimodal_config", type(multimodal_config)) + print(multimodal_config) + print("cache_config", type(cache_config)) + print(cache_config) + print("quant_config", type(quant_config)) + print(quant_config) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + for name, weight in weights: + print(name, weight.shape) + From 1c57f26a9d536a7405e92a88582f0fbce9627510 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 1 Sep 2024 20:32:17 -0700 Subject: [PATCH 03/75] copy original model --- vllm/model_executor/models/llamavl.py | 1574 ++++++++++++++++++++++++- 1 file changed, 1573 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 7d3e391d6ff0..25af74715f3f 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -26,6 +26,1578 @@ def get_max_llama_image_tokens(ctx: InputContext) -> int: print(ctx) return 1025 * 2 +# Image encoder for inference +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + return x + + +class ColumnParallelConv2dPatch(torch.nn.Module): + """Conv2D Patching layer with model parallelism. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. + Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + bias: Optional[bool] = False, + ) -> None: + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) + self._linear = ColumnParallelLinear( + in_channels * kernel_size[0] * kernel_size[1], + out_channels, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._unfold(x) + x = x.permute(0, 2, 1) + x = F.linear(x, self._linear.weight) + x = gather_from_tensor_model_parallel_region(x) + return x + + +class ImageFeedForward(torch.nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + dropout: float, + act_layer: Callable = nn.GELU, + ): + super().__init__() + # layers + self.c_fc = ColumnParallelLinear( + dim, + hidden_dim, + bias=True, + gather_output=False, + init_method=lambda x: x, + ) + self.c_proj = RowParallelLinear( + hidden_dim, + dim, + bias=True, + input_is_parallel=True, + init_method=lambda x: x, + ) + self.non_linearity = act_layer() + self.dropout = dropout + + def forward(self, x): + hidden = F.linear(x, self.c_fc.weight, self.c_fc.bias) + hidden = self.non_linearity(hidden) + hidden = F.linear(hidden, self.c_proj.weight) + hidden = reduce_from_tensor_model_parallel_region(hidden) + hidden += self.c_proj.bias + return hidden + + +class ImageAttention(nn.Module): + def __init__( + self, + dim, + head_dim, + n_heads, + ): + super().__init__() + model_parallel_size = fs_init.get_model_parallel_world_size() + qkvo_replication = 1 + if model_parallel_size > 16: + qkvo_replication = model_parallel_size // 8 + + self.n_kv_heads = n_heads + self.n_local_heads = n_heads * qkvo_replication // model_parallel_size + self.n_local_kv_heads = ( + self.n_kv_heads * qkvo_replication // model_parallel_size + ) + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = dim // n_heads + + self.wq = ColumnParallelLinear( + dim, + qkvo_replication * n_heads * self.head_dim, + bias=True, + gather_output=False, + init_method=lambda x: x, + ) + self.wk = ColumnParallelLinear( + dim, + qkvo_replication * self.n_kv_heads * self.head_dim, + bias=True, + gather_output=False, + init_method=lambda x: x, + ) + self.wv = ColumnParallelLinear( + dim, + qkvo_replication * self.n_kv_heads * self.head_dim, + bias=True, + gather_output=False, + init_method=lambda x: x, + ) + self.wo = RowParallelLinear( + qkvo_replication * n_heads * self.head_dim, + dim, + bias=True, + input_is_parallel=True, + init_method=lambda x: x, + ) + self.qkvo_replication = qkvo_replication + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor = None, + ): + + xq, xk, xv = [ + F.linear(x, w, b) + for (w, b) in [ + (self.wq.weight, self.wq.bias), + (self.wk.weight, self.wk.bias), + (self.wv.weight, self.wv.bias), + ] + ] + + bs, slen, _ = xq.shape + + xq = xq.view(bs, slen, self.n_local_heads, self.head_dim) + xk = xk.view(bs, xk.shape[1], self.n_local_kv_heads, self.head_dim) + xv = xv.view(bs, xv.shape[1], self.n_local_kv_heads, self.head_dim) + + xq, xk, xv = [tensor.transpose(1, 2) for tensor in (xq, xk, xv)] + + xk = xk.repeat_interleave(self.n_rep, dim=1) + xv = xv.repeat_interleave(self.n_rep, dim=1) + + attn_output = F.scaled_dot_product_attention( + xq, xk, xv, attn_mask=mask, dropout_p=0.0 + ) + + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1) + + out = F.linear(attn_output, self.wo.weight) + out = reduce_from_tensor_model_parallel_region(out) + out = out / self.qkvo_replication + out += self.wo.bias + return out + + +class ImageTransformerBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + gated: bool = False, + ): + super().__init__() + assert d_model % n_head == 0 + self.n_heads = n_head + self.head_dim = d_model // self.n_heads + self.attn = ImageAttention( + dim=d_model, + head_dim=self.head_dim, + n_heads=self.n_heads, + ) + self.ln_1 = LayerNorm(d_model) + self.mlp = ImageFeedForward( + dim=d_model, + hidden_dim=int(mlp_ratio * d_model), + dropout=0.0, + act_layer=act_layer, + ) + self.ln_2 = LayerNorm(d_model) + self.gated = gated + if gated: + self.gate_attn = nn.Parameter(torch.zeros(1)) + self.gate_ffn = nn.Parameter(torch.zeros(1)) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor = None, + ): + _gate_attn = 1 if not self.gated else self.gate_attn.tanh() + _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh() + x = x + _gate_attn * self.attn(self.ln_1(x), mask=mask) + x = x + _gate_ffn * self.mlp(self.ln_2(x)) + return x + + +class ImageTransformer(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + gated: bool = False, + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ImageTransformerBlock( + d_model=width, + n_head=heads, + mlp_ratio=mlp_ratio, + act_layer=act_layer, + gated=gated, + ) + for _ in range(self.layers) + ] + ) + + def forward(self, x: torch.Tensor, return_intermediate=None, mask=None): + out = [] + for idx, r in enumerate(self.resblocks): + if return_intermediate is not None and idx in return_intermediate: + out.append(x) + x = r(x, mask=mask) + if return_intermediate is not None: + return x, torch.stack(out, dim=-1) + return x + + +class VisionEncoder(nn.Module): + def __init__( + self, + max_num_tiles: int, + ckpt_path: str = None, + image_size: int = 224, + patch_size: int = 14, + width: int = 1280, + layers: int = 32, + heads: int = 16, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + in_channels: int = 3, + load_ckpt: bool = False, + n_global_layers: int = 2, + global_model: bool = False, + return_intermediate=None, + ): + super().__init__() + self.global_model = global_model + self.return_intermediate = return_intermediate + self.max_num_tiles = max_num_tiles + self.image_size = to_2tuple(image_size) + self.patch_size = to_2tuple(patch_size) + self.grid_size = ( + self.image_size[0] // self.patch_size[0], + self.image_size[1] // self.patch_size[1], + ) + self.conv1 = ColumnParallelConv2dPatch( + in_channels=in_channels, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter( + scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width) + ) + self.ln_post = LayerNorm(width) + self.ln_pre = LayerNorm(width) + self.transformer = ImageTransformer( + width, layers, heads, mlp_ratio, act_layer=act_layer + ) + # pre and post tile position embedding + self.global_transformer = ImageTransformer( + width, n_global_layers, heads, mlp_ratio, act_layer=act_layer, gated=True + ) + # pre and post tile position embedding + self.pre_tile_pos_embed = TilePositionEmbedding( + num_tiles=max_num_tiles, + width=width, + gated=True, + ) + self.post_tile_pos_embed = TilePositionEmbedding( + num_tiles=max_num_tiles, + width=width, + gated=True, + ) + self.gated_positional_embedding = nn.Parameter( + scale + * torch.randn( + max_num_tiles, + max_num_tiles, + self.grid_size[0] * self.grid_size[1] + 1, + width, + ) + ) + self.gated_positional_embedding_gate = nn.Parameter(torch.zeros(1)) + + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool = True, + missing_keys: List[str] = None, + unexpected_keys: List[str] = None, + error_msgs: List[str] = None, + return_state_dict: bool = False, + ) -> None: + orig_pos_embed = state_dict.get(prefix + "positional_embedding") + if orig_pos_embed is not None: + new_pos_embed = resize_local_position_embedding( + orig_pos_embed, self.grid_size + ) + state_dict[prefix + "positional_embedding"] = new_pos_embed + if hasattr(self, "gated_positional_embedding"): + if prefix + "gated_positional_embedding" not in state_dict: + # resize positional_embedding to fit the new grid size + global_pos_embed = initialize_global_position_embedding_from_local( + new_pos_embed, + self.grid_size, + self.max_num_tiles, + self.max_num_tiles, + ) + state_dict[prefix + "gated_positional_embedding"] = global_pos_embed + state_dict[prefix + "gated_positional_embedding_gate"] = torch.zeros( + 1, dtype=global_pos_embed.dtype + ) + logger.info( + f"Initialized global positional embedding with size {global_pos_embed.size()}" + ) + else: + global_pos_embed = resize_global_position_embedding( + state_dict[prefix + "gated_positional_embedding"], + self.grid_size, + self.max_num_tiles, + self.max_num_tiles, + ) + logger.info( + f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}" + ) + state_dict[prefix + "gated_positional_embedding"] = global_pos_embed + if return_state_dict: + return state_dict + + def apply_positional_embedding(self, x, ar): + out = [] + # apply regular position embedding + bsz, num_chunks, num_tokens, dim = x.shape + x = x.view(bsz * num_chunks, num_tokens, dim) + x = x + self.positional_embedding * ( + 1 - self.gated_positional_embedding_gate.tanh() + ) + x = x.view(bsz, num_chunks, num_tokens, dim) + for idx, arx in enumerate(ar): + _pos_embed = self.gated_positional_embedding[: arx[0], : arx[1]] + _pos_embed = _pos_embed.reshape(arx[0] * arx[1], *_pos_embed.shape[2:]) + x[idx, : arx[0] * arx[1]] += ( + _pos_embed * self.gated_positional_embedding_gate.tanh() + ) + return x + + def apply_class_embedding(self, x): + x = torch.cat( + [ + self.class_embedding.to(x.dtype) + + torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device + ), + x, + ], + dim=1, + ) + return x + + def forward(self, images: torch.Tensor, ar: torch.Tensor) -> torch.Tensor: + if images.ndim == 5: + num_concurrent_media = 1 + bsz, num_chunks, nch, w, h = images.shape + else: + bsz, num_concurrent_media, num_chunks, nch, w, h = images.shape + + images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, w, h) + ar = ar.reshape(bsz * num_concurrent_media, 2) + + # patch embedding + x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, w, h) + x = self.conv1(x) + _, ntok, dim = x.shape + x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim) + + # tile embeddings + x = self.pre_tile_pos_embed(x, ar) + x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim) + + # apply cls token + x = self.apply_class_embedding(x) + ntok += 1 + + # apply position embeddings + x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim) + x = self.apply_positional_embedding(x, ar) + + x = self.ln_pre(x) + npad, attn_mask = 0, None + x, npad = expand_num_tokens_to_mult8(x) + attn_mask = build_encoder_attention_mask(x, ar, ntok, num_chunks, 1) + x = x.view(bsz * num_concurrent_media, -1, dim) + x, int_x = self.transformer( + x, return_intermediate=self.return_intermediate, mask=attn_mask + ) + + x = self.ln_post(x) + x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim) + x = self.post_tile_pos_embed(x, ar) + x = x.reshape(bsz * num_concurrent_media, num_chunks * (ntok + npad), dim) + x = self.global_transformer(x, mask=attn_mask) + x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim) + x = contract_num_tokens_from_mult8(x, npad) + + # adding back intermediate layer outputs + x = x.reshape(bsz, num_concurrent_media, num_chunks, ntok, dim) + int_x = int_x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, -1) + int_x = contract_num_tokens_from_mult8(int_x, npad) + int_x = int_x.reshape(bsz, num_concurrent_media, num_chunks, ntok, -1) + x = torch.cat([x, int_x], dim=-1) + return x + + +class Attention(nn.Module): + """Multi-head attention module.""" + + def __init__(self, args: ModelArgs): + """ + Initialize the Attention module. + Args: + args (ModelArgs): Model configuration parameters. + Attributes: + n_kv_heads (int): Number of key and value heads. + n_local_heads (int): Number of local query heads. + n_local_kv_heads (int): Number of local key and value heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (ColumnParallelLinear): Linear transformation for queries. + wk (ColumnParallelLinear): Linear transformation for keys. + wv (ColumnParallelLinear): Linear transformation for values. + wo (RowParallelLinear): Linear transformation for output. + cache_k (torch.Tensor): Cached keys for attention. + cache_v (torch.Tensor): Cached values for attention. + """ + super().__init__() + model_parallel_size = fs_init.get_model_parallel_world_size() + replication_factor = 1 + if model_parallel_size > 8: + replication_factor = model_parallel_size // MP_SCALE + + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + self.n_kv_heads *= replication_factor + + self.n_local_heads = args.n_heads // model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self.head_dim = args.dim // args.n_heads + self.max_seq_len = args.max_seq_len + + self.wq = ColumnParallelLinear( + args.dim, + args.n_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wk = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wv = ColumnParallelLinear( + args.dim, + self.n_kv_heads * self.head_dim, + bias=False, + gather_output=False, + init_method=lambda x: x, + ) + self.wo = RowParallelLinear( + args.n_heads * self.head_dim, + args.dim, + bias=False, + input_is_parallel=True, + init_method=lambda x: x, + ) + self.n_heads = args.n_heads + + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + if prefix + "wqkv.weight" in state_dict: + total_n_heads = self.n_heads + self.n_kv_heads * 2 + wqkv = state_dict.pop(prefix + "wqkv.weight") + head_dim = wqkv.shape[0] // total_n_heads + dim1 = head_dim * self.n_heads + dim2 = dim1 + head_dim * self.n_kv_heads + dim3 = dim1 + head_dim * self.n_kv_heads * 2 + + wq = wqkv[:dim1] + wk = wqkv[dim1:dim2] + wv = wqkv[dim2:dim3] + + state_dict[prefix + "wq.weight"] = wq + state_dict[prefix + "wk.weight"] = wk + state_dict[prefix + "wv.weight"] = wv + + def setup_cache(self, max_batch_size: int, dtype: torch.dtype): + cache_shape = ( + max_batch_size, + self.max_seq_len, + self.n_local_kv_heads, + self.head_dim, + ) + device = next(self.parameters()).device + self.register_buffer( + "key_cache", + torch.zeros( + cache_shape, + dtype=dtype, + device=device, + ), + persistent=False, + ) + self.register_buffer( + "value_cache", + torch.zeros( + cache_shape, + dtype=dtype, + device=device, + ), + persistent=False, + ) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + freqs_cis: torch.Tensor, + position_ids: torch.LongTensor, + ): + + xq, xk, xv = [ + F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight] + ] + + bs, slen, _ = xq.shape + + xq = xq.view(bs, slen, self.n_local_heads, self.head_dim) + xk = xk.view(bs, xk.shape[1], self.n_local_kv_heads, self.head_dim) + xv = xv.view(bs, xv.shape[1], self.n_local_kv_heads, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis) + + self.key_cache[:bs, position_ids, ...] = xk + self.value_cache[:bs, position_ids, ...] = xv + + # TODO: we can avoid slicing on first dimension by always padding to max_batch_size() + xk = self.key_cache[:bs, ...] + xv = self.value_cache[:bs, ...] + + xq, xk, xv = [tensor.transpose(1, 2) for tensor in (xq, xk, xv)] + + xk = xk.repeat_interleave(self.n_rep, dim=1) + xv = xv.repeat_interleave(self.n_rep, dim=1) + + attn_output = F.scaled_dot_product_attention( + xq, xk, xv, attn_mask=mask, dropout_p=0.0 + ) + + attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1) + + out = F.linear(attn_output, self.wo.weight) + out = reduce_from_tensor_model_parallel_region(out) + return out + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + ): + """ + Initialize the FeedForward module. + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = ColumnParallelLinear( + dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x + ) + self.w2 = RowParallelLinear( + hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x + ) + self.w3 = ColumnParallelLinear( + dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x + ) + self._register_load_state_dict_pre_hook(self.load_hook) + + def forward(self, x): + x1, x3 = [F.linear(x, w) for w in [self.w1.weight, self.w3.weight]] + x1 = F.silu(x1) + x_in = x1 * x3 + out = F.linear(x_in, self.w2.weight) + out = reduce_from_tensor_model_parallel_region(out) + return out + + def load_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + if prefix + "mlp.fc1_weight" in state_dict: + fc1_weight, fc3_weight = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2) + state_dict[prefix + "w1.weight"] = fc1_weight + state_dict[prefix + "w3.weight"] = fc3_weight + + if prefix + "mlp.fc2_weight" in state_dict: + fc2_weight = state_dict.pop(prefix + "mlp.fc2_weight") + state_dict[prefix + "w2.weight"] = fc2_weight + + +class TransformerBlock(nn.Module): + def __init__(self, layer_id: int, args: ModelArgs): + """ + Initialize a TransformerBlock. + Args: + layer_id (int): Identifier for the layer. + args (ModelArgs): Model configuration parameters. + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. + """ + super().__init__() + self.n_heads = args.n_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = Attention(args) + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + multiple_of=args.multiple_of, + ffn_dim_multiplier=args.ffn_dim_multiplier, + ) + self.layer_id = layer_id + self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) + self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict: + state_dict[prefix + "ffn_norm.weight"] = state_dict.pop( + prefix + "feed_forward.mlp.layer_norm_weight" + ) + if prefix + "attention.wqkv.layer_norm_weight" in state_dict: + state_dict[prefix + "attention_norm.weight"] = state_dict.pop( + prefix + "attention.wqkv.layer_norm_weight" + ) + + def setup_cache(self, max_batch_size: int, dtype: torch.dtype): + self.attention.setup_cache(max_batch_size, dtype) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + mask: torch.Tensor, + position_ids: torch.LongTensor, + ) -> torch.Tensor: + """ + Perform a forward pass through the TransformerBlock. + Args: + x (torch.Tensor): Input tensor. + start_pos (int): Starting position for attention caching. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None. + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + """ + h = self.attention.forward( + x=self.attention_norm(x), + freqs_cis=freqs_cis, + mask=mask, + position_ids=position_ids, + ) + h = h + x + out = h + self.feed_forward.forward(self.ffn_norm(h)) + return out + + +class TilePositionEmbedding(nn.Module): + def __init__( + self, + num_tiles: int, + width: int, + gated: bool = False, + ): + super().__init__() + self.num_tiles = num_tiles + self.width = width + self.embedding = nn.Parameter( + torch.randn(num_tiles, num_tiles, 1, width) / math.sqrt(width) + ) + self.gated = gated + if gated: + self.gate = nn.Parameter(torch.zeros(1)) + + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + # load the weights from the checkpoint + embed = state_dict.get(prefix + "embedding") + if embed is not None: + # reshape the weights to the correct shape + nt_old, nt_old, _, w = embed.shape + logging.info( + f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}" + ) + embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles) + # assign the weights to the module + state_dict[prefix + "embedding"] = embed_new + + @staticmethod + def _dynamic_resize(embed: torch.Tensor, num_tiles: int): + nt_old, nt_old, _, w = embed.shape + embed = embed.permute(2, 3, 0, 1) + + embed_new = F.interpolate( + embed, + size=(num_tiles, num_tiles), + mode="bilinear", + align_corners=True, + ) + # reshape the weights to the correct shape + embed_new = embed_new.permute(2, 3, 0, 1) + return embed_new + + def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int = None): + embed = self.embedding + if num_tiles is None: + num_tiles = self.num_tiles + elif num_tiles > self.num_tiles: + embed = TilePositionEmbedding._dynamic_resize(self.embedding, num_tiles) + out_pos_embed = torch.zeros( + x.shape[0], num_tiles, 1, self.width, device=x.device, dtype=x.dtype + ) + for idx, arx in enumerate(ar): + w, h = arx + out_pos_embed[idx, : w * h] = embed[:w, :h].reshape(w * h, 1, self.width) + if self.gated: + out_pos_embed = out_pos_embed * self.gate.tanh() + x = x + out_pos_embed + return x + + +def _noinit(x): + return x + + +class CrossAttention(torch.nn.Module): + """Cross attention layer with model-parallel attention layers.""" + + def __init__( + self, + dim: int, + head_dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + ): + super().__init__() + self.model_parallel_size = fs_init.get_model_parallel_world_size() + replication_factor = 1 + if self.model_parallel_size > 8: + replication_factor = self.model_parallel_size // MP_SCALE + n_kv_heads *= replication_factor + + assert n_heads % n_kv_heads == 0 + + self.wq = ColumnParallelLinear( + dim, + n_heads * head_dim, + bias=False, + gather_output=False, + init_method=_noinit, + ) + + self.wk = ColumnParallelLinear( + dim, + n_kv_heads * head_dim, + bias=False, + gather_output=False, + init_method=_noinit, + ) + self.wv = ColumnParallelLinear( + dim, + n_kv_heads * head_dim, + bias=False, + gather_output=False, + init_method=_noinit, + ) + self.wo = RowParallelLinear( + n_heads * head_dim, + dim, + bias=False, + input_is_parallel=True, + init_method=_noinit, + ) + + self.n_heads = n_heads + self.head_dim = head_dim + self.n_kv_heads = n_kv_heads + + self.q_norm = RMSNorm( + self.head_dim, + eps=norm_eps, + ) + self.k_norm = RMSNorm( + self.head_dim, + eps=norm_eps, + ) + + # cross-attention heads are model parallel similar to + # self-attention, and we also use the identical KV head + # combination to ensure parity with the corresponding + # trunk LLM (i.e., group query attention) -- @dubeya + # local heads + assert self.n_heads % self.n_kv_heads == 0 + assert self.n_heads % self.model_parallel_size == 0 + assert self.n_kv_heads % self.model_parallel_size == 0 + self.n_local_heads = self.n_heads // self.model_parallel_size + self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size + self.n_rep = self.n_local_heads // self.n_local_kv_heads + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + if prefix + "inner_attention.q_norm.weight" in state_dict: + q_weight = state_dict.pop(prefix + "inner_attention.q_norm.weight") + state_dict[prefix + "q_norm.weight"] = q_weight + if prefix + "inner_attention.k_norm.weight" in state_dict: + k_weight = state_dict.pop(prefix + "inner_attention.k_norm.weight") + state_dict[prefix + "k_norm.weight"] = k_weight + if prefix + "wkv.weight" in state_dict: + wk, wv = state_dict.pop(prefix + "wkv.weight").chunk(2) + state_dict[prefix + "wk.weight"] = wk + state_dict[prefix + "wv.weight"] = wv + + def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: + bsz = xattn_tokens.shape[0] + xk = self.wk(xattn_tokens) + xv = self.wv(xattn_tokens) + + _, seqlen_y, _ = xk.shape + + xk = xk.view(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim) + xv = xv.view(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim) + + xk, xv = [tensor.transpose(1, 2) for tensor in (xk, xv)] + + # repeat k/v heads if n_kv_heads < n_heads + xk = xk.repeat_interleave(self.n_rep, dim=1) + xv = xv.repeat_interleave(self.n_rep, dim=1) + + xk = self.k_norm(xk) + + return torch.stack([xk, xv]) + + def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: + return self._compute_xattn_kv_cache(xattn_tokens) + + def forward( + self, + x: torch.Tensor, + xattn_mask: torch.Tensor, + full_text_row_masked_out_mask: torch.Tensor, + xattn_cache: torch.Tensor, + ) -> torch.Tensor: + xq = F.linear(x, self.wq.weight) + bsz, seqlen, _ = x.shape + + xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xq = self.q_norm(xq) + xq = xq.transpose(1, 2) + + xk, xv = xattn_cache + + output = F.scaled_dot_product_attention( + xq, xk, xv, attn_mask=xattn_mask, dropout_p=0.0 + ) + output = output * full_text_row_masked_out_mask + output = output.transpose(1, 2).contiguous().reshape(bsz, seqlen, -1) + + out = F.linear(output, self.wo.weight) + out = reduce_from_tensor_model_parallel_region(out) + return out + + +class CrossAttentionTransformerBlock(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention and feedforward.""" + + def __init__( + self, + args: ModelArgs, + layer_id: int, + no_ffn: bool = False, + ) -> None: + super().__init__() + self.layer_id = layer_id + self.n_heads = args.n_heads + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.attention = CrossAttention( + dim=args.dim, + head_dim=self.head_dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + norm_eps=args.norm_eps, + ) + + self.attention_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + ) + self.gate_attn = torch.nn.Parameter(torch.zeros(1)) + + self.feed_forward = FeedForward( + dim=args.dim, + hidden_dim=4 * args.dim, + ffn_dim_multiplier=args.ffn_dim_multiplier, + multiple_of=args.multiple_of, + ) + self.ffn_norm = RMSNorm( + args.dim, + eps=args.norm_eps, + ) + self.gate_ffwd = torch.nn.Parameter(torch.zeros(1)) + + self._register_load_state_dict_pre_hook(self.load_hook) + self.no_ffn = no_ffn + + def load_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + if prefix + "gate_attn" in state_dict: + attn_gate = state_dict.pop(prefix + "gate_attn") + if attn_gate.dim() == 1: + attn_gate = attn_gate[0].view(1) + if attn_gate.dim() == 3: + attn_gate = attn_gate.view(1) + state_dict[prefix + "gate_attn"] = attn_gate + if prefix + "gate_ffwd" in state_dict: + ffn_gate = state_dict.pop(prefix + "gate_ffwd") + if ffn_gate.dim() == 1: + ffn_gate = ffn_gate[0].view(1) + if ffn_gate.dim() == 3: + ffn_gate = ffn_gate.view(1) + state_dict[prefix + "gate_ffwd"] = ffn_gate + if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict: + state_dict[prefix + "ffn_norm.weight"] = state_dict.pop( + prefix + "feed_forward.mlp.layer_norm_weight" + ) + if prefix + "attention.wq.layer_norm_weight" in state_dict: + state_dict[prefix + "attention_norm.weight"] = state_dict.pop( + prefix + "attention.wq.layer_norm_weight" + ) + + def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: + return self.attention.compute_xattn_kv_cache(xattn_tokens) + + def forward( + self, + x: torch.Tensor, + xattn_mask: torch.Tensor, + full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], + xattn_cache: torch.Tensor, + ) -> torch.Tensor: + _attn_out = self.attention( + x=self.attention_norm(x), + xattn_mask=xattn_mask, + xattn_cache=xattn_cache, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + ) + h = x + self.gate_attn.tanh() * _attn_out + _ffn = self.feed_forward(self.ffn_norm(h)) + _ffn = full_text_row_masked_out_mask[:, 0] * _ffn # type: ignore + h = h + self.gate_ffwd.tanh() * _ffn * float(not self.no_ffn) + return h + + +class DummyCrossAttentionTransformerBlock: + """Dummy cross-attention transformer block with tanh-gated attention and feedforward.""" + + def __call__( + self, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + return x + + +class DummySelfAttentionTransformerBlock: + """Dummy self-attention transformer block""" + + def __call__( + self, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + return x + + +class CrossAttentionTransformerVision(torch.nn.Module): + def __init__(self, args: ModelArgs) -> None: + super().__init__() + return_intermediate = "3,7,15,23,30" + self.vision_input_dim = 1280 + self.image_res = args.vision_chunk_size + self.max_num_chunks = args.vision_max_num_chunks + if return_intermediate is not None: + return_intermediate = [int(l) for l in return_intermediate.split(",")] + self.vision_input_dim = ( + len(return_intermediate) + 1 + ) * self.vision_input_dim + self.patch_size = 14 + self.vision_encoder = VisionEncoder( + max_num_tiles=4, + image_size=args.vision_chunk_size, + patch_size=self.patch_size, + n_global_layers=8, + global_model=True, + return_intermediate=return_intermediate, + ) + # vision token projection + self.vision_projection = ColumnParallelLinear( + self.vision_input_dim, + args.dim, + bias=True, + init_method=lambda x: x, + ) + + def forward( + self, images: torch.Tensor, aspect_ratios: torch.Tensor + ) -> torch.Tensor: + # vision_tokens: (B, T, D) + # aspect_ratios: (B, T) + # h: (B, T, D) + vision_tokens = self.vision_encoder( + images.to(dtype=torch.bfloat16), aspect_ratios + ) + + vision_tokens = F.linear(vision_tokens, self.vision_projection.weight) + vision_tokens = gather_from_tensor_model_parallel_region(vision_tokens) + return vision_tokens + + +class CrossAttentionTransformerText(torch.nn.Module): + INFERENCE_IMAGE_TOKEN_ID = 128010 + + def __init__(self, args: ModelArgs) -> None: + super().__init__() + self.model_parallel_size = fs_init.get_model_parallel_world_size() + assert args.vocab_size > 0 + self.vocab_size = args.vocab_size + self.n_layers = args.n_layers + self.dim = args.dim + self.head_dim = args.dim // args.n_heads + self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads + self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size + assert self.vocab_size % self.model_parallel_size == 0 + self.tok_embeddings = VocabParallelEmbedding( + args.vocab_size, args.dim, init_method=lambda x: x + ) + self.pos_embeddings = None + # final norm layer (not necessary for post-norm) + self.norm = RMSNorm(args.dim, eps=args.norm_eps) + + # output layer + self.output = ColumnParallelLinear( + args.dim, args.vocab_size, bias=False, init_method=lambda x: x + ) + + self.n_llama_layers = args.n_layers + self.model_dim = args.dim + + # BLOCKS + + self.fusion_schedule = self._init_fusion_schedule( + args.vision_num_cross_attention_layers + ) + self.learnable_embedding = VocabParallelEmbedding( + max(fs_init.get_model_parallel_world_size(), 8), + args.dim, + init_method=lambda x: x, + ) + self.num_frozen_embeddings = self.tok_embeddings.num_embeddings + self._thresh = self.num_frozen_embeddings - 1 + + # transformer blocks + self.layers = torch.nn.ModuleList() + self.cross_attention_layers = torch.nn.ModuleList() + for i in range(args.n_layers): + layer_id = i + block = TransformerBlock(args=args, layer_id=layer_id) + self.layers.append(block) + if layer_id in self.fusion_schedule: + xa_layer_id = self.fusion_schedule.index(layer_id) + args.n_layers + block = CrossAttentionTransformerBlock( + args, + layer_id=xa_layer_id, + ) + self.cross_attention_layers.append(block) + + # add xattn and dummy layers to avoid conditionals in forward() + self.text_and_xattn_layers = [] + + for idx, layer in enumerate(self.layers): + if idx in self.fusion_schedule: + xattn_layer_idx = self.fusion_schedule.index(idx) + xattn_layer = self.cross_attention_layers[xattn_layer_idx] + else: + xattn_layer_idx = 0 + xattn_layer = DummyCrossAttentionTransformerBlock() + + self.text_and_xattn_layers.append( + ( + layer, + xattn_layer, + xattn_layer_idx, + ) + ) + self.freqs_cis = precompute_freqs_cis( + args.dim // args.n_heads, + args.max_seq_len * 2, + args.rope_theta, + args.use_scaled_rope, + ) + + self._register_load_state_dict_pre_hook(self.load_hook) + + self.args = args + self.cache_is_setup = False + self.max_seq_len = args.max_seq_len + + def _init_fusion_schedule( + self, + num_layers: int, + ) -> List[int]: + llama_layers = list(range(self.n_llama_layers)) + + # uniformly spread the layers + k = math.ceil(len(llama_layers) / num_layers) + return llama_layers[::-1][::k][:num_layers][::-1] + + def get_partially_trainable_embedding(self, x): + xz = torch.zeros_like(x, device=x.device) + oz = torch.ones_like(x, device=x.device) + x_orig = torch.minimum(x, torch.tensor(self._thresh, device=x.device)) + x_new = ( + torch.maximum(x, torch.tensor(self._thresh + 1, device=x.device)) + - self.num_frozen_embeddings + ) + + mask_orig = torch.where(x >= self.num_frozen_embeddings, xz, oz).unsqueeze(-1) + mask_new = torch.where(x < self.num_frozen_embeddings, xz, oz).unsqueeze(-1) + + x_orig = self.tok_embeddings(x_orig) + x_new = self.learnable_embedding(x_new).type_as(x_orig) + return x_orig * mask_orig.type_as(x_orig) + x_new * mask_new.type_as(x_new) + + def load_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + if "rope.freqs" in state_dict: + del state_dict["rope.freqs"] + + def forward( + self, + position_ids: torch.LongTensor, + h: torch.Tensor, + xattn_mask: torch.Tensor, + full_text_row_masked_out_mask: torch.Tensor, + xattn_caches: torch.Tensor, + ): + assert self.cache_is_setup, "Please set up cache before calling forward" + mask = self.mask_cache.index_select(2, position_ids) + freqs_cis = self.freqs_cis.index_select(0, position_ids) + + for idx, ( + layer, + xattn_layer, + xattn_layer_idx, + ) in enumerate(self.text_and_xattn_layers): + h = xattn_layer( + x=h, + xattn_mask=xattn_mask, + xattn_cache=xattn_caches[xattn_layer_idx], + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + ) + h = layer( + x=h, + mask=mask, + freqs_cis=freqs_cis, + position_ids=position_ids, + ) + + h = self.norm(h) + + output = F.linear(h, self.output.weight) + output = gather_from_tensor_model_parallel_region(output) + return output.float() + + def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16): + # Set up the text kv caches + device = next(self.parameters()).device + ones = torch.ones( + (self.max_seq_len, self.max_seq_len), + dtype=torch.bool, + device=device, + ) + self.register_buffer( + "mask_cache", + torch.tril( + ones, + ) + .unsqueeze(0) + .unsqueeze(0), + persistent=False, + ) + for layer in self.layers: + layer.setup_cache(max_batch_size, dtype=dtype) + self.cache_is_setup = True + + def _get_xattn_mask( + self, + num_tokens, + text_device, + text_dtype, + vision_tokens, + cross_attention_masks, + ) -> Tuple[Tensor, Tensor]: + assert vision_tokens is not None, "Vision tokens must be provided" + vision_seqlen = vision_tokens.shape[3] + assert ( + vision_tokens.shape[1] == cross_attention_masks.shape[2] + ), f"Mismatch in number of images given and number of masks given {vision_tokens.shape} {cross_attention_masks.shape}" + assert ( + vision_tokens.shape[2] == cross_attention_masks.shape[3] + ), f"Vision tokens shape {vision_tokens.shape} mismatch with xattn shape {cross_attention_masks.shape}" + assert ( + num_tokens == cross_attention_masks.shape[1] + ), f"Mismatch in text sequence length and cross attention mask sequence length {num_tokens} {cross_attention_masks.shape}" + _, _, _, num_image_tokens, image_token_dim = tuple(vision_tokens.shape) + bsz, ntext, nimg, nchunks = cross_attention_masks.shape + cross_attention_masks = ( + cross_attention_masks.repeat_interleave(vision_seqlen, dim=2) + .view(bsz, ntext, -1) + .unsqueeze(1) + ) + full_text_row_masked_out_mask = _get_full_row_masked_out_mask( + cross_attention_masks, + get_negative_inf_value(cross_attention_masks.dtype), + ) + cross_attention_masks *= full_text_row_masked_out_mask + + return ( + cross_attention_masks.to(device=text_device, dtype=text_dtype), + full_text_row_masked_out_mask, + ) + + +class CrossAttentionTransformer(torch.nn.Module): + def __init__(self, args: ModelArgs) -> None: + super().__init__() + self.params = args + + self.model_dim = args.dim + self.vision_model = CrossAttentionTransformerVision(args) + self.text_model = CrossAttentionTransformerText(args) + self.image_res = args.vision_chunk_size + self.max_num_chunks = args.vision_max_num_chunks + self.image_transform = partial( + VariableSizeImageTransform(size=args.vision_chunk_size), + max_num_chunks=args.vision_max_num_chunks, + ) + + def setup_cache(self, max_batch_size: int, dtype: torch.dtype): + self.text_model.setup_cache(max_batch_size, dtype) + + def compute_vision_tokens_masks( + self, + batch_images: List[List[PIL_Image.Image]], + batch_masks: List[List[List[int]]], + total_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + skip_vision_encoder = False + + assert len(batch_images) == len( + batch_masks + ), "Images and masks must have the same length" + + max_num_images = max(len(x) for x in batch_images) + bsz = len(batch_images) + + if max_num_images == 0: + num_chunks = [[self.max_num_chunks] for _ in batch_images] + skip_vision_encoder = True + else: + images_and_aspect_ratios = [ + [self.image_transform(im) for im in row] for row in batch_images + ] + transformed_images = [ + [x[0] for x in row] for row in images_and_aspect_ratios + ] + + aspect_ratios = torch.ones(bsz, max_num_images, 2, dtype=torch.int64) + for i, row in enumerate(images_and_aspect_ratios): + if len(row) > 0: + aspect_ratios[i, : len(row)] = torch.stack( + [torch.tensor(x[1]) for x in row] + ) + + stacked_images, num_chunks = _stack_images( + transformed_images, + max_num_chunks=self.max_num_chunks, + image_res=self.params.vision_chunk_size, + max_num_images=max_num_images, + ) + + if skip_vision_encoder: + vision_tokens = torch.zeros( + ( + bsz, + max_num_images, + self.max_num_chunks, + int( + (self.vision_model.image_res / self.vision_model.patch_size) + ** 2 + + 1 + ), + self.model_dim, + ), + ) + else: + vision_tokens = self.vision_model(stacked_images, aspect_ratios) + + vision_tokens = vision_tokens.to("cuda") + + bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape) + xattn_caches = torch.stack( + [ + layer.compute_xattn_kv_cache( + vision_tokens.view(bsz, -1, image_token_dim) + ) + for layer in self.text_model.cross_attention_layers + ] + ) + padded_masks = _pad_masks( + batch_masks, + num_chunks, + total_len, + self.max_num_chunks, + ) + + cross_attention_masks, full_text_row_masked_out_mask = ( + self.text_model._get_xattn_mask( + num_tokens=total_len, + text_device="cuda", + text_dtype=next(self.text_model.parameters()).dtype, + vision_tokens=vision_tokens, + cross_attention_masks=padded_masks, + ) + ) + + return (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask) + + def forward( + self, + position_ids: torch.Tensor, + tokens: torch.Tensor, + cross_attention_masks: torch.Tensor, + full_text_row_masked_out_mask: torch.Tensor, + xattn_caches: torch.Tensor, + ) -> torch.Tensor: + h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids]) + logits = self.text_model.forward( + position_ids=position_ids, + h=h, + xattn_mask=cross_attention_masks[:, :, position_ids], + full_text_row_masked_out_mask=full_text_row_masked_out_mask[ + :, :, position_ids + ], + xattn_caches=xattn_caches, + ) + return logits + + +def _stack_images( + images: List[List[PIL_Image.Image]], + max_num_chunks: int, + image_res: int, + max_num_images: int, +) -> Tuple[torch.Tensor, List[int]]: + """ + Takes a list of list of images and stacks them into a tensor. + This function is needed since images can be of completely + different resolutions and aspect ratios. + """ + out_images, out_num_chunks = [], [] + for imgs_sample in images: + out_images_i = torch.zeros( + max_num_images, + max_num_chunks, + 3, + image_res, + image_res, + ) + _num_chunks = [] + for j, chunks_image in enumerate(imgs_sample): + out_images_i[j, : chunks_image.shape[0]] = chunks_image + _num_chunks.append(chunks_image.shape[0]) + out_images.append(out_images_i) + out_num_chunks.append(_num_chunks) + return torch.stack(out_images), out_num_chunks + + +def _pad_masks( + all_masks: List[List[List[int]]], + all_num_chunks: List[List[int]], + total_len: int, + max_num_chunks: int, +) -> torch.Tensor: + dtype = torch.bfloat16 + inf_value = get_negative_inf_value(dtype) + + bsz = len(all_masks) + max_num_media = max([len(m) for m in all_masks]) + + out_masks = torch.full( + (bsz, total_len, max_num_media, max_num_chunks), + inf_value, + dtype=dtype, + ) + + for idx, (mask, num_chunks) in enumerate(zip(all_masks, all_num_chunks)): + for mask_idx, (mask_elem, mask_num_chunks) in enumerate(zip(mask, num_chunks)): + if len(mask_elem) == 2: + mask_elem[1] = min(mask_elem[1], total_len) + if mask_elem[1] == -1: + mask_elem[1] = total_len + out_masks[ + idx, mask_elem[0] : mask_elem[1], mask_idx, :mask_num_chunks + ].fill_(0.0) + + return out_masks + + @MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llama_image_tokens) class LlamaVLForCausalLM(nn.Module, SupportsMultiModal): @@ -46,4 +1618,4 @@ def __init__(self, config, def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, weight in weights: print(name, weight.shape) - + From 5233e2d857d2b719c52ae74c316c820ad15a0678 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Sep 2024 00:10:55 -0700 Subject: [PATCH 04/75] add LlamaVLConfig --- examples/offline_inference_vision_language.py | 2 +- vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/llamavl.py | 39 +++++++++++++++++++ 4 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 vllm/transformers_utils/configs/llamavl.py diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 76b1bed9a421..2ecf40f10b2d 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -162,7 +162,7 @@ def run_blip2(question): def run_llama(question, size: str): checkpoint_dir = "/data/zhang-chen/llama/checkpoints" # update checkpoint path here - llm = LLM(model=f"{checkpoint_dir}/Meta-Llama-3.2-{size}-Vision-Early/") + llm = LLM(model=f"{checkpoint_dir}/Meta-Llama-3.2-{size}-Vision-Early/", enforce_eager=True) raise NotImplementedError model_example_map = { diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index c2276b075c1d..3953eaa33b92 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -15,7 +15,7 @@ JAISConfig, MedusaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, RWConfig, - UltravoxConfig) + UltravoxConfig, LlamaVLConfig) if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -37,6 +37,7 @@ "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "ultravox": UltravoxConfig, + "llamavl": LlamaVLConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index dc2fd6a859e3..8b1d555f33b6 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -12,6 +12,7 @@ from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from vllm.transformers_utils.configs.llamavl import LlamaVLConfig __all__ = [ "ChatGLMConfig", @@ -22,6 +23,7 @@ "JAISConfig", "MedusaConfig", "EAGLEConfig", + "LlamaVLConfig", "MLPSpeculatorConfig", "NemotronConfig", "UltravoxConfig", diff --git a/vllm/transformers_utils/configs/llamavl.py b/vllm/transformers_utils/configs/llamavl.py new file mode 100644 index 000000000000..7af55b1b720a --- /dev/null +++ b/vllm/transformers_utils/configs/llamavl.py @@ -0,0 +1,39 @@ +from transformers import PretrainedConfig +from typing import Optional + + +class LlamaVLConfig(PretrainedConfig): + model_type = "llamavl" + + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: Optional[int] = None + vocab_size: int = -1 + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: Optional[float] = None + norm_eps: float = 1e-5 + rope_theta: float = 500000 + use_scaled_rope: bool = False + + max_batch_size: int = 32 + max_seq_len: int = 2048 + + # vision model params + vision_chunk_size: int = -1 # image resolution for image models + vision_max_num_chunks: int = 4 + vision_num_cross_attention_layers: int = -1 + + model_type: str = "llamavl" + architectures: list[str] = ["LlamaVLForCausalLM"] + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if hasattr(self, k): + setattr(self, k, v) + + if self.n_kv_heads is None: + self.n_kv_heads = self.n_heads + assert self.n_kv_heads <= self.n_heads + assert self.n_heads % self.n_kv_heads == 0 + assert self.dim % self.n_heads == 0 From 72b9a8a903cd0c8446f98da00681ca0f703e4c0b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Sep 2024 14:45:21 -0700 Subject: [PATCH 05/75] can load weight, attention is ignored --- vllm/model_executor/models/llamavl.py | 965 +++++++++++++++++--------- 1 file changed, 651 insertions(+), 314 deletions(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 25af74715f3f..4181d9b1e429 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -1,10 +1,16 @@ +from dataclasses import dataclass +from functools import partial import itertools +import collections +import math from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, - TypedDict, Union) + TypedDict, Union, Callable, Dict, Any) import torch import torch.nn as nn -# from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig +import torch.nn.functional as F +import torchvision.transforms as tv +from PIL import Image from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -13,12 +19,24 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.logger import init_logger from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsMultiModal +from .llama import LlamaAttention +from vllm.model_executor.layers.layernorm import RMSNorm +# from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, +# QKVParallelLinear, +# RowParallelLinear, +# ColumnParallelLinear) +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + logger = init_logger(__name__) +MP_SCALE = 8 +IMAGE_RES = 224 def get_max_llama_image_tokens(ctx: InputContext) -> int: logger.warning("need further check on max llama image tokens") @@ -26,6 +44,232 @@ def get_max_llama_image_tokens(ctx: InputContext) -> int: print(ctx) return 1025 * 2 + +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + +def resize_local_position_embedding(orig_pos_embed, grid_size): + """ + Resize position embedding for vision encoder. + Original position embedding is [n_tiles * n_tiles + 1, dim] + New position embedding will be [grid_size[0] * grid_size[1] + 1, dim] + """ + new_grid_size = to_2tuple(grid_size) + orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1))) + new_seq_len = new_grid_size[0] * new_grid_size[1] + 1 + + new_pos_emb_tok, new_pos_emb_img = ( + orig_pos_embed[:1], + orig_pos_embed[1:], + ) + logger.info( + f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}" + ) + + new_pos_emb_img = new_pos_emb_img.reshape( + 1, orig_grid_size[0], orig_grid_size[1], -1 + ).permute(0, 3, 1, 2) + + new_pos_emb_img = F.interpolate( + new_pos_emb_img, + size=new_grid_size, + mode="bilinear", + align_corners=True, + ) + new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape( + 1, new_grid_size[0] * new_grid_size[1], -1 + )[0] + new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0) + return new_pos_embed + + +def initialize_global_position_embedding_from_local( + pos_and_cls_embed, grid_size, x_scale, y_scale +): + """ + Takes a local position embedding for vision encoder and uses it + to initialize the global position embedding. + Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim] + Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim] + Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively. + """ + pos_embed = pos_and_cls_embed[1:] + cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1) + grid_size = to_2tuple(grid_size) + new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute( + 0, 3, 1, 2 + ) + new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1]) + new_pos_emb_img = F.interpolate( + new_pos_emb_img, + size=new_grid_size, + mode="bilinear", + align_corners=True, + ) + new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1) + new_pos_emb_img = new_pos_emb_img.view( + x_scale, grid_size[0], y_scale, grid_size[1], -1 + ) + new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous() + new_pos_emb_img = new_pos_emb_img.reshape( + x_scale, y_scale, grid_size[0] * grid_size[1], -1 + ) + cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1) + pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2) + return pos_and_cls_embed + + +def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale): + """ + Takes a global position embedding for vision encoder and resizes it to new size. + Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim] + Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim] + Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively. + """ + # first remove cls token + pos_embed = pos_and_cls_embed[:, :, 1:] + cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2) + + xs_old, ys_old, ntok, dim = pos_embed.shape + old_grid_size = int(math.sqrt(ntok)) + + # move to correct form for interpolation + pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim) + pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous() + pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim) + pos_embed = pos_embed.unsqueeze(0) + + # interpolate + new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale) + pos_embed = pos_embed.permute(0, 3, 1, 2) + pos_embed_resized = F.interpolate( + pos_embed, + size=new_size, + mode="bilinear", + align_corners=True, + ) + pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0] + + # move it back in place + pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim) + pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous() + pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim) + + # interpolate cls token + cls_embed = cls_embed.permute(2, 3, 0, 1) + cls_embed_resized = F.interpolate( + cls_embed, + size=(x_scale, y_scale), + mode="bilinear", + align_corners=True, + ) + cls_embed = cls_embed_resized.permute(2, 3, 0, 1) + # add cls token back in + pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2) + + return pos_and_cls_embed + + +def build_encoder_attention_mask( + x: torch.Tensor, + ar: torch.Tensor, + ntok: int, + num_chunks: int, + n_heads: int, +): + """ + Build vision encoder attention mask that omits padding tokens. + """ + masks = [] + for arx in ar: + mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype) + mask_i[: arx[0] * arx[1], :ntok] = 0 + mask_i = mask_i.view(num_chunks * x.shape[2], -1) + mask_i = mask_i @ mask_i.T * torch.finfo(x.dtype).min + mask_i = mask_i.unsqueeze(0) + masks.append(mask_i) + masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1) + return masks + + +def expand_num_tokens_to_mult8(x): + num_pad_tokens = 8 - (x.shape[-2] % 8) + if num_pad_tokens == 0: + return x, 0 + else: + return ( + torch.cat( + [ + x, + torch.zeros( + (x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]), + dtype=x.dtype, + device=x.device, + ), + ], + dim=-2, + ), + num_pad_tokens, + ) + + +def contract_num_tokens_from_mult8(x, num_pad_tokens): + if num_pad_tokens == 0: + return x + return x[:, :, :-num_pad_tokens] + +def apply_scaling(freqs: torch.Tensor): + # Values obtained from grid search + scale_factor = 8 + low_freq_factor = 1 + high_freq_factor = 4 + old_context_len = 8192 # original llama3 length + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + new_freqs = [] + for freq in freqs: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / scale_factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) + return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) + + +def precompute_freqs_cis( + dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False +): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + if use_scaled: + freqs = apply_scaling(freqs) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + +def _get_full_row_masked_out_mask( + attn_bias, + negative_inf_value, +): + """ + attn_bias should be a 4D tensor of shape [B, H, S1, S2] + where B is the batch size, H is the number of heads, + and S1/S2 are the sequence lengths. This returns + a 4D tensor of shape [B, H, S1, 1] which stores boolean + values which are 0 if the a full row in the last dimension + contains negative infinity values, otherwise it's 1. + """ + return (attn_bias != negative_inf_value).any(dim=-1).type_as(attn_bias)[..., None] + # Image encoder for inference class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16.""" @@ -54,13 +298,18 @@ def __init__( out_channels: int, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], - bias: Optional[bool] = False, + bias: bool = False, ) -> None: super().__init__() if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) - self._linear = ColumnParallelLinear( + # self._linear = ColumnParallelLinear( + # in_channels * kernel_size[0] * kernel_size[1], + # out_channels, + # bias=bias, + # ) + self._linear = nn.Linear( in_channels * kernel_size[0] * kernel_size[1], out_channels, bias=bias, @@ -69,8 +318,9 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self._unfold(x) x = x.permute(0, 2, 1) - x = F.linear(x, self._linear.weight) - x = gather_from_tensor_model_parallel_region(x) + x = self._linear(x) + # x = F.linear(x, self._linear.weight) + # x = gather_from_tensor_model_parallel_region(x) return x @@ -84,29 +334,33 @@ def __init__( ): super().__init__() # layers - self.c_fc = ColumnParallelLinear( - dim, - hidden_dim, - bias=True, - gather_output=False, - init_method=lambda x: x, - ) - self.c_proj = RowParallelLinear( - hidden_dim, - dim, - bias=True, - input_is_parallel=True, - init_method=lambda x: x, - ) + self.c_fc = nn.Linear(dim, hidden_dim, bias=True) + # self.c_fc = ColumnParallelLinear( + # dim, + # hidden_dim, + # bias=True, + # gather_output=False, + # init_method=lambda x: x, + # ) + self.c_proj = nn.Linear(hidden_dim, dim, bias=True) + # self.c_proj = RowParallelLinear( + # hidden_dim, + # dim, + # bias=True, + # input_is_parallel=True, + # init_method=lambda x: x, + # ) self.non_linearity = act_layer() self.dropout = dropout def forward(self, x): - hidden = F.linear(x, self.c_fc.weight, self.c_fc.bias) + hidden = self.c_fc(x) + # hidden = F.linear(x, self.c_fc.weight, self.c_fc.bias) hidden = self.non_linearity(hidden) - hidden = F.linear(hidden, self.c_proj.weight) - hidden = reduce_from_tensor_model_parallel_region(hidden) - hidden += self.c_proj.bias + hidden = self.c_proj(hidden) + # hidden = F.linear(hidden, self.c_proj.weight) + # hidden = reduce_from_tensor_model_parallel_region(hidden) + # hidden += self.c_proj.bias return hidden @@ -118,7 +372,7 @@ def __init__( n_heads, ): super().__init__() - model_parallel_size = fs_init.get_model_parallel_world_size() + model_parallel_size = get_tensor_model_parallel_world_size() qkvo_replication = 1 if model_parallel_size > 16: qkvo_replication = model_parallel_size // 8 @@ -131,34 +385,40 @@ def __init__( self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = dim // n_heads - self.wq = ColumnParallelLinear( - dim, - qkvo_replication * n_heads * self.head_dim, - bias=True, - gather_output=False, - init_method=lambda x: x, - ) - self.wk = ColumnParallelLinear( - dim, - qkvo_replication * self.n_kv_heads * self.head_dim, - bias=True, - gather_output=False, - init_method=lambda x: x, - ) - self.wv = ColumnParallelLinear( - dim, - qkvo_replication * self.n_kv_heads * self.head_dim, - bias=True, - gather_output=False, - init_method=lambda x: x, - ) - self.wo = RowParallelLinear( - qkvo_replication * n_heads * self.head_dim, - dim, - bias=True, - input_is_parallel=True, - init_method=lambda x: x, - ) + # The model provided by llama is with bias=True, but the weight does not contain bias + # During runtime, the llama executor set bias to zero. We use bias=False here to match the behavior + self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False) + self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False) + # self.wq = ColumnParallelLinear( + # dim, + # qkvo_replication * n_heads * self.head_dim, + # bias=True, + # gather_output=False, + # init_method=lambda x: x, + # ) + # self.wk = ColumnParallelLinear( + # dim, + # qkvo_replication * self.n_kv_heads * self.head_dim, + # bias=True, + # gather_output=False, + # init_method=lambda x: x, + # ) + # self.wv = ColumnParallelLinear( + # dim, + # qkvo_replication * self.n_kv_heads * self.head_dim, + # bias=True, + # gather_output=False, + # init_method=lambda x: x, + # ) + # self.wo = RowParallelLinear( + # qkvo_replication * n_heads * self.head_dim, + # dim, + # bias=True, + # input_is_parallel=True, + # init_method=lambda x: x, + # ) self.qkvo_replication = qkvo_replication def forward( @@ -194,7 +454,7 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1) out = F.linear(attn_output, self.wo.weight) - out = reduce_from_tensor_model_parallel_region(out) + # out = reduce_from_tensor_model_parallel_region(out) out = out / self.qkvo_replication out += self.wo.bias return out @@ -284,7 +544,7 @@ class VisionEncoder(nn.Module): def __init__( self, max_num_tiles: int, - ckpt_path: str = None, + # ckpt_path: str = None, image_size: int = 224, patch_size: int = 14, width: int = 1280, @@ -293,7 +553,7 @@ def __init__( mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, in_channels: int = 3, - load_ckpt: bool = False, + # load_ckpt: bool = False, n_global_layers: int = 2, global_model: bool = False, return_intermediate=None, @@ -484,170 +744,6 @@ def forward(self, images: torch.Tensor, ar: torch.Tensor) -> torch.Tensor: return x -class Attention(nn.Module): - """Multi-head attention module.""" - - def __init__(self, args: ModelArgs): - """ - Initialize the Attention module. - Args: - args (ModelArgs): Model configuration parameters. - Attributes: - n_kv_heads (int): Number of key and value heads. - n_local_heads (int): Number of local query heads. - n_local_kv_heads (int): Number of local key and value heads. - n_rep (int): Number of repetitions for local heads. - head_dim (int): Dimension size of each attention head. - wq (ColumnParallelLinear): Linear transformation for queries. - wk (ColumnParallelLinear): Linear transformation for keys. - wv (ColumnParallelLinear): Linear transformation for values. - wo (RowParallelLinear): Linear transformation for output. - cache_k (torch.Tensor): Cached keys for attention. - cache_v (torch.Tensor): Cached values for attention. - """ - super().__init__() - model_parallel_size = fs_init.get_model_parallel_world_size() - replication_factor = 1 - if model_parallel_size > 8: - replication_factor = model_parallel_size // MP_SCALE - - self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - self.n_kv_heads *= replication_factor - - self.n_local_heads = args.n_heads // model_parallel_size - self.n_local_kv_heads = self.n_kv_heads // model_parallel_size - self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = args.dim // args.n_heads - self.max_seq_len = args.max_seq_len - - self.wq = ColumnParallelLinear( - args.dim, - args.n_heads * self.head_dim, - bias=False, - gather_output=False, - init_method=lambda x: x, - ) - self.wk = ColumnParallelLinear( - args.dim, - self.n_kv_heads * self.head_dim, - bias=False, - gather_output=False, - init_method=lambda x: x, - ) - self.wv = ColumnParallelLinear( - args.dim, - self.n_kv_heads * self.head_dim, - bias=False, - gather_output=False, - init_method=lambda x: x, - ) - self.wo = RowParallelLinear( - args.n_heads * self.head_dim, - args.dim, - bias=False, - input_is_parallel=True, - init_method=lambda x: x, - ) - self.n_heads = args.n_heads - - self._register_load_state_dict_pre_hook(self.load_hook) - - def load_hook( - self, - state_dict: Dict[str, Any], - prefix: str, - local_metadata: Dict[str, Any], - strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], - ) -> None: - if prefix + "wqkv.weight" in state_dict: - total_n_heads = self.n_heads + self.n_kv_heads * 2 - wqkv = state_dict.pop(prefix + "wqkv.weight") - head_dim = wqkv.shape[0] // total_n_heads - dim1 = head_dim * self.n_heads - dim2 = dim1 + head_dim * self.n_kv_heads - dim3 = dim1 + head_dim * self.n_kv_heads * 2 - - wq = wqkv[:dim1] - wk = wqkv[dim1:dim2] - wv = wqkv[dim2:dim3] - - state_dict[prefix + "wq.weight"] = wq - state_dict[prefix + "wk.weight"] = wk - state_dict[prefix + "wv.weight"] = wv - - def setup_cache(self, max_batch_size: int, dtype: torch.dtype): - cache_shape = ( - max_batch_size, - self.max_seq_len, - self.n_local_kv_heads, - self.head_dim, - ) - device = next(self.parameters()).device - self.register_buffer( - "key_cache", - torch.zeros( - cache_shape, - dtype=dtype, - device=device, - ), - persistent=False, - ) - self.register_buffer( - "value_cache", - torch.zeros( - cache_shape, - dtype=dtype, - device=device, - ), - persistent=False, - ) - - def forward( - self, - x: torch.Tensor, - mask: torch.Tensor, - freqs_cis: torch.Tensor, - position_ids: torch.LongTensor, - ): - - xq, xk, xv = [ - F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight] - ] - - bs, slen, _ = xq.shape - - xq = xq.view(bs, slen, self.n_local_heads, self.head_dim) - xk = xk.view(bs, xk.shape[1], self.n_local_kv_heads, self.head_dim) - xv = xv.view(bs, xv.shape[1], self.n_local_kv_heads, self.head_dim) - - xq, xk = apply_rotary_emb(xq, xk, freqs_cis) - - self.key_cache[:bs, position_ids, ...] = xk - self.value_cache[:bs, position_ids, ...] = xv - - # TODO: we can avoid slicing on first dimension by always padding to max_batch_size() - xk = self.key_cache[:bs, ...] - xv = self.value_cache[:bs, ...] - - xq, xk, xv = [tensor.transpose(1, 2) for tensor in (xq, xk, xv)] - - xk = xk.repeat_interleave(self.n_rep, dim=1) - xv = xv.repeat_interleave(self.n_rep, dim=1) - - attn_output = F.scaled_dot_product_attention( - xq, xk, xv, attn_mask=mask, dropout_p=0.0 - ) - - attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1) - - out = F.linear(attn_output, self.wo.weight) - out = reduce_from_tensor_model_parallel_region(out) - return out - - class FeedForward(nn.Module): def __init__( self, @@ -675,15 +771,18 @@ def __init__( hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.w1 = ColumnParallelLinear( - dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x - ) - self.w2 = RowParallelLinear( - hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x - ) - self.w3 = ColumnParallelLinear( - dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x - ) + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + # self.w1 = ColumnParallelLinear( + # dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x + # ) + # self.w2 = RowParallelLinear( + # hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x + # ) + # self.w3 = ColumnParallelLinear( + # dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x + # ) self._register_load_state_dict_pre_hook(self.load_hook) def forward(self, x): @@ -691,7 +790,7 @@ def forward(self, x): x1 = F.silu(x1) x_in = x1 * x3 out = F.linear(x_in, self.w2.weight) - out = reduce_from_tensor_model_parallel_region(out) + # out = reduce_from_tensor_model_parallel_region(out) return out def load_hook( @@ -715,7 +814,7 @@ def load_hook( class TransformerBlock(nn.Module): - def __init__(self, layer_id: int, args: ModelArgs): + def __init__(self, layer_id: int, args): """ Initialize a TransformerBlock. Args: @@ -735,7 +834,8 @@ def __init__(self, layer_id: int, args: ModelArgs): self.n_heads = args.n_heads self.dim = args.dim self.head_dim = args.dim // args.n_heads - self.attention = Attention(args) + # self.attention = Attention(args) + logger.warning("skip attention") self.feed_forward = FeedForward( dim=args.dim, hidden_dim=4 * args.dim, @@ -786,13 +886,15 @@ def forward( Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = self.attention.forward( - x=self.attention_norm(x), - freqs_cis=freqs_cis, - mask=mask, - position_ids=position_ids, - ) - h = h + x + # h = self.attention.forward( + # x=self.attention_norm(x), + # freqs_cis=freqs_cis, + # mask=mask, + # position_ids=position_ids, + # ) + # h = h + x + h = x + logger.warning("skip attention") out = h + self.feed_forward.forward(self.ffn_norm(h)) return out @@ -831,7 +933,7 @@ def load_hook( if embed is not None: # reshape the weights to the correct shape nt_old, nt_old, _, w = embed.shape - logging.info( + logger.info( f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}" ) embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles) @@ -887,7 +989,7 @@ def __init__( norm_eps: float, ): super().__init__() - self.model_parallel_size = fs_init.get_model_parallel_world_size() + self.model_parallel_size = get_tensor_model_parallel_world_size() replication_factor = 1 if self.model_parallel_size > 8: replication_factor = self.model_parallel_size // MP_SCALE @@ -895,35 +997,40 @@ def __init__( assert n_heads % n_kv_heads == 0 - self.wq = ColumnParallelLinear( - dim, - n_heads * head_dim, - bias=False, - gather_output=False, - init_method=_noinit, - ) - self.wk = ColumnParallelLinear( - dim, - n_kv_heads * head_dim, - bias=False, - gather_output=False, - init_method=_noinit, - ) - self.wv = ColumnParallelLinear( - dim, - n_kv_heads * head_dim, - bias=False, - gather_output=False, - init_method=_noinit, - ) - self.wo = RowParallelLinear( - n_heads * head_dim, - dim, - bias=False, - input_is_parallel=True, - init_method=_noinit, - ) + self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) + self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) + # self.wq = ColumnParallelLinear( + # dim, + # n_heads * head_dim, + # bias=False, + # gather_output=False, + # init_method=_noinit, + # ) + + # self.wk = ColumnParallelLinear( + # dim, + # n_kv_heads * head_dim, + # bias=False, + # gather_output=False, + # init_method=_noinit, + # ) + # self.wv = ColumnParallelLinear( + # dim, + # n_kv_heads * head_dim, + # bias=False, + # gather_output=False, + # init_method=_noinit, + # ) + # self.wo = RowParallelLinear( + # n_heads * head_dim, + # dim, + # bias=False, + # input_is_parallel=True, + # init_method=_noinit, + # ) self.n_heads = n_heads self.head_dim = head_dim @@ -1018,7 +1125,7 @@ def forward( output = output.transpose(1, 2).contiguous().reshape(bsz, seqlen, -1) out = F.linear(output, self.wo.weight) - out = reduce_from_tensor_model_parallel_region(out) + # out = reduce_from_tensor_model_parallel_region(out) return out @@ -1027,7 +1134,7 @@ class CrossAttentionTransformerBlock(torch.nn.Module): def __init__( self, - args: ModelArgs, + args, layer_id: int, no_ffn: bool = False, ) -> None: @@ -1063,6 +1170,7 @@ def __init__( ) self.gate_ffwd = torch.nn.Parameter(torch.zeros(1)) + logger.warning("todo hook") self._register_load_state_dict_pre_hook(self.load_hook) self.no_ffn = no_ffn @@ -1147,7 +1255,7 @@ def __call__( class CrossAttentionTransformerVision(torch.nn.Module): - def __init__(self, args: ModelArgs) -> None: + def __init__(self, args) -> None: super().__init__() return_intermediate = "3,7,15,23,30" self.vision_input_dim = 1280 @@ -1168,12 +1276,17 @@ def __init__(self, args: ModelArgs) -> None: return_intermediate=return_intermediate, ) # vision token projection - self.vision_projection = ColumnParallelLinear( + self.vision_projection = nn.Linear( self.vision_input_dim, args.dim, bias=True, - init_method=lambda x: x, ) + # self.vision_projection = ColumnParallelLinear( + # self.vision_input_dim, + # args.dim, + # bias=True, + # init_method=lambda x: x, + # ) def forward( self, images: torch.Tensor, aspect_ratios: torch.Tensor @@ -1186,16 +1299,16 @@ def forward( ) vision_tokens = F.linear(vision_tokens, self.vision_projection.weight) - vision_tokens = gather_from_tensor_model_parallel_region(vision_tokens) + # vision_tokens = gather_from_tensor_model_parallel_region(vision_tokens) return vision_tokens class CrossAttentionTransformerText(torch.nn.Module): INFERENCE_IMAGE_TOKEN_ID = 128010 - def __init__(self, args: ModelArgs) -> None: + def __init__(self, args) -> None: super().__init__() - self.model_parallel_size = fs_init.get_model_parallel_world_size() + self.model_parallel_size = get_tensor_model_parallel_world_size() assert args.vocab_size > 0 self.vocab_size = args.vocab_size self.n_layers = args.n_layers @@ -1205,16 +1318,18 @@ def __init__(self, args: ModelArgs) -> None: self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size assert self.vocab_size % self.model_parallel_size == 0 self.tok_embeddings = VocabParallelEmbedding( - args.vocab_size, args.dim, init_method=lambda x: x + args.vocab_size, args.dim, + padding_size=self.model_parallel_size, ) self.pos_embeddings = None # final norm layer (not necessary for post-norm) self.norm = RMSNorm(args.dim, eps=args.norm_eps) # output layer - self.output = ColumnParallelLinear( - args.dim, args.vocab_size, bias=False, init_method=lambda x: x - ) + self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + # self.output = ColumnParallelLinear( + # args.dim, args.vocab_size, bias=False, init_method=lambda x: x + # ) self.n_llama_layers = args.n_layers self.model_dim = args.dim @@ -1225,9 +1340,9 @@ def __init__(self, args: ModelArgs) -> None: args.vision_num_cross_attention_layers ) self.learnable_embedding = VocabParallelEmbedding( - max(fs_init.get_model_parallel_world_size(), 8), + max(get_tensor_model_parallel_world_size(), 8), args.dim, - init_method=lambda x: x, + padding_size=self.model_parallel_size, ) self.num_frozen_embeddings = self.tok_embeddings.num_embeddings self._thresh = self.num_frozen_embeddings - 1 @@ -1350,7 +1465,7 @@ def forward( h = self.norm(h) output = F.linear(h, self.output.weight) - output = gather_from_tensor_model_parallel_region(output) + # output = gather_from_tensor_model_parallel_region(output) return output.float() def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16): @@ -1381,7 +1496,7 @@ def _get_xattn_mask( text_dtype, vision_tokens, cross_attention_masks, - ) -> Tuple[Tensor, Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: assert vision_tokens is not None, "Vision tokens must be provided" vision_seqlen = vision_tokens.shape[3] assert ( @@ -1402,7 +1517,7 @@ def _get_xattn_mask( ) full_text_row_masked_out_mask = _get_full_row_masked_out_mask( cross_attention_masks, - get_negative_inf_value(cross_attention_masks.dtype), + torch.finfo(cross_attention_masks.dtype).min, ) cross_attention_masks *= full_text_row_masked_out_mask @@ -1412,11 +1527,251 @@ def _get_xattn_mask( ) -class CrossAttentionTransformer(torch.nn.Module): - def __init__(self, args: ModelArgs) -> None: +class VariableSizeImageTransform(object): + """ + The variable size image transform will resize the image dynamically + based on the image aspect ratio and the number of image chunks we allow. + The algorithm will not upsample low-res images to fit a certain aspect + ratio, because that leads to a significant degradation in image quality. + For example, if an input image is of size 300x800, and we want to allow + a maximum of 16 image chunks, it will find the closest aspect ratio that + is allowed within 16 image chunks, i.e., 2:5 = 2 horizontal patches and + 5 vertical patches, giving a total of 10 chunks. + The image will then be resized to products of the base size (default is + 224px because MetaCLIP takes that), so in this case it will be resized to + 2*224:5*224 = 448:1120, where we maintain the original aspect ratio and + pad with the mean value for the rest. This approach minimizes the amount + of padding required for any arbitrary resolution. + The final output will therefore be of shape (11, 3, 224, 224), where 10 + patches are coming from the resizing and chunking, and the first patch + is a downsampled version of the image that preserves aspect ratios. + """ + + def __init__(self, size: int = IMAGE_RES) -> None: + self.size = size + self.to_tensor = tv.ToTensor() + self._mean = (0.48145466, 0.4578275, 0.40821073) + self._std = (0.26862954, 0.26130258, 0.27577711) + self.normalize = tv.Normalize( + mean=self._mean, + std=self._std, + inplace=True, + ) + + @staticmethod + def _factors(n: int): + """Return all factors of a number.""" + return set( + reduce( + list.__add__, + ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0), + ) + ) + + def _find_supported_aspect_ratios(self, num_chunks: int): + """ + This function computes all the allowed aspect ratios for a fixed + number of input chunks. + For example, with `num_chunks=5`, it will return: + { + 0.2: [(1, 5)], + 5.0: [(5, 1)], + 0.25: [(1, 4)], + 1.0: [(2, 2), (1, 1)], + 4.0: [(4, 1)], + 0.3333333333333333: [(1, 3)], + 3.0: [(3, 1)], + 0.5: [(1, 2)], + 2.0: [(2, 1)] + } + """ + asp_dict = {} + for chunk_size in range(num_chunks, 0, -1): + _factors = sorted(VariableSizeImageTransform._factors(chunk_size)) + _asp_ratios = [(x, chunk_size // x) for x in _factors] + for ratio in _asp_ratios: + k = ratio[0] / ratio[1] + if k not in asp_dict: + asp_dict[k] = [ratio] + else: + asp_dict[k].append(ratio) + return asp_dict + + def _find_closest_aspect_ratio( + self, num_chunks: int, img_width: int, img_height: int + ) -> Tuple: + """ + Given an image width, height and target number of chunks + this function will find the closest supported aspect ratio. + """ + tgt_ar = img_width / img_height + asp_dict = self._find_supported_aspect_ratios(num_chunks) + cl_d, cl_p = 1e23, None + if tgt_ar >= 1: + cl_p = min( + [k for k in asp_dict.keys() if k <= tgt_ar], + key=lambda x: abs(x - tgt_ar), + ) + v = asp_dict[cl_p] + # select width + widths = [(idx, self.size * vv[0]) for idx, vv in enumerate(v)] + tgt_idx = max(widths, key=lambda x: x[1])[0] + else: + cl_p = min( + [k for k in asp_dict.keys() if k > tgt_ar], + key=lambda x: abs(1 / x - 1 / tgt_ar), + ) + v = asp_dict[cl_p] + # select height + heights = [(idx, self.size * vv[1]) for idx, vv in enumerate(v)] + tgt_idx = max(heights, key=lambda x: x[1])[0] + out = v[tgt_idx] + return out + + def _resize( + self, image: Image.Image, target_width: int, target_height: int + ) -> Image.Image: + # Resize longer edge to given size. + w, h = image.size + scale = w / h + + if scale > 1.0: + # width > height + new_w = target_width + new_h = math.floor(new_w / scale) + else: + # height >= width + new_h = target_height + new_w = math.floor(new_h * scale) + + image = F.resize(image, (new_h, new_w)) + return image + + def _resize_max_side_to_size( + self, + image: Image.Image, + ) -> Image.Image: + # Resize longer edge to given size. + w, h = image.size + scale = w / h + + if scale > 1.0: + # width > height + new_w = max(self.size, w) + new_h = math.floor(new_w / scale) + else: + # height >= width + new_h = max(self.size, h) + new_w = math.floor(new_h * scale) + + image = F.resize(image, (new_h, new_w)) + return image + + def _pad(self, image: Image.Image, new_width: int, new_height: int) -> Image.Image: + mean_per_channel = tuple( + np.clip(np.array(image).mean(axis=(0, 1)), 0, 255).astype(np.uint8) + ) + new_im = Image.new(mode="RGB", size=(new_height, new_width), color=(0, 0, 0)) # type: ignore + new_im.paste(image) + return new_im + + def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor: + # Split image into number of required tiles (width x height) + num_channels, height, width = image.size() + image = image.view(num_channels, nch, height // nch, ncw, width // ncw) + # Permute dimensions to reorder the axes + image = image.permute(1, 3, 0, 2, 4).contiguous() + # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2) + image = image.view(ncw * nch, num_channels, height // nch, width // ncw) + return image + + def _fit_image_to_canvas( + self, num_chunks: int, img_width: int, img_height: int + ) -> Any: + """ + Given an image width, height and target number of chunks this function will see if the image + can be fit into any of the canvases that can be build from arranging the tiles in a grid. + If the image can be fit onto several canvases, it will return the canvas where the shorter edge + of the image will be largest. + """ + # Initialize the optimal canvas to None. If no canvas is found where image fits, function returns None. + optimal_canvas = None + optimal_image_width_height = None + + scale = img_width / img_height + + # Gather all potential supported image resolutions and iterate through them to find best match + potential_arrangements = [ + item + for sublist in self._find_supported_aspect_ratios(num_chunks).values() + for item in sublist + ] + current_gap = 1e23 + for n_w, n_h in potential_arrangements: + # Compute the canvas size + canvas_width, canvas_height = n_w * self.size, n_h * self.size + + # Check if image can fit into the canvas without downsampling + if canvas_width >= img_width and canvas_height >= img_height: + # If we did not find a good canvas yet, we will use the current one + if optimal_canvas is None: + # Set optimal canvas and determine the actual image height and width in the canvas with aspect ratio preserving resampling + optimal_canvas = (n_w, n_h) + optimal_image_width_height = (n_w * self.size, n_h * self.size) + else: + # Find closest fit based on gap + image_width_height = (n_w * self.size, n_h * self.size) + gap = abs(img_width - image_width_height[0]) + abs( + img_height - image_width_height[1] + ) + if gap < current_gap: + # If the gap is smaller than the previous one, we will update our optimal canvas and image width height + optimal_canvas = (n_w, n_h) + optimal_image_width_height = image_width_height + current_gap = gap + return optimal_canvas + + def __call__(self, image: Image.Image, max_num_chunks: int) -> Tuple[Any, Any]: + assert max_num_chunks > 0 + assert isinstance(image, Image.Image), type(image) + w, h = image.size + # Check if the image can be fit to the canvas without downsampling + ar = self._fit_image_to_canvas( + num_chunks=max_num_chunks, img_width=w, img_height=h + ) + if ar is None: + # If we did not find a canvas, we have to find the closest aspect ratio and downsample the image + ar = self._find_closest_aspect_ratio( + num_chunks=max_num_chunks, img_width=w, img_height=h + ) + image = self._resize(image, ar[0] * self.size, ar[1] * self.size) + else: + image = self._resize_max_side_to_size(image) + image = self._pad(image, ar[1] * self.size, ar[0] * self.size) + image = self.to_tensor(image) + image = self.normalize(image) + image = self._split(image, ar[0], ar[1]) # type: ignore + return image, ar + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llama_image_tokens) +class LlamaVLForCausalLM(nn.Module, SupportsMultiModal): + def __init__(self, config, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.params = args + print("config", type(config)) + print(config) + print("multimodal_config", type(multimodal_config)) + print(multimodal_config) + print("cache_config", type(cache_config)) + print(cache_config) + print("quant_config", type(quant_config)) + print(quant_config) + # self.params = args + args = config self.model_dim = args.dim self.vision_model = CrossAttentionTransformerVision(args) self.text_model = CrossAttentionTransformerText(args) @@ -1432,7 +1787,7 @@ def setup_cache(self, max_batch_size: int, dtype: torch.dtype): def compute_vision_tokens_masks( self, - batch_images: List[List[PIL_Image.Image]], + batch_images: List[List[Image.Image]], batch_masks: List[List[List[int]]], total_len: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -1517,6 +1872,10 @@ def compute_vision_tokens_masks( return (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + state_dict = {name: weight for name, weight in weights} + self.load_state_dict(state_dict, strict=False) + def forward( self, position_ids: torch.Tensor, @@ -1539,7 +1898,7 @@ def forward( def _stack_images( - images: List[List[PIL_Image.Image]], + images: List[List[Image.Image]], max_num_chunks: int, image_res: int, max_num_images: int, @@ -1574,7 +1933,7 @@ def _pad_masks( max_num_chunks: int, ) -> torch.Tensor: dtype = torch.bfloat16 - inf_value = get_negative_inf_value(dtype) + inf_value = torch.finfo(dtype).min bsz = len(all_masks) max_num_media = max([len(m) for m in all_masks]) @@ -1597,25 +1956,3 @@ def _pad_masks( return out_masks - -@MULTIMODAL_REGISTRY.register_image_input_mapper() -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llama_image_tokens) -class LlamaVLForCausalLM(nn.Module, SupportsMultiModal): - def __init__(self, config, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): - super().__init__() - print("config", type(config)) - print(config) - print("multimodal_config", type(multimodal_config)) - print(multimodal_config) - print("cache_config", type(cache_config)) - print(cache_config) - print("quant_config", type(quant_config)) - print(quant_config) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - for name, weight in weights: - print(name, weight.shape) - From 2dd36f5669889a8e8765521daccf4fb4902e48e8 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Sep 2024 16:29:17 -0700 Subject: [PATCH 06/75] skip profile run by hardcode, can start model execution --- vllm/transformers_utils/configs/llamavl.py | 9 +++++++++ vllm/worker/worker.py | 14 ++++++++------ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/transformers_utils/configs/llamavl.py b/vllm/transformers_utils/configs/llamavl.py index 7af55b1b720a..f04a20ce4eaa 100644 --- a/vllm/transformers_utils/configs/llamavl.py +++ b/vllm/transformers_utils/configs/llamavl.py @@ -27,6 +27,13 @@ class LlamaVLConfig(PretrainedConfig): model_type: str = "llamavl" architectures: list[str] = ["LlamaVLForCausalLM"] + attribute_map = { + "num_hidden_layers": "n_layers", + "hidden_size": "dim", + "num_attention_heads": "n_heads", + "num_key_value_heads": "n_kv_heads", + } + def __init__(self, **kwargs): for k, v in kwargs.items(): if hasattr(self, k): @@ -37,3 +44,5 @@ def __init__(self, **kwargs): assert self.n_kv_heads <= self.n_heads assert self.n_heads % self.n_kv_heads == 0 assert self.dim % self.n_heads == 0 + + super().__init__(**kwargs) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7ed609c3b447..44c237e23352 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -219,12 +219,14 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - self.model_runner.profile_run() - - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. - torch.cuda.synchronize() - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + # self.model_runner.profile_run() + + # # Calculate the number of blocks that can be allocated with the + # # profiled peak memory. + # torch.cuda.synchronize() + # free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + free_gpu_memory = 40 * 1024 * 1024 * 1024 + total_gpu_memory = 80 * 1024 * 1024 * 1024 # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. peak_memory = self.init_gpu_memory - free_gpu_memory From affa9bafd4e0d8f1a24bf00b71d91d9587e3f192 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 2 Sep 2024 18:00:47 -0700 Subject: [PATCH 07/75] can run text tokenizer now --- vllm/transformers_utils/tokenizers/llamavl.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/vllm/transformers_utils/tokenizers/llamavl.py b/vllm/transformers_utils/tokenizers/llamavl.py index 9ae273da1db7..c81e1adf598f 100644 --- a/vllm/transformers_utils/tokenizers/llamavl.py +++ b/vllm/transformers_utils/tokenizers/llamavl.py @@ -32,8 +32,8 @@ # of max consecutive non-whitespace or whitespace characters. MAX_NO_WHITESPACES_CHARS = 25_000 - -class LlamaVLTokenizer(PreTrainedTokenizer): +# TODO: this class is with some hack. need toreplace with official release +class LlamaVLTokenizer: """ Tokenizing and encoding/decoding text using the Tiktoken tokenizer. """ @@ -52,7 +52,7 @@ def __init__(self, model_path: str): model_path (str): The path to the Tiktoken model file. """ assert os.path.isfile(model_path), model_path - + mergeable_ranks = load_tiktoken_bpe(model_path) num_base_tokens = len(mergeable_ranks) special_tokens = [ @@ -99,14 +99,14 @@ def __init__(self, model_path: str): self.special_tokens["<|eot_id|>"], ] + self.bos_token_id = self.bos_id + self.eos_token_id = self.eos_id + print("need to replace tokenizer with official release") + print("warning: recheck add bos and add eos of encode function") + def encode( self, s: str, - *, - bos: bool, - eos: bool, - allowed_special: Optional[Union[Literal["all"], AbstractSet[str]]] = None, - disallowed_special: Union[Literal["all"], Collection[str]] = (), ) -> List[int]: """ Encodes a string into a list of token IDs. @@ -129,8 +129,8 @@ def encode( - Setting `allowed_special` to "all" will treat all text corresponding to special tokens to be encoded as special tokens. """ - if allowed_special is None: - allowed_special = set() + bos = True + eos = False assert type(s) is str substrs = ( @@ -145,14 +145,15 @@ def encode( t.extend( self.model.encode( substr, - allowed_special=allowed_special, - disallowed_special=disallowed_special, + allowed_special=set(), + disallowed_special=set(), ) ) if bos: t.insert(0, self.bos_id) if eos: t.append(self.eos_id) + print("t:", t) return t def decode(self, t: Sequence[int]) -> str: From f633de555895593b7e8087241b991e4cee80c5bc Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 3 Sep 2024 16:45:29 -0700 Subject: [PATCH 08/75] finish image preprocessor --- vllm/transformers_utils/configs/llamavl.py | 2 + vllm/transformers_utils/image_processor.py | 4 + .../multimodal_processors/llamavl.py | 365 ++++++++++++++++++ vllm/transformers_utils/tokenizers/llamavl.py | 1 - 4 files changed, 371 insertions(+), 1 deletion(-) create mode 100644 vllm/transformers_utils/multimodal_processors/llamavl.py diff --git a/vllm/transformers_utils/configs/llamavl.py b/vllm/transformers_utils/configs/llamavl.py index f04a20ce4eaa..c5de04a35b29 100644 --- a/vllm/transformers_utils/configs/llamavl.py +++ b/vllm/transformers_utils/configs/llamavl.py @@ -27,6 +27,8 @@ class LlamaVLConfig(PretrainedConfig): model_type: str = "llamavl" architectures: list[str] = ["LlamaVLForCausalLM"] + torch_dtype: str = "bfloat16" + attribute_map = { "num_hidden_layers": "n_layers", "hidden_size": "dim", diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py index c7d9eabd06f0..b792665fd9cd 100644 --- a/vllm/transformers_utils/image_processor.py +++ b/vllm/transformers_utils/image_processor.py @@ -14,6 +14,10 @@ def get_image_processor( from transformers.image_processing_utils import BaseImageProcessor try: + print("processor_name", processor_name) + if "Vision-Early" in processor_name: + from .multimodal_processors.llamavl import LlamaVLImageProcessor + return LlamaVLImageProcessor(processor_name, *args, **kwargs) processor = AutoImageProcessor.from_pretrained( processor_name, *args, diff --git a/vllm/transformers_utils/multimodal_processors/llamavl.py b/vllm/transformers_utils/multimodal_processors/llamavl.py new file mode 100644 index 000000000000..997c0e755c44 --- /dev/null +++ b/vllm/transformers_utils/multimodal_processors/llamavl.py @@ -0,0 +1,365 @@ +from transformers.image_processing_base import BatchFeature +from transformers.image_processing_utils import BaseImageProcessor + +import torch +from typing import List, Tuple +from PIL import Image +from functools import partial + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import math +from functools import reduce +from typing import Any, Tuple + +import numpy as np +import torch +import torchvision.transforms as tv +from PIL import Image +from torchvision.transforms import functional as F + +IMAGE_RES = 224 + +class TorchBF16Context: + + def __enter__(self): + self.prev_dtype = torch.get_default_dtype() + if torch.cuda.is_bf16_supported(): + torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) + else: + torch.set_default_tensor_type(torch.cuda.HalfTensor) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.prev_dtype == torch.float32: + torch.set_default_tensor_type(torch.FloatTensor) + else: + raise ValueError("Unsupported dtype") + +class VariableSizeImageTransform(object): + """ + The variable size image transform will resize the image dynamically + based on the image aspect ratio and the number of image chunks we allow. + The algorithm will not upsample low-res images to fit a certain aspect + ratio, because that leads to a significant degradation in image quality. + For example, if an input image is of size 300x800, and we want to allow + a maximum of 16 image chunks, it will find the closest aspect ratio that + is allowed within 16 image chunks, i.e., 2:5 = 2 horizontal patches and + 5 vertical patches, giving a total of 10 chunks. + The image will then be resized to products of the base size (default is + 224px because MetaCLIP takes that), so in this case it will be resized to + 2*224:5*224 = 448:1120, where we maintain the original aspect ratio and + pad with the mean value for the rest. This approach minimizes the amount + of padding required for any arbitrary resolution. + The final output will therefore be of shape (11, 3, 224, 224), where 10 + patches are coming from the resizing and chunking, and the first patch + is a downsampled version of the image that preserves aspect ratios. + """ + + def __init__(self, size: int = IMAGE_RES) -> None: + self.size = size + self.to_tensor = tv.ToTensor() + self._mean = (0.48145466, 0.4578275, 0.40821073) + self._std = (0.26862954, 0.26130258, 0.27577711) + self.normalize = tv.Normalize( + mean=self._mean, + std=self._std, + inplace=True, + ) + + @staticmethod + def _factors(n: int): + """Return all factors of a number.""" + return set( + reduce( + list.__add__, + ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0), + ) + ) + + def _find_supported_aspect_ratios(self, num_chunks: int): + """ + This function computes all the allowed aspect ratios for a fixed + number of input chunks. + For example, with `num_chunks=5`, it will return: + { + 0.2: [(1, 5)], + 5.0: [(5, 1)], + 0.25: [(1, 4)], + 1.0: [(2, 2), (1, 1)], + 4.0: [(4, 1)], + 0.3333333333333333: [(1, 3)], + 3.0: [(3, 1)], + 0.5: [(1, 2)], + 2.0: [(2, 1)] + } + """ + asp_dict = {} + for chunk_size in range(num_chunks, 0, -1): + _factors = sorted(VariableSizeImageTransform._factors(chunk_size)) + _asp_ratios = [(x, chunk_size // x) for x in _factors] + for ratio in _asp_ratios: + k = ratio[0] / ratio[1] + if k not in asp_dict: + asp_dict[k] = [ratio] + else: + asp_dict[k].append(ratio) + return asp_dict + + def _find_closest_aspect_ratio( + self, num_chunks: int, img_width: int, img_height: int + ) -> Tuple: + """ + Given an image width, height and target number of chunks + this function will find the closest supported aspect ratio. + """ + tgt_ar = img_width / img_height + asp_dict = self._find_supported_aspect_ratios(num_chunks) + cl_d, cl_p = 1e23, None + if tgt_ar >= 1: + cl_p = min( + [k for k in asp_dict.keys() if k <= tgt_ar], + key=lambda x: abs(x - tgt_ar), + ) + v = asp_dict[cl_p] + # select width + widths = [(idx, self.size * vv[0]) for idx, vv in enumerate(v)] + tgt_idx = max(widths, key=lambda x: x[1])[0] + else: + cl_p = min( + [k for k in asp_dict.keys() if k > tgt_ar], + key=lambda x: abs(1 / x - 1 / tgt_ar), + ) + v = asp_dict[cl_p] + # select height + heights = [(idx, self.size * vv[1]) for idx, vv in enumerate(v)] + tgt_idx = max(heights, key=lambda x: x[1])[0] + out = v[tgt_idx] + return out + + def _resize( + self, image: Image.Image, target_width: int, target_height: int + ) -> Image.Image: + # Resize longer edge to given size. + w, h = image.size + scale = w / h + + if scale > 1.0: + # width > height + new_w = target_width + new_h = math.floor(new_w / scale) + else: + # height >= width + new_h = target_height + new_w = math.floor(new_h * scale) + + image = F.resize(image, (new_h, new_w)) + return image + + def _resize_max_side_to_size( + self, + image: Image.Image, + ) -> Image.Image: + # Resize longer edge to given size. + w, h = image.size + scale = w / h + + if scale > 1.0: + # width > height + new_w = max(self.size, w) + new_h = math.floor(new_w / scale) + else: + # height >= width + new_h = max(self.size, h) + new_w = math.floor(new_h * scale) + + image = F.resize(image, (new_h, new_w)) + return image + + def _pad(self, image: Image.Image, new_width: int, new_height: int) -> Image.Image: + mean_per_channel = tuple( + np.clip(np.array(image).mean(axis=(0, 1)), 0, 255).astype(np.uint8) + ) + new_im = Image.new(mode="RGB", size=(new_height, new_width), color=(0, 0, 0)) # type: ignore + new_im.paste(image) + return new_im + + def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor: + # Split image into number of required tiles (width x height) + num_channels, height, width = image.size() + image = image.view(num_channels, nch, height // nch, ncw, width // ncw) + # Permute dimensions to reorder the axes + image = image.permute(1, 3, 0, 2, 4).contiguous() + # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2) + image = image.view(ncw * nch, num_channels, height // nch, width // ncw) + return image + + def _fit_image_to_canvas( + self, num_chunks: int, img_width: int, img_height: int + ) -> Any: + """ + Given an image width, height and target number of chunks this function will see if the image + can be fit into any of the canvases that can be build from arranging the tiles in a grid. + If the image can be fit onto several canvases, it will return the canvas where the shorter edge + of the image will be largest. + """ + # Initialize the optimal canvas to None. If no canvas is found where image fits, function returns None. + optimal_canvas = None + optimal_image_width_height = None + + scale = img_width / img_height + + # Gather all potential supported image resolutions and iterate through them to find best match + potential_arrangements = [ + item + for sublist in self._find_supported_aspect_ratios(num_chunks).values() + for item in sublist + ] + current_gap = 1e23 + for n_w, n_h in potential_arrangements: + # Compute the canvas size + canvas_width, canvas_height = n_w * self.size, n_h * self.size + + # Check if image can fit into the canvas without downsampling + if canvas_width >= img_width and canvas_height >= img_height: + # If we did not find a good canvas yet, we will use the current one + if optimal_canvas is None: + # Set optimal canvas and determine the actual image height and width in the canvas with aspect ratio preserving resampling + optimal_canvas = (n_w, n_h) + optimal_image_width_height = (n_w * self.size, n_h * self.size) + else: + # Find closest fit based on gap + image_width_height = (n_w * self.size, n_h * self.size) + gap = abs(img_width - image_width_height[0]) + abs( + img_height - image_width_height[1] + ) + if gap < current_gap: + # If the gap is smaller than the previous one, we will update our optimal canvas and image width height + optimal_canvas = (n_w, n_h) + optimal_image_width_height = image_width_height + current_gap = gap + return optimal_canvas + + def __call__(self, image: Image.Image, max_num_chunks: int) -> Tuple[Any, Any]: + assert max_num_chunks > 0 + assert isinstance(image, Image.Image), type(image) + + import numpy as np + w, h = image.size + # Check if the image can be fit to the canvas without downsampling + ar = self._fit_image_to_canvas( + num_chunks=max_num_chunks, img_width=w, img_height=h + ) + if ar is None: + # If we did not find a canvas, we have to find the closest aspect ratio and downsample the image + ar = self._find_closest_aspect_ratio( + num_chunks=max_num_chunks, img_width=w, img_height=h + ) + image = self._resize(image, ar[0] * self.size, ar[1] * self.size) + else: + image = self._resize_max_side_to_size(image) + + arr = np.array(image) + + image = self._pad(image, ar[1] * self.size, ar[0] * self.size) + image = self.to_tensor(image) + image = self.normalize(image) + image = self._split(image, ar[0], ar[1]) # type: ignore + return image, ar + + +def _stack_images( + images: List[List[Image.Image]], + max_num_chunks: int, + image_res: int, + max_num_images: int, +) -> Tuple[torch.Tensor, List[int]]: + """ + Takes a list of list of images and stacks them into a tensor. + This function is needed since images can be of completely + different resolutions and aspect ratios. + """ + out_images, out_num_chunks = [], [] + for imgs_sample in images: + out_images_i = torch.zeros( + max_num_images, + max_num_chunks, + 3, + image_res, + image_res, + ) + _num_chunks = [] + for j, chunks_image in enumerate(imgs_sample): + out_images_i[j, : chunks_image.shape[0]] = chunks_image + _num_chunks.append(chunks_image.shape[0]) + out_images.append(out_images_i) + out_num_chunks.append(_num_chunks) + return torch.stack(out_images), out_num_chunks + +class LlamaVLImageProcessor(BaseImageProcessor): + def __init__(self, name, *args, **kwargs): + if "11B" in name: + self.vision_chunk_size = 448 + elif "90B" in name: + self.vision_chunk_size = 560 + else: + raise ValueError(f"Unknown model name: {name}") + self.vision_max_num_chunks = 4 + self.max_num_chunks = self.vision_max_num_chunks + self.image_transform = partial( + VariableSizeImageTransform(size=self.vision_chunk_size), + max_num_chunks=self.vision_max_num_chunks, + ) + def preprocess(self, images, **kwargs) -> BatchFeature: + with TorchBF16Context(): + print("[warning] mask unsupported due to lack of example, replace with official release in the future") + # assert len(images) == len( + # batch_masks + # ), "Images and masks must have the same length" + + # preprocess is called for each batch now, so add batch dimension here. + images = [images] + + max_num_images = max(len(x) for x in images) + bsz = len(images) + + if max_num_images == 0: + data = {'pixel_values': None} + else: + images_and_aspect_ratios = [ + [self.image_transform(im) for im in row] for row in images + ] + transformed_images = [ + [x[0] for x in row] for row in images_and_aspect_ratios + ] + + aspect_ratios = torch.ones(bsz, max_num_images, 2, dtype=torch.int64) + for i, row in enumerate(images_and_aspect_ratios): + if len(row) > 0: + aspect_ratios[i, : len(row)] = torch.stack( + [torch.tensor(x[1]) for x in row] + ) + data = { + 'pixel_values': transformed_images, + 'aspect_ratios': aspect_ratios, + } + # print("transformed_images", transformed_images) + # for i, row in enumerate(transformed_images): + # for j, x in enumerate(row): + # print(i, j, x.shape) + # print("aspect_ratios", aspect_ratios) + # stacked_images, num_chunks = _stack_images( + # transformed_images, + # self.vision_max_num_chunks, + # self.vision_chunk_size, + # max_num_images, + # ) + # print("stacked_images", stacked_images.shape) + # print("num_chunks", num_chunks) + return BatchFeature(data, tensor_type=None) \ No newline at end of file diff --git a/vllm/transformers_utils/tokenizers/llamavl.py b/vllm/transformers_utils/tokenizers/llamavl.py index c81e1adf598f..7e9d2d26101e 100644 --- a/vllm/transformers_utils/tokenizers/llamavl.py +++ b/vllm/transformers_utils/tokenizers/llamavl.py @@ -153,7 +153,6 @@ def encode( t.insert(0, self.bos_id) if eos: t.append(self.eos_id) - print("t:", t) return t def decode(self, t: Sequence[int]) -> str: From de8bbaddad06764cd2b2a9b43fac8aa6b2f092f1 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 4 Sep 2024 15:28:26 -0700 Subject: [PATCH 09/75] can run vision encoder now --- vllm/model_executor/models/llamavl.py | 207 +++++++++--------- .../multimodal_processors/llamavl.py | 5 +- 2 files changed, 109 insertions(+), 103 deletions(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 4181d9b1e429..10be1fb9f619 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -22,7 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal from .llama import LlamaAttention from vllm.model_executor.layers.layernorm import RMSNorm @@ -38,6 +38,18 @@ MP_SCALE = 8 IMAGE_RES = 224 +class LlamaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size, max_num_image, max_num_chunk, num_channels, height, width)`""" + aspect_ratios: torch.Tensor + """Shape: `(batch_size, max_num_image, 2)`""" + num_chunks: List[List[int]] + +# TODO: support LlamaImageEmbeddingInputs + +LlavaImageInputs = LlamaImagePixelInputs + def get_max_llama_image_tokens(ctx: InputContext) -> int: logger.warning("need further check on max llama image tokens") print("ctx", type(ctx)) @@ -456,7 +468,7 @@ def forward( out = F.linear(attn_output, self.wo.weight) # out = reduce_from_tensor_model_parallel_region(out) out = out / self.qkvo_replication - out += self.wo.bias + # out += self.wo.bias return out @@ -1785,115 +1797,108 @@ def __init__(self, config, def setup_cache(self, max_batch_size: int, dtype: torch.dtype): self.text_model.setup_cache(max_batch_size, dtype) - def compute_vision_tokens_masks( - self, - batch_images: List[List[Image.Image]], - batch_masks: List[List[List[int]]], - total_len: int, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - skip_vision_encoder = False - - assert len(batch_images) == len( - batch_masks - ), "Images and masks must have the same length" - - max_num_images = max(len(x) for x in batch_images) - bsz = len(batch_images) - - if max_num_images == 0: - num_chunks = [[self.max_num_chunks] for _ in batch_images] - skip_vision_encoder = True - else: - images_and_aspect_ratios = [ - [self.image_transform(im) for im in row] for row in batch_images - ] - transformed_images = [ - [x[0] for x in row] for row in images_and_aspect_ratios - ] - - aspect_ratios = torch.ones(bsz, max_num_images, 2, dtype=torch.int64) - for i, row in enumerate(images_and_aspect_ratios): - if len(row) > 0: - aspect_ratios[i, : len(row)] = torch.stack( - [torch.tensor(x[1]) for x in row] - ) - stacked_images, num_chunks = _stack_images( - transformed_images, - max_num_chunks=self.max_num_chunks, - image_res=self.params.vision_chunk_size, - max_num_images=max_num_images, - ) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + state_dict = {name: weight for name, weight in weights} + self.load_state_dict(state_dict, strict=False) - if skip_vision_encoder: - vision_tokens = torch.zeros( - ( - bsz, - max_num_images, - self.max_num_chunks, - int( - (self.vision_model.image_res / self.vision_model.patch_size) - ** 2 - + 1 - ), - self.model_dim, - ), + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[LlavaImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + aspect_ratios = kwargs.pop("aspect_ratios", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError("Both pixel values and image embeds are provided.") + + if pixel_values is not None: + print("pixel shapes", [x.shape for x in pixel_values]) + # tensor with the same shape will be batched together by MultiModalInputs.batch, so pixel_values here can be: + # - List[List[torch.Tensor]]: with shape (num_chunks, 3, image_res, image_res) + # - List[torch.Tensor]: with shape (num_image_in_batch, num_chunks, 3, image_res, image_res) + # - torch.Tensor: with shape (bs, num_image_in_batch, num_chunks, 3, image_res, image_res) + # the best choice is to remove MultiModalInputs.batch + pixel_values_unpacked = [] + for b in range(len(pixel_values)): + pixel_values_unpacked_b = [] + for i in range(len(pixel_values[b])): + pixel_values_unpacked_b.append(pixel_values[b][i]) + pixel_values_unpacked.append(pixel_values_unpacked_b) + + max_num_images = max([len(x) for x in pixel_values_unpacked]) + max_num_chunks = max(max([len(x) for x in y]) for y in pixel_values_unpacked) + bsz = len(pixel_values_unpacked) + out_num_chunks = [] + out_images = torch.zeros( + bsz, + max_num_images, + max_num_chunks, + 3, + self.image_res, + self.image_res ) - else: - vision_tokens = self.vision_model(stacked_images, aspect_ratios) - - vision_tokens = vision_tokens.to("cuda") - - bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape) - xattn_caches = torch.stack( - [ - layer.compute_xattn_kv_cache( - vision_tokens.view(bsz, -1, image_token_dim) - ) - for layer in self.text_model.cross_attention_layers - ] - ) - padded_masks = _pad_masks( - batch_masks, - num_chunks, - total_len, - self.max_num_chunks, - ) - - cross_attention_masks, full_text_row_masked_out_mask = ( - self.text_model._get_xattn_mask( - num_tokens=total_len, - text_device="cuda", - text_dtype=next(self.text_model.parameters()).dtype, - vision_tokens=vision_tokens, - cross_attention_masks=padded_masks, + out_ar = torch.ones(bsz, max_num_images, 2, dtype=torch.int64) + for b in range(len(pixel_values_unpacked)): + _num_chunks = [] + for i in range(len(pixel_values_unpacked[b])): + img = pixel_values_unpacked[b][i] + out_images[b, i, :img.shape[0]] = img + out_ar[b, i] = aspect_ratios[b][i] + _num_chunks.append(img.shape[0]) + out_num_chunks.append(_num_chunks) + + return LlamaImagePixelInputs( + type="pixel_values", + data=out_images, + num_chunks=out_num_chunks, + aspect_ratios=out_ar, ) - ) - return (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask) + if image_embeds is not None: + raise NotImplementedError - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - state_dict = {name: weight for name, weight in weights} - self.load_state_dict(state_dict, strict=False) + raise AssertionError("This line should be unreachable.") def forward( self, - position_ids: torch.Tensor, - tokens: torch.Tensor, - cross_attention_masks: torch.Tensor, - full_text_row_masked_out_mask: torch.Tensor, - xattn_caches: torch.Tensor, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, ) -> torch.Tensor: - h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids]) - logits = self.text_model.forward( - position_ids=position_ids, - h=h, - xattn_mask=cross_attention_masks[:, :, position_ids], - full_text_row_masked_out_mask=full_text_row_masked_out_mask[ - :, :, position_ids - ], - xattn_caches=xattn_caches, - ) + print("input_ids", input_ids) + print("positions", positions) + print("kv_caches", len(kv_caches), kv_caches[0].shape) + print("attn_metadata", attn_metadata) + print("intermediate_tensors", intermediate_tensors) + print("kwargs", kwargs) + image = self._parse_and_validate_image_input(**kwargs) + if image is None: + raise ValueError("No images provided") + else: + # llama's reference implementation runs the vision model on CPU + cuda_images = image['data'].cuda() + cuda_aspect_ratios = image['aspect_ratios'].cuda() + vision_tokens = self.vision_model(cuda_images, cuda_aspect_ratios) + print("vision_tokens", vision_tokens.shape, vision_tokens) + # pixel_values = kwargs.pop("pixel_values", None) + # if pixel_values is not None: + + # h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids]) + # logits = self.text_model.forward( + # position_ids=position_ids, + # h=h, + # xattn_mask=cross_attention_masks[:, :, position_ids], + # full_text_row_masked_out_mask=full_text_row_masked_out_mask[ + # :, :, position_ids + # ], + # xattn_caches=xattn_caches, + # ) return logits diff --git a/vllm/transformers_utils/multimodal_processors/llamavl.py b/vllm/transformers_utils/multimodal_processors/llamavl.py index 997c0e755c44..e9e024b37af2 100644 --- a/vllm/transformers_utils/multimodal_processors/llamavl.py +++ b/vllm/transformers_utils/multimodal_processors/llamavl.py @@ -345,9 +345,10 @@ def preprocess(self, images, **kwargs) -> BatchFeature: aspect_ratios[i, : len(row)] = torch.stack( [torch.tensor(x[1]) for x in row] ) + assert bsz == 1, "the below code is not for batched images" data = { - 'pixel_values': transformed_images, - 'aspect_ratios': aspect_ratios, + 'pixel_values': transformed_images[0], + 'aspect_ratios': aspect_ratios[0], } # print("transformed_images", transformed_images) # for i, row in enumerate(transformed_images): From 30239ad126ef4cf8b552f10e66c658cb3e138930 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 5 Sep 2024 10:52:26 -0700 Subject: [PATCH 10/75] run prefill self attention --- vllm/model_executor/models/llamavl.py | 310 ++++++++++++------ vllm/transformers_utils/configs/llamavl.py | 4 +- vllm/transformers_utils/tokenizers/llamavl.py | 5 +- 3 files changed, 218 insertions(+), 101 deletions(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 10be1fb9f619..c17c0332f71a 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -12,11 +12,13 @@ import torchvision.transforms as tv from PIL import Image -from vllm.attention import AttentionMetadata +from vllm.attention import Attention, AttentionMetadata +from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding @@ -26,10 +28,10 @@ from .interfaces import SupportsMultiModal from .llama import LlamaAttention from vllm.model_executor.layers.layernorm import RMSNorm -# from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, -# QKVParallelLinear, -# RowParallelLinear, -# ColumnParallelLinear) +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, + ColumnParallelLinear) from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) @@ -824,9 +826,41 @@ def load_hook( fc2_weight = state_dict.pop(prefix + "mlp.fc2_weight") state_dict[prefix + "w2.weight"] = fc2_weight +class LlamaVLAttention(LlamaAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._register_load_state_dict_pre_hook(self.load_hook) + + + def load_hook( + self, + state_dict: Dict[str, Any], + prefix: str, + local_metadata: Dict[str, Any], + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ) -> None: + # print("state_dict", state_dict.keys()) + # print("params", [x[0] for x in self.named_parameters()]) + if prefix + "wqkv.weight" in state_dict: + state_dict[prefix + "qkv_proj.weight"] = state_dict.pop(prefix + "wqkv.weight") + if prefix + "wo.weight" in state_dict: + state_dict[prefix + "o_proj.weight"] = state_dict.pop(prefix + "wo.weight") + # raise NotImplementedError + # if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict: + # state_dict[prefix + "ffn_norm.weight"] = state_dict.pop( + # prefix + "feed_forward.mlp.layer_norm_weight" + # ) + # if prefix + "attention.wqkv.layer_norm_weight" in state_dict: + # state_dict[prefix + "attention_norm.weight"] = state_dict.pop( + # prefix + "attention.wqkv.layer_norm_weight" + # ) + class TransformerBlock(nn.Module): - def __init__(self, layer_id: int, args): + def __init__(self, layer_id: int, args, cache_config: Optional[CacheConfig] = None): """ Initialize a TransformerBlock. Args: @@ -846,8 +880,21 @@ def __init__(self, layer_id: int, args): self.n_heads = args.n_heads self.dim = args.dim self.head_dim = args.dim // args.n_heads - # self.attention = Attention(args) - logger.warning("skip attention") + # TODO: remove "use_scaled_rope" from args + self.attention = LlamaVLAttention( + config=args, + hidden_size=args.dim, + num_heads=self.n_heads, + num_kv_heads=args.n_kv_heads, + rope_theta=args.rope_theta, + rope_scaling=args.rope_scaling, + max_position_embeddings=512, + quant_config=None, + bias=False, + cache_config=cache_config, + prefix=f"tb.{layer_id}.self_attn", + ) + # logger.warning("skip attention") self.feed_forward = FeedForward( dim=args.dim, hidden_dim=4 * args.dim, @@ -878,15 +925,12 @@ def load_hook( prefix + "attention.wqkv.layer_norm_weight" ) - def setup_cache(self, max_batch_size: int, dtype: torch.dtype): - self.attention.setup_cache(max_batch_size, dtype) - def forward( self, x: torch.Tensor, - freqs_cis: torch.Tensor, - mask: torch.Tensor, - position_ids: torch.LongTensor, + positions: torch.LongTensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: """ Perform a forward pass through the TransformerBlock. @@ -898,15 +942,14 @@ def forward( Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. """ - # h = self.attention.forward( - # x=self.attention_norm(x), - # freqs_cis=freqs_cis, - # mask=mask, - # position_ids=position_ids, - # ) - # h = h + x - h = x - logger.warning("skip attention") + # TODO: need to compute qkv and then do attention + h = self.attention.forward( + positions=positions, + hidden_states=self.attention_norm(x), + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + h = h + x out = h + self.feed_forward.forward(self.ffn_norm(h)) return out @@ -1182,7 +1225,7 @@ def __init__( ) self.gate_ffwd = torch.nn.Parameter(torch.zeros(1)) - logger.warning("todo hook") + logger.warning("todo put hook in correct place") self._register_load_state_dict_pre_hook(self.load_hook) self.no_ffn = no_ffn @@ -1318,7 +1361,7 @@ def forward( class CrossAttentionTransformerText(torch.nn.Module): INFERENCE_IMAGE_TOKEN_ID = 128010 - def __init__(self, args) -> None: + def __init__(self, args, cache_config:Optional[CacheConfig]) -> None: super().__init__() self.model_parallel_size = get_tensor_model_parallel_world_size() assert args.vocab_size > 0 @@ -1364,7 +1407,7 @@ def __init__(self, args) -> None: self.cross_attention_layers = torch.nn.ModuleList() for i in range(args.n_layers): layer_id = i - block = TransformerBlock(args=args, layer_id=layer_id) + block = TransformerBlock(args=args, layer_id=layer_id, cache_config=cache_config) self.layers.append(block) if layer_id in self.fusion_schedule: xa_layer_id = self.fusion_schedule.index(layer_id) + args.n_layers @@ -1446,21 +1489,24 @@ def load_hook( def forward( self, - position_ids: torch.LongTensor, + positions: torch.LongTensor, h: torch.Tensor, xattn_mask: torch.Tensor, full_text_row_masked_out_mask: torch.Tensor, xattn_caches: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, ): - assert self.cache_is_setup, "Please set up cache before calling forward" - mask = self.mask_cache.index_select(2, position_ids) - freqs_cis = self.freqs_cis.index_select(0, position_ids) + # assert self.cache_is_setup, "Please set up cache before calling forward" + # mask = self.mask_cache.index_select(2, positions) + # freqs_cis = self.freqs_cis.index_select(0, positions) for idx, ( layer, xattn_layer, xattn_layer_idx, ) in enumerate(self.text_and_xattn_layers): + print("running layer", type(layer), type(xattn_layer)) h = xattn_layer( x=h, xattn_mask=xattn_mask, @@ -1469,38 +1515,20 @@ def forward( ) h = layer( x=h, - mask=mask, - freqs_cis=freqs_cis, - position_ids=position_ids, + # mask=mask, + # freqs_cis=freqs_cis, + positions=positions, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, ) h = self.norm(h) + exit(1) output = F.linear(h, self.output.weight) # output = gather_from_tensor_model_parallel_region(output) return output.float() - def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16): - # Set up the text kv caches - device = next(self.parameters()).device - ones = torch.ones( - (self.max_seq_len, self.max_seq_len), - dtype=torch.bool, - device=device, - ) - self.register_buffer( - "mask_cache", - torch.tril( - ones, - ) - .unsqueeze(0) - .unsqueeze(0), - persistent=False, - ) - for layer in self.layers: - layer.setup_cache(max_batch_size, dtype=dtype) - self.cache_is_setup = True - def _get_xattn_mask( self, num_tokens, @@ -1786,7 +1814,7 @@ def __init__(self, config, args = config self.model_dim = args.dim self.vision_model = CrossAttentionTransformerVision(args) - self.text_model = CrossAttentionTransformerText(args) + self.text_model = CrossAttentionTransformerText(args, cache_config=cache_config) self.image_res = args.vision_chunk_size self.max_num_chunks = args.vision_max_num_chunks self.image_transform = partial( @@ -1794,13 +1822,10 @@ def __init__(self, config, max_num_chunks=args.vision_max_num_chunks, ) - def setup_cache(self, max_batch_size: int, dtype: torch.dtype): - self.text_model.setup_cache(max_batch_size, dtype) - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): state_dict = {name: weight for name, weight in weights} - self.load_state_dict(state_dict, strict=False) + state_dict.pop('text_model.rope.freqs') + self.load_state_dict(state_dict, strict=True) def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: @@ -1871,24 +1896,79 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> torch.Tensor: - print("input_ids", input_ids) - print("positions", positions) - print("kv_caches", len(kv_caches), kv_caches[0].shape) + # print("input_ids", input_ids) + # print("positions", positions) + # print("kv_caches", len(kv_caches), kv_caches[0].shape) + # print("attn_metadata", attn_metadata) + # print("intermediate_tensors", intermediate_tensors) + # print("kwargs", kwargs) print("attn_metadata", attn_metadata) - print("intermediate_tensors", intermediate_tensors) - print("kwargs", kwargs) - image = self._parse_and_validate_image_input(**kwargs) - if image is None: + image_inputs = self._parse_and_validate_image_input(**kwargs) + if image_inputs is None: raise ValueError("No images provided") else: # llama's reference implementation runs the vision model on CPU - cuda_images = image['data'].cuda() - cuda_aspect_ratios = image['aspect_ratios'].cuda() + cuda_images = image_inputs['data'].cuda() + cuda_aspect_ratios = image_inputs['aspect_ratios'].cuda() vision_tokens = self.vision_model(cuda_images, cuda_aspect_ratios) - print("vision_tokens", vision_tokens.shape, vision_tokens) + batch_masks = [] + # TODO: get the sequence of each query without hack? 1) better attn metadata 2) better input processor to create vision mask during preprocess + # assert isinstance(attn_metadata, PagedAttentionMetadata) + start_pos = 0 + for seq_len in attn_metadata.seq_lens_tensor: + end_pos = start_pos + seq_len + batch_masks.append(create_vision_mask(input_ids[start_pos:end_pos])) + start_pos = end_pos + print("batch_masks", batch_masks) + # print("vision_tokens", vision_tokens.shape, vision_tokens) + + bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape) + xattn_caches = torch.stack( + [ + layer.compute_xattn_kv_cache( + vision_tokens.view(bsz, -1, image_token_dim) + ) + for layer in self.text_model.cross_attention_layers + ] + ) + # TODO: remove this hardcode + total_len = 512 + padded_masks = _pad_masks( + batch_masks, + image_inputs['num_chunks'], + total_len, + self.max_num_chunks, + ) + + cross_attention_masks, full_text_row_masked_out_mask = ( + self.text_model._get_xattn_mask( + num_tokens=total_len, + text_device="cuda", + text_dtype=next(self.text_model.parameters()).dtype, + vision_tokens=vision_tokens, + cross_attention_masks=padded_masks, + ) + ) + print("cross_attention_masks", cross_attention_masks.shape, cross_attention_masks) + print("full_text_row_masked_out_mask", full_text_row_masked_out_mask.shape, full_text_row_masked_out_mask) + + h = self.text_model.get_partially_trainable_embedding(input_ids) + print("h", h.shape, h) + print("positions", positions.shape, positions) + logits = self.text_model.forward( + positions=positions, + h=h, + xattn_mask=cross_attention_masks, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + xattn_caches=xattn_caches, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + print("prefill logits", logits.shape, logits) + exit(1) + # pixel_values = kwargs.pop("pixel_values", None) # if pixel_values is not None: - # h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids]) # logits = self.text_model.forward( # position_ids=position_ids, @@ -1902,33 +1982,40 @@ def forward( return logits -def _stack_images( - images: List[List[Image.Image]], - max_num_chunks: int, - image_res: int, - max_num_images: int, -) -> Tuple[torch.Tensor, List[int]]: - """ - Takes a list of list of images and stacks them into a tensor. - This function is needed since images can be of completely - different resolutions and aspect ratios. - """ - out_images, out_num_chunks = [], [] - for imgs_sample in images: - out_images_i = torch.zeros( - max_num_images, - max_num_chunks, - 3, - image_res, - image_res, - ) - _num_chunks = [] - for j, chunks_image in enumerate(imgs_sample): - out_images_i[j, : chunks_image.shape[0]] = chunks_image - _num_chunks.append(chunks_image.shape[0]) - out_images.append(out_images_i) - out_num_chunks.append(_num_chunks) - return torch.stack(out_images), out_num_chunks +def create_vision_mask( + tokens: List[int], + vision_token: int=128256, +) -> List[List[int]]: + # import pdb; pdb.set_trace() +# (Pdb) p tokens +# [128011, 128011, 128000, 644, 264, 11914, 11, 1521, 1403, 5448, 6308] + print("tokens", tokens) + vision_token_locations = [ + i for i, token in enumerate(tokens) if token == vision_token + ] + if len(vision_token_locations) == 0: + return [] + + if len(vision_token_locations) == 1: + # only one image present, unmask until end of sequence + return [[vision_token_locations[0], -1]] + vision_masks = [ + [loc1, loc2] + for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:]) + ] + # last image will attend to all subsequent text + vision_masks.append([vision_token_locations[-1], len(tokens)]) + + # if there are two or more consecutive vision tokens, + # they should all attend to all subsequent + # text present + last_mask_end = vision_masks[-1][1] + for vision_mask in vision_masks[::-1]: + if vision_mask[0] == vision_mask[1] - 1: + vision_mask[1] = last_mask_end + last_mask_end = vision_mask[1] + return vision_masks + def _pad_masks( @@ -1961,3 +2048,30 @@ def _pad_masks( return out_masks + + +# def _encode_content( +# self, content: InterleavedTextAttachment, bos: bool = False +# ) -> Tuple[List[int], List[PIL_Image.Image]]: +# tokens = [] +# images = [] + +# added_bos = False + +# def _process(c): +# nonlocal added_bos + +# if isinstance(c, str): +# tokens.extend( +# self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False) +# ) +# added_bos = True +# elif isinstance(c, ImageMedia): +# tokens.append(self.vision_token) +# images.append(c.image) + +# if isinstance(content, str): +# _process(content) +# elif isinstance(content, list): +# for c in content: +# _process(c) diff --git a/vllm/transformers_utils/configs/llamavl.py b/vllm/transformers_utils/configs/llamavl.py index c5de04a35b29..d186ddac2e32 100644 --- a/vllm/transformers_utils/configs/llamavl.py +++ b/vllm/transformers_utils/configs/llamavl.py @@ -1,5 +1,5 @@ from transformers import PretrainedConfig -from typing import Optional +from typing import Optional, Any class LlamaVLConfig(PretrainedConfig): @@ -29,6 +29,8 @@ class LlamaVLConfig(PretrainedConfig): torch_dtype: str = "bfloat16" + rope_scaling: Optional[dict[str, Any]] = None + attribute_map = { "num_hidden_layers": "n_layers", "hidden_size": "dim", diff --git a/vllm/transformers_utils/tokenizers/llamavl.py b/vllm/transformers_utils/tokenizers/llamavl.py index 7e9d2d26101e..5d4a2541edbf 100644 --- a/vllm/transformers_utils/tokenizers/llamavl.py +++ b/vllm/transformers_utils/tokenizers/llamavl.py @@ -78,6 +78,7 @@ def __init__(self, model_path: str): self.special_tokens = { token: num_base_tokens + i for i, token in enumerate(special_tokens) } + self.special_tokens["<|image|>"] = 128256 self.model = tiktoken.Encoding( name=Path(model_path).name, pat_str=self.pat_str, @@ -129,7 +130,7 @@ def encode( - Setting `allowed_special` to "all" will treat all text corresponding to special tokens to be encoded as special tokens. """ - bos = True + bos = False eos = False assert type(s) is str @@ -145,7 +146,7 @@ def encode( t.extend( self.model.encode( substr, - allowed_special=set(), + allowed_special="all", disallowed_special=set(), ) ) From 6972cbf9cb252ab45cf660c000713eda7a741eed Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 6 Sep 2024 16:49:09 -0700 Subject: [PATCH 11/75] run prefill crossattention --- vllm/model_executor/models/llamavl.py | 117 +++++++++++++------------- 1 file changed, 59 insertions(+), 58 deletions(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index c17c0332f71a..8c8692ec5c78 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -52,10 +52,13 @@ class LlamaImagePixelInputs(TypedDict): LlavaImageInputs = LlamaImagePixelInputs + +def input_processor_for_llamavl(ctx: InputContext, llm_inputs: LLMInputs): + # TODO: move image preprocessing here + return llm_inputs + def get_max_llama_image_tokens(ctx: InputContext) -> int: logger.warning("need further check on max llama image tokens") - print("ctx", type(ctx)) - print(ctx) return 1025 * 2 @@ -842,22 +845,10 @@ def load_hook( unexpected_keys: List[str], error_msgs: List[str], ) -> None: - # print("state_dict", state_dict.keys()) - # print("params", [x[0] for x in self.named_parameters()]) if prefix + "wqkv.weight" in state_dict: state_dict[prefix + "qkv_proj.weight"] = state_dict.pop(prefix + "wqkv.weight") if prefix + "wo.weight" in state_dict: state_dict[prefix + "o_proj.weight"] = state_dict.pop(prefix + "wo.weight") - # raise NotImplementedError - # if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict: - # state_dict[prefix + "ffn_norm.weight"] = state_dict.pop( - # prefix + "feed_forward.mlp.layer_norm_weight" - # ) - # if prefix + "attention.wqkv.layer_norm_weight" in state_dict: - # state_dict[prefix + "attention_norm.weight"] = state_dict.pop( - # prefix + "attention.wqkv.layer_norm_weight" - # ) - class TransformerBlock(nn.Module): def __init__(self, layer_id: int, args, cache_config: Optional[CacheConfig] = None): @@ -1157,30 +1148,60 @@ def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: return self._compute_xattn_kv_cache(xattn_tokens) + def unpack_value(self, x: torch.Tensor, positions: torch.LongTensor, attn_metadata: AttentionMetadata, xattn_mask: torch.Tensor, full_text_row_masked_out_mask: torch.Tensor): + x_unpacked = torch.zeros(attn_metadata.num_prefills, attn_metadata.max_query_len, x.shape[-1], device=x.device, dtype=x.dtype) + positions_unpacked = torch.zeros(attn_metadata.num_prefills, attn_metadata.max_query_len, device=positions.device, dtype=positions.dtype) + xattn_mask = xattn_mask[:, :, :attn_metadata.max_query_len] + # position + start_pos = 0 + for i, seq_len in enumerate(attn_metadata.seq_lens_tensor): + end_pos = start_pos + seq_len + x_unpacked[i, :seq_len] = x[start_pos:end_pos] + positions_unpacked[i, :seq_len] = positions[start_pos:end_pos] + xattn_mask[i, 0, seq_len:] = torch.finfo(xattn_mask.dtype).min + start_pos = end_pos + # xattn_mask = xattn_mask[:, :, :attn_metadata.max_query_len] + # full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, :attn_metadata.max_query_len] + return x_unpacked, positions_unpacked, xattn_mask, full_text_row_masked_out_mask + + def pack_value(self, x:torch.Tensor, attn_metadata: AttentionMetadata): + x_packed = torch.zeros(attn_metadata.num_prefill_tokens, x.shape[-1], device=x.device, dtype=x.dtype) + start_pos = 0 + for i, seq_len in enumerate(attn_metadata.seq_lens_tensor): + end_pos = start_pos + seq_len + x_packed[start_pos:end_pos] = x[i, :seq_len] + start_pos = end_pos + return x_packed + def forward( self, x: torch.Tensor, xattn_mask: torch.Tensor, full_text_row_masked_out_mask: torch.Tensor, xattn_cache: torch.Tensor, + positions: torch.LongTensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: xq = F.linear(x, self.wq.weight) - bsz, seqlen, _ = x.shape + n_token = xq.shape[0] + xq, positions, xattn_mask, full_text_row_masked_out_mask = self.unpack_value(xq, positions, attn_metadata, xattn_mask, full_text_row_masked_out_mask) + bsz, seqlen, _ = xq.shape xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) xq = self.q_norm(xq) - xq = xq.transpose(1, 2) + xq = xq.transpose(1, 2) # [bs, n_head, seq_len, head_dim] xk, xv = xattn_cache output = F.scaled_dot_product_attention( xq, xk, xv, attn_mask=xattn_mask, dropout_p=0.0 ) - output = output * full_text_row_masked_out_mask - output = output.transpose(1, 2).contiguous().reshape(bsz, seqlen, -1) + + output = output.transpose(1, 2).reshape(bsz, seqlen, -1).contiguous() + output = self.pack_value(output, attn_metadata) + output = output * full_text_row_masked_out_mask out = F.linear(output, self.wo.weight) - # out = reduce_from_tensor_model_parallel_region(out) return out @@ -1269,18 +1290,22 @@ def forward( self, x: torch.Tensor, xattn_mask: torch.Tensor, - full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], + full_text_row_masked_out_mask: torch.Tensor, xattn_cache: torch.Tensor, + positions: torch.LongTensor, + attn_metadata: AttentionMetadata, ) -> torch.Tensor: _attn_out = self.attention( x=self.attention_norm(x), xattn_mask=xattn_mask, - xattn_cache=xattn_cache, full_text_row_masked_out_mask=full_text_row_masked_out_mask, + xattn_cache=xattn_cache, + positions=positions, + attn_metadata=attn_metadata ) h = x + self.gate_attn.tanh() * _attn_out _ffn = self.feed_forward(self.ffn_norm(h)) - _ffn = full_text_row_masked_out_mask[:, 0] * _ffn # type: ignore + _ffn = full_text_row_masked_out_mask * _ffn # type: ignore h = h + self.gate_ffwd.tanh() * _ffn * float(not self.no_ffn) return h @@ -1506,12 +1531,13 @@ def forward( xattn_layer, xattn_layer_idx, ) in enumerate(self.text_and_xattn_layers): - print("running layer", type(layer), type(xattn_layer)) h = xattn_layer( x=h, xattn_mask=xattn_mask, - xattn_cache=xattn_caches[xattn_layer_idx], full_text_row_masked_out_mask=full_text_row_masked_out_mask, + xattn_cache=xattn_caches[xattn_layer_idx], + positions=positions, + attn_metadata=attn_metadata, ) h = layer( x=h, @@ -1523,8 +1549,6 @@ def forward( ) h = self.norm(h) - exit(1) - output = F.linear(h, self.output.weight) # output = gather_from_tensor_model_parallel_region(output) return output.float() @@ -1795,6 +1819,7 @@ def __call__(self, image: Image.Image, max_num_chunks: int) -> Tuple[Any, Any]: @MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llama_image_tokens) +@INPUT_REGISTRY.register_input_processor(input_processor_for_llamavl) class LlamaVLForCausalLM(nn.Module, SupportsMultiModal): def __init__(self, config, multimodal_config: MultiModalConfig, @@ -1840,7 +1865,6 @@ def _parse_and_validate_image_input( raise ValueError("Both pixel values and image embeds are provided.") if pixel_values is not None: - print("pixel shapes", [x.shape for x in pixel_values]) # tensor with the same shape will be batched together by MultiModalInputs.batch, so pixel_values here can be: # - List[List[torch.Tensor]]: with shape (num_chunks, 3, image_res, image_res) # - List[torch.Tensor]: with shape (num_image_in_batch, num_chunks, 3, image_res, image_res) @@ -1896,13 +1920,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> torch.Tensor: - # print("input_ids", input_ids) - # print("positions", positions) - # print("kv_caches", len(kv_caches), kv_caches[0].shape) - # print("attn_metadata", attn_metadata) - # print("intermediate_tensors", intermediate_tensors) - # print("kwargs", kwargs) - print("attn_metadata", attn_metadata) image_inputs = self._parse_and_validate_image_input(**kwargs) if image_inputs is None: raise ValueError("No images provided") @@ -1919,9 +1936,7 @@ def forward( end_pos = start_pos + seq_len batch_masks.append(create_vision_mask(input_ids[start_pos:end_pos])) start_pos = end_pos - print("batch_masks", batch_masks) - # print("vision_tokens", vision_tokens.shape, vision_tokens) - + bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape) xattn_caches = torch.stack( [ @@ -1949,12 +1964,15 @@ def forward( cross_attention_masks=padded_masks, ) ) - print("cross_attention_masks", cross_attention_masks.shape, cross_attention_masks) - print("full_text_row_masked_out_mask", full_text_row_masked_out_mask.shape, full_text_row_masked_out_mask) + full_text_row_masked_out_mask_plain = torch.zeros(attn_metadata.num_prefill_tokens, 1, dtype=full_text_row_masked_out_mask.dtype) + start_pos = 0 + for i, seq_len in enumerate(attn_metadata.seq_lens_tensor): + end_pos = start_pos + seq_len + full_text_row_masked_out_mask_plain[start_pos:end_pos, 0] = full_text_row_masked_out_mask[i, 0, :seq_len, 0] + start_pos = end_pos + full_text_row_masked_out_mask = full_text_row_masked_out_mask_plain.cuda() h = self.text_model.get_partially_trainable_embedding(input_ids) - print("h", h.shape, h) - print("positions", positions.shape, positions) logits = self.text_model.forward( positions=positions, h=h, @@ -1964,24 +1982,8 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, ) - print("prefill logits", logits.shape, logits) - exit(1) - - # pixel_values = kwargs.pop("pixel_values", None) - # if pixel_values is not None: - # h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids]) - # logits = self.text_model.forward( - # position_ids=position_ids, - # h=h, - # xattn_mask=cross_attention_masks[:, :, position_ids], - # full_text_row_masked_out_mask=full_text_row_masked_out_mask[ - # :, :, position_ids - # ], - # xattn_caches=xattn_caches, - # ) return logits - def create_vision_mask( tokens: List[int], vision_token: int=128256, @@ -1989,7 +1991,6 @@ def create_vision_mask( # import pdb; pdb.set_trace() # (Pdb) p tokens # [128011, 128011, 128000, 644, 264, 11914, 11, 1521, 1403, 5448, 6308] - print("tokens", tokens) vision_token_locations = [ i for i, token in enumerate(tokens) if token == vision_token ] From 4e1344bf92615d5d087825407c57c405fe8d2e78 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 6 Sep 2024 22:31:36 -0700 Subject: [PATCH 12/75] can generate the first token :) --- vllm/model_executor/models/llamavl.py | 68 ++++++++++--------- vllm/transformers_utils/tokenizers/llamavl.py | 51 ++++++++++---- 2 files changed, 72 insertions(+), 47 deletions(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 8c8692ec5c78..92ee43a49a9d 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -21,8 +21,11 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal @@ -1406,7 +1409,7 @@ def __init__(self, args, cache_config:Optional[CacheConfig]) -> None: self.norm = RMSNorm(args.dim, eps=args.norm_eps) # output layer - self.output = nn.Linear(args.dim, args.vocab_size, bias=False) + # self.output = nn.Linear(args.dim, args.vocab_size, bias=False) # self.output = ColumnParallelLinear( # args.dim, args.vocab_size, bias=False, init_method=lambda x: x # ) @@ -1549,9 +1552,10 @@ def forward( ) h = self.norm(h) - output = F.linear(h, self.output.weight) + return h + # output = F.linear(h, self.output.weight) # output = gather_from_tensor_model_parallel_region(output) - return output.float() + # return output.float() def _get_xattn_mask( self, @@ -1846,10 +1850,20 @@ def __init__(self, config, VariableSizeImageTransform(size=args.vision_chunk_size), max_num_chunks=args.vision_max_num_chunks, ) + self.lm_head = ParallelLMHead( + args.vocab_size, + args.dim, + org_num_embeddings=args.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(args.dim, args.vocab_size) + self.sampler = Sampler() def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): state_dict = {name: weight for name, weight in weights} state_dict.pop('text_model.rope.freqs') + state_dict['lm_head.weight'] = state_dict.pop('text_model.output.weight') self.load_state_dict(state_dict, strict=True) def _parse_and_validate_image_input( @@ -1911,6 +1925,24 @@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def forward( self, input_ids: torch.Tensor, @@ -2048,31 +2080,3 @@ def _pad_masks( ].fill_(0.0) return out_masks - - - -# def _encode_content( -# self, content: InterleavedTextAttachment, bos: bool = False -# ) -> Tuple[List[int], List[PIL_Image.Image]]: -# tokens = [] -# images = [] - -# added_bos = False - -# def _process(c): -# nonlocal added_bos - -# if isinstance(c, str): -# tokens.extend( -# self.tokenizer.encode(c, bos=False if added_bos else bos, eos=False) -# ) -# added_bos = True -# elif isinstance(c, ImageMedia): -# tokens.append(self.vision_token) -# images.append(c.image) - -# if isinstance(content, str): -# _process(content) -# elif isinstance(content, list): -# for c in content: -# _process(c) diff --git a/vllm/transformers_utils/tokenizers/llamavl.py b/vllm/transformers_utils/tokenizers/llamavl.py index 5d4a2541edbf..532fa269bb8c 100644 --- a/vllm/transformers_utils/tokenizers/llamavl.py +++ b/vllm/transformers_utils/tokenizers/llamavl.py @@ -12,10 +12,11 @@ Optional, Sequence, Union, + Any, ) from transformers.tokenization_utils import PreTrainedTokenizer - +# TODO: now use tiktoken, but I believe it should be replaced with tokenizer in huggingface import tiktoken from tiktoken.load import load_tiktoken_bpe @@ -105,6 +106,16 @@ def __init__(self, model_path: str): print("need to replace tokenizer with official release") print("warning: recheck add bos and add eos of encode function") + # the following attributes are set to fit VLLM's design (copied from MistralTokenizer) + self.is_fast = False + self.chat_template = True + self.all_special_ids: List[Any] = [] + self.all_special_tokens: List[Any] = [] + self.all_special_tokens_extended: List[Any] = [] + + def get_added_vocab(self) -> List[str]: + return [] + def encode( self, s: str, @@ -155,19 +166,26 @@ def encode( if eos: t.append(self.eos_id) return t - - def decode(self, t: Sequence[int]) -> str: - """ - Decodes a list of token IDs into a string. - - Args: - t (List[int]): The list of token IDs to be decoded. - - Returns: - str: The decoded string. - """ - # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. - return self.model.decode(cast(List[int], t)) + + def convert_ids_to_tokens( + self, + tokens: List[int], + skip_special_tokens: Optional[bool] = True) -> List[str]: + # TODO(Patrick) - potentially allow special tokens to not be skipped + assert ( + skip_special_tokens + ), "Skipping special tokens is not supported for Mistral tokenizers." + + # assert isinstance(self.tokenizer, + # (Tekkenizer, SentencePieceTokenizer)), type( + # self.tokenizer) + + # TODO: handle skip_special_tokens + # TODO: self.model.decode returns a string, but the interface expects a list of words + return [self.model.decode(tokens)] + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return "".join(tokens) @staticmethod def _split_whitespaces_or_nonwhitespaces( @@ -197,4 +215,7 @@ def _split_whitespaces_or_nonwhitespaces( @classmethod def from_pretrained(cls, model_path: str) -> "LlamaVLTokenizer": - return cls(os.path.join(model_path, "tokenizer.model")) \ No newline at end of file + return cls(os.path.join(model_path, "tokenizer.model")) + + def __len__(self): + return self.n_words \ No newline at end of file From f3d869d99f08736ac3d6e816cb1621d239fee86a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 6 Sep 2024 22:47:20 -0700 Subject: [PATCH 13/75] can perform offline e2e run without decode crossattn, but wrong answer --- vllm/model_executor/models/llamavl.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 92ee43a49a9d..090435949b0f 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -1534,14 +1534,16 @@ def forward( xattn_layer, xattn_layer_idx, ) in enumerate(self.text_and_xattn_layers): - h = xattn_layer( - x=h, - xattn_mask=xattn_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - xattn_cache=xattn_caches[xattn_layer_idx], - positions=positions, - attn_metadata=attn_metadata, - ) + # TODO: a hack now. skip decode cross attention + if xattn_mask is not None: + h = xattn_layer( + x=h, + xattn_mask=xattn_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + xattn_cache=xattn_caches[xattn_layer_idx], + positions=positions, + attn_metadata=attn_metadata, + ) h = layer( x=h, # mask=mask, @@ -1954,7 +1956,9 @@ def forward( ) -> torch.Tensor: image_inputs = self._parse_and_validate_image_input(**kwargs) if image_inputs is None: - raise ValueError("No images provided") + cross_attention_masks = None + full_text_row_masked_out_mask = None + xattn_caches = None else: # llama's reference implementation runs the vision model on CPU cuda_images = image_inputs['data'].cuda() From 6f26a3be1cd1384359e8d023da7edf3047a36234 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 7 Sep 2024 17:36:35 -0700 Subject: [PATCH 14/75] pass mm data in encoder-decoder --- vllm/engine/llm_engine.py | 12 ++++++++---- vllm/inputs/data.py | 6 ++++++ vllm/model_executor/models/llamavl.py | 25 +++++++++++++++++++++++-- vllm/multimodal/image.py | 5 +++++ vllm/sequence.py | 5 ++++- vllm/worker/enc_dec_model_runner.py | 6 +++++- vllm/worker/utils.py | 4 ---- 7 files changed, 51 insertions(+), 12 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1eab83f3b988..1ba678f72668 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -885,8 +885,8 @@ def _build_enc_dec_llm_inputs( encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps - if encoder_mm_data is not None or decoder_mm_data is not None: - raise ValueError("Multi-modal encoder-decoder models are " + if decoder_mm_data is not None: + raise ValueError("Multi-modality decoder inputs of encoder-decoder models are " "not supported yet") decoder_prompt_ids = ( @@ -895,8 +895,10 @@ def _build_enc_dec_llm_inputs( return EncoderDecoderLLMInputs( prompt_token_ids=decoder_prompt_ids, prompt=decoder_prompt, + multi_modal_data=decoder_mm_data, encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt=encoder_prompt, + encoder_multi_modal_data = encoder_mm_data, ) def _process_encoder_decoder_prompt( @@ -1098,7 +1100,6 @@ def add_request( "not enabled!") if arrival_time is None: arrival_time = time.time() - processed_inputs = self.process_model_inputs( inputs, request_id=request_id, @@ -2011,7 +2012,10 @@ def is_embedding_model(self): def _validate_model_inputs(self, inputs: Union[LLMInputs, EncoderDecoderLLMInputs]): - if self.is_encoder_decoder_model(): + if self.model_config.is_multimodal_model: + # For encoder-decoder multimodal models, the max_prompt_len restricts the decoder prompt length + prompt_ids = inputs.get("prompt_token_ids") + elif self.is_encoder_decoder_model(): prompt_ids = inputs.get("encoder_prompt_token_ids") else: prompt_ids = inputs.get("prompt_token_ids") diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 75ab0c770155..a71e9a7b5db6 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -139,6 +139,12 @@ class EncoderDecoderLLMInputs(LLMInputs): available. """ + encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + """ + Optional multi-modal data to pass to the encoder model, + if the model supports it. + """ + _T1 = TypeVar("_T1", bound=SingletonPromptInputs, diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 090435949b0f..32284916dc3c 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -38,6 +38,7 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.transformers_utils.multimodal_processors.llamavl import LlamaVLImageProcessor logger = init_logger(__name__) MP_SCALE = 8 @@ -54,10 +55,28 @@ class LlamaImagePixelInputs(TypedDict): # TODO: support LlamaImageEmbeddingInputs LlavaImageInputs = LlamaImagePixelInputs - +image_processor = None def input_processor_for_llamavl(ctx: InputContext, llm_inputs: LLMInputs): - # TODO: move image preprocessing here + multi_modal_data = llm_inputs.get("encoder_multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + global image_processor + if image_processor is None: + image_processor = LlamaVLImageProcessor(ctx.model_config.model) + + processed_image = image_processor(multi_modal_data["image"]) + llm_inputs["encoder_multi_modal_data"]["image"] = processed_image + + num_chunks = int(processed_image["aspect_ratios"].sum()) + assert ctx.model_config.hf_config.vision_chunk_size % 14 == 0, "chunk size should be multiple of 14" + token_per_chunk = (ctx.model_config.hf_config.vision_chunk_size // 14) ** 2 + 1 + num_tokens = num_chunks * token_per_chunk + llm_inputs["encoder_prompt"] = "<|image|>" * num_tokens + llm_inputs["encoder_prompt_token_ids"] = [128256] * num_tokens + + assert "decoder_multi_modal_data" not in llm_inputs, "multi-modal data should be put in encoder message of LLaMA Vision" + return llm_inputs def get_max_llama_image_tokens(ctx: InputContext) -> int: @@ -1954,6 +1973,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> torch.Tensor: + # import pdb; pdb.set_trace() image_inputs = self._parse_and_validate_image_input(**kwargs) if image_inputs is None: cross_attention_masks = None @@ -2018,6 +2038,7 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, ) + # import pdb; pdb.set_trace() return logits def create_vision_mask( diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 6cdde949bc2b..5b10e199c562 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_image_processor from vllm.utils import is_list_of +from transformers.image_processing_base import BatchFeature from .base import MultiModalData, MultiModalInputs, MultiModalPlugin @@ -33,6 +34,10 @@ def _default_input_mapper( data: MultiModalData[object], ) -> MultiModalInputs: model_config = ctx.model_config + + # Processed by input processor + if isinstance(data, BatchFeature): + return MultiModalInputs(data.data) # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): diff --git a/vllm/sequence.py b/vllm/sequence.py index 87b3d21fa7ae..2a6f6952cf84 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -440,7 +440,10 @@ def prompt_token_ids(self) -> List[int]: @property def multi_modal_data(self) -> "MultiModalDataDict": - return self.inputs.get("multi_modal_data") or {} + if self.inputs.get("multi_modal_data") and self.inputs.get( + "encoder_multi_modal_data"): + raise ValueError("Multi-modal data in both encoder and decoder is not supported yet.") + return self.inputs.get("multi_modal_data") or self.inputs.get("encoder_multi_modal_data") or {} @property def lora_int_id(self) -> int: diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index d6189d82d51d..11fb2376d496 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -17,7 +17,7 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry, MultiModalInputs from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) @@ -184,6 +184,8 @@ def execute_model( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } if self.has_seqlen_agnostic else {} + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -192,6 +194,8 @@ def execute_model( kv_caches=kv_caches, attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), **seqlen_agnostic_kwargs) logits = self.model.compute_logits(hidden_or_intermediate_states, diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index d73023e8e172..287ed14f68f9 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -39,10 +39,6 @@ def assert_enc_dec_mr_supported_scenario( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) - if enc_dec_mr.model_config.is_multimodal_model: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM']) - if enc_dec_mr.scheduler_config.num_lookahead_slots > 0: raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC']) From fa0912ebf061659470bcc141604645edaac300d1 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 10 Sep 2024 22:41:29 -0700 Subject: [PATCH 15/75] prefill result matches now. Model is speaking human words. --- tests/models/test_llamavl.py | 85 ++++++++++++++++++++++++++++++++++++ vllm/engine/llm_engine.py | 6 ++- 2 files changed, 89 insertions(+), 2 deletions(-) create mode 100644 tests/models/test_llamavl.py diff --git a/tests/models/test_llamavl.py b/tests/models/test_llamavl.py new file mode 100644 index 000000000000..dc2f28563d45 --- /dev/null +++ b/tests/models/test_llamavl.py @@ -0,0 +1,85 @@ +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.assets.image import ImageAsset +from vllm.utils import FlexibleArgumentParser + +from functools import partial +from PIL import Image as PIL_Image + + +if __name__ == "__main__": + model_size_map = { + "llama-3.2-11b": "11B", + "llama-3.2-90b": "90B", + } + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'vision language models') + parser.add_argument('--model-type', + '-m', + type=str, + default="llama-3.2-11b", + choices=model_size_map.keys(), + help='Huggingface "model_type".') + + args = parser.parse_args() + + size = model_size_map[args.model_type] + checkpoint_dir = "/data/zhang-chen/llama/checkpoints" # update checkpoint path here + llm = LLM(model=f"{checkpoint_dir}/Meta-Llama-3.2-{size}-Vision-Early/", enforce_eager=True, limit_mm_per_prompt={"image": 2}, + # load_format="dummy" + ) + + resource_dir = "/home/eecs/zhang-chen/venv/vllm-multimodal/lib/python3.10/site-packages/llama_models/scripts/resources/" + # Input image and question + with open(f"{resource_dir}/dog.jpg", "rb") as f: + image = PIL_Image.open(f).convert("RGB") + with open(f"{resource_dir}/pasta.jpeg", "rb") as f: + image2 = PIL_Image.open(f).convert("RGB") + + # inputs = [ + # { + # "prompt": "<|image|><|image|><|begin_of_text|>In a sentence, these two images paint", + # "multi_modal_data": { + # "image": [image, image2], + # } + # }, + # { + # "prompt": "<|image|><|begin_of_text|>If I had to write a haiku for this one", + # "multi_modal_data": { + # "image": [image] + # } + # }, + # # { + # # "prompt": "he color of the sky is blue but sometimes it can also be", + # # }, + # ] + inputs = [ + { + "encoder_prompt": { + "prompt": "", + "multi_modal_data": { + "image": [image, image2], + } + }, + "decoder_prompt": "<|image|><|image|><|begin_of_text|>In a sentence, these two images paint", + }, + { + "encoder_prompt":{ + "prompt": "", + "multi_modal_data": { + "image": [image] + } + }, + "decoder_prompt": "<|image|><|begin_of_text|>If I had to write a haiku for this one", + }, + # { + # "prompt": "he color of the sky is blue but sometimes it can also be", + # }, + ] + outputs = llm.generate(inputs, SamplingParams(temperature=0.6, top_p=0.9, max_tokens=512)) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1ba678f72668..994d18478a78 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -718,6 +718,7 @@ def stop_remote_worker_execution_loop(self) -> None: def _prepare_decoder_input_ids_for_generation( self, decoder_input_ids: Optional[List[int]], + force_bos: bool = False, ) -> List[int]: """ Prepares `decoder_input_ids` for generation with encoder-decoder models. @@ -747,7 +748,7 @@ def _prepare_decoder_input_ids_for_generation( # use decoder_start_token_id as decoder_input_ids decoder_input_ids = self._get_default_enc_dec_decoder_prompt() - if (len(decoder_input_ids) == 0 + if force_bos and (len(decoder_input_ids) == 0 or decoder_input_ids[0] != decoder_start_token_id): decoder_input_ids = [decoder_start_token_id] + decoder_input_ids @@ -889,8 +890,9 @@ def _build_enc_dec_llm_inputs( raise ValueError("Multi-modality decoder inputs of encoder-decoder models are " "not supported yet") + # For Multi-Modal models, the start token can be the image token decoder_prompt_ids = ( - self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) + self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids, force_bos=(encoder_mm_data is None and decoder_mm_data is None))) return EncoderDecoderLLMInputs( prompt_token_ids=decoder_prompt_ids, From 46634ffcf59f619c12b5d64151cb7f996a01d883 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 11 Sep 2024 21:25:42 -0700 Subject: [PATCH 16/75] generate correct result for single image --- vllm/model_executor/models/llamavl.py | 197 +++++++++++--------------- 1 file changed, 85 insertions(+), 112 deletions(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 32284916dc3c..7f2026120be3 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -12,7 +12,7 @@ import torchvision.transforms as tv from PIL import Image -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs @@ -393,13 +393,11 @@ def __init__( self.dropout = dropout def forward(self, x): - hidden = self.c_fc(x) - # hidden = F.linear(x, self.c_fc.weight, self.c_fc.bias) + hidden = F.linear(x, self.c_fc.weight, self.c_fc.bias) hidden = self.non_linearity(hidden) - hidden = self.c_proj(hidden) - # hidden = F.linear(hidden, self.c_proj.weight) + hidden = F.linear(hidden, self.c_proj.weight) # hidden = reduce_from_tensor_model_parallel_region(hidden) - # hidden += self.c_proj.bias + hidden += self.c_proj.bias return hidden @@ -854,6 +852,15 @@ def load_hook( class LlamaVLAttention(LlamaAttention): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + rope_scaling = kwargs.get("rope_scaling", None) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, # force to use neox=False + ) self._register_load_state_dict_pre_hook(self.load_hook) @@ -867,8 +874,6 @@ def load_hook( unexpected_keys: List[str], error_msgs: List[str], ) -> None: - if prefix + "wqkv.weight" in state_dict: - state_dict[prefix + "qkv_proj.weight"] = state_dict.pop(prefix + "wqkv.weight") if prefix + "wo.weight" in state_dict: state_dict[prefix + "o_proj.weight"] = state_dict.pop(prefix + "wo.weight") @@ -1112,6 +1117,13 @@ def __init__( self.head_dim, eps=norm_eps, ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention( + self.n_heads, + self.head_dim, + self.scaling, + self.n_kv_heads, + ) # cross-attention heads are model parallel similar to # self-attention, and we also use the identical KV head @@ -1157,12 +1169,6 @@ def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: xk = xk.view(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim) xv = xv.view(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim) - xk, xv = [tensor.transpose(1, 2) for tensor in (xk, xv)] - - # repeat k/v heads if n_kv_heads < n_heads - xk = xk.repeat_interleave(self.n_rep, dim=1) - xv = xv.repeat_interleave(self.n_rep, dim=1) - xk = self.k_norm(xk) return torch.stack([xk, xv]) @@ -1170,59 +1176,25 @@ def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: return self._compute_xattn_kv_cache(xattn_tokens) - def unpack_value(self, x: torch.Tensor, positions: torch.LongTensor, attn_metadata: AttentionMetadata, xattn_mask: torch.Tensor, full_text_row_masked_out_mask: torch.Tensor): - x_unpacked = torch.zeros(attn_metadata.num_prefills, attn_metadata.max_query_len, x.shape[-1], device=x.device, dtype=x.dtype) - positions_unpacked = torch.zeros(attn_metadata.num_prefills, attn_metadata.max_query_len, device=positions.device, dtype=positions.dtype) - xattn_mask = xattn_mask[:, :, :attn_metadata.max_query_len] - # position - start_pos = 0 - for i, seq_len in enumerate(attn_metadata.seq_lens_tensor): - end_pos = start_pos + seq_len - x_unpacked[i, :seq_len] = x[start_pos:end_pos] - positions_unpacked[i, :seq_len] = positions[start_pos:end_pos] - xattn_mask[i, 0, seq_len:] = torch.finfo(xattn_mask.dtype).min - start_pos = end_pos - # xattn_mask = xattn_mask[:, :, :attn_metadata.max_query_len] - # full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, :attn_metadata.max_query_len] - return x_unpacked, positions_unpacked, xattn_mask, full_text_row_masked_out_mask - - def pack_value(self, x:torch.Tensor, attn_metadata: AttentionMetadata): - x_packed = torch.zeros(attn_metadata.num_prefill_tokens, x.shape[-1], device=x.device, dtype=x.dtype) - start_pos = 0 - for i, seq_len in enumerate(attn_metadata.seq_lens_tensor): - end_pos = start_pos + seq_len - x_packed[start_pos:end_pos] = x[i, :seq_len] - start_pos = end_pos - return x_packed - def forward( self, x: torch.Tensor, - xattn_mask: torch.Tensor, - full_text_row_masked_out_mask: torch.Tensor, + # xattn_mask: torch.Tensor, + # full_text_row_masked_out_mask: torch.Tensor, xattn_cache: torch.Tensor, - positions: torch.LongTensor, + kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: xq = F.linear(x, self.wq.weight) - n_token = xq.shape[0] - xq, positions, xattn_mask, full_text_row_masked_out_mask = self.unpack_value(xq, positions, attn_metadata, xattn_mask, full_text_row_masked_out_mask) - bsz, seqlen, _ = xq.shape - xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) + xq = xq.view(-1, self.n_local_heads, self.head_dim) xq = self.q_norm(xq) - xq = xq.transpose(1, 2) # [bs, n_head, seq_len, head_dim] - - xk, xv = xattn_cache - output = F.scaled_dot_product_attention( - xq, xk, xv, attn_mask=xattn_mask, dropout_p=0.0 - ) - - output = output.transpose(1, 2).reshape(bsz, seqlen, -1).contiguous() - output = self.pack_value(output, attn_metadata) - - output = output * full_text_row_masked_out_mask + if xattn_cache is not None: + xk, xv = xattn_cache + else: + xk, xv = None, None + output = self.attn(xq, xk, xv, kv_cache, attn_metadata, attn_type=AttentionType.ENCODER_DECODER) out = F.linear(output, self.wo.weight) return out @@ -1311,23 +1283,23 @@ def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: def forward( self, x: torch.Tensor, - xattn_mask: torch.Tensor, - full_text_row_masked_out_mask: torch.Tensor, + # xattn_mask: torch.Tensor, + # full_text_row_masked_out_mask: torch.Tensor, xattn_cache: torch.Tensor, - positions: torch.LongTensor, + kv_cache: torch.LongTensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: _attn_out = self.attention( x=self.attention_norm(x), - xattn_mask=xattn_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, + # xattn_mask=xattn_mask, + # full_text_row_masked_out_mask=full_text_row_masked_out_mask, xattn_cache=xattn_cache, - positions=positions, + kv_cache=kv_cache, attn_metadata=attn_metadata ) h = x + self.gate_attn.tanh() * _attn_out _ffn = self.feed_forward(self.ffn_norm(h)) - _ffn = full_text_row_masked_out_mask * _ffn # type: ignore + # _ffn = full_text_row_masked_out_mask * _ffn # type: ignore h = h + self.gate_ffwd.tanh() * _ffn * float(not self.no_ffn) return h @@ -1399,7 +1371,6 @@ def forward( vision_tokens = self.vision_encoder( images.to(dtype=torch.bfloat16), aspect_ratios ) - vision_tokens = F.linear(vision_tokens, self.vision_projection.weight) # vision_tokens = gather_from_tensor_model_parallel_region(vision_tokens) return vision_tokens @@ -1538,8 +1509,8 @@ def forward( self, positions: torch.LongTensor, h: torch.Tensor, - xattn_mask: torch.Tensor, - full_text_row_masked_out_mask: torch.Tensor, + # xattn_mask: torch.Tensor, + # full_text_row_masked_out_mask: torch.Tensor, xattn_caches: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, @@ -1553,16 +1524,14 @@ def forward( xattn_layer, xattn_layer_idx, ) in enumerate(self.text_and_xattn_layers): - # TODO: a hack now. skip decode cross attention - if xattn_mask is not None: - h = xattn_layer( - x=h, - xattn_mask=xattn_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - xattn_cache=xattn_caches[xattn_layer_idx], - positions=positions, - attn_metadata=attn_metadata, - ) + h = xattn_layer( + x=h, + # xattn_mask=xattn_mask, + # full_text_row_masked_out_mask=full_text_row_masked_out_mask, + xattn_cache=xattn_caches[xattn_layer_idx] if xattn_caches is not None else None, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) h = layer( x=h, # mask=mask, @@ -1574,9 +1543,6 @@ def forward( h = self.norm(h) return h - # output = F.linear(h, self.output.weight) - # output = gather_from_tensor_model_parallel_region(output) - # return output.float() def _get_xattn_mask( self, @@ -1885,7 +1851,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): state_dict = {name: weight for name, weight in weights} state_dict.pop('text_model.rope.freqs') state_dict['lm_head.weight'] = state_dict.pop('text_model.output.weight') - self.load_state_dict(state_dict, strict=True) + param_dict = {k: v for k, v in self.named_parameters()} + for i in range(self.text_model.n_layers): + module = self.text_model.layers[i].attention.qkv_proj + param_name = f"text_model.layers.{i}.attention.qkv_proj.weight" + weight_name = f"text_model.layers.{i}.attention.wqkv.weight" + param = param_dict[param_name] + weight = state_dict.pop(weight_name) + module.weight_loader(param, weight) + self.load_state_dict(state_dict, strict=False) + def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: @@ -1973,7 +1948,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> torch.Tensor: - # import pdb; pdb.set_trace() image_inputs = self._parse_and_validate_image_input(**kwargs) if image_inputs is None: cross_attention_masks = None @@ -2003,43 +1977,42 @@ def forward( ] ) # TODO: remove this hardcode - total_len = 512 - padded_masks = _pad_masks( - batch_masks, - image_inputs['num_chunks'], - total_len, - self.max_num_chunks, - ) - - cross_attention_masks, full_text_row_masked_out_mask = ( - self.text_model._get_xattn_mask( - num_tokens=total_len, - text_device="cuda", - text_dtype=next(self.text_model.parameters()).dtype, - vision_tokens=vision_tokens, - cross_attention_masks=padded_masks, - ) - ) - - full_text_row_masked_out_mask_plain = torch.zeros(attn_metadata.num_prefill_tokens, 1, dtype=full_text_row_masked_out_mask.dtype) - start_pos = 0 - for i, seq_len in enumerate(attn_metadata.seq_lens_tensor): - end_pos = start_pos + seq_len - full_text_row_masked_out_mask_plain[start_pos:end_pos, 0] = full_text_row_masked_out_mask[i, 0, :seq_len, 0] - start_pos = end_pos - full_text_row_masked_out_mask = full_text_row_masked_out_mask_plain.cuda() + # total_len = 512 + # padded_masks = _pad_masks( + # batch_masks, + # image_inputs['num_chunks'], + # total_len, + # self.max_num_chunks, + # ) + + # cross_attention_masks, full_text_row_masked_out_mask = ( + # self.text_model._get_xattn_mask( + # num_tokens=total_len, + # text_device="cuda", + # text_dtype=next(self.text_model.parameters()).dtype, + # vision_tokens=vision_tokens, + # cross_attention_masks=padded_masks, + # ) + # ) + + # full_text_row_masked_out_mask_plain = torch.zeros(attn_metadata.num_prefill_tokens, 1, dtype=full_text_row_masked_out_mask.dtype) + # start_pos = 0 + # for i, seq_len in enumerate(attn_metadata.seq_lens_tensor): + # end_pos = start_pos + seq_len + # full_text_row_masked_out_mask_plain[start_pos:end_pos, 0] = full_text_row_masked_out_mask[i, 0, :seq_len, 0] + # start_pos = end_pos + # full_text_row_masked_out_mask = full_text_row_masked_out_mask_plain.cuda() h = self.text_model.get_partially_trainable_embedding(input_ids) - logits = self.text_model.forward( + h = self.text_model.forward( positions=positions, h=h, - xattn_mask=cross_attention_masks, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, + # xattn_mask=cross_attention_masks, + # full_text_row_masked_out_mask=full_text_row_masked_out_mask, xattn_caches=xattn_caches, kv_caches=kv_caches, attn_metadata=attn_metadata, ) - # import pdb; pdb.set_trace() - return logits + return h def create_vision_mask( tokens: List[int], From 6b73f4d9fc1a05956f1a49f525bb5f1ab363ba99 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 12 Sep 2024 15:50:13 -0700 Subject: [PATCH 17/75] can support arbitary number of image, need better mask for image_cnt<>1 --- vllm/model_executor/models/llamavl.py | 59 ++++++++++++++----- .../multimodal_processors/llamavl.py | 4 +- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 7f2026120be3..3dc7ebdc30dd 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -40,6 +40,17 @@ from vllm.transformers_utils.multimodal_processors.llamavl import LlamaVLImageProcessor +step_name = "prefill" +pt_dir = "" + +def check(tensor, file_name): pass + # with open(f"{pt_dir}{file_name}", "rb") as f: + # data = torch.load(f) + + # tensor_flat = tensor.cpu().reshape(-1) + # data_flat = data.cpu().reshape(-1) + # print("check:", file_name, torch.allclose(tensor_flat, data_flat), torch.max(torch.abs(tensor_flat-data_flat)), tensor_flat.shape, data_flat.shape) + logger = init_logger(__name__) MP_SCALE = 8 IMAGE_RES = 224 @@ -79,10 +90,11 @@ def input_processor_for_llamavl(ctx: InputContext, llm_inputs: LLMInputs): return llm_inputs -def get_max_llama_image_tokens(ctx: InputContext) -> int: - logger.warning("need further check on max llama image tokens") - return 1025 * 2 +def get_max_llama_image_tokens(ctx: InputContext) -> int: + hf_config = ctx.model_config.hf_config + token_per_chunk = (hf_config.vision_chunk_size // 14) ** 2 + 1 + return hf_config.max_num_image * hf_config.vision_max_num_chunks * token_per_chunk def to_2tuple(x): if isinstance(x, collections.abc.Iterable): @@ -103,7 +115,7 @@ def resize_local_position_embedding(orig_pos_embed, grid_size): orig_pos_embed[:1], orig_pos_embed[1:], ) - logger.info( + logger.debug( f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}" ) @@ -690,7 +702,7 @@ def load_hook( self.max_num_tiles, self.max_num_tiles, ) - logger.info( + logger.debug( f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}" ) state_dict[prefix + "gated_positional_embedding"] = global_pos_embed @@ -1164,10 +1176,10 @@ def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: xk = self.wk(xattn_tokens) xv = self.wv(xattn_tokens) - _, seqlen_y, _ = xk.shape + # _, seqlen_y, _ = xk.shape - xk = xk.view(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim) - xv = xv.view(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim) + xk = xk.view(-1, self.n_local_kv_heads, self.head_dim) + xv = xv.view(-1, self.n_local_kv_heads, self.head_dim) xk = self.k_norm(xk) @@ -1532,6 +1544,7 @@ def forward( kv_cache=kv_caches[idx], attn_metadata=attn_metadata, ) + # check(h, f"layer_{idx}_xh_{step_name}.pt") h = layer( x=h, # mask=mask, @@ -1540,8 +1553,10 @@ def forward( kv_cache=kv_caches[idx], attn_metadata=attn_metadata, ) + # check(h, f"layer_{idx}_h_{step_name}.pt") h = self.norm(h) + # check(h, f"finalh_{step_name}.pt") return h def _get_xattn_mask( @@ -1810,6 +1825,7 @@ def __call__(self, image: Image.Image, max_num_chunks: int) -> Tuple[Any, Any]: @MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llama_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llamavl) @INPUT_REGISTRY.register_input_processor(input_processor_for_llamavl) class LlamaVLForCausalLM(nn.Module, SupportsMultiModal): def __init__(self, config, @@ -1948,6 +1964,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> torch.Tensor: + # import pdb; pdb.set_trace() image_inputs = self._parse_and_validate_image_input(**kwargs) if image_inputs is None: cross_attention_masks = None @@ -1958,6 +1975,17 @@ def forward( cuda_images = image_inputs['data'].cuda() cuda_aspect_ratios = image_inputs['aspect_ratios'].cuda() vision_tokens = self.vision_model(cuda_images, cuda_aspect_ratios) + # import pdb; pdb.set_trace() + bsz, _, _, _, image_token_dim = tuple(vision_tokens.shape) + vision_tokens = vision_tokens.view(bsz, -1, image_token_dim) + + vision_tokens_flat = torch.zeros(sum(attn_metadata.encoder_seq_lens), image_token_dim, device=vision_tokens.device, dtype=vision_tokens.dtype) + start_pos = 0 + for seq_len, vision_token_in_batch in zip(attn_metadata.encoder_seq_lens, vision_tokens): + end_pos = start_pos + seq_len + vision_tokens_flat[start_pos:end_pos] = vision_token_in_batch[:seq_len] + start_pos = end_pos + batch_masks = [] # TODO: get the sequence of each query without hack? 1) better attn metadata 2) better input processor to create vision mask during preprocess # assert isinstance(attn_metadata, PagedAttentionMetadata) @@ -1967,12 +1995,9 @@ def forward( batch_masks.append(create_vision_mask(input_ids[start_pos:end_pos])) start_pos = end_pos - bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape) xattn_caches = torch.stack( [ - layer.compute_xattn_kv_cache( - vision_tokens.view(bsz, -1, image_token_dim) - ) + layer.compute_xattn_kv_cache(vision_tokens_flat) for layer in self.text_model.cross_attention_layers ] ) @@ -2002,7 +2027,12 @@ def forward( # full_text_row_masked_out_mask_plain[start_pos:end_pos, 0] = full_text_row_masked_out_mask[i, 0, :seq_len, 0] # start_pos = end_pos # full_text_row_masked_out_mask = full_text_row_masked_out_mask_plain.cuda() + # print("input_ids", input_ids) + if positions.numel() == 1: + global step_name + step_name = f"decode_{positions.item()}" h = self.text_model.get_partially_trainable_embedding(input_ids) + # check(h, f"h_{step_name}.pt") h = self.text_model.forward( positions=positions, h=h, @@ -2012,15 +2042,14 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, ) + # if positions.numel() == 1 and positions.item() == 20: + # exit(0) return h def create_vision_mask( tokens: List[int], vision_token: int=128256, ) -> List[List[int]]: - # import pdb; pdb.set_trace() -# (Pdb) p tokens -# [128011, 128011, 128000, 644, 264, 11914, 11, 1521, 1403, 5448, 6308] vision_token_locations = [ i for i, token in enumerate(tokens) if token == vision_token ] diff --git a/vllm/transformers_utils/multimodal_processors/llamavl.py b/vllm/transformers_utils/multimodal_processors/llamavl.py index e9e024b37af2..8d3537b457f5 100644 --- a/vllm/transformers_utils/multimodal_processors/llamavl.py +++ b/vllm/transformers_utils/multimodal_processors/llamavl.py @@ -318,7 +318,6 @@ def __init__(self, name, *args, **kwargs): ) def preprocess(self, images, **kwargs) -> BatchFeature: with TorchBF16Context(): - print("[warning] mask unsupported due to lack of example, replace with official release in the future") # assert len(images) == len( # batch_masks # ), "Images and masks must have the same length" @@ -328,9 +327,8 @@ def preprocess(self, images, **kwargs) -> BatchFeature: max_num_images = max(len(x) for x in images) bsz = len(images) - if max_num_images == 0: - data = {'pixel_values': None} + data = None else: images_and_aspect_ratios = [ [self.image_transform(im) for im in row] for row in images From fb10a70eb0b7cb8dec7737d3b670a0fae5a0f8fb Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 12 Sep 2024 16:10:23 -0700 Subject: [PATCH 18/75] temp save for profile run --- tests/models/test_llamavl.py | 12 ++++++++---- vllm/model_executor/models/llamavl.py | 8 +++++++- vllm/worker/enc_dec_model_runner.py | 7 ++++--- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/tests/models/test_llamavl.py b/tests/models/test_llamavl.py index dc2f28563d45..26749d66ec65 100644 --- a/tests/models/test_llamavl.py +++ b/tests/models/test_llamavl.py @@ -74,12 +74,16 @@ }, "decoder_prompt": "<|image|><|begin_of_text|>If I had to write a haiku for this one", }, - # { - # "prompt": "he color of the sky is blue but sometimes it can also be", - # }, + { + "encoder_prompt":{ + "prompt": "", + }, + "decoder_prompt": "The color of the sky is blue but sometimes it can also be", + }, ] - outputs = llm.generate(inputs, SamplingParams(temperature=0.6, top_p=0.9, max_tokens=512)) + outputs = llm.generate(inputs, SamplingParams(temperature=0, top_p=0.9, max_tokens=512)) for o in outputs: generated_text = o.outputs[0].text print(generated_text) + print("==================================") diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 3dc7ebdc30dd..b8cac6bdca6d 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -90,6 +90,13 @@ def input_processor_for_llamavl(ctx: InputContext, llm_inputs: LLMInputs): return llm_inputs +def dummy_data_for_llamavl(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): + # seq_len: 16 + # mm_counts: {'image': 2, 'audio': 1} + hf_config = ctx.model_config.hf_config + token_per_chunk = (hf_config.vision_chunk_size // 14) ** 2 + 1 + num_tokens = hf_config.max_num_image * hf_config.vision_max_num_chunks * token_per_chunk + import pdb; pdb.set_trace() def get_max_llama_image_tokens(ctx: InputContext) -> int: hf_config = ctx.model_config.hf_config @@ -1964,7 +1971,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> torch.Tensor: - # import pdb; pdb.set_trace() image_inputs = self._parse_and_validate_image_input(**kwargs) if image_inputs is None: cross_attention_masks = None diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 11fb2376d496..7a1b5f45fb2d 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -281,8 +281,8 @@ def profile_run(self) -> None: max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( self.model_config) if max_mm_tokens > 0: - raise NotImplementedError( - "Multi-modal encoder-decoder models are not supported yet") + logger.warning( + "profile run for multi-modal models") batch_size = 0 for group_id in range(max_num_seqs): @@ -290,7 +290,7 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, _ = self.input_registry \ + seq_data, dummy_multi_modal_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, seq_len, self.mm_registry) @@ -308,6 +308,7 @@ def profile_run(self) -> None: block_tables=None, encoder_seq_data=seq_data, cross_block_table=None, + multi_modal_data=dummy_multi_modal_data, ) seqs.append(seq) From 718f87929e1803e188a777e8d698dc189ed25700 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 12 Sep 2024 23:22:02 -0700 Subject: [PATCH 19/75] can run tp, but wrong answer --- tests/models/test_llamavl.py | 5 +- vllm/model_executor/models/llamavl.py | 373 +++++++++++++++++--------- vllm/worker/enc_dec_model_runner.py | 1 + 3 files changed, 248 insertions(+), 131 deletions(-) diff --git a/tests/models/test_llamavl.py b/tests/models/test_llamavl.py index 26749d66ec65..d8c3eb6f13e8 100644 --- a/tests/models/test_llamavl.py +++ b/tests/models/test_llamavl.py @@ -27,7 +27,10 @@ size = model_size_map[args.model_type] checkpoint_dir = "/data/zhang-chen/llama/checkpoints" # update checkpoint path here - llm = LLM(model=f"{checkpoint_dir}/Meta-Llama-3.2-{size}-Vision-Early/", enforce_eager=True, limit_mm_per_prompt={"image": 2}, + llm = LLM(model=f"{checkpoint_dir}/Meta-Llama-3.2-{size}-Vision-Early/", + enforce_eager=True, + limit_mm_per_prompt={"image": 2}, + tensor_parallel_size=1 # load_format="dummy" ) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index b8cac6bdca6d..47f94ea78b13 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -29,7 +29,7 @@ from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal -from .llama import LlamaAttention +from .llama import LlamaAttention, LlamaMLP from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, @@ -428,7 +428,8 @@ def __init__( n_heads, ): super().__init__() - model_parallel_size = get_tensor_model_parallel_world_size() + model_parallel_size = 1 # skip TP for image now + # model_parallel_size = get_tensor_model_parallel_world_size() qkvo_replication = 1 if model_parallel_size > 16: qkvo_replication = model_parallel_size // 8 @@ -493,7 +494,7 @@ def forward( ] bs, slen, _ = xq.shape - + # print("xq.shape", xq.shape) xq = xq.view(bs, slen, self.n_local_heads, self.head_dim) xk = xk.view(bs, xk.shape[1], self.n_local_kv_heads, self.head_dim) xv = xv.view(bs, xv.shape[1], self.n_local_kv_heads, self.head_dim) @@ -880,21 +881,21 @@ def __init__(self, *args, **kwargs): rope_scaling=rope_scaling, is_neox_style=False, # force to use neox=False ) - self._register_load_state_dict_pre_hook(self.load_hook) - - - def load_hook( - self, - state_dict: Dict[str, Any], - prefix: str, - local_metadata: Dict[str, Any], - strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], - ) -> None: - if prefix + "wo.weight" in state_dict: - state_dict[prefix + "o_proj.weight"] = state_dict.pop(prefix + "wo.weight") + # self._register_load_state_dict_pre_hook(self.load_hook) + + + # def load_hook( + # self, + # state_dict: Dict[str, Any], + # prefix: str, + # local_metadata: Dict[str, Any], + # strict: bool, + # missing_keys: List[str], + # unexpected_keys: List[str], + # error_msgs: List[str], + # ) -> None: + # if prefix + "wo.weight" in state_dict: + # state_dict[prefix + "o_proj.weight"] = state_dict.pop(prefix + "wo.weight") class TransformerBlock(nn.Module): def __init__(self, layer_id: int, args, cache_config: Optional[CacheConfig] = None): @@ -932,11 +933,18 @@ def __init__(self, layer_id: int, args, cache_config: Optional[CacheConfig] = No prefix=f"tb.{layer_id}.self_attn", ) # logger.warning("skip attention") - self.feed_forward = FeedForward( - dim=args.dim, - hidden_dim=4 * args.dim, - multiple_of=args.multiple_of, - ffn_dim_multiplier=args.ffn_dim_multiplier, + + hidden_dim = args.dim * 4 + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if args.ffn_dim_multiplier is not None: + hidden_dim = int(args.ffn_dim_multiplier * hidden_dim) + hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of) + + self.feed_forward = LlamaMLP( + hidden_size=args.dim, + intermediate_size=hidden_dim, + hidden_act="silu", ) self.layer_id = layer_id self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) @@ -1089,40 +1097,39 @@ def __init__( assert n_heads % n_kv_heads == 0 + # TODO: change to Q/KV seperate linear after #7448 is merged + self.qkv_proj = QKVParallelLinear( + dim, + head_dim, + n_heads, + n_kv_heads, + bias=False, + ) - self.wq = nn.Linear(dim, n_heads * head_dim, bias=False) - self.wk = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.wv = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.wo = nn.Linear(n_heads * head_dim, dim, bias=False) - # self.wq = ColumnParallelLinear( + # self.wqkv = ColumnParallelLinear( # dim, # n_heads * head_dim, # bias=False, # gather_output=False, - # init_method=_noinit, # ) - # self.wk = ColumnParallelLinear( # dim, # n_kv_heads * head_dim, # bias=False, # gather_output=False, - # init_method=_noinit, # ) # self.wv = ColumnParallelLinear( # dim, # n_kv_heads * head_dim, # bias=False, # gather_output=False, - # init_method=_noinit, - # ) - # self.wo = RowParallelLinear( - # n_heads * head_dim, - # dim, - # bias=False, - # input_is_parallel=True, - # init_method=_noinit, # ) + self.wo = RowParallelLinear( + n_heads * head_dim, + dim, + bias=False, + input_is_parallel=True, + ) self.n_heads = n_heads self.head_dim = head_dim @@ -1137,12 +1144,6 @@ def __init__( eps=norm_eps, ) self.scaling = self.head_dim**-0.5 - self.attn = Attention( - self.n_heads, - self.head_dim, - self.scaling, - self.n_kv_heads, - ) # cross-attention heads are model parallel similar to # self-attention, and we also use the identical KV head @@ -1155,66 +1156,59 @@ def __init__( self.n_local_heads = self.n_heads // self.model_parallel_size self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size self.n_rep = self.n_local_heads // self.n_local_kv_heads - self._register_load_state_dict_pre_hook(self.load_hook) + self.q_local_size = self.n_local_heads * self.head_dim + self.kv_local_size = self.n_local_kv_heads * self.head_dim - def load_hook( - self, - state_dict: Dict[str, Any], - prefix: str, - local_metadata: Dict[str, Any], - strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], - ) -> None: - if prefix + "inner_attention.q_norm.weight" in state_dict: - q_weight = state_dict.pop(prefix + "inner_attention.q_norm.weight") - state_dict[prefix + "q_norm.weight"] = q_weight - if prefix + "inner_attention.k_norm.weight" in state_dict: - k_weight = state_dict.pop(prefix + "inner_attention.k_norm.weight") - state_dict[prefix + "k_norm.weight"] = k_weight - if prefix + "wkv.weight" in state_dict: - wk, wv = state_dict.pop(prefix + "wkv.weight").chunk(2) - state_dict[prefix + "wk.weight"] = wk - state_dict[prefix + "wv.weight"] = wv + self.attn = Attention( + self.n_local_heads, + self.head_dim, + self.scaling, + self.n_local_kv_heads, + ) - def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: - bsz = xattn_tokens.shape[0] - xk = self.wk(xattn_tokens) - xv = self.wv(xattn_tokens) + # def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: + # bsz = xattn_tokens.shape[0] + # xk, _ = self.wk(xattn_tokens) + # xv, _ = self.wv(xattn_tokens) - # _, seqlen_y, _ = xk.shape + # # _, seqlen_y, _ = xk.shape - xk = xk.view(-1, self.n_local_kv_heads, self.head_dim) - xv = xv.view(-1, self.n_local_kv_heads, self.head_dim) + # xk = xk.view(-1, self.n_local_kv_heads, self.head_dim) + # xv = xv.view(-1, self.n_local_kv_heads, self.head_dim) - xk = self.k_norm(xk) + # xk = self.k_norm(xk) - return torch.stack([xk, xv]) + # return torch.stack([xk, xv]) - def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: - return self._compute_xattn_kv_cache(xattn_tokens) + # def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: + # return self._compute_xattn_kv_cache(xattn_tokens) def forward( self, - x: torch.Tensor, + decoder_hidden_states: torch.Tensor, # xattn_mask: torch.Tensor, # full_text_row_masked_out_mask: torch.Tensor, - xattn_cache: torch.Tensor, + # xattn_cache: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, ) -> torch.Tensor: - xq = F.linear(x, self.wq.weight) + qkv_dec, _ = self.qkv_proj(decoder_hidden_states) + q, _, _ = qkv_dec.split([self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) + if encoder_hidden_states is None: + k = None + v = None + else: + qkv_enc, _ = self.qkv_proj(encoder_hidden_states) + _, k, v = qkv_enc.split([self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) - xq = xq.view(-1, self.n_local_heads, self.head_dim) - xq = self.q_norm(xq) + q = q.view(-1, self.n_local_heads, self.head_dim) + q = self.q_norm(q) - if xattn_cache is not None: - xk, xv = xattn_cache - else: - xk, xv = None, None - output = self.attn(xq, xk, xv, kv_cache, attn_metadata, attn_type=AttentionType.ENCODER_DECODER) - out = F.linear(output, self.wo.weight) + output = self.attn(q, k, v, kv_cache, attn_metadata, attn_type=AttentionType.ENCODER_DECODER) + out, _ = self.wo(output) return out @@ -1247,12 +1241,19 @@ def __init__( ) self.gate_attn = torch.nn.Parameter(torch.zeros(1)) - self.feed_forward = FeedForward( - dim=args.dim, - hidden_dim=4 * args.dim, - ffn_dim_multiplier=args.ffn_dim_multiplier, - multiple_of=args.multiple_of, + hidden_dim = args.dim * 4 + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if args.ffn_dim_multiplier is not None: + hidden_dim = int(args.ffn_dim_multiplier * hidden_dim) + hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of) + + self.feed_forward = LlamaMLP( + hidden_size=args.dim, + intermediate_size=hidden_dim, + hidden_act="silu", ) + self.ffn_norm = RMSNorm( args.dim, eps=args.norm_eps, @@ -1296,25 +1297,27 @@ def load_hook( prefix + "attention.wq.layer_norm_weight" ) - def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: - return self.attention.compute_xattn_kv_cache(xattn_tokens) + # def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: + # return self.attention.compute_xattn_kv_cache(xattn_tokens) def forward( self, x: torch.Tensor, # xattn_mask: torch.Tensor, # full_text_row_masked_out_mask: torch.Tensor, - xattn_cache: torch.Tensor, + # xattn_cache: torch.Tensor, kv_cache: torch.LongTensor, attn_metadata: AttentionMetadata, + vision_hidden_states: Optional[torch.Tensor], ) -> torch.Tensor: _attn_out = self.attention( - x=self.attention_norm(x), + decoder_hidden_states=self.attention_norm(x), # xattn_mask=xattn_mask, # full_text_row_masked_out_mask=full_text_row_masked_out_mask, - xattn_cache=xattn_cache, + # xattn_cache=xattn_cache, kv_cache=kv_cache, - attn_metadata=attn_metadata + attn_metadata=attn_metadata, + encoder_hidden_states=vision_hidden_states, ) h = x + self.gate_attn.tanh() * _attn_out _ffn = self.feed_forward(self.ffn_norm(h)) @@ -1530,9 +1533,10 @@ def forward( h: torch.Tensor, # xattn_mask: torch.Tensor, # full_text_row_masked_out_mask: torch.Tensor, - xattn_caches: torch.Tensor, + # xattn_caches: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + vision_hidden_states: Optional[torch.Tensor], ): # assert self.cache_is_setup, "Please set up cache before calling forward" # mask = self.mask_cache.index_select(2, positions) @@ -1547,9 +1551,10 @@ def forward( x=h, # xattn_mask=xattn_mask, # full_text_row_masked_out_mask=full_text_row_masked_out_mask, - xattn_cache=xattn_caches[xattn_layer_idx] if xattn_caches is not None else None, + # xattn_cache=xattn_caches[xattn_layer_idx] if xattn_caches is not None else None, kv_cache=kv_caches[idx], attn_metadata=attn_metadata, + vision_hidden_states=vision_hidden_states, ) # check(h, f"layer_{idx}_xh_{step_name}.pt") h = layer( @@ -1874,16 +1879,122 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): state_dict = {name: weight for name, weight in weights} state_dict.pop('text_model.rope.freqs') state_dict['lm_head.weight'] = state_dict.pop('text_model.output.weight') - param_dict = {k: v for k, v in self.named_parameters()} - for i in range(self.text_model.n_layers): - module = self.text_model.layers[i].attention.qkv_proj - param_name = f"text_model.layers.{i}.attention.qkv_proj.weight" - weight_name = f"text_model.layers.{i}.attention.wqkv.weight" - param = param_dict[param_name] - weight = state_dict.pop(weight_name) - module.weight_loader(param, weight) - self.load_state_dict(state_dict, strict=False) - + load_succ = True + + def load_weight(param, weight): + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight) + + for param_name, param in self.named_parameters(): + # print("loading", param_name) + if param_name.startswith("text_model.layers"): + layer_id = int(param_name.split(".")[2]) + if param_name.endswith("attention.qkv_proj.weight"): + # "text_model.layers.{i}.attention.qkv_proj.weight" + weight_name = f"text_model.layers.{layer_id}.attention.wqkv.weight" + weight = state_dict.pop(weight_name) + module = self.text_model.layers[layer_id].attention.qkv_proj + module.weight_loader(param, weight) + continue + elif param_name.endswith("attention.o_proj.weight"): + # "text_model.layers.{i}.attention.o_proj.weight" + weight_name = f"text_model.layers.{layer_id}.attention.wo.weight" + weight = state_dict.pop(weight_name) + module = self.text_model.layers[layer_id].attention.o_proj + module.weight_loader(param, weight) + continue + elif param_name.endswith("feed_forward.gate_up_proj.weight"): + # "text_model.layers.{i}.feed_forward.mlp.fc1_weight" + weight_name = f"text_model.layers.{layer_id}.feed_forward.mlp.fc1_weight" + weight = state_dict.pop(weight_name) + module = self.text_model.layers[layer_id].feed_forward.gate_up_proj + module.weight_loader(param, weight) + continue + elif param_name.endswith("feed_forward.down_proj.weight"): + # "text_model.layers.{i}.feed_forward.mlp.fc2_weight" + weight_name = f"text_model.layers.{layer_id}.feed_forward.mlp.fc2_weight" + weight = state_dict.pop(weight_name) + module = self.text_model.layers[layer_id].feed_forward.down_proj + module.weight_loader(param, weight) + continue + elif param_name.endswith("attention_norm.weight"): + # "text_model.layers.{i}.attention_norm.weight" + weight_name = f"text_model.layers.{layer_id}.attention.wqkv.layer_norm_weight" + weight = state_dict.pop(weight_name) + load_weight(param, weight) + continue + elif param_name.endswith("ffn_norm.weight"): + # "text_model.layers.{i}.ffn_norm.weight" + weight_name = f"text_model.layers.{layer_id}.feed_forward.mlp.layer_norm_weight" + weight = state_dict.pop(weight_name) + load_weight(param, weight) + continue + if param_name.startswith("text_model.cross_attention_layers"): + layer_id = int(param_name.split(".")[2]) + if param_name.endswith('gate_attn'): + attn_gate = state_dict.pop(param_name) + if attn_gate.dim() == 1: + attn_gate = attn_gate[0].view(1) + if attn_gate.dim() == 3: + attn_gate = attn_gate.view(1) + load_weight(param, attn_gate) + continue + if param_name.endswith('gate_ffwd'): + ffn_gate = state_dict.pop(param_name) + if ffn_gate.dim() == 1: + ffn_gate = ffn_gate[0].view(1) + if ffn_gate.dim() == 3: + ffn_gate = ffn_gate.view(1) + load_weight(param, ffn_gate) + continue + if param_name.endswith('ffn_norm.weight'): + load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.feed_forward.mlp.layer_norm_weight")) + continue + if param_name.endswith('attention_norm.weight'): + load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.wq.layer_norm_weight")) + continue + # if param_name.endswith('attention.wk.weight') or param_name.endswith('attention.wv.weight'): + # if f"text_model.cross_attention_layers.{layer_id}.attention.wkv.weight" in state_dict: + # weight = state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.wkv.weight") + # state_dict[f"text_model.cross_attention_layers.{layer_id}.attention.wkv.weight.1"] = weight + # else: + # weight = state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.wkv.weight.1") + # if param_name.endswith('attention.wk.weight'): + # weight = weight.chunk(2)[0] + # else: + # weight = weight.chunk(2)[1] + # load_weight(param, weight) + # continue + if param_name.endswith('attention.qkv_proj.weight'): + q_weight = state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.wq.weight") + kv_weight = state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.wkv.weight") + qkv_weight = torch.cat([q_weight, kv_weight], dim=0) + module = self.text_model.cross_attention_layers[layer_id].attention.qkv_proj + module.weight_loader(param, qkv_weight) + continue + if param_name.endswith('attention.q_norm.weight'): + load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.inner_attention.q_norm.weight")) + continue + if param_name.endswith('attention.k_norm.weight'): + load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.inner_attention.k_norm.weight")) + continue + if param_name.endswith('feed_forward.gate_up_proj.weight'): + load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.feed_forward.mlp.fc1_weight")) + continue + if param_name.endswith('feed_forward.down_proj.weight'): + load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.feed_forward.mlp.fc2_weight")) + continue + if param_name in state_dict: + loaded_weight = state_dict.pop(param_name) + load_weight(param, loaded_weight) + continue + + raise ValueError(f"Unexpected parameter {param_name}") + + if len(state_dict) > 0: + raise ValueError(f"unused keys: {state_dict.keys()}") + def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: @@ -1976,6 +2087,7 @@ def forward( cross_attention_masks = None full_text_row_masked_out_mask = None xattn_caches = None + vision_tokens = None else: # llama's reference implementation runs the vision model on CPU cuda_images = image_inputs['data'].cuda() @@ -1991,22 +2103,23 @@ def forward( end_pos = start_pos + seq_len vision_tokens_flat[start_pos:end_pos] = vision_token_in_batch[:seq_len] start_pos = end_pos + vision_tokens = vision_tokens_flat - batch_masks = [] - # TODO: get the sequence of each query without hack? 1) better attn metadata 2) better input processor to create vision mask during preprocess - # assert isinstance(attn_metadata, PagedAttentionMetadata) - start_pos = 0 - for seq_len in attn_metadata.seq_lens_tensor: - end_pos = start_pos + seq_len - batch_masks.append(create_vision_mask(input_ids[start_pos:end_pos])) - start_pos = end_pos + # batch_masks = [] + # # TODO: get the sequence of each query without hack? 1) better attn metadata 2) better input processor to create vision mask during preprocess + # # assert isinstance(attn_metadata, PagedAttentionMetadata) + # start_pos = 0 + # for seq_len in attn_metadata.seq_lens_tensor: + # end_pos = start_pos + seq_len + # batch_masks.append(create_vision_mask(input_ids[start_pos:end_pos])) + # start_pos = end_pos - xattn_caches = torch.stack( - [ - layer.compute_xattn_kv_cache(vision_tokens_flat) - for layer in self.text_model.cross_attention_layers - ] - ) + # xattn_caches = torch.stack( + # [ + # layer.compute_xattn_kv_cache(vision_tokens_flat) + # for layer in self.text_model.cross_attention_layers + # ] + # ) # TODO: remove this hardcode # total_len = 512 # padded_masks = _pad_masks( @@ -2033,10 +2146,10 @@ def forward( # full_text_row_masked_out_mask_plain[start_pos:end_pos, 0] = full_text_row_masked_out_mask[i, 0, :seq_len, 0] # start_pos = end_pos # full_text_row_masked_out_mask = full_text_row_masked_out_mask_plain.cuda() - # print("input_ids", input_ids) - if positions.numel() == 1: - global step_name - step_name = f"decode_{positions.item()}" + # print("input_ids", input_ids, vision_tokens is None) + # if positions.numel() == 1: + # global step_name + # step_name = f"decode_{positions.item()}" h = self.text_model.get_partially_trainable_embedding(input_ids) # check(h, f"h_{step_name}.pt") h = self.text_model.forward( @@ -2044,7 +2157,7 @@ def forward( h=h, # xattn_mask=cross_attention_masks, # full_text_row_masked_out_mask=full_text_row_masked_out_mask, - xattn_caches=xattn_caches, + vision_hidden_states=vision_tokens, kv_caches=kv_caches, attn_metadata=attn_metadata, ) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 7a1b5f45fb2d..2330975b54ac 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -50,6 +50,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "virtual_engine": self.virtual_engine, "request_ids_to_seq_ids": self.request_ids_to_seq_ids, "finished_requests_ids": self.finished_requests_ids, + "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, From 264434906ba5acee00260092c475710df8937b3b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 13 Sep 2024 00:51:10 -0700 Subject: [PATCH 20/75] can run tp for small model with correct result --- vllm/model_executor/models/llamavl.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 47f94ea78b13..a93d0a6b3d8c 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -30,7 +30,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal from .llama import LlamaAttention, LlamaMLP -from vllm.model_executor.layers.layernorm import RMSNorm +# from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, @@ -46,10 +46,12 @@ def check(tensor, file_name): pass # with open(f"{pt_dir}{file_name}", "rb") as f: # data = torch.load(f) - # tensor_flat = tensor.cpu().reshape(-1) # data_flat = data.cpu().reshape(-1) - # print("check:", file_name, torch.allclose(tensor_flat, data_flat), torch.max(torch.abs(tensor_flat-data_flat)), tensor_flat.shape, data_flat.shape) + # if tensor_flat.shape != data_flat.shape: + # print("check:", file_name, "shape missmatch", tensor_flat.shape, data_flat.shape) + # return + # print("check:", file_name, torch.allclose(tensor_flat, data_flat), torch.max(torch.abs(tensor_flat-data_flat)), tensor.shape, data.shape) logger = init_logger(__name__) MP_SCALE = 8 @@ -328,6 +330,21 @@ def _get_full_row_masked_out_mask( """ return (attn_bias != negative_inf_value).any(dim=-1).type_as(attn_bias)[..., None] +# use float RMSNorm to make result closer to reference impl. +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * self.weight + + # Image encoder for inference class LayerNorm(nn.LayerNorm): """Subclass torch's LayerNorm to handle fp16.""" @@ -1203,7 +1220,9 @@ def forward( qkv_enc, _ = self.qkv_proj(encoder_hidden_states) _, k, v = qkv_enc.split([self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1) - + k = k.view(-1, self.n_local_kv_heads, self.head_dim) + v = v.view(-1, self.n_local_kv_heads, self.head_dim) + k = self.k_norm(k) q = q.view(-1, self.n_local_heads, self.head_dim) q = self.q_norm(q) From ec4cb9c680045601b0a126ccf95e77603981ba3e Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 13 Sep 2024 17:43:21 -0700 Subject: [PATCH 21/75] tp for vision encoder --- vllm/model_executor/models/llamavl.py | 418 ++++++-------------------- 1 file changed, 86 insertions(+), 332 deletions(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index a93d0a6b3d8c..2a059bc302fa 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -39,6 +39,7 @@ get_tensor_model_parallel_world_size) from vllm.transformers_utils.multimodal_processors.llamavl import LlamaVLImageProcessor +import vllm.distributed.parallel_state as ps step_name = "prefill" pt_dir = "" @@ -379,12 +380,7 @@ def __init__( if isinstance(kernel_size, int): kernel_size = (kernel_size, kernel_size) self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) - # self._linear = ColumnParallelLinear( - # in_channels * kernel_size[0] * kernel_size[1], - # out_channels, - # bias=bias, - # ) - self._linear = nn.Linear( + self._linear = ColumnParallelLinear( in_channels * kernel_size[0] * kernel_size[1], out_channels, bias=bias, @@ -393,9 +389,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: x = self._unfold(x) x = x.permute(0, 2, 1) - x = self._linear(x) - # x = F.linear(x, self._linear.weight) - # x = gather_from_tensor_model_parallel_region(x) + x, _ = self._linear(x) return x @@ -409,31 +403,26 @@ def __init__( ): super().__init__() # layers - self.c_fc = nn.Linear(dim, hidden_dim, bias=True) - # self.c_fc = ColumnParallelLinear( - # dim, - # hidden_dim, - # bias=True, - # gather_output=False, - # init_method=lambda x: x, - # ) - self.c_proj = nn.Linear(hidden_dim, dim, bias=True) - # self.c_proj = RowParallelLinear( - # hidden_dim, - # dim, - # bias=True, - # input_is_parallel=True, - # init_method=lambda x: x, - # ) + self.c_fc = ColumnParallelLinear( + dim, + hidden_dim, + bias=True, + ) + self.c_proj = RowParallelLinear( + hidden_dim, + dim, + bias=True, + input_is_parallel=True, + skip_bias_add=True, # add bias explicitly for precision concern + ) self.non_linearity = act_layer() self.dropout = dropout def forward(self, x): - hidden = F.linear(x, self.c_fc.weight, self.c_fc.bias) + hidden, _ = self.c_fc(x) hidden = self.non_linearity(hidden) - hidden = F.linear(hidden, self.c_proj.weight) - # hidden = reduce_from_tensor_model_parallel_region(hidden) - hidden += self.c_proj.bias + hidden, bias = self.c_proj(hidden) # skip_bias_add=True + hidden += bias return hidden @@ -441,96 +430,58 @@ class ImageAttention(nn.Module): def __init__( self, dim, - head_dim, n_heads, ): super().__init__() - model_parallel_size = 1 # skip TP for image now - # model_parallel_size = get_tensor_model_parallel_world_size() - qkvo_replication = 1 - if model_parallel_size > 16: - qkvo_replication = model_parallel_size // 8 - + model_parallel_size = get_tensor_model_parallel_world_size() + self.n_heads = n_heads self.n_kv_heads = n_heads - self.n_local_heads = n_heads * qkvo_replication // model_parallel_size + self.n_local_heads = n_heads // model_parallel_size self.n_local_kv_heads = ( - self.n_kv_heads * qkvo_replication // model_parallel_size + self.n_kv_heads // model_parallel_size ) self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = dim // n_heads + self.q_size = self.n_local_heads * self.head_dim + self.kv_size = self.n_local_kv_heads * self.head_dim + assert self.n_heads % self.n_kv_heads == 0 + assert self.n_heads % model_parallel_size == 0 + assert self.n_kv_heads % model_parallel_size == 0 # The model provided by llama is with bias=True, but the weight does not contain bias # During runtime, the llama executor set bias to zero. We use bias=False here to match the behavior - self.wq = nn.Linear(dim, n_heads * self.head_dim, bias=False) - self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False) - self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias=False) - self.wo = nn.Linear(n_heads * self.head_dim, dim, bias=False) - # self.wq = ColumnParallelLinear( - # dim, - # qkvo_replication * n_heads * self.head_dim, - # bias=True, - # gather_output=False, - # init_method=lambda x: x, - # ) - # self.wk = ColumnParallelLinear( - # dim, - # qkvo_replication * self.n_kv_heads * self.head_dim, - # bias=True, - # gather_output=False, - # init_method=lambda x: x, - # ) - # self.wv = ColumnParallelLinear( - # dim, - # qkvo_replication * self.n_kv_heads * self.head_dim, - # bias=True, - # gather_output=False, - # init_method=lambda x: x, - # ) - # self.wo = RowParallelLinear( - # qkvo_replication * n_heads * self.head_dim, - # dim, - # bias=True, - # input_is_parallel=True, - # init_method=lambda x: x, - # ) - self.qkvo_replication = qkvo_replication + self.qkv_proj = QKVParallelLinear( + dim, + self.head_dim, + n_heads, + bias=False, + ) + self.wo = RowParallelLinear( + n_heads * self.head_dim, + dim, + bias=False, + input_is_parallel=True, + ) def forward( self, x: torch.Tensor, mask: torch.Tensor = None, ): + qkv, _ = self.qkv_proj(x) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(q.shape[0], q.shape[1], self.n_local_heads, self.head_dim).transpose(1, 2) + k = k.view(k.shape[0], k.shape[1], self.n_local_kv_heads, self.head_dim).transpose(1, 2) + v = v.view(v.shape[0], v.shape[1], self.n_local_kv_heads, self.head_dim).transpose(1, 2) - xq, xk, xv = [ - F.linear(x, w, b) - for (w, b) in [ - (self.wq.weight, self.wq.bias), - (self.wk.weight, self.wk.bias), - (self.wv.weight, self.wv.bias), - ] - ] - - bs, slen, _ = xq.shape - # print("xq.shape", xq.shape) - xq = xq.view(bs, slen, self.n_local_heads, self.head_dim) - xk = xk.view(bs, xk.shape[1], self.n_local_kv_heads, self.head_dim) - xv = xv.view(bs, xv.shape[1], self.n_local_kv_heads, self.head_dim) - - xq, xk, xv = [tensor.transpose(1, 2) for tensor in (xq, xk, xv)] - - xk = xk.repeat_interleave(self.n_rep, dim=1) - xv = xv.repeat_interleave(self.n_rep, dim=1) - + # TODO: remove padding in image encoder attn_output = F.scaled_dot_product_attention( - xq, xk, xv, attn_mask=mask, dropout_p=0.0 + q, k, v, attn_mask=mask, dropout_p=0.0 ) - attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1) - - out = F.linear(attn_output, self.wo.weight) - # out = reduce_from_tensor_model_parallel_region(out) - out = out / self.qkvo_replication - # out += self.wo.bias + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(attn_output.shape[0], attn_output.shape[1], -1) + out, _ = self.wo(attn_output) return out @@ -549,7 +500,6 @@ def __init__( self.head_dim = d_model // self.n_heads self.attn = ImageAttention( dim=d_model, - head_dim=self.head_dim, n_heads=self.n_heads, ) self.ln_1 = LayerNorm(d_model) @@ -685,55 +635,6 @@ def __init__( ) self.gated_positional_embedding_gate = nn.Parameter(torch.zeros(1)) - self._register_load_state_dict_pre_hook(self.load_hook) - - def load_hook( - self, - state_dict: Dict[str, Any], - prefix: str, - local_metadata: Dict[str, Any], - strict: bool = True, - missing_keys: List[str] = None, - unexpected_keys: List[str] = None, - error_msgs: List[str] = None, - return_state_dict: bool = False, - ) -> None: - orig_pos_embed = state_dict.get(prefix + "positional_embedding") - if orig_pos_embed is not None: - new_pos_embed = resize_local_position_embedding( - orig_pos_embed, self.grid_size - ) - state_dict[prefix + "positional_embedding"] = new_pos_embed - if hasattr(self, "gated_positional_embedding"): - if prefix + "gated_positional_embedding" not in state_dict: - # resize positional_embedding to fit the new grid size - global_pos_embed = initialize_global_position_embedding_from_local( - new_pos_embed, - self.grid_size, - self.max_num_tiles, - self.max_num_tiles, - ) - state_dict[prefix + "gated_positional_embedding"] = global_pos_embed - state_dict[prefix + "gated_positional_embedding_gate"] = torch.zeros( - 1, dtype=global_pos_embed.dtype - ) - logger.info( - f"Initialized global positional embedding with size {global_pos_embed.size()}" - ) - else: - global_pos_embed = resize_global_position_embedding( - state_dict[prefix + "gated_positional_embedding"], - self.grid_size, - self.max_num_tiles, - self.max_num_tiles, - ) - logger.debug( - f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}" - ) - state_dict[prefix + "gated_positional_embedding"] = global_pos_embed - if return_state_dict: - return state_dict - def apply_positional_embedding(self, x, ar): out = [] # apply regular position embedding @@ -777,11 +678,12 @@ def forward(self, images: torch.Tensor, ar: torch.Tensor) -> torch.Tensor: # patch embedding x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, w, h) x = self.conv1(x) + x = ps.get_tp_group().all_gather(x) _, ntok, dim = x.shape x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim) # tile embeddings - x = self.pre_tile_pos_embed(x, ar) + x = self.pre_tile_pos_embed(x, ar) # call all_gather here, dim will change x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim) # apply cls token @@ -818,74 +720,6 @@ def forward(self, images: torch.Tensor, ar: torch.Tensor) -> torch.Tensor: return x -class FeedForward(nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - ): - """ - Initialize the FeedForward module. - Args: - dim (int): Input dimension. - hidden_dim (int): Hidden dimension of the feedforward layer. - multiple_of (int): Value to ensure hidden dimension is a multiple of this value. - ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. - Attributes: - w1 (ColumnParallelLinear): Linear transformation for the first layer. - w2 (RowParallelLinear): Linear transformation for the second layer. - w3 (ColumnParallelLinear): Linear transformation for the third layer. - """ - super().__init__() - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if ffn_dim_multiplier is not None: - hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - - self.w1 = nn.Linear(dim, hidden_dim, bias=False) - self.w2 = nn.Linear(hidden_dim, dim, bias=False) - self.w3 = nn.Linear(dim, hidden_dim, bias=False) - # self.w1 = ColumnParallelLinear( - # dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x - # ) - # self.w2 = RowParallelLinear( - # hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x - # ) - # self.w3 = ColumnParallelLinear( - # dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x - # ) - self._register_load_state_dict_pre_hook(self.load_hook) - - def forward(self, x): - x1, x3 = [F.linear(x, w) for w in [self.w1.weight, self.w3.weight]] - x1 = F.silu(x1) - x_in = x1 * x3 - out = F.linear(x_in, self.w2.weight) - # out = reduce_from_tensor_model_parallel_region(out) - return out - - def load_hook( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - if prefix + "mlp.fc1_weight" in state_dict: - fc1_weight, fc3_weight = state_dict.pop(prefix + "mlp.fc1_weight").chunk(2) - state_dict[prefix + "w1.weight"] = fc1_weight - state_dict[prefix + "w3.weight"] = fc3_weight - - if prefix + "mlp.fc2_weight" in state_dict: - fc2_weight = state_dict.pop(prefix + "mlp.fc2_weight") - state_dict[prefix + "w2.weight"] = fc2_weight - class LlamaVLAttention(LlamaAttention): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -898,21 +732,6 @@ def __init__(self, *args, **kwargs): rope_scaling=rope_scaling, is_neox_style=False, # force to use neox=False ) - # self._register_load_state_dict_pre_hook(self.load_hook) - - - # def load_hook( - # self, - # state_dict: Dict[str, Any], - # prefix: str, - # local_metadata: Dict[str, Any], - # strict: bool, - # missing_keys: List[str], - # unexpected_keys: List[str], - # error_msgs: List[str], - # ) -> None: - # if prefix + "wo.weight" in state_dict: - # state_dict[prefix + "o_proj.weight"] = state_dict.pop(prefix + "wo.weight") class TransformerBlock(nn.Module): def __init__(self, layer_id: int, args, cache_config: Optional[CacheConfig] = None): @@ -966,26 +785,6 @@ def __init__(self, layer_id: int, args, cache_config: Optional[CacheConfig] = No self.layer_id = layer_id self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) - self._register_load_state_dict_pre_hook(self.load_hook) - - def load_hook( - self, - state_dict: Dict[str, Any], - prefix: str, - local_metadata: Dict[str, Any], - strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], - ) -> None: - if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict: - state_dict[prefix + "ffn_norm.weight"] = state_dict.pop( - prefix + "feed_forward.mlp.layer_norm_weight" - ) - if prefix + "attention.wqkv.layer_norm_weight" in state_dict: - state_dict[prefix + "attention_norm.weight"] = state_dict.pop( - prefix + "attention.wqkv.layer_norm_weight" - ) def forward( self, @@ -1033,30 +832,6 @@ def __init__( if gated: self.gate = nn.Parameter(torch.zeros(1)) - self._register_load_state_dict_pre_hook(self.load_hook) - - def load_hook( - self, - state_dict, - prefix, - local_metadata, - strict, - missing_keys, - unexpected_keys, - error_msgs, - ): - # load the weights from the checkpoint - embed = state_dict.get(prefix + "embedding") - if embed is not None: - # reshape the weights to the correct shape - nt_old, nt_old, _, w = embed.shape - logger.info( - f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}" - ) - embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles) - # assign the weights to the module - state_dict[prefix + "embedding"] = embed_new - @staticmethod def _dynamic_resize(embed: torch.Tensor, num_tiles: int): nt_old, nt_old, _, w = embed.shape @@ -1279,46 +1054,8 @@ def __init__( ) self.gate_ffwd = torch.nn.Parameter(torch.zeros(1)) - logger.warning("todo put hook in correct place") - self._register_load_state_dict_pre_hook(self.load_hook) self.no_ffn = no_ffn - def load_hook( - self, - state_dict: Dict[str, Any], - prefix: str, - local_metadata: Dict[str, Any], - strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], - ) -> None: - if prefix + "gate_attn" in state_dict: - attn_gate = state_dict.pop(prefix + "gate_attn") - if attn_gate.dim() == 1: - attn_gate = attn_gate[0].view(1) - if attn_gate.dim() == 3: - attn_gate = attn_gate.view(1) - state_dict[prefix + "gate_attn"] = attn_gate - if prefix + "gate_ffwd" in state_dict: - ffn_gate = state_dict.pop(prefix + "gate_ffwd") - if ffn_gate.dim() == 1: - ffn_gate = ffn_gate[0].view(1) - if ffn_gate.dim() == 3: - ffn_gate = ffn_gate.view(1) - state_dict[prefix + "gate_ffwd"] = ffn_gate - if prefix + "feed_forward.mlp.layer_norm_weight" in state_dict: - state_dict[prefix + "ffn_norm.weight"] = state_dict.pop( - prefix + "feed_forward.mlp.layer_norm_weight" - ) - if prefix + "attention.wq.layer_norm_weight" in state_dict: - state_dict[prefix + "attention_norm.weight"] = state_dict.pop( - prefix + "attention.wq.layer_norm_weight" - ) - - # def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: - # return self.attention.compute_xattn_kv_cache(xattn_tokens) - def forward( self, x: torch.Tensor, @@ -1395,7 +1132,7 @@ def __init__(self, args) -> None: self.vision_input_dim, args.dim, bias=True, - ) + ) # ORZZZZZZZZZZ # self.vision_projection = ColumnParallelLinear( # self.vision_input_dim, # args.dim, @@ -1501,8 +1238,6 @@ def __init__(self, args, cache_config:Optional[CacheConfig]) -> None: args.use_scaled_rope, ) - self._register_load_state_dict_pre_hook(self.load_hook) - self.args = args self.cache_is_setup = False self.max_seq_len = args.max_seq_len @@ -1533,19 +1268,6 @@ def get_partially_trainable_embedding(self, x): x_new = self.learnable_embedding(x_new).type_as(x_orig) return x_orig * mask_orig.type_as(x_orig) + x_new * mask_new.type_as(x_new) - def load_hook( - self, - state_dict: Dict[str, Any], - prefix: str, - local_metadata: Dict[str, Any], - strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], - ) -> None: - if "rope.freqs" in state_dict: - del state_dict["rope.freqs"] - def forward( self, positions: torch.LongTensor, @@ -1895,10 +1617,15 @@ def __init__(self, config, self.sampler = Sampler() def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # my_rank = get_tensor_model_parallel_rank() state_dict = {name: weight for name, weight in weights} + # if my_rank == 0: + # with open("weight_shape_map.log", "w") as f: + # for name, weight in state_dict.items(): + # f.write(f"{name}-{tuple(weight.shape)}-{weight.dtype}\n") + state_dict.pop('text_model.rope.freqs') state_dict['lm_head.weight'] = state_dict.pop('text_model.output.weight') - load_succ = True def load_weight(param, weight): weight_loader = getattr(param, "weight_loader", @@ -2004,6 +1731,34 @@ def load_weight(param, weight): if param_name.endswith('feed_forward.down_proj.weight'): load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.feed_forward.mlp.fc2_weight")) continue + if param_name.startswith("vision_model.vision_encoder"): + if param_name == 'vision_model.vision_encoder.conv1._linear.weight': + module = self.vision_model.vision_encoder.conv1._linear + weight = state_dict.pop('vision_model.vision_encoder.conv1._linear.weight') + module.weight_loader(param, weight) + continue + if param_name.startswith("vision_model.vision_encoder.transformer.resblocks") or param_name.startswith("vision_model.vision_encoder.global_transformer.resblocks"): + layer_id = int(param_name.split(".")[4]) + if param_name.startswith('vision_model.vision_encoder.transformer.resblocks'): + prefix = 'vision_model.vision_encoder.transformer.resblocks' + transformer_block: ImageTransformerBlock = self.vision_model.vision_encoder.transformer.resblocks[layer_id] + else: + prefix = 'vision_model.vision_encoder.global_transformer.resblocks' + transformer_block = self.vision_model.vision_encoder.global_transformer.resblocks[layer_id] + if param_name.endswith("mlp.c_fc.weight"): + module = transformer_block.mlp.c_fc + weight = state_dict.pop(f"{prefix}.{layer_id}.mlp.c_fc.weight") + module.weight_loader(param, weight) + continue + if param_name.endswith("attn.qkv_proj.weight"): + module = transformer_block.attn.qkv_proj + q_weight = state_dict.pop(f"{prefix}.{layer_id}.attn.wq.weight") + k_weight = state_dict.pop(f"{prefix}.{layer_id}.attn.wk.weight") + v_weight = state_dict.pop(f"{prefix}.{layer_id}.attn.wv.weight") + # import pdb; pdb.set_trace() + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) + module.weight_loader(param, qkv_weight) + continue if param_name in state_dict: loaded_weight = state_dict.pop(param_name) load_weight(param, loaded_weight) @@ -2013,7 +1768,6 @@ def load_weight(param, weight): if len(state_dict) > 0: raise ValueError(f"unused keys: {state_dict.keys()}") - def _parse_and_validate_image_input( self, **kwargs: object) -> Optional[LlavaImageInputs]: From fc012665788ce1cec40e78dfa7c8cc809f4a17cd Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 15 Sep 2024 00:23:42 -0700 Subject: [PATCH 22/75] update image preprocessor --- tests/models/test_llamavl.py | 30 +- vllm/entrypoints/chat_utils.py | 2 + vllm/model_executor/models/llamavl.py | 516 +++++++++++++++++--------- 3 files changed, 345 insertions(+), 203 deletions(-) diff --git a/tests/models/test_llamavl.py b/tests/models/test_llamavl.py index d8c3eb6f13e8..5283d52e802c 100644 --- a/tests/models/test_llamavl.py +++ b/tests/models/test_llamavl.py @@ -30,7 +30,7 @@ llm = LLM(model=f"{checkpoint_dir}/Meta-Llama-3.2-{size}-Vision-Early/", enforce_eager=True, limit_mm_per_prompt={"image": 2}, - tensor_parallel_size=1 + tensor_parallel_size=1, # load_format="dummy" ) @@ -40,34 +40,8 @@ image = PIL_Image.open(f).convert("RGB") with open(f"{resource_dir}/pasta.jpeg", "rb") as f: image2 = PIL_Image.open(f).convert("RGB") - - # inputs = [ - # { - # "prompt": "<|image|><|image|><|begin_of_text|>In a sentence, these two images paint", - # "multi_modal_data": { - # "image": [image, image2], - # } - # }, - # { - # "prompt": "<|image|><|begin_of_text|>If I had to write a haiku for this one", - # "multi_modal_data": { - # "image": [image] - # } - # }, - # # { - # # "prompt": "he color of the sky is blue but sometimes it can also be", - # # }, - # ] + inputs = [ - { - "encoder_prompt": { - "prompt": "", - "multi_modal_data": { - "image": [image, image2], - } - }, - "decoder_prompt": "<|image|><|image|><|begin_of_text|>In a sentence, these two images paint", - }, { "encoder_prompt":{ "prompt": "", diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c70c6d9330b1..9e1b77de8425 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -135,6 +135,8 @@ def add(self, modality: Literal["image", "audio"], self._model_config.hf_config.image_token_index) if model_type in ("chameleon", "internvl_chat"): return "" + if model_type == "llamavl": + return "<|image|>" raise TypeError(f"Unknown model type: {model_type}") elif modality == "audio": diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index 2a059bc302fa..be59bfc21988 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -4,7 +4,7 @@ import collections import math from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, - TypedDict, Union, Callable, Dict, Any) + TypedDict, Union, Callable, Dict, Any, Set) import torch import torch.nn as nn @@ -42,7 +42,7 @@ import vllm.distributed.parallel_state as ps step_name = "prefill" -pt_dir = "" +pt_dir = "/home/eecs/zhang-chen/MultiModal/scripts/" def check(tensor, file_name): pass # with open(f"{pt_dir}{file_name}", "rb") as f: @@ -666,6 +666,7 @@ def apply_class_embedding(self, x): return x def forward(self, images: torch.Tensor, ar: torch.Tensor) -> torch.Tensor: + # TODO: run tp in this function if images.ndim == 5: num_concurrent_media = 1 bsz, num_chunks, nch, w, h = images.shape @@ -857,8 +858,8 @@ def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int = None): x.shape[0], num_tiles, 1, self.width, device=x.device, dtype=x.dtype ) for idx, arx in enumerate(ar): - w, h = arx - out_pos_embed[idx, : w * h] = embed[:w, :h].reshape(w * h, 1, self.width) + h, w = arx + out_pos_embed[idx, : w * h] = embed[:h, :w].reshape(w * h, 1, self.width) if self.gated: out_pos_embed = out_pos_embed * self.gate.tanh() x = x + out_pos_embed @@ -1149,7 +1150,7 @@ def forward( vision_tokens = self.vision_encoder( images.to(dtype=torch.bfloat16), aspect_ratios ) - vision_tokens = F.linear(vision_tokens, self.vision_projection.weight) + vision_tokens = self.vision_projection(vision_tokens) # vision_tokens = gather_from_tensor_model_parallel_region(vision_tokens) return vision_tokens @@ -1334,7 +1335,7 @@ def _get_xattn_mask( _, _, _, num_image_tokens, image_token_dim = tuple(vision_tokens.shape) bsz, ntext, nimg, nchunks = cross_attention_masks.shape cross_attention_masks = ( - cross_attention_masks.repeat_interleave(vision_seqlen, dim=2) + cross_attention_masks.repeat_interleave(vision_seqlen, dim=3) .view(bsz, ntext, -1) .unsqueeze(1) ) @@ -1352,26 +1353,42 @@ def _get_xattn_mask( class VariableSizeImageTransform(object): """ - The variable size image transform will resize the image dynamically + This class accepts images of any size and dynamically resize, pads and chunks it based on the image aspect ratio and the number of image chunks we allow. - The algorithm will not upsample low-res images to fit a certain aspect - ratio, because that leads to a significant degradation in image quality. - For example, if an input image is of size 300x800, and we want to allow - a maximum of 16 image chunks, it will find the closest aspect ratio that - is allowed within 16 image chunks, i.e., 2:5 = 2 horizontal patches and - 5 vertical patches, giving a total of 10 chunks. - The image will then be resized to products of the base size (default is - 224px because MetaCLIP takes that), so in this case it will be resized to - 2*224:5*224 = 448:1120, where we maintain the original aspect ratio and - pad with the mean value for the rest. This approach minimizes the amount - of padding required for any arbitrary resolution. - The final output will therefore be of shape (11, 3, 224, 224), where 10 - patches are coming from the resizing and chunking, and the first patch - is a downsampled version of the image that preserves aspect ratios. + + The algorithm will NOT distort the image fit a certain aspect ratio, because + that leads to a significant degradation in image quality. + + It can be summarized in 6 steps: + 1. Find all possible canvas combinations of max_num_chunks; + 2. Find the best canvas to fit the image; + 3. Resize without distortion + 4. Pad + 5. Normalize + 6. Chunk + + For example, if an input image is of size 300x800, patch_size of 224, + and max_num_chunks = 8, it will find the closest aspect ratio that + is allowed within 8 image chunks, with some restrictions. + In this case, 2:4 = 2 horizontal patches and 4 vertical patches, + giving a total of 8 chunks. + + If resize_to_max_canvas, the image will be resized (without distortion), + to the largest possible resolution. In this case, 388:896, and padded to 448:896, + where we maintain the original aspect ratio and pad with zeros value for the rest. + This approach minimizes the amount of padding required for any arbitrary resolution. + + However, if limit_upscaling_to_patch_size is set to True, + the upscaling will be limited to the patch size. In the example above, + the image would remain 300x800 (no upscaling), and then padded to 448:896. + + The final output will therefore be of shape (8, 3, 224, 224), where 2x4 + patches are coming from the resizing and chunking. """ def __init__(self, size: int = IMAGE_RES) -> None: self.size = size + logger.info(f"VariableSizeImageTransform size: {self.size}") self.to_tensor = tv.ToTensor() self._mean = (0.48145466, 0.4578275, 0.40821073) self._std = (0.26862954, 0.26130258, 0.27577711) @@ -1380,121 +1397,118 @@ def __init__(self, size: int = IMAGE_RES) -> None: std=self._std, inplace=True, ) + self.resample = tv.InterpolationMode.BILINEAR @staticmethod - def _factors(n: int): - """Return all factors of a number.""" - return set( - reduce( - list.__add__, - ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0), - ) - ) + def get_factors(n: int) -> Set[int]: + """ + Calculate all factors of a given number, i.e. a dividor that leaves + no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}. + + Args: + n (int): The number to find factors for. - def _find_supported_aspect_ratios(self, num_chunks: int): + Returns: + set: A set containing all factors of the number. + """ + factors_set = set() + + for i in range(1, int(n**0.5) + 1): + if n % i == 0: + factors_set.add(i) + factors_set.add(n // i) + return factors_set + + def find_supported_resolutions( + self, max_num_chunks: int, patch_size: int + ) -> torch.Tensor: """ - This function computes all the allowed aspect ratios for a fixed - number of input chunks. - For example, with `num_chunks=5`, it will return: - { - 0.2: [(1, 5)], - 5.0: [(5, 1)], + Computes all of the allowed resoltuions for a fixed number of chunks + and patch_size. Useful for when dividing an image into chunks. + + Args: + max_num_chunks (int): Maximum number of chunks for processing. + patch_size (int): Size of the side of the patch. + + Returns: + torch.Tensor: List of possible resolutions as tuples (height, width). + + Example: + >>> max_num_chunks = 5 + >>> patch_size = 224 + >>> find_supported_resolutions(max_num_chunks, patch_size) + tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), + (672, 224), (224, 448), (448, 224)]) + + Given max_num_chunks=4, patch_size=224, it will create a dictionary: + { 0.25: [(1, 4)], 1.0: [(2, 2), (1, 1)], 4.0: [(4, 1)], - 0.3333333333333333: [(1, 3)], + 0.33: [(1, 3)], 3.0: [(3, 1)], 0.5: [(1, 2)], 2.0: [(2, 1)] - } - """ - asp_dict = {} - for chunk_size in range(num_chunks, 0, -1): - _factors = sorted(VariableSizeImageTransform._factors(chunk_size)) - _asp_ratios = [(x, chunk_size // x) for x in _factors] - for ratio in _asp_ratios: - k = ratio[0] / ratio[1] - if k not in asp_dict: - asp_dict[k] = [ratio] - else: - asp_dict[k].append(ratio) - return asp_dict + } - def _find_closest_aspect_ratio( - self, num_chunks: int, img_width: int, img_height: int - ) -> Tuple: - """ - Given an image width, height and target number of chunks - this function will find the closest supported aspect ratio. + and return the resolutions multiplied by the patch_size: + [(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)] """ - tgt_ar = img_width / img_height - asp_dict = self._find_supported_aspect_ratios(num_chunks) - cl_d, cl_p = 1e23, None - if tgt_ar >= 1: - cl_p = min( - [k for k in asp_dict.keys() if k <= tgt_ar], - key=lambda x: abs(x - tgt_ar), - ) - v = asp_dict[cl_p] - # select width - widths = [(idx, self.size * vv[0]) for idx, vv in enumerate(v)] - tgt_idx = max(widths, key=lambda x: x[1])[0] - else: - cl_p = min( - [k for k in asp_dict.keys() if k > tgt_ar], - key=lambda x: abs(1 / x - 1 / tgt_ar), - ) - v = asp_dict[cl_p] - # select height - heights = [(idx, self.size * vv[1]) for idx, vv in enumerate(v)] - tgt_idx = max(heights, key=lambda x: x[1])[0] - out = v[tgt_idx] - return out + asp_dict = collections.defaultdict(list) + for chunk_size in range(max_num_chunks, 0, -1): + _factors = sorted(self.get_factors(chunk_size)) + _asp_ratios = [(factor, chunk_size // factor) for factor in _factors] + for height, width in _asp_ratios: + ratio_float = height / width + asp_dict[ratio_float].append((height, width)) + + # get the resolutions multiplied by the patch_size + possible_resolutions = [] + for key, value in asp_dict.items(): + for height, depth in value: + possible_resolutions.append((height * patch_size, depth * patch_size)) + + return possible_resolutions - def _resize( - self, image: Image.Image, target_width: int, target_height: int - ) -> Image.Image: - # Resize longer edge to given size. - w, h = image.size - scale = w / h + @staticmethod + def get_max_res_without_distortion( + image_size: Tuple[int, int], + target_size: Tuple[int, int], + ) -> Tuple[int, int]: + """ + Determines the maximum resolution to which an image can be resized to without distorting its + aspect ratio, based on the target resolution. - if scale > 1.0: - # width > height - new_w = target_width - new_h = math.floor(new_w / scale) - else: - # height >= width - new_h = target_height - new_w = math.floor(new_h * scale) + Args: + image_size (Tuple[int, int]): The original resolution of the image (height, width). + target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width). + Returns: + Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized. + Example: + >>> _get_max_res_without_distortion([200, 300], target_size = [450, 200]) + (134, 200) + >>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300]) + (450, 338) + """ - image = F.resize(image, (new_h, new_w)) - return image + original_width, original_height = image_size + target_width, target_height = target_size - def _resize_max_side_to_size( - self, - image: Image.Image, - ) -> Image.Image: - # Resize longer edge to given size. - w, h = image.size - scale = w / h + scale_w = target_width / original_width + scale_h = target_height / original_height - if scale > 1.0: - # width > height - new_w = max(self.size, w) - new_h = math.floor(new_w / scale) + if scale_w < scale_h: + new_width = target_width + new_height = min(math.floor(original_height * scale_w), target_height) else: - # height >= width - new_h = max(self.size, h) - new_w = math.floor(new_h * scale) + new_height = target_height + new_width = min(math.floor(original_width * scale_h), target_width) - image = F.resize(image, (new_h, new_w)) - return image + return new_width, new_height - def _pad(self, image: Image.Image, new_width: int, new_height: int) -> Image.Image: - mean_per_channel = tuple( - np.clip(np.array(image).mean(axis=(0, 1)), 0, 255).astype(np.uint8) - ) - new_im = Image.new(mode="RGB", size=(new_height, new_width), color=(0, 0, 0)) # type: ignore + def _pad(self, image: Image.Image, target_size) -> Image.Image: + new_width, new_height = target_size + new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore new_im.paste(image) return new_im @@ -1508,72 +1522,224 @@ def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor: image = image.view(ncw * nch, num_channels, height // nch, width // ncw) return image - def _fit_image_to_canvas( - self, num_chunks: int, img_width: int, img_height: int - ) -> Any: + def resize_without_distortion( + self, + image: torch.Tensor, + target_size: Tuple[int, int], + max_upscaling_size: Optional[int], + ) -> torch.Tensor: """ - Given an image width, height and target number of chunks this function will see if the image - can be fit into any of the canvases that can be build from arranging the tiles in a grid. - If the image can be fit onto several canvases, it will return the canvas where the shorter edge - of the image will be largest. + Used to resize an image to target_resolution, without distortion. + + If target_size requires upscaling the image, the user can set max_upscaling_size to + limit the upscaling to a maximum size. In this case, since we rescale without distortion, + modifying target_size works as a boundary for the image's largest side. + + Args: + resample (str): Resampling method used when resizing images. + Supports "nearest", "nearest_exact", "bilinear", "bicubic". + max_upscaling_size (int): The maximum size to upscale the image to. + If None, there is no limit. + Examples: + >>> target_size = (1000, 1200) + >>> max_upscaling_size = 600 + >>> image_size = (400, 200) + >>> resize_without_distortion(image_size, target_size, max_upscaling_size) + (600, 300) # new_size_without_distortion + + >>> target_size = (1000, 1200) + >>> max_upscaling_size = 600 + >>> image_size = (2000, 200) + >>> resize_without_distortion(image_size, target_size, max_upscaling_size) + (1000, 100) # new_size_without_distortion + + >>> target_size = (1000, 1200) + >>> max_upscaling_size = 2000 + >>> image_size = (400, 200) + >>> resize_without_distortion(image_size, target_size, max_upscaling_size) + (1000, 500) # new_size_without_distortion + + >>> target_size = (1000, 1200) + >>> max_upscaling_size = None + >>> image_size = (400, 200) + >>> resize_without_distortion(image_size, target_size, max_upscaling_size) + (1000, 500) # new_size_without_distortion + """ + + image_width, image_height = image.size + image_size = (image_width, image_height) + + # If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size + if max_upscaling_size is not None: + new_target_width = min(max(image_width, max_upscaling_size), target_size[0]) + new_target_height = min( + max(image_height, max_upscaling_size), target_size[1] + ) + target_size = (new_target_width, new_target_height) + + # resize to target_size while preserving aspect ratio + new_size_without_distortion = self.get_max_res_without_distortion( + image_size, target_size + ) + + image = F.resize( + image, + (new_size_without_distortion[1], new_size_without_distortion[0]), + interpolation=self.resample, + ) + + return image + + def get_best_fit( + self, + image_size: Tuple[int, int], + possible_resolutions: torch.Tensor, + resize_to_max_canvas: bool = False, + ) -> Tuple[int, int]: + """ + Determines the best canvas possible from a list of possible resolutions to, without distortion, + resize an image to. + + For each possible resolution, calculates the scaling factors for + width and height, and selects the smallest one, which is the limiting side. + E.g. to match the canvas you can upscale height by 2x, and width by 1.5x, + therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5. + + If upscaling is possible (any of the scaling factors is greater than 1), + then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True. + + If upscaling is not possible, then picks the largest scaling factor <= 1, i.e. + reduce downscaling as much as possible. + + If there are multiple resolutions with the same max scale, we pick the one with the lowest area, + to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter + has more padding. + + Args: + image_size (Tuple[int, int]): A tuple containing the height and width of the image. + possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each + row represents a possible resolution (height, width). + use_max_upscaling (bool): If True, will return the largest upscaling resolution. + + Returns: + List[int]: The best resolution [height, width] for the given image. + + Example: + >>> image_size = (200, 300) + >>> possible_resolutions = torch.tensor([[224, 672], + ... [672, 224], + ... [224, 448], + ... [448, 224], + ... [224, 224]]) + >>> _get_smallest_upscaling_possibility(image_size, possible_resolutions) + [224, 448] + + We have: + scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467]) + scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200]) + scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467]) + Only one of the scales > 1: + upscaling_possible = tensor([1.1200, 1.1200]) + smallest_rescale = tensor(1.1200) + So we pick the resolution with the smallest smallest area: + areas = tensor([150528, 100352]) # [672, 224], [224, 448] + optimal_canvas = tensor([224, 448]) + """ + + original_width, original_height = image_size + + # get all possible resolutions heights/widths + target_widths, target_heights = ( + possible_resolutions[:, 0], + possible_resolutions[:, 1], + ) + + # get scaling factors to resize the image without distortion + scale_w = target_widths / original_width + scale_h = target_heights / original_height + + # get the min scale between width and height (limiting side -> no distortion) + scales = torch.where(scale_w > scale_h, scale_h, scale_w) + + # filter only scales that allow upscaling + upscaling_options = scales[scales >= 1] + if len(upscaling_options) > 0: + if resize_to_max_canvas: + selected_scale = torch.max(upscaling_options) + else: + selected_scale = torch.min(upscaling_options) + else: + # no upscaling possible, + # get the minimum downscaling (max scale for scales<1) + downscaling_options = scales[scales < 1] + selected_scale = torch.max(downscaling_options) + + # get all resolutions that support this scaling factor, + # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion + chosen_canvas = possible_resolutions[scales == selected_scale] + + # if there are multiple resolutions, + # get the one with minimum area to reduce padding + if len(chosen_canvas) > 1: + areas = chosen_canvas[:, 0] * chosen_canvas[:, 1] + optimal_idx = torch.argmin(areas) + optimal_canvas = chosen_canvas[optimal_idx] + else: + optimal_canvas = chosen_canvas[0] + + return tuple(optimal_canvas.tolist()) + + def __call__( + self, + image: Image.Image, + max_num_chunks: int, + normalize_img: bool = True, + resize_to_max_canvas: bool = False, + ) -> Tuple[Any, Any]: + """ + Args: + image (PIL.Image): Image to be resized. + max_num_chunks (int): Maximum number of chunks to split the image into. + normalize_img (bool): Whether to normalize the image. + resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size. + If True, picks the canvas the allows the largest resizing without distortion. + If False, downsample as little as possible, including no resizing at all, + but never upsample, unless the image is smaller than the patch size. """ - # Initialize the optimal canvas to None. If no canvas is found where image fits, function returns None. - optimal_canvas = None - optimal_image_width_height = None - - scale = img_width / img_height - - # Gather all potential supported image resolutions and iterate through them to find best match - potential_arrangements = [ - item - for sublist in self._find_supported_aspect_ratios(num_chunks).values() - for item in sublist - ] - current_gap = 1e23 - for n_w, n_h in potential_arrangements: - # Compute the canvas size - canvas_width, canvas_height = n_w * self.size, n_h * self.size - - # Check if image can fit into the canvas without downsampling - if canvas_width >= img_width and canvas_height >= img_height: - # If we did not find a good canvas yet, we will use the current one - if optimal_canvas is None: - # Set optimal canvas and determine the actual image height and width in the canvas with aspect ratio preserving resampling - optimal_canvas = (n_w, n_h) - optimal_image_width_height = (n_w * self.size, n_h * self.size) - else: - # Find closest fit based on gap - image_width_height = (n_w * self.size, n_h * self.size) - gap = abs(img_width - image_width_height[0]) + abs( - img_height - image_width_height[1] - ) - if gap < current_gap: - # If the gap is smaller than the previous one, we will update our optimal canvas and image width height - optimal_canvas = (n_w, n_h) - optimal_image_width_height = image_width_height - current_gap = gap - return optimal_canvas - - def __call__(self, image: Image.Image, max_num_chunks: int) -> Tuple[Any, Any]: assert max_num_chunks > 0 assert isinstance(image, Image.Image), type(image) w, h = image.size - # Check if the image can be fit to the canvas without downsampling - ar = self._fit_image_to_canvas( - num_chunks=max_num_chunks, img_width=w, img_height=h + + possible_resolutions = self.find_supported_resolutions( + max_num_chunks=max_num_chunks, patch_size=self.size ) - if ar is None: - # If we did not find a canvas, we have to find the closest aspect ratio and downsample the image - ar = self._find_closest_aspect_ratio( - num_chunks=max_num_chunks, img_width=w, img_height=h - ) - image = self._resize(image, ar[0] * self.size, ar[1] * self.size) - else: - image = self._resize_max_side_to_size(image) - image = self._pad(image, ar[1] * self.size, ar[0] * self.size) + possible_resolutions = torch.tensor(possible_resolutions) + + best_resolution = self.get_best_fit( + image_size=(w, h), + possible_resolutions=possible_resolutions, + resize_to_max_canvas=resize_to_max_canvas, + ) + + max_upscaling_size = None if resize_to_max_canvas else self.size + image = self.resize_without_distortion( + image, best_resolution, max_upscaling_size + ) + image = self._pad(image, best_resolution) + image = self.to_tensor(image) - image = self.normalize(image) - image = self._split(image, ar[0], ar[1]) # type: ignore + + if normalize_img: + image = self.normalize(image) + + ratio_w, ratio_h = ( + best_resolution[0] // self.size, + best_resolution[1] // self.size, + ) + + image = self._split(image, ratio_w, ratio_h) # type: ignore + + ar = (ratio_h, ratio_w) return image, ar @MULTIMODAL_REGISTRY.register_image_input_mapper() From 3e1d249d45ec61a9677dfa4c51c6bf7f03490888 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 15 Sep 2024 00:45:18 -0700 Subject: [PATCH 23/75] support text-only input --- vllm/model_executor/models/llamavl.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index be59bfc21988..cf207b71cbd8 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -1066,6 +1066,7 @@ def forward( kv_cache: torch.LongTensor, attn_metadata: AttentionMetadata, vision_hidden_states: Optional[torch.Tensor], + run_xattn_mask: torch.Tensor, ) -> torch.Tensor: _attn_out = self.attention( decoder_hidden_states=self.attention_norm(x), @@ -1076,10 +1077,11 @@ def forward( attn_metadata=attn_metadata, encoder_hidden_states=vision_hidden_states, ) - h = x + self.gate_attn.tanh() * _attn_out + # import pdb; pdb.set_trace() + h = x + self.gate_attn.tanh() * _attn_out * run_xattn_mask _ffn = self.feed_forward(self.ffn_norm(h)) # _ffn = full_text_row_masked_out_mask * _ffn # type: ignore - h = h + self.gate_ffwd.tanh() * _ffn * float(not self.no_ffn) + h = h + self.gate_ffwd.tanh() * _ffn * float(not self.no_ffn) * run_xattn_mask return h @@ -1279,6 +1281,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, vision_hidden_states: Optional[torch.Tensor], + run_xattn_mask: torch.Tensor, ): # assert self.cache_is_setup, "Please set up cache before calling forward" # mask = self.mask_cache.index_select(2, positions) @@ -1297,6 +1300,7 @@ def forward( kv_cache=kv_caches[idx], attn_metadata=attn_metadata, vision_hidden_states=vision_hidden_states, + run_xattn_mask=run_xattn_mask, ) # check(h, f"layer_{idx}_xh_{step_name}.pt") h = layer( @@ -2021,10 +2025,12 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> torch.Tensor: + if attn_metadata.num_prefill_tokens > 0 and attn_metadata.num_decode_tokens > 0: + raise ValueError("Chunk prefill not supported") image_inputs = self._parse_and_validate_image_input(**kwargs) if image_inputs is None: cross_attention_masks = None - full_text_row_masked_out_mask = None + run_xattn_mask = (attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).cuda() xattn_caches = None vision_tokens = None else: @@ -2044,6 +2050,13 @@ def forward( start_pos = end_pos vision_tokens = vision_tokens_flat + run_xattn_mask = torch.ones((attn_metadata.num_prefill_tokens, 1), dtype=torch.bool, device=vision_tokens.device) + start_pos = 0 + for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens_tensor, attn_metadata.encoder_seq_lens): + if encoder_seq_len == 0: + run_xattn_mask[start_pos:start_pos+seq_len] = False + start_pos += seq_len + # batch_masks = [] # # TODO: get the sequence of each query without hack? 1) better attn metadata 2) better input processor to create vision mask during preprocess # # assert isinstance(attn_metadata, PagedAttentionMetadata) @@ -2099,6 +2112,7 @@ def forward( vision_hidden_states=vision_tokens, kv_caches=kv_caches, attn_metadata=attn_metadata, + run_xattn_mask=run_xattn_mask, ) # if positions.numel() == 1 and positions.item() == 20: # exit(0) From cac19d5ab06fc6d2ad42f22374c7f020ca6ef1c7 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 15 Sep 2024 19:20:11 -0700 Subject: [PATCH 24/75] enable profile run --- tests/models/test_llamavl.py | 1 + vllm/model_executor/models/llamavl.py | 33 +++++++++++++++++++++------ vllm/worker/worker.py | 9 ++++---- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/tests/models/test_llamavl.py b/tests/models/test_llamavl.py index 5283d52e802c..c221b15e97f0 100644 --- a/tests/models/test_llamavl.py +++ b/tests/models/test_llamavl.py @@ -30,6 +30,7 @@ llm = LLM(model=f"{checkpoint_dir}/Meta-Llama-3.2-{size}-Vision-Early/", enforce_eager=True, limit_mm_per_prompt={"image": 2}, + max_num_seqs=16, tensor_parallel_size=1, # load_format="dummy" ) diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py index cf207b71cbd8..84084651c762 100644 --- a/vllm/model_executor/models/llamavl.py +++ b/vllm/model_executor/models/llamavl.py @@ -1,3 +1,4 @@ +from array import array from dataclasses import dataclass from functools import partial import itertools @@ -40,6 +41,7 @@ from vllm.transformers_utils.multimodal_processors.llamavl import LlamaVLImageProcessor import vllm.distributed.parallel_state as ps +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData step_name = "prefill" pt_dir = "/home/eecs/zhang-chen/MultiModal/scripts/" @@ -57,6 +59,7 @@ def check(tensor, file_name): pass logger = init_logger(__name__) MP_SCALE = 8 IMAGE_RES = 224 +LLAMA_IMAGE_TOKEN_ID = 128256 class LlamaImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -87,19 +90,35 @@ def input_processor_for_llamavl(ctx: InputContext, llm_inputs: LLMInputs): token_per_chunk = (ctx.model_config.hf_config.vision_chunk_size // 14) ** 2 + 1 num_tokens = num_chunks * token_per_chunk llm_inputs["encoder_prompt"] = "<|image|>" * num_tokens - llm_inputs["encoder_prompt_token_ids"] = [128256] * num_tokens + llm_inputs["encoder_prompt_token_ids"] = [LLAMA_IMAGE_TOKEN_ID] * num_tokens assert "decoder_multi_modal_data" not in llm_inputs, "multi-modal data should be put in encoder message of LLaMA Vision" return llm_inputs + +def dummy_seq_data( + seq_len: int, + num_images: int +): + assert seq_len >= num_images, "seq_len should be greater than or equal to num_images" + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [LLAMA_IMAGE_TOKEN_ID]) * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - num_images) + return SequenceData(token_ids) + + +def dummy_image( + num_images: int, +): + width = height = 512 + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + def dummy_data_for_llamavl(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): - # seq_len: 16 - # mm_counts: {'image': 2, 'audio': 1} - hf_config = ctx.model_config.hf_config - token_per_chunk = (hf_config.vision_chunk_size // 14) ** 2 + 1 - num_tokens = hf_config.max_num_image * hf_config.vision_max_num_chunks * token_per_chunk - import pdb; pdb.set_trace() + num_images = mm_counts["image"] + return dummy_seq_data(seq_len, num_images), dummy_image(num_images) def get_max_llama_image_tokens(ctx: InputContext) -> int: hf_config = ctx.model_config.hf_config diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index ba57a9312443..7907f322cdcf 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -220,14 +220,13 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Execute a forward pass with dummy inputs to profile the memory usage # of the model. - # self.model_runner.profile_run() + self.model_runner.profile_run() # # Calculate the number of blocks that can be allocated with the # # profiled peak memory. - # torch.cuda.synchronize() - # free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - free_gpu_memory = 40 * 1024 * 1024 * 1024 - total_gpu_memory = 80 * 1024 * 1024 * 1024 + torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. peak_memory = self.init_gpu_memory - free_gpu_memory From 7e5eadd89867b67722dbab22e753441fbee36642 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 17 Sep 2024 13:17:40 -0700 Subject: [PATCH 25/75] copy mllama from transformer --- vllm/entrypoints/chat_utils.py | 2 +- vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/mllama.py | 1849 ++++++++++++++++++++ vllm/transformers_utils/image_processor.py | 2 +- vllm/transformers_utils/tokenizer.py | 2 +- 5 files changed, 1853 insertions(+), 3 deletions(-) create mode 100644 vllm/model_executor/models/mllama.py diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 6a075cd08fe0..d243718e7ca9 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -159,7 +159,7 @@ def _placeholder_str(self, modality: ModalityStr, hf_config.image_token_index) if model_type in ("chameleon", "internvl_chat"): return "" - if model_type == "llamavl": + if model_type == "llamavl" or model_type == "mllama": return "<|image|>" if model_type == "qwen2_vl": return "<|vision_start|><|image_pad|><|vision_end|>" diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 9e7df18060da..94d7854f6ae7 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -97,6 +97,7 @@ "Qwen2VLForConditionalGeneration"), "UltravoxModel": ("ultravox", "UltravoxModel"), "LlamaVLForCausalLM": ("llamavl", "LlamaVLForCausalLM"), + "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py new file mode 100644 index 000000000000..12250048f47f --- /dev/null +++ b/vllm/model_executor/models/mllama.py @@ -0,0 +1,1849 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mllama model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MllamaConfig" + + +# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position +def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + min_dtype: float, + cache_position: torch.Tensor, + batch_size: int, +): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + min_dtype (`float`): + The minimum value representable with the dtype `dtype`. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + +def _prepare_cross_attention_mask( + cross_attention_mask: torch.Tensor, + past_key_values: Cache, + num_vision_tokens: int, + cross_attention_states: torch.Tensor, + cross_attention_layers: List[int], + device: str, + dtype: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + if cross_attention_mask is None: + # should we raise error or prepare a full attn mask with all ones? + return None, None + else: + # reshape so it can be used by attn module + batch_size, text_total_length, *_ = cross_attention_mask.shape + cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) + cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) + cross_attention_mask = cross_attention_mask.unsqueeze(1) + + # invert the mask + inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) + cross_attention_mask = inverted_cross_attn_mask.masked_fill( + inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min + ) + + # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's + # last dimension contains negative infinity values, otherwise it's 1 + negative_inf_value = torch.finfo(dtype).min + full_text_row_masked_out_mask = ( + (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] + ) + cross_attention_mask *= full_text_row_masked_out_mask + + # In case we receive a new image but already have previous cross-attention key/values in cache, + # then we need to extend the attention-mask and add previous images' lengths + if ( + past_key_values is not None + and cross_attention_states is not None + and past_key_values.get_seq_length(cross_attention_layers[0]) != 0 + ): + # make all zeros mask for cross-attn-mask from previuos cached hidden_states, all zeros right? + # i.e. extend current cross-attn-mask on image-seq-length dimension to account for past_seen_tokens + past_cross_attn_kv_length = past_key_values.get_seq_length(cross_attention_layers[0]) + past_cross_attn_mask = torch.zeros( + (*cross_attention_mask.shape[:-1], past_cross_attn_kv_length), dtype=dtype, device=device + ) + # concatenate both on image-seq-length dimension + cross_attention_mask = torch.cat([past_cross_attn_mask, cross_attention_mask], dim=-1) + + return cross_attention_mask, full_text_row_masked_out_mask + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) + attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1) + attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + def __init__(self, config: MllamaVisionConfig, is_gated: bool = True): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size) + if is_gated: + self.gate = nn.Parameter(torch.zeros(1)) + + def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + def __init__(self, config: MllamaVisionConfig): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter(torch.zeros(1)) + + # position embedding + position_embedding = torch.randn(self.num_patches, self.hidden_size) + self.embedding = nn.Parameter(self.scale * position_embedding) + + # tile position embedding + self.tile_embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, self.max_num_tiles * self.num_patches * self.hidden_size + ) + + def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + # position embeddings + gated_position_embedding = (1 - self.gate.tanh()) * self.embedding + hidden_state = hidden_state + gated_position_embedding.view(1, 1, self.num_patches, self.hidden_size) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size + ) + gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision +class MllamaVisionMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +class MllamaVisionSdpaAttention(nn.Module): + def __init__(self, config: MllamaVisionConfig): + super().__init__() + + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + + self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=False) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = None, + ) -> torch.Tensor: + query = self.q_proj(hidden_state) + key = self.k_proj(hidden_state) + value = self.v_proj(hidden_state) + + batch_size, q_seq_len, _ = query.shape + _, kv_seq_len, _ = key.shape + + query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) + key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_seq_len, -1) + + output = self.o_proj(attn_output) + + return output + + +class MllamaVisionEncoderLayer(nn.Module): + def __init__(self, config, is_gated: bool = False): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention(config) + self.mlp = MllamaVisionMLP(config) + + self.input_layernorm = nn.LayerNorm(self.hidden_size) + self.post_attention_layernorm = nn.LayerNorm(self.hidden_size) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) + self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual + gate_ffn * hidden_state + + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a + [`MllamaEncoderLayer`]. + + Args: + config: MllamaConfig + """ + + def __init__(self, config: MllamaVisionConfig, num_layers=32, is_gated=False): + super().__init__() + self.config = config + self.layers = nn.ModuleList([MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)]) + self.gradient_checkpointing = False + self.config = config + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + hidden_states = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + hidden_states = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + # SDPA never returns attn weights, so the kwarg isn't used at all + # TODO: fix this + # if output_attentions: + # all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class MllamaVisionModel(PreTrainedModel): + config_class = MllamaVisionConfig + base_model_prefix = "vision_encoder" + _no_split_modules = ["MllamaVisionSdpaAttention"] + _supports_sdpa = True + + def __init__(self, config: MllamaVisionConfig): + super().__init__(config) + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.in_channels = config.in_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = nn.Conv2d( + in_channels=config.in_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + bias=False, + ) + + self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config) + + self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size) + self.layernorm_post = nn.LayerNorm(self.hidden_size) + + # encoders + self.transformer = MllamaVisionEncoder(config, config.num_hidden_layers, is_gated=False) + self.global_transformer = MllamaVisionEncoder(config, config.num_global_layers, is_gated=True) + + def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + hidden_state = torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward( + self, pixel_values: torch.Tensor, aspect_ratio_ids: torch.Tensor, attention_mask: torch.Tensor + ) -> torch.Tensor: + batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape + + pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1) + + # patch embedding + patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device)) + hidden_state = patch_embeds.flatten(2).transpose(1, 2) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids) + + # apply cls token + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = (0, 0, 0, num_padding_patches) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + if attention_mask is not None: + attention_mask = attention_mask.reshape(batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + output_hidden_states=True, + ) + hidden_state, all_intermediate_hidden_states = output[0], output[1] + intermediate_hidden_states = [ + hidden_state + for idx, hidden_state in enumerate(all_intermediate_hidden_states) + if idx in self.intermediate_layers_indices + ] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = self.post_tile_positional_embedding(hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles * (num_patches + num_padding_patches), dim + ) + hidden_state = self.global_transformer(hidden_state, attention_mask=attention_mask)[0] + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim + ) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, -1 + ) + intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1 + ) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + return hidden_state + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText +class MllamaTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MllamaTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MllamaTextCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Optional[MllamaTextConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.config = config + self.num_heads = self.config.num_attention_heads + self.num_key_value_heads = self.config.num_key_value_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // self.num_heads + self.layer_idx = layer_idx + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + query_states = self.q_norm(query_states) + + if cross_attention_states is not None: + key_states = self.k_proj(cross_attention_states) + value_states = self.v_proj(cross_attention_states) + key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + key_states = self.k_norm(key_states) + if past_key_value is not None: + # if we have a new image + new tokens, we only computed key_states on that new image + # we still update the cross key states, past_image, new_image. And use it! + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + elif past_key_value.get_seq_length(self.layer_idx) != 0: + key_states, value_states = ( + past_key_value.key_cache[self.layer_idx], + past_key_value.value_cache[self.layer_idx], + ) + else: + raise ValueError( + "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" + ) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MllamaTextSelfAttention(nn.Module): + def __init__( + self, + config: Optional[MllamaTextConfig] = None, + layer_idx: Optional[int] = None, + ): + super().__init__() + self.config = config + self.num_heads = config.num_attention_heads + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.num_key_value_heads = config.num_key_value_heads + self.head_dim = config.hidden_size // self.num_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.layer_idx = layer_idx + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + position_embeddings: torch.Tensor, + output_attentions: bool = False, + use_cache: bool = False, + past_key_value=None, + cache_position=None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + return attn_output, None, past_key_value + + +MLLAMA_TEXT_CROSS_ATTENTION_CLASSES = {"eager": MllamaTextCrossAttention, "sdpa": MllamaTextCrossAttention} +MLLAMA_TEXT_ATTENTION_CLASSES = {"eager": MllamaTextSelfAttention, "sdpa": MllamaTextSelfAttention} + + +# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText +class MllamaTextMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_activation] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT +class MllamaSelfAttentionDecoderLayer(nn.Module): + def __init__(self, config: MllamaTextConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MLLAMA_TEXT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = MllamaTextMLP(config) + self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Ignore copy + self.layer_idx = layer_idx + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class MllamaCrossAttentionDecoderLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention and feedforward.""" + + def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: + super().__init__() + self.layer_idx = layer_idx + self.cross_attn = MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + + self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) + + self.mlp = MllamaTextMLP(config) + self.post_attention_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + attention_mask: torch.Tensor, + full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states, attn_weights, past_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + if full_text_row_masked_out_mask is not None: + hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore + hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + return outputs + + +class MllamaRotaryEmbedding(nn.Module): + def __init__( + self, + config: Optional[MllamaTextConfig] = None, + device=None, + ): + super().__init__() + self.rope_type = config.rope_scaling["rope_type"] + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 (see https://github.com/huggingface/transformers/pull/29285) + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention + cos = cos * self.attention_scaling + sin = sin * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +MLLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MllamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +class MllamaPreTrainedModel(PreTrainedModel): + config_class = MllamaConfig + base_model_prefix = "model" + _no_split_modules = ["MllamaSdpaCrossAttention"] + _supports_cache_class = True + _supports_static_cache = True + _supports_sdpa = True + _supports_quantized_cache = True + + def _init_weights(self, module): + std = self.config.get_text_config().initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Parameter): + module.data.normal_(mean=0.0, std=std) + + +class MllamaTextModel(MllamaPreTrainedModel): + config_class = MllamaTextConfig + base_model_prefix = "model" + _no_split_modules = ["MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"] + + def __init__(self, config: MllamaTextConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size + 8, config.hidden_size, self.padding_idx) + self.cross_attention_layers = config.cross_attention_layers + + layers = [] + for layer_idx in range(config.num_hidden_layers): + if layer_idx in self.cross_attention_layers: + layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx)) + else: + layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx)) + + self.layers = nn.ModuleList(layers) + self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = MllamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.FloatTensor] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if ( + idx in self.cross_attention_layers + and cross_attention_states is None + and ( + past_key_values is None + or (past_key_values is not None and past_key_values.get_seq_length(idx) == 0) + ) + ): + continue + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + cross_attention_states, + cross_attention_mask, + causal_mask, + full_text_row_masked_out_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + attention_mask=causal_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line + # self.config._attn_implementation == "sdpa" and + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class MllamaForCausalLM(MllamaPreTrainedModel): + config_class = MllamaTextConfig + base_model_prefix = "language_model" + _no_split_modules = ["MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"] + + def __init__(self, config): + super().__init__(config) + self.vocab_size = config.vocab_size + self.model = MllamaTextModel._from_config(config, attn_implementation=config._attn_implementation) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + cross_attention_states: Optional[torch.LongTensor] = None, + cross_attention_mask: Optional[torch.LongTensor] = None, + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MllamaForCausalLM + + >>> model = MllamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-11b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-11b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + cross_attention_states=cross_attention_states, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() + + loss = None + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.lm_head.weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + +MLLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`MllamaImageProcessor.__call__`] for details ([]`MllamaProcessor`] uses + [`MllamaImageProcessor`] for processing images). + aspect_ratio_mask: Optional[List[List[int]]] = None, # TODO + aspect_ratio_ids: Optional[torch.Tensor] = None, # TODO + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + cross_attention_mask: Optional[torch.Tensor] = None, # TODO + cross_attention_states: Optional[torch.Tensor] = None, # TODO + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + """The MLLAMA model which consists of a vision backbone and a language model.""", + MLLAMA_START_DOCSTRING, +) +class MllamaForConditionalGeneration(MllamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.vocab_size = config.text_config.vocab_size + self.hidden_size = self.config.text_config.hidden_size + self.max_num_tiles = config.vision_config.max_num_tiles + self.vision_output_dim = config.vision_config.vision_output_dim + self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + + self.vision_model = MllamaVisionModel._from_config( + config.vision_config, attn_implementation=config._attn_implementation + ) + self.language_model = MllamaForCausalLM._from_config( + config.text_config, attn_implementation=config._attn_implementation + ) + self.multi_modal_projector = nn.Linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + ) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.language_model.set_output_embeddings(new_embeddings) + + def set_decoder(self, decoder): + self.language_model.set_decoder(decoder) + + def get_decoder(self): + return self.language_model.get_decoder() + + def tie_weights(self): + return self.language_model.tie_weights() + + @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + pixel_values: Optional[torch.FloatTensor] = None, + aspect_ratio_mask: Optional[List[List[int]]] = None, + aspect_ratio_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[List[List[List[int]]]] = None, + cross_attention_mask: Optional[torch.Tensor] = None, + cross_attention_states: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, MllamaForConditionalGeneration + + >>> model = MllamaForConditionalGeneration.from_pretrained("") + >>> processor = AutoProcessor.from_pretrained("") + + >>> prompt = "<|image|><|begin_of_text|>If I had to write a haiku for this one" + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> inputs = processor(text=prompt, images=image, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_new_tokens=15) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "TODO: fill this out" + ```""" + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + ) + + if pixel_values is not None and cross_attention_states is not None: + raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") + + if pixel_values is not None: + if aspect_ratio_ids is None: + raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") + # get vision tokens from vision model + cross_attention_states = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask) + cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( + -1, cross_attention_states.shape[-2], self.hidden_size + ) + + cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + cross_attention_mask, + past_key_values=past_key_values, + num_vision_tokens=self.vision_model.num_patches, + cross_attention_layers=self.language_model.model.cross_attention_layers, + cross_attention_states=cross_attention_states, + device=self.device, + dtype=self.dtype, + ) + + if cross_attention_mask is not None and cache_position is not None: + cross_attention_mask = cross_attention_mask[:, :, cache_position] + full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + + outputs = self.language_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + past_key_values=past_key_values, + use_cache=use_cache, + inputs_embeds=inputs_embeds, + labels=labels, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, + num_logits_to_keep=num_logits_to_keep, + ) + + return outputs + + def prepare_inputs_for_generation( + self, + input_ids=None, + inputs_embeds=None, + attention_mask=None, + position_ids=None, + pixel_values=None, + aspect_ratio_ids=None, + aspect_ratio_mask=None, + cross_attention_mask=None, + past_key_values=None, + use_cache=False, + cache_position=None, + num_logits_to_keep=None, + **kwargs, + ): + # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens + # Exception 1: when passing input_embeds, input_ids may be missing entries + # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here + if past_key_values is not None: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + position_ids = position_ids.clone(memory_format=torch.contiguous_format) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and cache_position[0] == 0: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + # The clone here is for the same reason as for `position_ids`. + model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} + + if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: + if model_inputs["inputs_embeds"] is not None: + batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape + device = model_inputs["inputs_embeds"].device + else: + batch_size, sequence_length = model_inputs["input_ids"].shape + device = model_inputs["input_ids"].device + + dtype = self.get_output_embeddings().weight.dtype + min_dtype = torch.finfo(dtype).min + + attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=past_key_values.get_max_length(), + dtype=dtype, + device=device, + min_dtype=min_dtype, + cache_position=cache_position, + batch_size=batch_size, + ) + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "cross_attention_mask": cross_attention_mask, + } + ) + + # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios + # to compute image hidden states, otherwise they are cached within each cross attn layer + if (input_ids == self.config.image_token_index).any(): + model_inputs["pixel_values"] = pixel_values + model_inputs["aspect_ratio_ids"] = aspect_ratio_ids + model_inputs["aspect_ratio_mask"] = aspect_ratio_mask + + return model_inputs + + def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): + cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None) + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + **kwargs, + ) + + # add cross-attn mask for new token + if cross_attention_mask_prev is not None: + model_kwargs["cross_attention_mask"] = torch.cat( + [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1 + ) + return model_kwargs diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py index afbc52788ba1..ec39379c3b0b 100644 --- a/vllm/transformers_utils/image_processor.py +++ b/vllm/transformers_utils/image_processor.py @@ -42,7 +42,7 @@ def get_image_processor( try: print("processor_name", processor_name) - if "Vision-Early" in processor_name: + if "Vision-Early" in processor_name and "checkpoints" not in processor_name: from .multimodal_processors.llamavl import LlamaVLImageProcessor return LlamaVLImageProcessor(processor_name, *args, **kwargs) processor = AutoImageProcessor.from_pretrained( diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 21653a395b89..e06a3dab2cdc 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -116,7 +116,7 @@ def get_tokenizer( if tokenizer_mode == "mistral": tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), revision=revision) - elif "Meta-Llama-3.2-11B-Vision-Early" in str(tokenizer_name) or "Meta-Llama-3.2-90B-Vision-Early" in str(tokenizer_name): + elif ("Meta-Llama-3.2-11B-Vision-Early" in str(tokenizer_name) or "Meta-Llama-3.2-90B-Vision-Early" in str(tokenizer_name)) and "checkpoints" not in str(tokenizer_name): tokenizer = LlamaVLTokenizer.from_pretrained(str(tokenizer_name)) else: try: From 7e3fb1e86220733a66bb3523c55550c769ae86b4 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 17 Sep 2024 15:19:21 -0700 Subject: [PATCH 26/75] can init model from vllm --- vllm/model_executor/models/mllama.py | 450 +++++++++++++++------------ 1 file changed, 247 insertions(+), 203 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 12250048f47f..7abec8bd0a53 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Mllama model.""" - +from array import array import math -from typing import List, Optional, Tuple, Union +from PIL import Image +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union, Callable, Dict, Any, Set) import torch import torch.nn.functional as F @@ -36,10 +38,108 @@ replace_return_docstrings, ) from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.attention.ops.paged_attn import PagedAttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.logger import init_logger +from vllm.sequence import IntermediateTensors +from .interfaces import SupportsMultiModal +from .llama import LlamaAttention, LlamaMLP +# from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, + ColumnParallelLinear) +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + +from vllm.transformers_utils.multimodal_processors.llamavl import LlamaVLImageProcessor +import vllm.distributed.parallel_state as ps +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData logger = logging.get_logger(__name__) +logger = init_logger(__name__) +MP_SCALE = 8 +IMAGE_RES = 224 +LLAMA_IMAGE_TOKEN_ID = 128256 + +class MllamaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size, max_num_image, max_num_chunk, num_channels, height, width)`""" + aspect_ratios: torch.Tensor + """Shape: `(batch_size, max_num_image, 2)`""" + num_chunks: List[List[int]] + +# TODO: support LlamaImageEmbeddingInputs + +image_processor = None + +def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("encoder_multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + global image_processor + if image_processor is None: + image_processor = LlamaVLImageProcessor(ctx.model_config.model) + + processed_image = image_processor(multi_modal_data["image"]) + llm_inputs["encoder_multi_modal_data"]["image"] = processed_image + + num_chunks = int(processed_image["aspect_ratios"].sum()) + assert ctx.model_config.hf_config.vision_chunk_size % 14 == 0, "chunk size should be multiple of 14" + token_per_chunk = (ctx.model_config.hf_config.vision_chunk_size // 14) ** 2 + 1 + num_tokens = num_chunks * token_per_chunk + llm_inputs["encoder_prompt"] = "<|image|>" * num_tokens + llm_inputs["encoder_prompt_token_ids"] = [LLAMA_IMAGE_TOKEN_ID] * num_tokens + + assert "decoder_multi_modal_data" not in llm_inputs, "multi-modal data should be put in encoder message of LLaMA Vision" + + return llm_inputs + + +def dummy_seq_data( + seq_len: int, + num_images: int +): + assert seq_len >= num_images, "seq_len should be greater than or equal to num_images" + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [LLAMA_IMAGE_TOKEN_ID]) * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - num_images) + return SequenceData(token_ids) + + +def dummy_image( + num_images: int, +): + width = height = 512 + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + +def dummy_data_for_mllama(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + return dummy_seq_data(seq_len, num_images), dummy_image(num_images) + +def get_max_mllama_image_tokens(ctx: InputContext) -> int: + hf_config = ctx.model_config.hf_config + token_per_chunk = (hf_config.vision_chunk_size // 14) ** 2 + 1 + return hf_config.max_num_image * hf_config.vision_max_num_chunks * token_per_chunk + + _CONFIG_FOR_DOC = "MllamaConfig" @@ -176,6 +276,45 @@ def _prepare_aspect_ratio_attention_mask( return attention_mask +class ColumnParallelConv2dPatch(torch.nn.Module): + """Conv2D Patching layer with model parallelism. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. + Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + bias: bool = False, + ) -> None: + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) + self._linear = ColumnParallelLinear( + in_channels * kernel_size[0] * kernel_size[1], + out_channels, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._unfold(x) + x = x.permute(0, 2, 1) + x, _ = self._linear(x) + return x + + + class MllamaPrecomputedAspectRatioEmbedding(nn.Module): def __init__(self, config: MllamaVisionConfig, is_gated: bool = True): super().__init__() @@ -242,8 +381,8 @@ def __init__(self, config): super().__init__() self.config = config self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) - self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + self.fc1 = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=True) + self.fc2 = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=True) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.fc1(hidden_states) @@ -260,39 +399,38 @@ def __init__(self, config: MllamaVisionConfig): self.num_heads = config.attention_heads self.head_dim = config.hidden_size // config.attention_heads - self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=False) + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=False, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=False, + input_is_parallel=True, + ) def forward( self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = None, ) -> torch.Tensor: - query = self.q_proj(hidden_state) - key = self.k_proj(hidden_state) - value = self.v_proj(hidden_state) - - batch_size, q_seq_len, _ = query.shape - _, kv_seq_len, _ = key.shape - - query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim) - key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) - value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask) + qkv, _ = self.qkv_proj(hidden_state) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(q.shape[0], q.shape[1], self.n_local_heads, self.head_dim).transpose(1, 2) + k = k.view(k.shape[0], k.shape[1], self.n_local_kv_heads, self.head_dim).transpose(1, 2) + v = v.view(v.shape[0], v.shape[1], self.n_local_kv_heads, self.head_dim).transpose(1, 2) + + # TODO: remove padding in image encoder + attn_output = F.scaled_dot_product_attention( + q, k, v, attn_mask=attention_mask, dropout_p=0.0 + ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_seq_len, -1) - - output = self.o_proj(attn_output) - + attn_output = attn_output.reshape(attn_output.shape[0], attn_output.shape[1], -1) + output, _ = self.o_proj(attn_output) return output @@ -444,15 +582,14 @@ def __init__(self, config: MllamaVisionConfig): self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 self.scale = config.hidden_size**-0.5 - self.patch_embedding = nn.Conv2d( + self.patch_embedding = ColumnParallelConv2dPatch( in_channels=config.in_channels, out_channels=self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size, - padding="valid", bias=False, ) - + self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config) @@ -600,10 +737,20 @@ def __init__( self.layer_idx = layer_idx self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + # TODO(heheda12345): change to Q/KV seperate linear after #7448 is merged + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=False, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + ) self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) @@ -618,6 +765,7 @@ def forward( use_cache: bool = None, cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + raise NotImplementedError """Input shape: Batch x Time x Channel""" bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -717,119 +865,39 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class MllamaTextSelfAttention(nn.Module): - def __init__( - self, - config: Optional[MllamaTextConfig] = None, - layer_idx: Optional[int] = None, - ): - super().__init__() - self.config = config - self.num_heads = config.num_attention_heads - self.dropout = config.dropout - self.hidden_size = config.hidden_size - self.num_key_value_heads = config.num_key_value_heads - self.head_dim = config.hidden_size // self.num_heads - self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.layer_idx = layer_idx - - self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) - self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - position_embeddings: torch.Tensor, - output_attentions: bool = False, - use_cache: bool = False, - past_key_value=None, - cache_position=None, - **kwargs, - ): - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - causal_mask = attention_mask - if attention_mask is not None: - causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] - - # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, - # Reference: https://github.com/pytorch/pytorch/issues/112577. - if query_states.device.type == "cuda" and causal_mask is not None: - query_states = query_states.contiguous() - key_states = key_states.contiguous() - value_states = value_states.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - is_causal = True if causal_mask is None and q_len > 1 else False - - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=causal_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, +class MllamaTextSelfAttention(LlamaAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + rope_scaling = kwargs.get("rope_scaling", None) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False, # force to use neox=False ) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(bsz, q_len, -1) - - attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value - - -MLLAMA_TEXT_CROSS_ATTENTION_CLASSES = {"eager": MllamaTextCrossAttention, "sdpa": MllamaTextCrossAttention} -MLLAMA_TEXT_ATTENTION_CLASSES = {"eager": MllamaTextSelfAttention, "sdpa": MllamaTextSelfAttention} - - -# Copied from transformers.models.gemma2.modeling_gemma2.Gemma2MLP with Gemma2->MllamaText -class MllamaTextMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) - self.act_fn = ACT2FN[config.hidden_activation] - - def forward(self, x): - return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) - # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT class MllamaSelfAttentionDecoderLayer(nn.Module): - def __init__(self, config: MllamaTextConfig, layer_idx: int): + def __init__(self, config: MllamaTextConfig, layer_idx: int, cache_config: Optional[CacheConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MLLAMA_TEXT_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.self_attn = MllamaTextSelfAttention( + config=config, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=config.rope_theta, + rope_scaling=config.rope_scaling, + max_position_embeddings=config.max_position_embeddings, + quant_config=None, + bias=False, + cache_config=cache_config) - self.mlp = MllamaTextMLP(config) + self.mlp = LlamaMLP(config.hidden_size, config.intermediate_size, hidden_act=config.hidden_activation) self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -911,12 +979,19 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: super().__init__() self.layer_idx = layer_idx - self.cross_attn = MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + self.cross_attn = MllamaTextCrossAttention( + config=config, + layer_idx=layer_idx, + ) self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) - self.mlp = MllamaTextMLP(config) + self.mlp = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_activation, + ) self.post_attention_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) @@ -1068,11 +1143,11 @@ class MllamaTextModel(MllamaPreTrainedModel): base_model_prefix = "model" _no_split_modules = ["MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"] - def __init__(self, config: MllamaTextConfig): + def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig]): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size + 8, config.hidden_size, self.padding_idx) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, config.hidden_size, self.padding_idx) self.cross_attention_layers = config.cross_attention_layers layers = [] @@ -1080,7 +1155,7 @@ def __init__(self, config: MllamaTextConfig): if layer_idx in self.cross_attention_layers: layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx)) else: - layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx)) + layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx, cache_config=cache_config)) self.layers = nn.ModuleList(layers) self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1296,32 +1371,11 @@ class MllamaForCausalLM(MllamaPreTrainedModel): base_model_prefix = "language_model" _no_split_modules = ["MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"] - def __init__(self, config): + def __init__(self, config: MllamaConfig, cache_config:Optional[CacheConfig]): super().__init__(config) self.vocab_size = config.vocab_size - self.model = MllamaTextModel._from_config(config, attn_implementation=config._attn_implementation) - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.model.embed_tokens + self.model = MllamaTextModel(config, cache_config) - def set_input_embeddings(self, value): - self.model.embed_tokens = value - - def get_output_embeddings(self): - return self.lm_head - - def set_output_embeddings(self, new_embeddings): - self.lm_head = new_embeddings - - def set_decoder(self, decoder): - self.model = decoder - - def get_decoder(self): - return self.model def forward( self, @@ -1340,7 +1394,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, - ) -> Union[Tuple, CausalLMOutputWithPast]: + ) -> Tuple: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1395,34 +1449,7 @@ def forward( ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float() - - loss = None - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) + return hidden_states def prepare_inputs_for_generation( self, @@ -1586,8 +1613,15 @@ def prepare_inputs_for_generation( """The MLLAMA model which consists of a vision backbone and a language model.""", MLLAMA_START_DOCSTRING, ) -class MllamaForConditionalGeneration(MllamaPreTrainedModel): - def __init__(self, config): +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_mllama) +@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) +class MllamaForConditionalGeneration(MllamaPreTrainedModel, SupportsMultiModal): + def __init__(self, config: MllamaConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): super().__init__(config) self.vocab_size = config.text_config.vocab_size self.hidden_size = self.config.text_config.hidden_size @@ -1595,17 +1629,27 @@ def __init__(self, config): self.vision_output_dim = config.vision_config.vision_output_dim self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - self.vision_model = MllamaVisionModel._from_config( - config.vision_config, attn_implementation=config._attn_implementation + self.vision_model = MllamaVisionModel( + config.vision_config, ) - self.language_model = MllamaForCausalLM._from_config( - config.text_config, attn_implementation=config._attn_implementation + self.language_model = MllamaForCausalLM( + config.text_config, + cache_config=cache_config, ) self.multi_modal_projector = nn.Linear( config.vision_config.vision_output_dim, config.text_config.hidden_size, bias=True, ) + self.lm_head = ParallelLMHead( + config.text_config.vocab_size, + config.text_config.hidden_size, + org_num_embeddings=config.text_config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.output_hidden_states, config.text_config.vocab_size) + self.sampler = Sampler() self.post_init() def get_input_embeddings(self): From 49b05d623ecc65278f1036f70619a4a72c806712 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 17 Sep 2024 16:27:37 -0700 Subject: [PATCH 27/75] weight loader --- tests/models/test_llamavl.py | 6 +- vllm/model_executor/models/mllama.py | 277 +++++++-------------------- 2 files changed, 69 insertions(+), 214 deletions(-) diff --git a/tests/models/test_llamavl.py b/tests/models/test_llamavl.py index c221b15e97f0..fcf6c5c48d8a 100644 --- a/tests/models/test_llamavl.py +++ b/tests/models/test_llamavl.py @@ -26,8 +26,9 @@ args = parser.parse_args() size = model_size_map[args.model_type] - checkpoint_dir = "/data/zhang-chen/llama/checkpoints" # update checkpoint path here - llm = LLM(model=f"{checkpoint_dir}/Meta-Llama-3.2-{size}-Vision-Early/", + # checkpoint_dir = "/data/zhang-chen/llama/checkpoints" # update checkpoint path here + model_id = "/data/zhang-chen/Llama-3.2-11B-Vision-Early" + llm = LLM(model=model_id, enforce_eager=True, limit_mm_per_prompt={"image": 2}, max_num_seqs=16, @@ -64,4 +65,3 @@ generated_text = o.outputs[0].text print(generated_text) print("==================================") - diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 7abec8bd0a53..fa5947771ebd 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -23,20 +23,12 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss -from transformers import PreTrainedModel from transformers.activations import ACT2FN from transformers.cache_utils import Cache, StaticCache from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.ops.paged_attn import PagedAttentionMetadata @@ -69,8 +61,6 @@ from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData -logger = logging.get_logger(__name__) - logger = init_logger(__name__) MP_SCALE = 8 IMAGE_RES = 224 @@ -140,9 +130,6 @@ def get_max_mllama_image_tokens(ctx: InputContext) -> int: return hf_config.max_num_image * hf_config.vision_max_num_chunks * token_per_chunk -_CONFIG_FOR_DOC = "MllamaConfig" - - # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, @@ -564,14 +551,14 @@ def forward( ) -class MllamaVisionModel(PreTrainedModel): +class MllamaVisionModel(nn.Module): config_class = MllamaVisionConfig base_model_prefix = "vision_encoder" _no_split_modules = ["MllamaVisionSdpaAttention"] _supports_sdpa = True def __init__(self, config: MllamaVisionConfig): - super().__init__(config) + super().__init__() self.image_size = config.image_size self.patch_size = config.patch_size self.max_num_tiles = config.max_num_tiles @@ -1098,55 +1085,16 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -MLLAMA_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`MllamaConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. -""" - - -class MllamaPreTrainedModel(PreTrainedModel): - config_class = MllamaConfig - base_model_prefix = "model" - _no_split_modules = ["MllamaSdpaCrossAttention"] - _supports_cache_class = True - _supports_static_cache = True - _supports_sdpa = True - _supports_quantized_cache = True - - def _init_weights(self, module): - std = self.config.get_text_config().initializer_range - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, nn.Parameter): - module.data.normal_(mean=0.0, std=std) - - -class MllamaTextModel(MllamaPreTrainedModel): +class MllamaTextModel(nn.Module): config_class = MllamaTextConfig base_model_prefix = "model" _no_split_modules = ["MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"] def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig]): - super().__init__(config) + super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + print("vocab_size", self.vocab_size) self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, config.hidden_size, self.padding_idx) self.cross_attention_layers = config.cross_attention_layers @@ -1161,7 +1109,6 @@ def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig]) self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.rotary_emb = MllamaRotaryEmbedding(config=config) self.gradient_checkpointing = False - self.post_init() def get_input_embeddings(self): return self.embed_tokens @@ -1366,15 +1313,22 @@ def _update_causal_mask( return causal_mask -class MllamaForCausalLM(MllamaPreTrainedModel): +class MllamaForCausalLM(nn.Module): config_class = MllamaTextConfig base_model_prefix = "language_model" _no_split_modules = ["MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"] - def __init__(self, config: MllamaConfig, cache_config:Optional[CacheConfig]): - super().__init__(config) + def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig], quant_config: Optional[QuantizationConfig]): + super().__init__() self.vocab_size = config.vocab_size self.model = MllamaTextModel(config, cache_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) def forward( @@ -1395,36 +1349,6 @@ def forward( cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Tuple: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MllamaForCausalLM - - >>> model = MllamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-11b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-11b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1526,108 +1450,21 @@ def prepare_inputs_for_generation( return model_inputs -MLLAMA_INPUTS_DOCSTRING = r""" - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide - it. - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - [What are input IDs?](../glossary#input-ids) - pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_images, max_num_tiles, channels, image_size, image_size)): - The tensors corresponding to the input images. Pixel values can be obtained using - [`AutoImageProcessor`]. See [`MllamaImageProcessor.__call__`] for details ([]`MllamaProcessor`] uses - [`MllamaImageProcessor`] for processing images). - aspect_ratio_mask: Optional[List[List[int]]] = None, # TODO - aspect_ratio_ids: Optional[torch.Tensor] = None, # TODO - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and - [`PreTrainedTokenizer.__call__`] for details. - - If `past_key_values` is used, optionally only the last `input_ids` have to be input (see - `past_key_values`). - - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] - and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more - information on the default strategy. - - - 1 indicates the head is **not masked**, - - 0 indicates the head is **masked**. - cross_attention_mask: Optional[torch.Tensor] = None, # TODO - cross_attention_states: Optional[torch.Tensor] = None, # TODO - position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, - config.n_positions - 1]`. - - [What are position IDs?](../glossary#position-ids) - past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): - Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention - blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` - returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. - - Two formats are allowed: - - a [`~cache_utils.Cache`] instance, see our - [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); - - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of - shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy - cache format. - - The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the - legacy cache format will be returned. - - If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't - have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` - of shape `(batch_size, sequence_length)`. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This - is useful if you want more control over how to convert `input_ids` indices into associated vectors than the - model's internal embedding lookup matrix. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see - `past_key_values`). - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for - more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, - this tensor is not affected by padding. It is used to update the cache in the correct position and to infer - the complete sequence length. -""" - - -@add_start_docstrings( - """The MLLAMA model which consists of a vision backbone and a language model.""", - MLLAMA_START_DOCSTRING, -) @MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_mllama) @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) -class MllamaForConditionalGeneration(MllamaPreTrainedModel, SupportsMultiModal): +class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, config: MllamaConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): - super().__init__(config) + super().__init__() self.vocab_size = config.text_config.vocab_size - self.hidden_size = self.config.text_config.hidden_size + self.hidden_size = config.text_config.hidden_size self.max_num_tiles = config.vision_config.max_num_tiles self.vision_output_dim = config.vision_config.vision_output_dim - self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self.pad_token_id = config.pad_token_id if config.pad_token_id is not None else -1 self.vision_model = MllamaVisionModel( config.vision_config, @@ -1635,46 +1472,34 @@ def __init__(self, config: MllamaConfig, self.language_model = MllamaForCausalLM( config.text_config, cache_config=cache_config, + quant_config=quant_config, ) self.multi_modal_projector = nn.Linear( config.vision_config.vision_output_dim, config.text_config.hidden_size, bias=True, ) - self.lm_head = ParallelLMHead( - config.text_config.vocab_size, - config.text_config.hidden_size, - org_num_embeddings=config.text_config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - quant_config=quant_config, - ) self.logits_processor = LogitsProcessor(config.output_hidden_states, config.text_config.vocab_size) self.sampler = Sampler() - self.post_init() - - def get_input_embeddings(self): - return self.language_model.get_input_embeddings() - - def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) - - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - def get_decoder(self): - return self.language_model.get_decoder() - def tie_weights(self): - return self.language_model.tie_weights() + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.language_model.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens - @add_start_docstrings_to_model_forward(MLLAMA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, @@ -1891,3 +1716,33 @@ def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_ [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1 ) return model_kwargs + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + updated_params = set() + for name, loaded_weight in weights: + if 'patch_embedding.weight' in name: + name = name.replace('patch_embedding.weight', 'patch_embedding._linear.weight') + loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + updated_params.add(name) + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict.pop(name) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 2e66a5dde12d99606e9348c2f5bec13ef3b5b689 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 18 Sep 2024 00:35:34 -0700 Subject: [PATCH 28/75] run image encoder now --- vllm/model_executor/models/mllama.py | 372 +++++++++++++-------------- vllm/multimodal/base.py | 6 + 2 files changed, 191 insertions(+), 187 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index fa5947771ebd..c834b07c2a5f 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -30,6 +30,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig +from transformers.models.mllama.image_processing_mllama import MllamaImageProcessor from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -56,7 +57,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.transformers_utils.multimodal_processors.llamavl import LlamaVLImageProcessor import vllm.distributed.parallel_state as ps from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -70,29 +70,39 @@ class MllamaImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor """Shape: `(batch_size, max_num_image, max_num_chunk, num_channels, height, width)`""" - aspect_ratios: torch.Tensor - """Shape: `(batch_size, max_num_image, 2)`""" - num_chunks: List[List[int]] + aspect_ratio_ids: torch.Tensor + """Shape: `(batch_size, max_num_image)`""" + aspect_ratio_mask: torch.Tensor + """Shape: `(batch_size, max_num_image, max_num_tiles)`""" # TODO: support LlamaImageEmbeddingInputs image_processor = None +def recursive_sum(x): + if isinstance(x, torch.Tensor): + return x.sum() + if isinstance(x, (list, tuple)): + return sum(recursive_sum(v) for v in x) + return 0 + def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("encoder_multi_modal_data") + hf_config = ctx.model_config.hf_config if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs global image_processor if image_processor is None: - image_processor = LlamaVLImageProcessor(ctx.model_config.model) - + image_processor = MllamaImageProcessor( + ctx.model_config.model, + size={"height": hf_config.vision_config.image_size, "width": hf_config.vision_config.image_size}, + ) processed_image = image_processor(multi_modal_data["image"]) llm_inputs["encoder_multi_modal_data"]["image"] = processed_image - - num_chunks = int(processed_image["aspect_ratios"].sum()) - assert ctx.model_config.hf_config.vision_chunk_size % 14 == 0, "chunk size should be multiple of 14" - token_per_chunk = (ctx.model_config.hf_config.vision_chunk_size // 14) ** 2 + 1 - num_tokens = num_chunks * token_per_chunk + num_tiles = recursive_sum(processed_image["num_tiles"]) + assert hf_config.vision_config.image_size % 14 == 0, "chunk size should be multiple of 14" + token_per_chunk = (hf_config.vision_config.image_size // 14) ** 2 + 1 + num_tokens = num_tiles * token_per_chunk llm_inputs["encoder_prompt"] = "<|image|>" * num_tokens llm_inputs["encoder_prompt_token_ids"] = [LLAMA_IMAGE_TOKEN_ID] * num_tokens @@ -126,8 +136,8 @@ def dummy_data_for_mllama(ctx: InputContext, seq_len: int, mm_counts: Mapping[st def get_max_mllama_image_tokens(ctx: InputContext) -> int: hf_config = ctx.model_config.hf_config - token_per_chunk = (hf_config.vision_chunk_size // 14) ** 2 + 1 - return hf_config.max_num_image * hf_config.vision_max_num_chunks * token_per_chunk + token_per_chunk = (hf_config.vision_config.image_size // 14) ** 2 + 1 + return hf_config.vision_config.max_num_tiles * token_per_chunk # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position @@ -372,9 +382,9 @@ def __init__(self, config): self.fc2 = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=True) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.fc1(hidden_states) + hidden_states, _ = self.fc1(hidden_states) hidden_states = self.activation_fn(hidden_states) - hidden_states = self.fc2(hidden_states) + hidden_states, _ = self.fc2(hidden_states) return hidden_states @@ -382,9 +392,13 @@ class MllamaVisionSdpaAttention(nn.Module): def __init__(self, config: MllamaVisionConfig): super().__init__() + model_parallel_size = get_tensor_model_parallel_world_size() self.embed_dim = config.hidden_size self.num_heads = config.attention_heads self.head_dim = config.hidden_size // config.attention_heads + self.num_local_heads = self.num_heads // model_parallel_size + self.q_size = self.num_local_heads * self.head_dim + self.kv_size = self.num_local_heads * self.head_dim self.qkv_proj = QKVParallelLinear( self.embed_dim, @@ -406,9 +420,9 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_state) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q.view(q.shape[0], q.shape[1], self.n_local_heads, self.head_dim).transpose(1, 2) - k = k.view(k.shape[0], k.shape[1], self.n_local_kv_heads, self.head_dim).transpose(1, 2) - v = v.view(v.shape[0], v.shape[1], self.n_local_kv_heads, self.head_dim).transpose(1, 2) + q = q.view(q.shape[0], q.shape[1], self.num_local_heads, self.head_dim).transpose(1, 2) + k = k.view(k.shape[0], k.shape[1], self.num_local_heads, self.head_dim).transpose(1, 2) + v = v.view(v.shape[0], v.shape[1], self.num_local_heads, self.head_dim).transpose(1, 2) # TODO: remove padding in image encoder attn_output = F.scaled_dot_product_attention( @@ -606,8 +620,8 @@ def forward( aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1) # patch embedding - patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device)) - hidden_state = patch_embeds.flatten(2).transpose(1, 2) + patch_embeds = self.patch_embedding(pixel_values.to(self.layernorm_pre.weight.dtype)) + hidden_state = patch_embeds # tile embeddings _, num_patches, dim = hidden_state.shape @@ -633,6 +647,7 @@ def forward( # Pad the tensor hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) slice_index = -num_padding_patches if num_padding_patches > 0 else None + # import pdb; pdb.set_trace() if attention_mask is not None: attention_mask = attention_mask.reshape(batch_size * num_concurrent_media, -1) @@ -640,7 +655,7 @@ def forward( aspect_ratio_mask=attention_mask, num_patches=self.num_patches, target_length=hidden_state.shape[2], - dtype=self.dtype, + dtype=self.layernorm_pre.weight.dtype, ) hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) @@ -1118,41 +1133,15 @@ def set_input_embeddings(self, value): def forward( self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.FloatTensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" - ) - - if self.gradient_checkpointing and self.training and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." - ) - use_cache = False - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - + input_ids: torch.LongTensor, + position_ids: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds if cache_position is None: @@ -1333,46 +1322,23 @@ def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig], def forward( self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - ) -> Tuple: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + input_ids: torch.LongTensor, + position_ids: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.model( input_ids=input_ids, - cross_attention_states=cross_attention_states, - attention_mask=attention_mask, position_ids=position_ids, + cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, + kv_caches=kv_caches, + attn_metadata=attn_metadata, ) - - hidden_states = outputs[0] return hidden_states def prepare_inputs_for_generation( @@ -1465,6 +1431,7 @@ def __init__(self, config: MllamaConfig, self.max_num_tiles = config.vision_config.max_num_tiles self.vision_output_dim = config.vision_config.vision_output_dim self.pad_token_id = config.pad_token_id if config.pad_token_id is not None else -1 + self.image_size = config.vision_config.image_size self.vision_model = MllamaVisionModel( config.vision_config, @@ -1499,116 +1466,147 @@ def sample( ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + + def _parse_and_validate_image_input( + self, **kwargs: object): + print("kwargs", kwargs.keys()) + # import pdb; pdb.set_trace() + # tensor with the same shape will be batched together by MultiModalInputs.batch, so pixel_values here can be: + # - List[List[torch.Tensor]]: with shape (num_tiles, 3, image_res, image_res) + # - List[torch.Tensor]: with shape (num_image_in_batch, num_tiles, 3, image_res, image_res) + # - torch.Tensor: with shape (bs, num_image_in_batch, num_tiles, 3, image_res, image_res) + pixel_values: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = kwargs.pop("pixel_values", None) + image_embeds: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = kwargs.pop("image_embeds", None) + aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = kwargs.pop("aspect_ratio_ids", None) + aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = kwargs.pop("aspect_ratio_mask", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError("Both pixel values and image embeds are provided.") - def forward( - self, - input_ids: torch.LongTensor = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[List[List[int]]] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[List[List[List[int]]]] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - num_logits_to_keep: int = 0, - ) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - num_logits_to_keep (`int`, *optional*): - Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, MllamaForConditionalGeneration - - >>> model = MllamaForConditionalGeneration.from_pretrained("") - >>> processor = AutoProcessor.from_pretrained("") - - >>> prompt = "<|image|><|begin_of_text|>If I had to write a haiku for this one" - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(text=prompt, images=image, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "TODO: fill this out" - ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError( - "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + if pixel_values is not None: + assert aspect_ratio_ids is not None + assert aspect_ratio_mask is not None + max_num_images = max([len(x[0]) for x in pixel_values]) + if max_num_images == 0: + raise ValueError("No images provided.") + max_num_tiles = max(max([len(x) for x in y[0]]) for y in pixel_values) + device = self.multi_modal_projector.weight.device + bsz = len(pixel_values) + out_num_tiles = [] + out_images = torch.zeros( + bsz, + max_num_images, + max_num_tiles, + 3, + self.image_size, + self.image_size, + dtype=torch.float32, + device=device, ) - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" + out_ar_ids = torch.ones(bsz, max_num_images, dtype=torch.int64, device=device) + out_ar_mask = torch.zeros(bsz, max_num_images, max_num_tiles, dtype=torch.int64, device=device) + for b in range(len(pixel_values)): + _num_tiles = [] + for i in range(len(pixel_values[b][0])): + img = pixel_values[b][0][i] + out_images[b, i, :img.shape[0]] = img + out_ar_ids[b, i] = aspect_ratio_ids[b][0][i] + out_ar_mask[b, i] = aspect_ratio_mask[b][0][i] + _num_tiles.append(img.shape[0]) + out_num_tiles.append(_num_tiles) + + return MllamaImagePixelInputs( + type="pixel_values", + data=out_images, + aspect_ratio_ids=out_ar_ids, + aspect_ratio_mask=out_ar_mask, ) - if pixel_values is not None and cross_attention_states is not None: - raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") + if image_embeds is not None: + raise NotImplementedError - if pixel_values is not None: - if aspect_ratio_ids is None: - raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") - # get vision tokens from vision model - cross_attention_states = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask) - cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( - -1, cross_attention_states.shape[-2], self.hidden_size - ) + raise AssertionError("This line should be unreachable.") - cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( - cross_attention_mask, - past_key_values=past_key_values, - num_vision_tokens=self.vision_model.num_patches, - cross_attention_layers=self.language_model.model.cross_attention_layers, - cross_attention_states=cross_attention_states, - device=self.device, - dtype=self.dtype, - ) + - if cross_attention_mask is not None and cache_position is not None: - cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[Tuple, CausalLMOutputWithPast]: + if attn_metadata.num_prefill_tokens > 0 and attn_metadata.num_decode_tokens > 0: + raise ValueError("Chunk prefill not supported") + image_inputs = self._parse_and_validate_image_input(**kwargs) + if image_inputs is None: + cross_attention_masks = None + run_xattn_mask = (attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).cuda() + xattn_caches = None + vision_tokens = None + else: + # llama's reference implementation runs the vision model on CPU + pixel_values = image_inputs['data'] + aspect_ratio_ids = image_inputs['aspect_ratio_ids'] + aspect_ratio_mask = image_inputs['aspect_ratio_mask'] + cross_attention_states = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask) + import pdb; pdb.set_trace() + bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape) + cross_attention_states = cross_attention_states.view(bsz, -1, image_token_dim) + + cross_attention_states_flat = torch.zeros(sum(attn_metadata.encoder_seq_lens), image_token_dim, device=vision_tokens.device, dtype=vision_tokens.dtype) + start_pos = 0 + for seq_len, vision_token_in_batch in zip(attn_metadata.encoder_seq_lens, vision_tokens): + end_pos = start_pos + seq_len + cross_attention_states_flat[start_pos:end_pos] = vision_token_in_batch[:seq_len] + start_pos = end_pos + cross_attention_states = cross_attention_states_flat + cross_attention_mask = None # TODO + full_text_row_masked_out_mask = None # TODO + + # run_xattn_mask = torch.ones((attn_metadata.num_prefill_tokens, 1), dtype=torch.bool, device=cross_attention_states.device) + # start_pos = 0 + # for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens_tensor, attn_metadata.encoder_seq_lens): + # if encoder_seq_len == 0: + # run_xattn_mask[start_pos:start_pos+seq_len] = False + # start_pos += seq_len + + # if pixel_values is not None: + # if aspect_ratio_ids is None: + # raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") + # # get vision tokens from vision model + # cross_attention_states = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask) + # cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( + # -1, cross_attention_states.shape[-2], self.hidden_size + # ) + + # cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( + # cross_attention_mask, + # past_key_values=past_key_values, + # num_vision_tokens=self.vision_model.num_patches, + # cross_attention_layers=self.language_model.model.cross_attention_layers, + # cross_attention_states=cross_attention_states, + # device=self.device, + # dtype=self.dtype, + # ) + + # if cross_attention_mask is not None and cache_position is not None: + # cross_attention_mask = cross_attention_mask[:, :, cache_position] + # full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] outputs = self.language_model( input_ids=input_ids, - attention_mask=attention_mask, position_ids=position_ids, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - use_cache=use_cache, - inputs_embeds=inputs_embeds, - labels=labels, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - num_logits_to_keep=num_logits_to_keep, + kv_caches=kv_caches, + attn_metadata=attn_metadata, ) return outputs diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 032964fe0ac4..499380292b79 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -52,6 +52,12 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: """ if isinstance(nested_tensors, torch.Tensor): return nested_tensors + + if isinstance(nested_tensors, np.ndarray): + return torch.from_numpy(nested_tensors) + + if isinstance(nested_tensors, (int, float)): + return torch.tensor(nested_tensors) stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] if not is_list_of(stacked, torch.Tensor, check="all"): From 9770d84d3693d3b07ff476e21930014695abba11 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Wed, 18 Sep 2024 16:37:00 -0700 Subject: [PATCH 29/75] Add API Server Support --- examples/openai_vision_api_client.py | 2 ++ examples/template_llama3.2.jinja | 25 +++++++++++++++++++ .../multimodal_processors/llamavl.py | 4 ++- 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 examples/template_llama3.2.jinja diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index 1ba702ef019e..09854003d501 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -55,6 +55,8 @@ result = chat_completion_from_url.choices[0].message.content print("Chat completion output:", result) +print("remove me: testing done, exitting...") +import sys; sys.exit(0) ## Use base64 encoded image in the payload def encode_image_base64_from_url(image_url: str) -> str: diff --git a/examples/template_llama3.2.jinja b/examples/template_llama3.2.jinja new file mode 100644 index 000000000000..66a074be5610 --- /dev/null +++ b/examples/template_llama3.2.jinja @@ -0,0 +1,25 @@ +{% for message in messages %} + {% if loop.index0 == 0 %} + {{ bos_token }} + {% endif %} + + {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} + + {% if message['content'] is string %} + {{ message['content'] }} + {% else %} + {% for content in message['content'] %} + {% if content['type'] == 'image' %} + {{ '<|image|>' }} + {% elif content['type'] == 'text' %} + {{ content['text'] }} + {% endif %} + {% endfor %} + {% endif %} + + {{ '<|eot_id|>' }} +{% endfor %} + +{% if add_generation_prompt %} + {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{% endif %} \ No newline at end of file diff --git a/vllm/transformers_utils/multimodal_processors/llamavl.py b/vllm/transformers_utils/multimodal_processors/llamavl.py index 8d3537b457f5..2ac9fcdcb933 100644 --- a/vllm/transformers_utils/multimodal_processors/llamavl.py +++ b/vllm/transformers_utils/multimodal_processors/llamavl.py @@ -323,6 +323,8 @@ def preprocess(self, images, **kwargs) -> BatchFeature: # ), "Images and masks must have the same length" # preprocess is called for each batch now, so add batch dimension here. + if not isinstance(images, list): + images = [images] images = [images] max_num_images = max(len(x) for x in images) @@ -361,4 +363,4 @@ def preprocess(self, images, **kwargs) -> BatchFeature: # ) # print("stacked_images", stacked_images.shape) # print("num_chunks", num_chunks) - return BatchFeature(data, tensor_type=None) \ No newline at end of file + return BatchFeature(data, tensor_type=None) From c9d612b53319dd5f3bd87bd1be57b2a8693be422 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 18 Sep 2024 17:27:26 -0700 Subject: [PATCH 30/75] run single image reqeusts correctly --- vllm/model_executor/models/mllama.py | 364 +++++---------------- vllm/transformers_utils/image_processor.py | 2 +- 2 files changed, 85 insertions(+), 281 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index c834b07c2a5f..6f5742fd7de8 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -84,6 +84,8 @@ def recursive_sum(x): return x.sum() if isinstance(x, (list, tuple)): return sum(recursive_sum(v) for v in x) + if isinstance(x, (int, float)): + return x return 0 def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): @@ -731,13 +733,18 @@ def __init__( ): super().__init__() self.config = config + self.model_parallel_size = get_tensor_model_parallel_world_size() self.num_heads = self.config.num_attention_heads + self.num_local_heads = self.num_heads // self.model_parallel_size self.num_key_value_heads = self.config.num_key_value_heads + self.num_local_key_value_heads = self.num_key_value_heads // self.model_parallel_size self.dropout = config.dropout self.hidden_size = config.hidden_size self.head_dim = config.hidden_size // self.num_heads self.layer_idx = layer_idx self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.q_local_size = self.num_local_heads * self.head_dim + self.kv_local_size = self.num_local_key_value_heads * self.head_dim # TODO(heheda12345): change to Q/KV seperate linear after #7448 is merged self.qkv_proj = QKVParallelLinear( @@ -756,67 +763,43 @@ def __init__( self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.scaling = self.head_dim**-0.5 + + self.attn = Attention( + self.num_local_heads, + self.head_dim, + self.scaling, + self.num_local_key_value_heads, + ) def forward( self, hidden_states: torch.Tensor, - cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - use_cache: bool = None, - cache_position: Optional[torch.LongTensor] = None, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - raise NotImplementedError + attention_mask: Optional[torch.Tensor], + cross_attention_states: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: """Input shape: Batch x Time x Channel""" - bsz, q_len, _ = hidden_states.size() - query_states = self.q_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - query_states = self.q_norm(query_states) - - if cross_attention_states is not None: - key_states = self.k_proj(cross_attention_states) - value_states = self.v_proj(cross_attention_states) - key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - key_states = self.k_norm(key_states) - if past_key_value is not None: - # if we have a new image + new tokens, we only computed key_states on that new image - # we still update the cross key states, past_image, new_image. And use it! - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, {"cache_position": cache_position} - ) - elif past_key_value.get_seq_length(self.layer_idx) != 0: - key_states, value_states = ( - past_key_value.key_cache[self.layer_idx], - past_key_value.value_cache[self.layer_idx], - ) + qkv_dec, _ = self.qkv_proj(hidden_states) + q, _, _ = qkv_dec.split([self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) + if cross_attention_states is None: + k = None + v = None else: - raise ValueError( - "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!" - ) - - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - - if attention_mask is not None: # no matter the length, we just slice it - causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] - attn_weights = attn_weights + causal_mask - - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - + qkv_enc, _ = self.qkv_proj(cross_attention_states) + _, k, v = qkv_enc.split([self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) + k = k.view(-1, self.num_local_key_value_heads, self.head_dim) + v = v.view(-1, self.num_local_key_value_heads, self.head_dim) + k = self.k_norm(k) + q = q.view(-1, self.num_local_heads, self.head_dim) + q = self.q_norm(q) + + output = self.attn(q, k, v, kv_cache, attn_metadata, attn_type=AttentionType.ENCODER_DECODER) + out, _ = self.o_proj(output) + return out # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): @@ -909,52 +892,20 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int, cache_config: Optio def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 - **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`, *optional*): - attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, - query_sequence_length, key_sequence_length)` if default attention is used. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - use_cache (`bool`, *optional*): - If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding - (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states - cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): - Indices depicting the position of the input sequence tokens in the sequence - position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): - Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, - with `head_dim` being the embedding dimension of each attention head. - kwargs (`dict`, *optional*): - Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code - into the model - """ + positions: Optional[torch.LongTensor], + kv_cache: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.FloatTensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) # Self Attention - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states = self.self_attn( hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, + positions=positions, + kv_cache=kv_cache, + attn_metadata=attn_metadata, ) hidden_states = residual + hidden_states @@ -964,15 +915,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs + return hidden_states class MllamaCrossAttentionDecoderLayer(torch.nn.Module): @@ -1002,24 +945,19 @@ def forward( hidden_states: torch.Tensor, cross_attention_states: torch.Tensor, cross_attention_mask: torch.Tensor, - attention_mask: torch.Tensor, full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], - past_key_value: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, + kv_cache: List[torch.Tensor], + attn_metadata: AttentionMetadata, ) -> torch.Tensor: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) - hidden_states, attn_weights, past_key_value = self.cross_attn( + hidden_states = self.cross_attn( hidden_states=hidden_states, attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, - past_key_value=past_key_value, - output_attentions=output_attentions, - cache_position=cache_position, + kv_cache=kv_cache, + attn_metadata=attn_metadata, ) hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states @@ -1029,76 +967,7 @@ def forward( if full_text_row_masked_out_mask is not None: hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states - - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - if use_cache: - outputs += (past_key_value,) - - return outputs - - -class MllamaRotaryEmbedding(nn.Module): - def __init__( - self, - config: Optional[MllamaTextConfig] = None, - device=None, - ): - super().__init__() - self.rope_type = config.rope_scaling["rope_type"] - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - def _dynamic_frequency_update(self, position_ids, device): - """ - dynamic RoPE layers should recompute `inv_freq` in the following situations: - 1 - growing beyond the cached sequence length (allow scaling) - 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) - """ - seq_len = torch.max(position_ids) + 1 - if seq_len > self.max_seq_len_cached: # growth - inv_freq, self.attention_scaling = self.rope_init_fn( - self.config, device, seq_len=seq_len, **self.rope_kwargs - ) - self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation - self.max_seq_len_cached = seq_len - - if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset - self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) - self.max_seq_len_cached = self.original_max_seq_len - - @torch.no_grad() - def forward(self, x, position_ids): - if "dynamic" in self.rope_type: - self._dynamic_frequency_update(position_ids, device=x.device) - - # Core RoPE block - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - position_ids_expanded = position_ids[:, None, :].float() - # Force float32 (see https://github.com/huggingface/transformers/pull/29285) - device_type = x.device.type - device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - - # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention - cos = cos * self.attention_scaling - sin = sin * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - + return hidden_states class MllamaTextModel(nn.Module): config_class = MllamaTextConfig @@ -1109,7 +978,6 @@ def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig]) super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - print("vocab_size", self.vocab_size) self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, config.hidden_size, self.padding_idx) self.cross_attention_layers = config.cross_attention_layers @@ -1122,7 +990,7 @@ def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig]) self.layers = nn.ModuleList(layers) self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = MllamaRotaryEmbedding(config=config) + # self.rotary_emb = MllamaRotaryEmbedding(config=config) self.gradient_checkpointing = False def get_input_embeddings(self): @@ -1134,7 +1002,7 @@ def set_input_embeddings(self, value): def forward( self, input_ids: torch.LongTensor, - position_ids: Optional[torch.LongTensor], + positions: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor], full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], @@ -1144,94 +1012,29 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions - ) - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - for idx, decoder_layer in enumerate(self.layers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - if ( - idx in self.cross_attention_layers - and cross_attention_states is None - and ( - past_key_values is None - or (past_key_values is not None and past_key_values.get_seq_length(idx) == 0) - ) - ): - continue - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - cross_attention_states, - cross_attention_mask, - causal_mask, - full_text_row_masked_out_mask, - position_ids, - past_key_values, - output_attentions, - use_cache, - cache_position, - position_embeddings, - ) - else: - layer_outputs = decoder_layer( - hidden_states, + if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): + hidden_states = decoder_layer( + hidden_states=hidden_states, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, - attention_mask=causal_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, + # xattn_cache=xattn_caches[xattn_layer_idx] if xattn_caches is not None else None, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] - - if output_attentions: - all_self_attns += (layer_outputs[1],) - + elif isinstance(decoder_layer, MllamaSelfAttentionDecoderLayer): + hidden_states = decoder_layer( + hidden_states=hidden_states, + positions=positions, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) + else: + raise ValueError(f"Unknown decoder layer type {type(decoder_layer)}") hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) + return hidden_states + def _update_causal_mask( self, @@ -1323,7 +1126,7 @@ def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig], def forward( self, input_ids: torch.LongTensor, - position_ids: Optional[torch.LongTensor], + positions: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor], full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], @@ -1332,7 +1135,7 @@ def forward( ) -> torch.Tensor: hidden_states = self.model( input_ids=input_ids, - position_ids=position_ids, + positions=positions, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, @@ -1455,7 +1258,7 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.language_model.logits_processor(self.lm_head, hidden_states, + logits = self.logits_processor(self.language_model.lm_head, hidden_states, sampling_metadata) return logits @@ -1469,8 +1272,6 @@ def sample( def _parse_and_validate_image_input( self, **kwargs: object): - print("kwargs", kwargs.keys()) - # import pdb; pdb.set_trace() # tensor with the same shape will be batched together by MultiModalInputs.batch, so pixel_values here can be: # - List[List[torch.Tensor]]: with shape (num_tiles, 3, image_res, image_res) # - List[torch.Tensor]: with shape (num_image_in_batch, num_tiles, 3, image_res, image_res) @@ -1535,7 +1336,7 @@ def _parse_and_validate_image_input( def forward( self, input_ids: torch.Tensor, - position_ids: torch.Tensor, + positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, @@ -1545,23 +1346,26 @@ def forward( raise ValueError("Chunk prefill not supported") image_inputs = self._parse_and_validate_image_input(**kwargs) if image_inputs is None: - cross_attention_masks = None + cross_attention_mask = None run_xattn_mask = (attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).cuda() xattn_caches = None vision_tokens = None + cross_attention_states = None + full_text_row_masked_out_mask = None else: # llama's reference implementation runs the vision model on CPU pixel_values = image_inputs['data'] aspect_ratio_ids = image_inputs['aspect_ratio_ids'] aspect_ratio_mask = image_inputs['aspect_ratio_mask'] cross_attention_states = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask) - import pdb; pdb.set_trace() + cross_attention_states = self.multi_modal_projector(cross_attention_states) + bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape) cross_attention_states = cross_attention_states.view(bsz, -1, image_token_dim) - cross_attention_states_flat = torch.zeros(sum(attn_metadata.encoder_seq_lens), image_token_dim, device=vision_tokens.device, dtype=vision_tokens.dtype) + cross_attention_states_flat = torch.zeros(sum(attn_metadata.encoder_seq_lens), image_token_dim, device=cross_attention_states.device, dtype=cross_attention_states.dtype) start_pos = 0 - for seq_len, vision_token_in_batch in zip(attn_metadata.encoder_seq_lens, vision_tokens): + for seq_len, vision_token_in_batch in zip(attn_metadata.encoder_seq_lens, cross_attention_states): end_pos = start_pos + seq_len cross_attention_states_flat[start_pos:end_pos] = vision_token_in_batch[:seq_len] start_pos = end_pos @@ -1601,7 +1405,7 @@ def forward( outputs = self.language_model( input_ids=input_ids, - position_ids=position_ids, + positions=positions, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py index ec39379c3b0b..2bb167a3fc65 100644 --- a/vllm/transformers_utils/image_processor.py +++ b/vllm/transformers_utils/image_processor.py @@ -42,7 +42,7 @@ def get_image_processor( try: print("processor_name", processor_name) - if "Vision-Early" in processor_name and "checkpoints" not in processor_name: + if "Vision-Early" in processor_name and "checkpoints" in processor_name: from .multimodal_processors.llamavl import LlamaVLImageProcessor return LlamaVLImageProcessor(processor_name, *args, **kwargs) processor = AutoImageProcessor.from_pretrained( From 2f54ae395e12d71a1725d6b13075510e210fdbd5 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 18 Sep 2024 22:40:32 -0700 Subject: [PATCH 31/75] single image match huggingface result --- vllm/model_executor/models/mllama.py | 25 +++++++++---------------- vllm/transformers_utils/tokenizer.py | 3 --- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 6f5742fd7de8..43578d26a602 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -624,6 +624,7 @@ def forward( # patch embedding patch_embeds = self.patch_embedding(pixel_values.to(self.layernorm_pre.weight.dtype)) hidden_state = patch_embeds + hidden_state = ps.get_tp_group().all_gather(hidden_state) # tile embeddings _, num_patches, dim = hidden_state.shape @@ -649,7 +650,6 @@ def forward( # Pad the tensor hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) slice_index = -num_padding_patches if num_padding_patches > 0 else None - # import pdb; pdb.set_trace() if attention_mask is not None: attention_mask = attention_mask.reshape(batch_size * num_concurrent_media, -1) @@ -850,27 +850,13 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -class MllamaTextSelfAttention(LlamaAttention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - rope_scaling = kwargs.get("rope_scaling", None) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False, # force to use neox=False - ) - - # Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT class MllamaSelfAttentionDecoderLayer(nn.Module): def __init__(self, config: MllamaTextConfig, layer_idx: int, cache_config: Optional[CacheConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = MllamaTextSelfAttention( + self.self_attn = LlamaAttention( config=config, hidden_size=config.hidden_size, num_heads=config.num_attention_heads, @@ -1403,6 +1389,11 @@ def forward( # cross_attention_mask = cross_attention_mask[:, :, cache_position] # full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] + # print("input_ids", input_ids, cross_attention_states is None) + # if positions.numel() == 1: + # global step_name + # step_name = f"decode_{positions.item()}" + outputs = self.language_model( input_ids=input_ids, positions=positions, @@ -1412,6 +1403,8 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, ) + # if positions.numel() == 1 and positions.item() == 20: + # exit(0) return outputs diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index e06a3dab2cdc..d904b035c7f5 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -112,12 +112,9 @@ def get_tokenizer( 'encoding and decoding.', FutureWarning, stacklevel=2) - print("get tokenizer, tokenizer_name:", tokenizer_name, "Meta-Llama-3.2-11B-Vision-Early" in tokenizer_name,) if tokenizer_mode == "mistral": tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), revision=revision) - elif ("Meta-Llama-3.2-11B-Vision-Early" in str(tokenizer_name) or "Meta-Llama-3.2-90B-Vision-Early" in str(tokenizer_name)) and "checkpoints" not in str(tokenizer_name): - tokenizer = LlamaVLTokenizer.from_pretrained(str(tokenizer_name)) else: try: tokenizer = AutoTokenizer.from_pretrained( From 8f3989e794c0392a63d66c5c8a791ed635e84ccc Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 19 Sep 2024 00:15:34 -0700 Subject: [PATCH 32/75] small fix --- vllm/model_executor/models/mllama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 43578d26a602..669bb6f626ff 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -964,7 +964,7 @@ def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig]) super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, config.hidden_size, self.padding_idx) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, config.hidden_size) self.cross_attention_layers = config.cross_attention_layers layers = [] From 01621a57e23c2cbb6c04dfc30db9554a1c9a1e63 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 19 Sep 2024 11:19:02 -0700 Subject: [PATCH 33/75] remove old code --- vllm/entrypoints/chat_utils.py | 2 +- vllm/model_executor/models/__init__.py | 1 - vllm/model_executor/models/llamavl.py | 2200 ----------------- vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/configs/__init__.py | 2 - vllm/transformers_utils/configs/llamavl.py | 52 - vllm/transformers_utils/image_processor.py | 4 - .../multimodal_processors/llamavl.py | 366 --- vllm/transformers_utils/tokenizer.py | 5 +- .../transformers_utils/tokenizers/__init__.py | 1 - vllm/transformers_utils/tokenizers/llamavl.py | 221 -- 11 files changed, 4 insertions(+), 2853 deletions(-) delete mode 100644 vllm/model_executor/models/llamavl.py delete mode 100644 vllm/transformers_utils/configs/llamavl.py delete mode 100644 vllm/transformers_utils/multimodal_processors/llamavl.py delete mode 100644 vllm/transformers_utils/tokenizers/llamavl.py diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index d243718e7ca9..a4d0f7c44437 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -159,7 +159,7 @@ def _placeholder_str(self, modality: ModalityStr, hf_config.image_token_index) if model_type in ("chameleon", "internvl_chat"): return "" - if model_type == "llamavl" or model_type == "mllama": + if model_type == "mllama": return "<|image|>" if model_type == "qwen2_vl": return "<|vision_start|><|image_pad|><|vision_end|>" diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 94d7854f6ae7..78afecd8e5ab 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -96,7 +96,6 @@ "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), "UltravoxModel": ("ultravox", "UltravoxModel"), - "LlamaVLForCausalLM": ("llamavl", "LlamaVLForCausalLM"), "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), } _CONDITIONAL_GENERATION_MODELS = { diff --git a/vllm/model_executor/models/llamavl.py b/vllm/model_executor/models/llamavl.py deleted file mode 100644 index 84084651c762..000000000000 --- a/vllm/model_executor/models/llamavl.py +++ /dev/null @@ -1,2200 +0,0 @@ -from array import array -from dataclasses import dataclass -from functools import partial -import itertools -import collections -import math -from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, - TypedDict, Union, Callable, Dict, Any, Set) - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.transforms as tv -from PIL import Image - -from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.attention.ops.paged_attn import PagedAttentionMetadata -from vllm.config import CacheConfig, MultiModalConfig -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors -from .interfaces import SupportsMultiModal -from .llama import LlamaAttention, LlamaMLP -# from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear, - ColumnParallelLinear) -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) - -from vllm.transformers_utils.multimodal_processors.llamavl import LlamaVLImageProcessor -import vllm.distributed.parallel_state as ps -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData - -step_name = "prefill" -pt_dir = "/home/eecs/zhang-chen/MultiModal/scripts/" - -def check(tensor, file_name): pass - # with open(f"{pt_dir}{file_name}", "rb") as f: - # data = torch.load(f) - # tensor_flat = tensor.cpu().reshape(-1) - # data_flat = data.cpu().reshape(-1) - # if tensor_flat.shape != data_flat.shape: - # print("check:", file_name, "shape missmatch", tensor_flat.shape, data_flat.shape) - # return - # print("check:", file_name, torch.allclose(tensor_flat, data_flat), torch.max(torch.abs(tensor_flat-data_flat)), tensor.shape, data.shape) - -logger = init_logger(__name__) -MP_SCALE = 8 -IMAGE_RES = 224 -LLAMA_IMAGE_TOKEN_ID = 128256 - -class LlamaImagePixelInputs(TypedDict): - type: Literal["pixel_values"] - data: torch.Tensor - """Shape: `(batch_size, max_num_image, max_num_chunk, num_channels, height, width)`""" - aspect_ratios: torch.Tensor - """Shape: `(batch_size, max_num_image, 2)`""" - num_chunks: List[List[int]] - -# TODO: support LlamaImageEmbeddingInputs - -LlavaImageInputs = LlamaImagePixelInputs -image_processor = None - -def input_processor_for_llamavl(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("encoder_multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return llm_inputs - global image_processor - if image_processor is None: - image_processor = LlamaVLImageProcessor(ctx.model_config.model) - - processed_image = image_processor(multi_modal_data["image"]) - llm_inputs["encoder_multi_modal_data"]["image"] = processed_image - - num_chunks = int(processed_image["aspect_ratios"].sum()) - assert ctx.model_config.hf_config.vision_chunk_size % 14 == 0, "chunk size should be multiple of 14" - token_per_chunk = (ctx.model_config.hf_config.vision_chunk_size // 14) ** 2 + 1 - num_tokens = num_chunks * token_per_chunk - llm_inputs["encoder_prompt"] = "<|image|>" * num_tokens - llm_inputs["encoder_prompt_token_ids"] = [LLAMA_IMAGE_TOKEN_ID] * num_tokens - - assert "decoder_multi_modal_data" not in llm_inputs, "multi-modal data should be put in encoder message of LLaMA Vision" - - return llm_inputs - - -def dummy_seq_data( - seq_len: int, - num_images: int -): - assert seq_len >= num_images, "seq_len should be greater than or equal to num_images" - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [LLAMA_IMAGE_TOKEN_ID]) * num_images - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - num_images) - return SequenceData(token_ids) - - -def dummy_image( - num_images: int, -): - width = height = 512 - image = Image.new("RGB", (width, height), color=0) - return {"image": image if num_images == 1 else [image] * num_images} - -def dummy_data_for_llamavl(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): - num_images = mm_counts["image"] - return dummy_seq_data(seq_len, num_images), dummy_image(num_images) - -def get_max_llama_image_tokens(ctx: InputContext) -> int: - hf_config = ctx.model_config.hf_config - token_per_chunk = (hf_config.vision_chunk_size // 14) ** 2 + 1 - return hf_config.max_num_image * hf_config.vision_max_num_chunks * token_per_chunk - -def to_2tuple(x): - if isinstance(x, collections.abc.Iterable): - return x - return (x, x) - -def resize_local_position_embedding(orig_pos_embed, grid_size): - """ - Resize position embedding for vision encoder. - Original position embedding is [n_tiles * n_tiles + 1, dim] - New position embedding will be [grid_size[0] * grid_size[1] + 1, dim] - """ - new_grid_size = to_2tuple(grid_size) - orig_grid_size = to_2tuple(int(math.sqrt(len(orig_pos_embed) - 1))) - new_seq_len = new_grid_size[0] * new_grid_size[1] + 1 - - new_pos_emb_tok, new_pos_emb_img = ( - orig_pos_embed[:1], - orig_pos_embed[1:], - ) - logger.debug( - f"resizing position embedding grid-size from {orig_grid_size} to {new_grid_size}" - ) - - new_pos_emb_img = new_pos_emb_img.reshape( - 1, orig_grid_size[0], orig_grid_size[1], -1 - ).permute(0, 3, 1, 2) - - new_pos_emb_img = F.interpolate( - new_pos_emb_img, - size=new_grid_size, - mode="bilinear", - align_corners=True, - ) - new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1).reshape( - 1, new_grid_size[0] * new_grid_size[1], -1 - )[0] - new_pos_embed = torch.cat([new_pos_emb_tok, new_pos_emb_img], dim=0) - return new_pos_embed - - -def initialize_global_position_embedding_from_local( - pos_and_cls_embed, grid_size, x_scale, y_scale -): - """ - Takes a local position embedding for vision encoder and uses it - to initialize the global position embedding. - Input: local position embedding of shape [grid_size[0] * grid_size[1] + 1, dim] - Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim] - Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively. - """ - pos_embed = pos_and_cls_embed[1:] - cls_embed = pos_and_cls_embed[0].view(1, 1, 1, -1) - grid_size = to_2tuple(grid_size) - new_pos_emb_img = pos_embed.reshape(1, grid_size[0], grid_size[1], -1).permute( - 0, 3, 1, 2 - ) - new_grid_size = (x_scale * grid_size[0], y_scale * grid_size[1]) - new_pos_emb_img = F.interpolate( - new_pos_emb_img, - size=new_grid_size, - mode="bilinear", - align_corners=True, - ) - new_pos_emb_img = new_pos_emb_img.permute(0, 2, 3, 1) - new_pos_emb_img = new_pos_emb_img.view( - x_scale, grid_size[0], y_scale, grid_size[1], -1 - ) - new_pos_emb_img = new_pos_emb_img.permute(0, 2, 1, 3, 4).contiguous() - new_pos_emb_img = new_pos_emb_img.reshape( - x_scale, y_scale, grid_size[0] * grid_size[1], -1 - ) - cls_embed = cls_embed.expand(x_scale, y_scale, -1, -1) - pos_and_cls_embed = torch.cat([cls_embed, new_pos_emb_img], dim=2) - return pos_and_cls_embed - - -def resize_global_position_embedding(pos_and_cls_embed, grid_size, x_scale, y_scale): - """ - Takes a global position embedding for vision encoder and resizes it to new size. - Input: global position embedding of shape [x_old, y_old, old_grid_size[0] * old_grid_size[1] + 1, dim] - Returns: global position embedding of shape [x_scale, y_scale, grid_size[0] * grid_size[1] + 1, dim] - Here x_scale and y_scale are the number of tiles along x-axis and y-axis respectively. - """ - # first remove cls token - pos_embed = pos_and_cls_embed[:, :, 1:] - cls_embed = pos_and_cls_embed[:, :, 0].unsqueeze(2) - - xs_old, ys_old, ntok, dim = pos_embed.shape - old_grid_size = int(math.sqrt(ntok)) - - # move to correct form for interpolation - pos_embed = pos_embed.view(xs_old, ys_old, old_grid_size, old_grid_size, dim) - pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous() - pos_embed = pos_embed.view(xs_old * old_grid_size, ys_old * old_grid_size, dim) - pos_embed = pos_embed.unsqueeze(0) - - # interpolate - new_size = (grid_size[0] * x_scale, grid_size[1] * y_scale) - pos_embed = pos_embed.permute(0, 3, 1, 2) - pos_embed_resized = F.interpolate( - pos_embed, - size=new_size, - mode="bilinear", - align_corners=True, - ) - pos_embed = pos_embed_resized.permute(0, 2, 3, 1)[0] - - # move it back in place - pos_embed = pos_embed.view(x_scale, grid_size[0], y_scale, grid_size[1], dim) - pos_embed = pos_embed.permute(0, 2, 1, 3, 4).contiguous() - pos_embed = pos_embed.view(x_scale, y_scale, grid_size[0] * grid_size[1], dim) - - # interpolate cls token - cls_embed = cls_embed.permute(2, 3, 0, 1) - cls_embed_resized = F.interpolate( - cls_embed, - size=(x_scale, y_scale), - mode="bilinear", - align_corners=True, - ) - cls_embed = cls_embed_resized.permute(2, 3, 0, 1) - # add cls token back in - pos_and_cls_embed = torch.cat([cls_embed, pos_embed], dim=2) - - return pos_and_cls_embed - - -def build_encoder_attention_mask( - x: torch.Tensor, - ar: torch.Tensor, - ntok: int, - num_chunks: int, - n_heads: int, -): - """ - Build vision encoder attention mask that omits padding tokens. - """ - masks = [] - for arx in ar: - mask_i = torch.ones((num_chunks, x.shape[2], 1), dtype=x.dtype) - mask_i[: arx[0] * arx[1], :ntok] = 0 - mask_i = mask_i.view(num_chunks * x.shape[2], -1) - mask_i = mask_i @ mask_i.T * torch.finfo(x.dtype).min - mask_i = mask_i.unsqueeze(0) - masks.append(mask_i) - masks = torch.stack(masks).to(x.device).expand(-1, n_heads, -1, -1) - return masks - - -def expand_num_tokens_to_mult8(x): - num_pad_tokens = 8 - (x.shape[-2] % 8) - if num_pad_tokens == 0: - return x, 0 - else: - return ( - torch.cat( - [ - x, - torch.zeros( - (x.shape[0], x.shape[1], num_pad_tokens, x.shape[-1]), - dtype=x.dtype, - device=x.device, - ), - ], - dim=-2, - ), - num_pad_tokens, - ) - - -def contract_num_tokens_from_mult8(x, num_pad_tokens): - if num_pad_tokens == 0: - return x - return x[:, :, :-num_pad_tokens] - -def apply_scaling(freqs: torch.Tensor): - # Values obtained from grid search - scale_factor = 8 - low_freq_factor = 1 - high_freq_factor = 4 - old_context_len = 8192 # original llama3 length - - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor - new_freqs = [] - for freq in freqs: - wavelen = 2 * math.pi / freq - if wavelen < high_freq_wavelen: - new_freqs.append(freq) - elif wavelen > low_freq_wavelen: - new_freqs.append(freq / scale_factor) - else: - assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / ( - high_freq_factor - low_freq_factor - ) - new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) - return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) - - -def precompute_freqs_cis( - dim: int, end: int, theta: float = 10000.0, use_scaled: bool = False -): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device, dtype=torch.float32) - if use_scaled: - freqs = apply_scaling(freqs) - freqs = torch.outer(t, freqs) - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis - -def _get_full_row_masked_out_mask( - attn_bias, - negative_inf_value, -): - """ - attn_bias should be a 4D tensor of shape [B, H, S1, S2] - where B is the batch size, H is the number of heads, - and S1/S2 are the sequence lengths. This returns - a 4D tensor of shape [B, H, S1, 1] which stores boolean - values which are 0 if the a full row in the last dimension - contains negative infinity values, otherwise it's 1. - """ - return (attn_bias != negative_inf_value).any(dim=-1).type_as(attn_bias)[..., None] - -# use float RMSNorm to make result closer to reference impl. -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, eps: float = 1e-6): - super().__init__() - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def _norm(self, x): - return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) - - def forward(self, x): - output = self._norm(x.float()).type_as(x) - return output * self.weight - - -# Image encoder for inference -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - return x - - -class ColumnParallelConv2dPatch(torch.nn.Module): - """Conv2D Patching layer with model parallelism. - Column parallel over unfolded input. - Arguments: - in_channels: Input channels. - out_channels: Output channels. - kernel_size: Size of convolution kernel. - stride (default 1): Stride for convolution. - bias (default False): Use bias in Conv2d. - Input: (bsz, in_channels, width, height) - Output: (bsz, num_tokens, out_channels) - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int, int]], - stride: Union[int, Tuple[int, int]], - bias: bool = False, - ) -> None: - super().__init__() - if isinstance(kernel_size, int): - kernel_size = (kernel_size, kernel_size) - self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) - self._linear = ColumnParallelLinear( - in_channels * kernel_size[0] * kernel_size[1], - out_channels, - bias=bias, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self._unfold(x) - x = x.permute(0, 2, 1) - x, _ = self._linear(x) - return x - - -class ImageFeedForward(torch.nn.Module): - def __init__( - self, - dim: int, - hidden_dim: int, - dropout: float, - act_layer: Callable = nn.GELU, - ): - super().__init__() - # layers - self.c_fc = ColumnParallelLinear( - dim, - hidden_dim, - bias=True, - ) - self.c_proj = RowParallelLinear( - hidden_dim, - dim, - bias=True, - input_is_parallel=True, - skip_bias_add=True, # add bias explicitly for precision concern - ) - self.non_linearity = act_layer() - self.dropout = dropout - - def forward(self, x): - hidden, _ = self.c_fc(x) - hidden = self.non_linearity(hidden) - hidden, bias = self.c_proj(hidden) # skip_bias_add=True - hidden += bias - return hidden - - -class ImageAttention(nn.Module): - def __init__( - self, - dim, - n_heads, - ): - super().__init__() - model_parallel_size = get_tensor_model_parallel_world_size() - self.n_heads = n_heads - self.n_kv_heads = n_heads - self.n_local_heads = n_heads // model_parallel_size - self.n_local_kv_heads = ( - self.n_kv_heads // model_parallel_size - ) - self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.head_dim = dim // n_heads - self.q_size = self.n_local_heads * self.head_dim - self.kv_size = self.n_local_kv_heads * self.head_dim - assert self.n_heads % self.n_kv_heads == 0 - assert self.n_heads % model_parallel_size == 0 - assert self.n_kv_heads % model_parallel_size == 0 - - # The model provided by llama is with bias=True, but the weight does not contain bias - # During runtime, the llama executor set bias to zero. We use bias=False here to match the behavior - self.qkv_proj = QKVParallelLinear( - dim, - self.head_dim, - n_heads, - bias=False, - ) - self.wo = RowParallelLinear( - n_heads * self.head_dim, - dim, - bias=False, - input_is_parallel=True, - ) - - def forward( - self, - x: torch.Tensor, - mask: torch.Tensor = None, - ): - qkv, _ = self.qkv_proj(x) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q.view(q.shape[0], q.shape[1], self.n_local_heads, self.head_dim).transpose(1, 2) - k = k.view(k.shape[0], k.shape[1], self.n_local_kv_heads, self.head_dim).transpose(1, 2) - v = v.view(v.shape[0], v.shape[1], self.n_local_kv_heads, self.head_dim).transpose(1, 2) - - # TODO: remove padding in image encoder - attn_output = F.scaled_dot_product_attention( - q, k, v, attn_mask=mask, dropout_p=0.0 - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(attn_output.shape[0], attn_output.shape[1], -1) - out, _ = self.wo(attn_output) - return out - - -class ImageTransformerBlock(nn.Module): - def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - gated: bool = False, - ): - super().__init__() - assert d_model % n_head == 0 - self.n_heads = n_head - self.head_dim = d_model // self.n_heads - self.attn = ImageAttention( - dim=d_model, - n_heads=self.n_heads, - ) - self.ln_1 = LayerNorm(d_model) - self.mlp = ImageFeedForward( - dim=d_model, - hidden_dim=int(mlp_ratio * d_model), - dropout=0.0, - act_layer=act_layer, - ) - self.ln_2 = LayerNorm(d_model) - self.gated = gated - if gated: - self.gate_attn = nn.Parameter(torch.zeros(1)) - self.gate_ffn = nn.Parameter(torch.zeros(1)) - - def forward( - self, - x: torch.Tensor, - mask: torch.Tensor = None, - ): - _gate_attn = 1 if not self.gated else self.gate_attn.tanh() - _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh() - x = x + _gate_attn * self.attn(self.ln_1(x), mask=mask) - x = x + _gate_ffn * self.mlp(self.ln_2(x)) - return x - - -class ImageTransformer(nn.Module): - def __init__( - self, - width: int, - layers: int, - heads: int, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - gated: bool = False, - ): - super().__init__() - self.width = width - self.layers = layers - self.resblocks = nn.ModuleList( - [ - ImageTransformerBlock( - d_model=width, - n_head=heads, - mlp_ratio=mlp_ratio, - act_layer=act_layer, - gated=gated, - ) - for _ in range(self.layers) - ] - ) - - def forward(self, x: torch.Tensor, return_intermediate=None, mask=None): - out = [] - for idx, r in enumerate(self.resblocks): - if return_intermediate is not None and idx in return_intermediate: - out.append(x) - x = r(x, mask=mask) - if return_intermediate is not None: - return x, torch.stack(out, dim=-1) - return x - - -class VisionEncoder(nn.Module): - def __init__( - self, - max_num_tiles: int, - # ckpt_path: str = None, - image_size: int = 224, - patch_size: int = 14, - width: int = 1280, - layers: int = 32, - heads: int = 16, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - in_channels: int = 3, - # load_ckpt: bool = False, - n_global_layers: int = 2, - global_model: bool = False, - return_intermediate=None, - ): - super().__init__() - self.global_model = global_model - self.return_intermediate = return_intermediate - self.max_num_tiles = max_num_tiles - self.image_size = to_2tuple(image_size) - self.patch_size = to_2tuple(patch_size) - self.grid_size = ( - self.image_size[0] // self.patch_size[0], - self.image_size[1] // self.patch_size[1], - ) - self.conv1 = ColumnParallelConv2dPatch( - in_channels=in_channels, - out_channels=width, - kernel_size=patch_size, - stride=patch_size, - bias=False, - ) - scale = width**-0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter( - scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width) - ) - self.ln_post = LayerNorm(width) - self.ln_pre = LayerNorm(width) - self.transformer = ImageTransformer( - width, layers, heads, mlp_ratio, act_layer=act_layer - ) - # pre and post tile position embedding - self.global_transformer = ImageTransformer( - width, n_global_layers, heads, mlp_ratio, act_layer=act_layer, gated=True - ) - # pre and post tile position embedding - self.pre_tile_pos_embed = TilePositionEmbedding( - num_tiles=max_num_tiles, - width=width, - gated=True, - ) - self.post_tile_pos_embed = TilePositionEmbedding( - num_tiles=max_num_tiles, - width=width, - gated=True, - ) - self.gated_positional_embedding = nn.Parameter( - scale - * torch.randn( - max_num_tiles, - max_num_tiles, - self.grid_size[0] * self.grid_size[1] + 1, - width, - ) - ) - self.gated_positional_embedding_gate = nn.Parameter(torch.zeros(1)) - - def apply_positional_embedding(self, x, ar): - out = [] - # apply regular position embedding - bsz, num_chunks, num_tokens, dim = x.shape - x = x.view(bsz * num_chunks, num_tokens, dim) - x = x + self.positional_embedding * ( - 1 - self.gated_positional_embedding_gate.tanh() - ) - x = x.view(bsz, num_chunks, num_tokens, dim) - for idx, arx in enumerate(ar): - _pos_embed = self.gated_positional_embedding[: arx[0], : arx[1]] - _pos_embed = _pos_embed.reshape(arx[0] * arx[1], *_pos_embed.shape[2:]) - x[idx, : arx[0] * arx[1]] += ( - _pos_embed * self.gated_positional_embedding_gate.tanh() - ) - return x - - def apply_class_embedding(self, x): - x = torch.cat( - [ - self.class_embedding.to(x.dtype) - + torch.zeros( - x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device - ), - x, - ], - dim=1, - ) - return x - - def forward(self, images: torch.Tensor, ar: torch.Tensor) -> torch.Tensor: - # TODO: run tp in this function - if images.ndim == 5: - num_concurrent_media = 1 - bsz, num_chunks, nch, w, h = images.shape - else: - bsz, num_concurrent_media, num_chunks, nch, w, h = images.shape - - images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, w, h) - ar = ar.reshape(bsz * num_concurrent_media, 2) - - # patch embedding - x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, w, h) - x = self.conv1(x) - x = ps.get_tp_group().all_gather(x) - _, ntok, dim = x.shape - x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim) - - # tile embeddings - x = self.pre_tile_pos_embed(x, ar) # call all_gather here, dim will change - x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim) - - # apply cls token - x = self.apply_class_embedding(x) - ntok += 1 - - # apply position embeddings - x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim) - x = self.apply_positional_embedding(x, ar) - - x = self.ln_pre(x) - npad, attn_mask = 0, None - x, npad = expand_num_tokens_to_mult8(x) - attn_mask = build_encoder_attention_mask(x, ar, ntok, num_chunks, 1) - x = x.view(bsz * num_concurrent_media, -1, dim) - x, int_x = self.transformer( - x, return_intermediate=self.return_intermediate, mask=attn_mask - ) - - x = self.ln_post(x) - x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim) - x = self.post_tile_pos_embed(x, ar) - x = x.reshape(bsz * num_concurrent_media, num_chunks * (ntok + npad), dim) - x = self.global_transformer(x, mask=attn_mask) - x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim) - x = contract_num_tokens_from_mult8(x, npad) - - # adding back intermediate layer outputs - x = x.reshape(bsz, num_concurrent_media, num_chunks, ntok, dim) - int_x = int_x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, -1) - int_x = contract_num_tokens_from_mult8(int_x, npad) - int_x = int_x.reshape(bsz, num_concurrent_media, num_chunks, ntok, -1) - x = torch.cat([x, int_x], dim=-1) - return x - - -class LlamaVLAttention(LlamaAttention): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - rope_scaling = kwargs.get("rope_scaling", None) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=self.max_position_embeddings, - base=self.rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False, # force to use neox=False - ) - -class TransformerBlock(nn.Module): - def __init__(self, layer_id: int, args, cache_config: Optional[CacheConfig] = None): - """ - Initialize a TransformerBlock. - Args: - layer_id (int): Identifier for the layer. - args (ModelArgs): Model configuration parameters. - Attributes: - n_heads (int): Number of attention heads. - dim (int): Dimension size of the model. - head_dim (int): Dimension size of each attention head. - attention (Attention): Attention module. - feed_forward (FeedForward): FeedForward module. - layer_id (int): Identifier for the layer. - attention_norm (RMSNorm): Layer normalization for attention output. - ffn_norm (RMSNorm): Layer normalization for feedforward output. - """ - super().__init__() - self.n_heads = args.n_heads - self.dim = args.dim - self.head_dim = args.dim // args.n_heads - # TODO: remove "use_scaled_rope" from args - self.attention = LlamaVLAttention( - config=args, - hidden_size=args.dim, - num_heads=self.n_heads, - num_kv_heads=args.n_kv_heads, - rope_theta=args.rope_theta, - rope_scaling=args.rope_scaling, - max_position_embeddings=512, - quant_config=None, - bias=False, - cache_config=cache_config, - prefix=f"tb.{layer_id}.self_attn", - ) - # logger.warning("skip attention") - - hidden_dim = args.dim * 4 - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if args.ffn_dim_multiplier is not None: - hidden_dim = int(args.ffn_dim_multiplier * hidden_dim) - hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of) - - self.feed_forward = LlamaMLP( - hidden_size=args.dim, - intermediate_size=hidden_dim, - hidden_act="silu", - ) - self.layer_id = layer_id - self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) - self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) - - def forward( - self, - x: torch.Tensor, - positions: torch.LongTensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - """ - Perform a forward pass through the TransformerBlock. - Args: - x (torch.Tensor): Input tensor. - start_pos (int): Starting position for attention caching. - freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. - mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None. - Returns: - torch.Tensor: Output tensor after applying attention and feedforward layers. - """ - # TODO: need to compute qkv and then do attention - h = self.attention.forward( - positions=positions, - hidden_states=self.attention_norm(x), - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - h = h + x - out = h + self.feed_forward.forward(self.ffn_norm(h)) - return out - - -class TilePositionEmbedding(nn.Module): - def __init__( - self, - num_tiles: int, - width: int, - gated: bool = False, - ): - super().__init__() - self.num_tiles = num_tiles - self.width = width - self.embedding = nn.Parameter( - torch.randn(num_tiles, num_tiles, 1, width) / math.sqrt(width) - ) - self.gated = gated - if gated: - self.gate = nn.Parameter(torch.zeros(1)) - - @staticmethod - def _dynamic_resize(embed: torch.Tensor, num_tiles: int): - nt_old, nt_old, _, w = embed.shape - embed = embed.permute(2, 3, 0, 1) - - embed_new = F.interpolate( - embed, - size=(num_tiles, num_tiles), - mode="bilinear", - align_corners=True, - ) - # reshape the weights to the correct shape - embed_new = embed_new.permute(2, 3, 0, 1) - return embed_new - - def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int = None): - embed = self.embedding - if num_tiles is None: - num_tiles = self.num_tiles - elif num_tiles > self.num_tiles: - embed = TilePositionEmbedding._dynamic_resize(self.embedding, num_tiles) - out_pos_embed = torch.zeros( - x.shape[0], num_tiles, 1, self.width, device=x.device, dtype=x.dtype - ) - for idx, arx in enumerate(ar): - h, w = arx - out_pos_embed[idx, : w * h] = embed[:h, :w].reshape(w * h, 1, self.width) - if self.gated: - out_pos_embed = out_pos_embed * self.gate.tanh() - x = x + out_pos_embed - return x - - -def _noinit(x): - return x - - -class CrossAttention(torch.nn.Module): - """Cross attention layer with model-parallel attention layers.""" - - def __init__( - self, - dim: int, - head_dim: int, - n_heads: int, - n_kv_heads: int, - norm_eps: float, - ): - super().__init__() - self.model_parallel_size = get_tensor_model_parallel_world_size() - replication_factor = 1 - if self.model_parallel_size > 8: - replication_factor = self.model_parallel_size // MP_SCALE - n_kv_heads *= replication_factor - - assert n_heads % n_kv_heads == 0 - - # TODO: change to Q/KV seperate linear after #7448 is merged - self.qkv_proj = QKVParallelLinear( - dim, - head_dim, - n_heads, - n_kv_heads, - bias=False, - ) - - # self.wqkv = ColumnParallelLinear( - # dim, - # n_heads * head_dim, - # bias=False, - # gather_output=False, - # ) - # self.wk = ColumnParallelLinear( - # dim, - # n_kv_heads * head_dim, - # bias=False, - # gather_output=False, - # ) - # self.wv = ColumnParallelLinear( - # dim, - # n_kv_heads * head_dim, - # bias=False, - # gather_output=False, - # ) - self.wo = RowParallelLinear( - n_heads * head_dim, - dim, - bias=False, - input_is_parallel=True, - ) - - self.n_heads = n_heads - self.head_dim = head_dim - self.n_kv_heads = n_kv_heads - - self.q_norm = RMSNorm( - self.head_dim, - eps=norm_eps, - ) - self.k_norm = RMSNorm( - self.head_dim, - eps=norm_eps, - ) - self.scaling = self.head_dim**-0.5 - - # cross-attention heads are model parallel similar to - # self-attention, and we also use the identical KV head - # combination to ensure parity with the corresponding - # trunk LLM (i.e., group query attention) -- @dubeya - # local heads - assert self.n_heads % self.n_kv_heads == 0 - assert self.n_heads % self.model_parallel_size == 0 - assert self.n_kv_heads % self.model_parallel_size == 0 - self.n_local_heads = self.n_heads // self.model_parallel_size - self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size - self.n_rep = self.n_local_heads // self.n_local_kv_heads - self.q_local_size = self.n_local_heads * self.head_dim - self.kv_local_size = self.n_local_kv_heads * self.head_dim - - self.attn = Attention( - self.n_local_heads, - self.head_dim, - self.scaling, - self.n_local_kv_heads, - ) - - # def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: - # bsz = xattn_tokens.shape[0] - # xk, _ = self.wk(xattn_tokens) - # xv, _ = self.wv(xattn_tokens) - - # # _, seqlen_y, _ = xk.shape - - # xk = xk.view(-1, self.n_local_kv_heads, self.head_dim) - # xv = xv.view(-1, self.n_local_kv_heads, self.head_dim) - - # xk = self.k_norm(xk) - - # return torch.stack([xk, xv]) - - # def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor: - # return self._compute_xattn_kv_cache(xattn_tokens) - - def forward( - self, - decoder_hidden_states: torch.Tensor, - # xattn_mask: torch.Tensor, - # full_text_row_masked_out_mask: torch.Tensor, - # xattn_cache: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - encoder_hidden_states: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qkv_dec, _ = self.qkv_proj(decoder_hidden_states) - q, _, _ = qkv_dec.split([self.q_local_size, self.kv_local_size, self.kv_local_size], - dim=-1) - if encoder_hidden_states is None: - k = None - v = None - else: - qkv_enc, _ = self.qkv_proj(encoder_hidden_states) - _, k, v = qkv_enc.split([self.q_local_size, self.kv_local_size, self.kv_local_size], - dim=-1) - k = k.view(-1, self.n_local_kv_heads, self.head_dim) - v = v.view(-1, self.n_local_kv_heads, self.head_dim) - k = self.k_norm(k) - q = q.view(-1, self.n_local_heads, self.head_dim) - q = self.q_norm(q) - - output = self.attn(q, k, v, kv_cache, attn_metadata, attn_type=AttentionType.ENCODER_DECODER) - out, _ = self.wo(output) - return out - - -class CrossAttentionTransformerBlock(torch.nn.Module): - """Cross-attention transformer block with tanh-gated attention and feedforward.""" - - def __init__( - self, - args, - layer_id: int, - no_ffn: bool = False, - ) -> None: - super().__init__() - self.layer_id = layer_id - self.n_heads = args.n_heads - self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - self.dim = args.dim - self.head_dim = args.dim // args.n_heads - self.attention = CrossAttention( - dim=args.dim, - head_dim=self.head_dim, - n_heads=self.n_heads, - n_kv_heads=self.n_kv_heads, - norm_eps=args.norm_eps, - ) - - self.attention_norm = RMSNorm( - args.dim, - eps=args.norm_eps, - ) - self.gate_attn = torch.nn.Parameter(torch.zeros(1)) - - hidden_dim = args.dim * 4 - hidden_dim = int(2 * hidden_dim / 3) - # custom dim factor multiplier - if args.ffn_dim_multiplier is not None: - hidden_dim = int(args.ffn_dim_multiplier * hidden_dim) - hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of) - - self.feed_forward = LlamaMLP( - hidden_size=args.dim, - intermediate_size=hidden_dim, - hidden_act="silu", - ) - - self.ffn_norm = RMSNorm( - args.dim, - eps=args.norm_eps, - ) - self.gate_ffwd = torch.nn.Parameter(torch.zeros(1)) - - self.no_ffn = no_ffn - - def forward( - self, - x: torch.Tensor, - # xattn_mask: torch.Tensor, - # full_text_row_masked_out_mask: torch.Tensor, - # xattn_cache: torch.Tensor, - kv_cache: torch.LongTensor, - attn_metadata: AttentionMetadata, - vision_hidden_states: Optional[torch.Tensor], - run_xattn_mask: torch.Tensor, - ) -> torch.Tensor: - _attn_out = self.attention( - decoder_hidden_states=self.attention_norm(x), - # xattn_mask=xattn_mask, - # full_text_row_masked_out_mask=full_text_row_masked_out_mask, - # xattn_cache=xattn_cache, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - encoder_hidden_states=vision_hidden_states, - ) - # import pdb; pdb.set_trace() - h = x + self.gate_attn.tanh() * _attn_out * run_xattn_mask - _ffn = self.feed_forward(self.ffn_norm(h)) - # _ffn = full_text_row_masked_out_mask * _ffn # type: ignore - h = h + self.gate_ffwd.tanh() * _ffn * float(not self.no_ffn) * run_xattn_mask - return h - - -class DummyCrossAttentionTransformerBlock: - """Dummy cross-attention transformer block with tanh-gated attention and feedforward.""" - - def __call__( - self, - x: torch.Tensor, - *args, - **kwargs, - ) -> torch.Tensor: - return x - - -class DummySelfAttentionTransformerBlock: - """Dummy self-attention transformer block""" - - def __call__( - self, - x: torch.Tensor, - *args, - **kwargs, - ) -> torch.Tensor: - return x - - -class CrossAttentionTransformerVision(torch.nn.Module): - def __init__(self, args) -> None: - super().__init__() - return_intermediate = "3,7,15,23,30" - self.vision_input_dim = 1280 - self.image_res = args.vision_chunk_size - self.max_num_chunks = args.vision_max_num_chunks - if return_intermediate is not None: - return_intermediate = [int(l) for l in return_intermediate.split(",")] - self.vision_input_dim = ( - len(return_intermediate) + 1 - ) * self.vision_input_dim - self.patch_size = 14 - self.vision_encoder = VisionEncoder( - max_num_tiles=4, - image_size=args.vision_chunk_size, - patch_size=self.patch_size, - n_global_layers=8, - global_model=True, - return_intermediate=return_intermediate, - ) - # vision token projection - self.vision_projection = nn.Linear( - self.vision_input_dim, - args.dim, - bias=True, - ) # ORZZZZZZZZZZ - # self.vision_projection = ColumnParallelLinear( - # self.vision_input_dim, - # args.dim, - # bias=True, - # init_method=lambda x: x, - # ) - - def forward( - self, images: torch.Tensor, aspect_ratios: torch.Tensor - ) -> torch.Tensor: - # vision_tokens: (B, T, D) - # aspect_ratios: (B, T) - # h: (B, T, D) - vision_tokens = self.vision_encoder( - images.to(dtype=torch.bfloat16), aspect_ratios - ) - vision_tokens = self.vision_projection(vision_tokens) - # vision_tokens = gather_from_tensor_model_parallel_region(vision_tokens) - return vision_tokens - - -class CrossAttentionTransformerText(torch.nn.Module): - INFERENCE_IMAGE_TOKEN_ID = 128010 - - def __init__(self, args, cache_config:Optional[CacheConfig]) -> None: - super().__init__() - self.model_parallel_size = get_tensor_model_parallel_world_size() - assert args.vocab_size > 0 - self.vocab_size = args.vocab_size - self.n_layers = args.n_layers - self.dim = args.dim - self.head_dim = args.dim // args.n_heads - self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads - self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size - assert self.vocab_size % self.model_parallel_size == 0 - self.tok_embeddings = VocabParallelEmbedding( - args.vocab_size, args.dim, - padding_size=self.model_parallel_size, - ) - self.pos_embeddings = None - # final norm layer (not necessary for post-norm) - self.norm = RMSNorm(args.dim, eps=args.norm_eps) - - # output layer - # self.output = nn.Linear(args.dim, args.vocab_size, bias=False) - # self.output = ColumnParallelLinear( - # args.dim, args.vocab_size, bias=False, init_method=lambda x: x - # ) - - self.n_llama_layers = args.n_layers - self.model_dim = args.dim - - # BLOCKS - - self.fusion_schedule = self._init_fusion_schedule( - args.vision_num_cross_attention_layers - ) - self.learnable_embedding = VocabParallelEmbedding( - max(get_tensor_model_parallel_world_size(), 8), - args.dim, - padding_size=self.model_parallel_size, - ) - self.num_frozen_embeddings = self.tok_embeddings.num_embeddings - self._thresh = self.num_frozen_embeddings - 1 - - # transformer blocks - self.layers = torch.nn.ModuleList() - self.cross_attention_layers = torch.nn.ModuleList() - for i in range(args.n_layers): - layer_id = i - block = TransformerBlock(args=args, layer_id=layer_id, cache_config=cache_config) - self.layers.append(block) - if layer_id in self.fusion_schedule: - xa_layer_id = self.fusion_schedule.index(layer_id) + args.n_layers - block = CrossAttentionTransformerBlock( - args, - layer_id=xa_layer_id, - ) - self.cross_attention_layers.append(block) - - # add xattn and dummy layers to avoid conditionals in forward() - self.text_and_xattn_layers = [] - - for idx, layer in enumerate(self.layers): - if idx in self.fusion_schedule: - xattn_layer_idx = self.fusion_schedule.index(idx) - xattn_layer = self.cross_attention_layers[xattn_layer_idx] - else: - xattn_layer_idx = 0 - xattn_layer = DummyCrossAttentionTransformerBlock() - - self.text_and_xattn_layers.append( - ( - layer, - xattn_layer, - xattn_layer_idx, - ) - ) - self.freqs_cis = precompute_freqs_cis( - args.dim // args.n_heads, - args.max_seq_len * 2, - args.rope_theta, - args.use_scaled_rope, - ) - - self.args = args - self.cache_is_setup = False - self.max_seq_len = args.max_seq_len - - def _init_fusion_schedule( - self, - num_layers: int, - ) -> List[int]: - llama_layers = list(range(self.n_llama_layers)) - - # uniformly spread the layers - k = math.ceil(len(llama_layers) / num_layers) - return llama_layers[::-1][::k][:num_layers][::-1] - - def get_partially_trainable_embedding(self, x): - xz = torch.zeros_like(x, device=x.device) - oz = torch.ones_like(x, device=x.device) - x_orig = torch.minimum(x, torch.tensor(self._thresh, device=x.device)) - x_new = ( - torch.maximum(x, torch.tensor(self._thresh + 1, device=x.device)) - - self.num_frozen_embeddings - ) - - mask_orig = torch.where(x >= self.num_frozen_embeddings, xz, oz).unsqueeze(-1) - mask_new = torch.where(x < self.num_frozen_embeddings, xz, oz).unsqueeze(-1) - - x_orig = self.tok_embeddings(x_orig) - x_new = self.learnable_embedding(x_new).type_as(x_orig) - return x_orig * mask_orig.type_as(x_orig) + x_new * mask_new.type_as(x_new) - - def forward( - self, - positions: torch.LongTensor, - h: torch.Tensor, - # xattn_mask: torch.Tensor, - # full_text_row_masked_out_mask: torch.Tensor, - # xattn_caches: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - vision_hidden_states: Optional[torch.Tensor], - run_xattn_mask: torch.Tensor, - ): - # assert self.cache_is_setup, "Please set up cache before calling forward" - # mask = self.mask_cache.index_select(2, positions) - # freqs_cis = self.freqs_cis.index_select(0, positions) - - for idx, ( - layer, - xattn_layer, - xattn_layer_idx, - ) in enumerate(self.text_and_xattn_layers): - h = xattn_layer( - x=h, - # xattn_mask=xattn_mask, - # full_text_row_masked_out_mask=full_text_row_masked_out_mask, - # xattn_cache=xattn_caches[xattn_layer_idx] if xattn_caches is not None else None, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, - vision_hidden_states=vision_hidden_states, - run_xattn_mask=run_xattn_mask, - ) - # check(h, f"layer_{idx}_xh_{step_name}.pt") - h = layer( - x=h, - # mask=mask, - # freqs_cis=freqs_cis, - positions=positions, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, - ) - # check(h, f"layer_{idx}_h_{step_name}.pt") - - h = self.norm(h) - # check(h, f"finalh_{step_name}.pt") - return h - - def _get_xattn_mask( - self, - num_tokens, - text_device, - text_dtype, - vision_tokens, - cross_attention_masks, - ) -> Tuple[torch.Tensor, torch.Tensor]: - assert vision_tokens is not None, "Vision tokens must be provided" - vision_seqlen = vision_tokens.shape[3] - assert ( - vision_tokens.shape[1] == cross_attention_masks.shape[2] - ), f"Mismatch in number of images given and number of masks given {vision_tokens.shape} {cross_attention_masks.shape}" - assert ( - vision_tokens.shape[2] == cross_attention_masks.shape[3] - ), f"Vision tokens shape {vision_tokens.shape} mismatch with xattn shape {cross_attention_masks.shape}" - assert ( - num_tokens == cross_attention_masks.shape[1] - ), f"Mismatch in text sequence length and cross attention mask sequence length {num_tokens} {cross_attention_masks.shape}" - _, _, _, num_image_tokens, image_token_dim = tuple(vision_tokens.shape) - bsz, ntext, nimg, nchunks = cross_attention_masks.shape - cross_attention_masks = ( - cross_attention_masks.repeat_interleave(vision_seqlen, dim=3) - .view(bsz, ntext, -1) - .unsqueeze(1) - ) - full_text_row_masked_out_mask = _get_full_row_masked_out_mask( - cross_attention_masks, - torch.finfo(cross_attention_masks.dtype).min, - ) - cross_attention_masks *= full_text_row_masked_out_mask - - return ( - cross_attention_masks.to(device=text_device, dtype=text_dtype), - full_text_row_masked_out_mask, - ) - - -class VariableSizeImageTransform(object): - """ - This class accepts images of any size and dynamically resize, pads and chunks it - based on the image aspect ratio and the number of image chunks we allow. - - The algorithm will NOT distort the image fit a certain aspect ratio, because - that leads to a significant degradation in image quality. - - It can be summarized in 6 steps: - 1. Find all possible canvas combinations of max_num_chunks; - 2. Find the best canvas to fit the image; - 3. Resize without distortion - 4. Pad - 5. Normalize - 6. Chunk - - For example, if an input image is of size 300x800, patch_size of 224, - and max_num_chunks = 8, it will find the closest aspect ratio that - is allowed within 8 image chunks, with some restrictions. - In this case, 2:4 = 2 horizontal patches and 4 vertical patches, - giving a total of 8 chunks. - - If resize_to_max_canvas, the image will be resized (without distortion), - to the largest possible resolution. In this case, 388:896, and padded to 448:896, - where we maintain the original aspect ratio and pad with zeros value for the rest. - This approach minimizes the amount of padding required for any arbitrary resolution. - - However, if limit_upscaling_to_patch_size is set to True, - the upscaling will be limited to the patch size. In the example above, - the image would remain 300x800 (no upscaling), and then padded to 448:896. - - The final output will therefore be of shape (8, 3, 224, 224), where 2x4 - patches are coming from the resizing and chunking. - """ - - def __init__(self, size: int = IMAGE_RES) -> None: - self.size = size - logger.info(f"VariableSizeImageTransform size: {self.size}") - self.to_tensor = tv.ToTensor() - self._mean = (0.48145466, 0.4578275, 0.40821073) - self._std = (0.26862954, 0.26130258, 0.27577711) - self.normalize = tv.Normalize( - mean=self._mean, - std=self._std, - inplace=True, - ) - self.resample = tv.InterpolationMode.BILINEAR - - @staticmethod - def get_factors(n: int) -> Set[int]: - """ - Calculate all factors of a given number, i.e. a dividor that leaves - no remainder. For example, if n=12, it will return {1, 2, 3, 4, 6, 12}. - - Args: - n (int): The number to find factors for. - - Returns: - set: A set containing all factors of the number. - """ - factors_set = set() - - for i in range(1, int(n**0.5) + 1): - if n % i == 0: - factors_set.add(i) - factors_set.add(n // i) - return factors_set - - def find_supported_resolutions( - self, max_num_chunks: int, patch_size: int - ) -> torch.Tensor: - """ - Computes all of the allowed resoltuions for a fixed number of chunks - and patch_size. Useful for when dividing an image into chunks. - - Args: - max_num_chunks (int): Maximum number of chunks for processing. - patch_size (int): Size of the side of the patch. - - Returns: - torch.Tensor: List of possible resolutions as tuples (height, width). - - Example: - >>> max_num_chunks = 5 - >>> patch_size = 224 - >>> find_supported_resolutions(max_num_chunks, patch_size) - tensor([(224, 896), (448, 448), (224, 224), (896, 224), (224, 672), - (672, 224), (224, 448), (448, 224)]) - - Given max_num_chunks=4, patch_size=224, it will create a dictionary: - { - 0.25: [(1, 4)], - 1.0: [(2, 2), (1, 1)], - 4.0: [(4, 1)], - 0.33: [(1, 3)], - 3.0: [(3, 1)], - 0.5: [(1, 2)], - 2.0: [(2, 1)] - } - - and return the resolutions multiplied by the patch_size: - [(1*224, 4*224), (2*224, 2*224), ..., (2*224, 1*224)] - """ - asp_dict = collections.defaultdict(list) - for chunk_size in range(max_num_chunks, 0, -1): - _factors = sorted(self.get_factors(chunk_size)) - _asp_ratios = [(factor, chunk_size // factor) for factor in _factors] - for height, width in _asp_ratios: - ratio_float = height / width - asp_dict[ratio_float].append((height, width)) - - # get the resolutions multiplied by the patch_size - possible_resolutions = [] - for key, value in asp_dict.items(): - for height, depth in value: - possible_resolutions.append((height * patch_size, depth * patch_size)) - - return possible_resolutions - - @staticmethod - def get_max_res_without_distortion( - image_size: Tuple[int, int], - target_size: Tuple[int, int], - ) -> Tuple[int, int]: - """ - Determines the maximum resolution to which an image can be resized to without distorting its - aspect ratio, based on the target resolution. - - Args: - image_size (Tuple[int, int]): The original resolution of the image (height, width). - target_resolution (Tuple[int, int]): The desired resolution to fit the image into (height, width). - Returns: - Tuple[int, int]: The optimal dimensions (height, width) to which the image should be resized. - Example: - >>> _get_max_res_without_distortion([200, 300], target_size = [450, 200]) - (134, 200) - >>> _get_max_res_without_distortion([800, 600], target_size = [450, 1300]) - (450, 338) - """ - - original_width, original_height = image_size - target_width, target_height = target_size - - scale_w = target_width / original_width - scale_h = target_height / original_height - - if scale_w < scale_h: - new_width = target_width - new_height = min(math.floor(original_height * scale_w), target_height) - else: - new_height = target_height - new_width = min(math.floor(original_width * scale_h), target_width) - - return new_width, new_height - - def _pad(self, image: Image.Image, target_size) -> Image.Image: - new_width, new_height = target_size - new_im = Image.new(mode="RGB", size=(new_width, new_height), color=(0, 0, 0)) # type: ignore - new_im.paste(image) - return new_im - - def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor: - # Split image into number of required tiles (width x height) - num_channels, height, width = image.size() - image = image.view(num_channels, nch, height // nch, ncw, width // ncw) - # Permute dimensions to reorder the axes - image = image.permute(1, 3, 0, 2, 4).contiguous() - # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2) - image = image.view(ncw * nch, num_channels, height // nch, width // ncw) - return image - - def resize_without_distortion( - self, - image: torch.Tensor, - target_size: Tuple[int, int], - max_upscaling_size: Optional[int], - ) -> torch.Tensor: - """ - Used to resize an image to target_resolution, without distortion. - - If target_size requires upscaling the image, the user can set max_upscaling_size to - limit the upscaling to a maximum size. In this case, since we rescale without distortion, - modifying target_size works as a boundary for the image's largest side. - - Args: - resample (str): Resampling method used when resizing images. - Supports "nearest", "nearest_exact", "bilinear", "bicubic". - max_upscaling_size (int): The maximum size to upscale the image to. - If None, there is no limit. - Examples: - >>> target_size = (1000, 1200) - >>> max_upscaling_size = 600 - >>> image_size = (400, 200) - >>> resize_without_distortion(image_size, target_size, max_upscaling_size) - (600, 300) # new_size_without_distortion - - >>> target_size = (1000, 1200) - >>> max_upscaling_size = 600 - >>> image_size = (2000, 200) - >>> resize_without_distortion(image_size, target_size, max_upscaling_size) - (1000, 100) # new_size_without_distortion - - >>> target_size = (1000, 1200) - >>> max_upscaling_size = 2000 - >>> image_size = (400, 200) - >>> resize_without_distortion(image_size, target_size, max_upscaling_size) - (1000, 500) # new_size_without_distortion - - >>> target_size = (1000, 1200) - >>> max_upscaling_size = None - >>> image_size = (400, 200) - >>> resize_without_distortion(image_size, target_size, max_upscaling_size) - (1000, 500) # new_size_without_distortion - """ - - image_width, image_height = image.size - image_size = (image_width, image_height) - - # If target_size requires upscaling, we might want to limit the upscaling to max_upscaling_size - if max_upscaling_size is not None: - new_target_width = min(max(image_width, max_upscaling_size), target_size[0]) - new_target_height = min( - max(image_height, max_upscaling_size), target_size[1] - ) - target_size = (new_target_width, new_target_height) - - # resize to target_size while preserving aspect ratio - new_size_without_distortion = self.get_max_res_without_distortion( - image_size, target_size - ) - - image = F.resize( - image, - (new_size_without_distortion[1], new_size_without_distortion[0]), - interpolation=self.resample, - ) - - return image - - def get_best_fit( - self, - image_size: Tuple[int, int], - possible_resolutions: torch.Tensor, - resize_to_max_canvas: bool = False, - ) -> Tuple[int, int]: - """ - Determines the best canvas possible from a list of possible resolutions to, without distortion, - resize an image to. - - For each possible resolution, calculates the scaling factors for - width and height, and selects the smallest one, which is the limiting side. - E.g. to match the canvas you can upscale height by 2x, and width by 1.5x, - therefore, the maximum upscaling you can do is min(2, 1.5) = 1.5. - - If upscaling is possible (any of the scaling factors is greater than 1), - then picks the smallest upscaling factor > 1, unless resize_to_max_canvas is True. - - If upscaling is not possible, then picks the largest scaling factor <= 1, i.e. - reduce downscaling as much as possible. - - If there are multiple resolutions with the same max scale, we pick the one with the lowest area, - to minimize padding. E.g., the same image can be upscaled to 224x224 and 224x448, but the latter - has more padding. - - Args: - image_size (Tuple[int, int]): A tuple containing the height and width of the image. - possible_resolutions (torch.Tensor): A tensor of shape (N, 2) where each - row represents a possible resolution (height, width). - use_max_upscaling (bool): If True, will return the largest upscaling resolution. - - Returns: - List[int]: The best resolution [height, width] for the given image. - - Example: - >>> image_size = (200, 300) - >>> possible_resolutions = torch.tensor([[224, 672], - ... [672, 224], - ... [224, 448], - ... [448, 224], - ... [224, 224]]) - >>> _get_smallest_upscaling_possibility(image_size, possible_resolutions) - [224, 448] - - We have: - scale_w = tensor([2.2400, 0.7467, 1.4933, 0.7467, 0.7467]) - scale_h = tensor([1.1200, 3.3600, 1.1200, 2.2400, 1.1200]) - scales = tensor([1.1200, 0.7467, 1.1200, 0.7467, 0.7467]) - Only one of the scales > 1: - upscaling_possible = tensor([1.1200, 1.1200]) - smallest_rescale = tensor(1.1200) - So we pick the resolution with the smallest smallest area: - areas = tensor([150528, 100352]) # [672, 224], [224, 448] - optimal_canvas = tensor([224, 448]) - """ - - original_width, original_height = image_size - - # get all possible resolutions heights/widths - target_widths, target_heights = ( - possible_resolutions[:, 0], - possible_resolutions[:, 1], - ) - - # get scaling factors to resize the image without distortion - scale_w = target_widths / original_width - scale_h = target_heights / original_height - - # get the min scale between width and height (limiting side -> no distortion) - scales = torch.where(scale_w > scale_h, scale_h, scale_w) - - # filter only scales that allow upscaling - upscaling_options = scales[scales >= 1] - if len(upscaling_options) > 0: - if resize_to_max_canvas: - selected_scale = torch.max(upscaling_options) - else: - selected_scale = torch.min(upscaling_options) - else: - # no upscaling possible, - # get the minimum downscaling (max scale for scales<1) - downscaling_options = scales[scales < 1] - selected_scale = torch.max(downscaling_options) - - # get all resolutions that support this scaling factor, - # e.g. you can upscale to 224x224, 224x448, 224x672 without distortion - chosen_canvas = possible_resolutions[scales == selected_scale] - - # if there are multiple resolutions, - # get the one with minimum area to reduce padding - if len(chosen_canvas) > 1: - areas = chosen_canvas[:, 0] * chosen_canvas[:, 1] - optimal_idx = torch.argmin(areas) - optimal_canvas = chosen_canvas[optimal_idx] - else: - optimal_canvas = chosen_canvas[0] - - return tuple(optimal_canvas.tolist()) - - def __call__( - self, - image: Image.Image, - max_num_chunks: int, - normalize_img: bool = True, - resize_to_max_canvas: bool = False, - ) -> Tuple[Any, Any]: - """ - Args: - image (PIL.Image): Image to be resized. - max_num_chunks (int): Maximum number of chunks to split the image into. - normalize_img (bool): Whether to normalize the image. - resize_to_max_canvas (bool): Whether to resize the image to the maximum canvas size. - If True, picks the canvas the allows the largest resizing without distortion. - If False, downsample as little as possible, including no resizing at all, - but never upsample, unless the image is smaller than the patch size. - """ - assert max_num_chunks > 0 - assert isinstance(image, Image.Image), type(image) - w, h = image.size - - possible_resolutions = self.find_supported_resolutions( - max_num_chunks=max_num_chunks, patch_size=self.size - ) - possible_resolutions = torch.tensor(possible_resolutions) - - best_resolution = self.get_best_fit( - image_size=(w, h), - possible_resolutions=possible_resolutions, - resize_to_max_canvas=resize_to_max_canvas, - ) - - max_upscaling_size = None if resize_to_max_canvas else self.size - image = self.resize_without_distortion( - image, best_resolution, max_upscaling_size - ) - image = self._pad(image, best_resolution) - - image = self.to_tensor(image) - - if normalize_img: - image = self.normalize(image) - - ratio_w, ratio_h = ( - best_resolution[0] // self.size, - best_resolution[1] // self.size, - ) - - image = self._split(image, ratio_w, ratio_h) # type: ignore - - ar = (ratio_h, ratio_w) - return image, ar - -@MULTIMODAL_REGISTRY.register_image_input_mapper() -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llama_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llamavl) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llamavl) -class LlamaVLForCausalLM(nn.Module, SupportsMultiModal): - def __init__(self, config, - multimodal_config: MultiModalConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): - super().__init__() - print("config", type(config)) - print(config) - print("multimodal_config", type(multimodal_config)) - print(multimodal_config) - print("cache_config", type(cache_config)) - print(cache_config) - print("quant_config", type(quant_config)) - print(quant_config) - - # self.params = args - args = config - self.model_dim = args.dim - self.vision_model = CrossAttentionTransformerVision(args) - self.text_model = CrossAttentionTransformerText(args, cache_config=cache_config) - self.image_res = args.vision_chunk_size - self.max_num_chunks = args.vision_max_num_chunks - self.image_transform = partial( - VariableSizeImageTransform(size=args.vision_chunk_size), - max_num_chunks=args.vision_max_num_chunks, - ) - self.lm_head = ParallelLMHead( - args.vocab_size, - args.dim, - org_num_embeddings=args.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE, - quant_config=quant_config, - ) - self.logits_processor = LogitsProcessor(args.dim, args.vocab_size) - self.sampler = Sampler() - - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - # my_rank = get_tensor_model_parallel_rank() - state_dict = {name: weight for name, weight in weights} - # if my_rank == 0: - # with open("weight_shape_map.log", "w") as f: - # for name, weight in state_dict.items(): - # f.write(f"{name}-{tuple(weight.shape)}-{weight.dtype}\n") - - state_dict.pop('text_model.rope.freqs') - state_dict['lm_head.weight'] = state_dict.pop('text_model.output.weight') - - def load_weight(param, weight): - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, weight) - - for param_name, param in self.named_parameters(): - # print("loading", param_name) - if param_name.startswith("text_model.layers"): - layer_id = int(param_name.split(".")[2]) - if param_name.endswith("attention.qkv_proj.weight"): - # "text_model.layers.{i}.attention.qkv_proj.weight" - weight_name = f"text_model.layers.{layer_id}.attention.wqkv.weight" - weight = state_dict.pop(weight_name) - module = self.text_model.layers[layer_id].attention.qkv_proj - module.weight_loader(param, weight) - continue - elif param_name.endswith("attention.o_proj.weight"): - # "text_model.layers.{i}.attention.o_proj.weight" - weight_name = f"text_model.layers.{layer_id}.attention.wo.weight" - weight = state_dict.pop(weight_name) - module = self.text_model.layers[layer_id].attention.o_proj - module.weight_loader(param, weight) - continue - elif param_name.endswith("feed_forward.gate_up_proj.weight"): - # "text_model.layers.{i}.feed_forward.mlp.fc1_weight" - weight_name = f"text_model.layers.{layer_id}.feed_forward.mlp.fc1_weight" - weight = state_dict.pop(weight_name) - module = self.text_model.layers[layer_id].feed_forward.gate_up_proj - module.weight_loader(param, weight) - continue - elif param_name.endswith("feed_forward.down_proj.weight"): - # "text_model.layers.{i}.feed_forward.mlp.fc2_weight" - weight_name = f"text_model.layers.{layer_id}.feed_forward.mlp.fc2_weight" - weight = state_dict.pop(weight_name) - module = self.text_model.layers[layer_id].feed_forward.down_proj - module.weight_loader(param, weight) - continue - elif param_name.endswith("attention_norm.weight"): - # "text_model.layers.{i}.attention_norm.weight" - weight_name = f"text_model.layers.{layer_id}.attention.wqkv.layer_norm_weight" - weight = state_dict.pop(weight_name) - load_weight(param, weight) - continue - elif param_name.endswith("ffn_norm.weight"): - # "text_model.layers.{i}.ffn_norm.weight" - weight_name = f"text_model.layers.{layer_id}.feed_forward.mlp.layer_norm_weight" - weight = state_dict.pop(weight_name) - load_weight(param, weight) - continue - if param_name.startswith("text_model.cross_attention_layers"): - layer_id = int(param_name.split(".")[2]) - if param_name.endswith('gate_attn'): - attn_gate = state_dict.pop(param_name) - if attn_gate.dim() == 1: - attn_gate = attn_gate[0].view(1) - if attn_gate.dim() == 3: - attn_gate = attn_gate.view(1) - load_weight(param, attn_gate) - continue - if param_name.endswith('gate_ffwd'): - ffn_gate = state_dict.pop(param_name) - if ffn_gate.dim() == 1: - ffn_gate = ffn_gate[0].view(1) - if ffn_gate.dim() == 3: - ffn_gate = ffn_gate.view(1) - load_weight(param, ffn_gate) - continue - if param_name.endswith('ffn_norm.weight'): - load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.feed_forward.mlp.layer_norm_weight")) - continue - if param_name.endswith('attention_norm.weight'): - load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.wq.layer_norm_weight")) - continue - # if param_name.endswith('attention.wk.weight') or param_name.endswith('attention.wv.weight'): - # if f"text_model.cross_attention_layers.{layer_id}.attention.wkv.weight" in state_dict: - # weight = state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.wkv.weight") - # state_dict[f"text_model.cross_attention_layers.{layer_id}.attention.wkv.weight.1"] = weight - # else: - # weight = state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.wkv.weight.1") - # if param_name.endswith('attention.wk.weight'): - # weight = weight.chunk(2)[0] - # else: - # weight = weight.chunk(2)[1] - # load_weight(param, weight) - # continue - if param_name.endswith('attention.qkv_proj.weight'): - q_weight = state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.wq.weight") - kv_weight = state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.wkv.weight") - qkv_weight = torch.cat([q_weight, kv_weight], dim=0) - module = self.text_model.cross_attention_layers[layer_id].attention.qkv_proj - module.weight_loader(param, qkv_weight) - continue - if param_name.endswith('attention.q_norm.weight'): - load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.inner_attention.q_norm.weight")) - continue - if param_name.endswith('attention.k_norm.weight'): - load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.attention.inner_attention.k_norm.weight")) - continue - if param_name.endswith('feed_forward.gate_up_proj.weight'): - load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.feed_forward.mlp.fc1_weight")) - continue - if param_name.endswith('feed_forward.down_proj.weight'): - load_weight(param, state_dict.pop(f"text_model.cross_attention_layers.{layer_id}.feed_forward.mlp.fc2_weight")) - continue - if param_name.startswith("vision_model.vision_encoder"): - if param_name == 'vision_model.vision_encoder.conv1._linear.weight': - module = self.vision_model.vision_encoder.conv1._linear - weight = state_dict.pop('vision_model.vision_encoder.conv1._linear.weight') - module.weight_loader(param, weight) - continue - if param_name.startswith("vision_model.vision_encoder.transformer.resblocks") or param_name.startswith("vision_model.vision_encoder.global_transformer.resblocks"): - layer_id = int(param_name.split(".")[4]) - if param_name.startswith('vision_model.vision_encoder.transformer.resblocks'): - prefix = 'vision_model.vision_encoder.transformer.resblocks' - transformer_block: ImageTransformerBlock = self.vision_model.vision_encoder.transformer.resblocks[layer_id] - else: - prefix = 'vision_model.vision_encoder.global_transformer.resblocks' - transformer_block = self.vision_model.vision_encoder.global_transformer.resblocks[layer_id] - if param_name.endswith("mlp.c_fc.weight"): - module = transformer_block.mlp.c_fc - weight = state_dict.pop(f"{prefix}.{layer_id}.mlp.c_fc.weight") - module.weight_loader(param, weight) - continue - if param_name.endswith("attn.qkv_proj.weight"): - module = transformer_block.attn.qkv_proj - q_weight = state_dict.pop(f"{prefix}.{layer_id}.attn.wq.weight") - k_weight = state_dict.pop(f"{prefix}.{layer_id}.attn.wk.weight") - v_weight = state_dict.pop(f"{prefix}.{layer_id}.attn.wv.weight") - # import pdb; pdb.set_trace() - qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) - module.weight_loader(param, qkv_weight) - continue - if param_name in state_dict: - loaded_weight = state_dict.pop(param_name) - load_weight(param, loaded_weight) - continue - - raise ValueError(f"Unexpected parameter {param_name}") - - if len(state_dict) > 0: - raise ValueError(f"unused keys: {state_dict.keys()}") - - def _parse_and_validate_image_input( - self, **kwargs: object) -> Optional[LlavaImageInputs]: - pixel_values = kwargs.pop("pixel_values", None) - image_embeds = kwargs.pop("image_embeds", None) - aspect_ratios = kwargs.pop("aspect_ratios", None) - - if pixel_values is None and image_embeds is None: - return None - - if pixel_values is not None and image_embeds is not None: - raise ValueError("Both pixel values and image embeds are provided.") - - if pixel_values is not None: - # tensor with the same shape will be batched together by MultiModalInputs.batch, so pixel_values here can be: - # - List[List[torch.Tensor]]: with shape (num_chunks, 3, image_res, image_res) - # - List[torch.Tensor]: with shape (num_image_in_batch, num_chunks, 3, image_res, image_res) - # - torch.Tensor: with shape (bs, num_image_in_batch, num_chunks, 3, image_res, image_res) - # the best choice is to remove MultiModalInputs.batch - pixel_values_unpacked = [] - for b in range(len(pixel_values)): - pixel_values_unpacked_b = [] - for i in range(len(pixel_values[b])): - pixel_values_unpacked_b.append(pixel_values[b][i]) - pixel_values_unpacked.append(pixel_values_unpacked_b) - - max_num_images = max([len(x) for x in pixel_values_unpacked]) - max_num_chunks = max(max([len(x) for x in y]) for y in pixel_values_unpacked) - bsz = len(pixel_values_unpacked) - out_num_chunks = [] - out_images = torch.zeros( - bsz, - max_num_images, - max_num_chunks, - 3, - self.image_res, - self.image_res - ) - out_ar = torch.ones(bsz, max_num_images, 2, dtype=torch.int64) - for b in range(len(pixel_values_unpacked)): - _num_chunks = [] - for i in range(len(pixel_values_unpacked[b])): - img = pixel_values_unpacked[b][i] - out_images[b, i, :img.shape[0]] = img - out_ar[b, i] = aspect_ratios[b][i] - _num_chunks.append(img.shape[0]) - out_num_chunks.append(_num_chunks) - - return LlamaImagePixelInputs( - type="pixel_values", - data=out_images, - num_chunks=out_num_chunks, - aspect_ratios=out_ar, - ) - - if image_embeds is not None: - raise NotImplementedError - - raise AssertionError("This line should be unreachable.") - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs: object, - ) -> torch.Tensor: - if attn_metadata.num_prefill_tokens > 0 and attn_metadata.num_decode_tokens > 0: - raise ValueError("Chunk prefill not supported") - image_inputs = self._parse_and_validate_image_input(**kwargs) - if image_inputs is None: - cross_attention_masks = None - run_xattn_mask = (attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).cuda() - xattn_caches = None - vision_tokens = None - else: - # llama's reference implementation runs the vision model on CPU - cuda_images = image_inputs['data'].cuda() - cuda_aspect_ratios = image_inputs['aspect_ratios'].cuda() - vision_tokens = self.vision_model(cuda_images, cuda_aspect_ratios) - # import pdb; pdb.set_trace() - bsz, _, _, _, image_token_dim = tuple(vision_tokens.shape) - vision_tokens = vision_tokens.view(bsz, -1, image_token_dim) - - vision_tokens_flat = torch.zeros(sum(attn_metadata.encoder_seq_lens), image_token_dim, device=vision_tokens.device, dtype=vision_tokens.dtype) - start_pos = 0 - for seq_len, vision_token_in_batch in zip(attn_metadata.encoder_seq_lens, vision_tokens): - end_pos = start_pos + seq_len - vision_tokens_flat[start_pos:end_pos] = vision_token_in_batch[:seq_len] - start_pos = end_pos - vision_tokens = vision_tokens_flat - - run_xattn_mask = torch.ones((attn_metadata.num_prefill_tokens, 1), dtype=torch.bool, device=vision_tokens.device) - start_pos = 0 - for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens_tensor, attn_metadata.encoder_seq_lens): - if encoder_seq_len == 0: - run_xattn_mask[start_pos:start_pos+seq_len] = False - start_pos += seq_len - - # batch_masks = [] - # # TODO: get the sequence of each query without hack? 1) better attn metadata 2) better input processor to create vision mask during preprocess - # # assert isinstance(attn_metadata, PagedAttentionMetadata) - # start_pos = 0 - # for seq_len in attn_metadata.seq_lens_tensor: - # end_pos = start_pos + seq_len - # batch_masks.append(create_vision_mask(input_ids[start_pos:end_pos])) - # start_pos = end_pos - - # xattn_caches = torch.stack( - # [ - # layer.compute_xattn_kv_cache(vision_tokens_flat) - # for layer in self.text_model.cross_attention_layers - # ] - # ) - # TODO: remove this hardcode - # total_len = 512 - # padded_masks = _pad_masks( - # batch_masks, - # image_inputs['num_chunks'], - # total_len, - # self.max_num_chunks, - # ) - - # cross_attention_masks, full_text_row_masked_out_mask = ( - # self.text_model._get_xattn_mask( - # num_tokens=total_len, - # text_device="cuda", - # text_dtype=next(self.text_model.parameters()).dtype, - # vision_tokens=vision_tokens, - # cross_attention_masks=padded_masks, - # ) - # ) - - # full_text_row_masked_out_mask_plain = torch.zeros(attn_metadata.num_prefill_tokens, 1, dtype=full_text_row_masked_out_mask.dtype) - # start_pos = 0 - # for i, seq_len in enumerate(attn_metadata.seq_lens_tensor): - # end_pos = start_pos + seq_len - # full_text_row_masked_out_mask_plain[start_pos:end_pos, 0] = full_text_row_masked_out_mask[i, 0, :seq_len, 0] - # start_pos = end_pos - # full_text_row_masked_out_mask = full_text_row_masked_out_mask_plain.cuda() - # print("input_ids", input_ids, vision_tokens is None) - # if positions.numel() == 1: - # global step_name - # step_name = f"decode_{positions.item()}" - h = self.text_model.get_partially_trainable_embedding(input_ids) - # check(h, f"h_{step_name}.pt") - h = self.text_model.forward( - positions=positions, - h=h, - # xattn_mask=cross_attention_masks, - # full_text_row_masked_out_mask=full_text_row_masked_out_mask, - vision_hidden_states=vision_tokens, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - run_xattn_mask=run_xattn_mask, - ) - # if positions.numel() == 1 and positions.item() == 20: - # exit(0) - return h - -def create_vision_mask( - tokens: List[int], - vision_token: int=128256, -) -> List[List[int]]: - vision_token_locations = [ - i for i, token in enumerate(tokens) if token == vision_token - ] - if len(vision_token_locations) == 0: - return [] - - if len(vision_token_locations) == 1: - # only one image present, unmask until end of sequence - return [[vision_token_locations[0], -1]] - vision_masks = [ - [loc1, loc2] - for loc1, loc2 in zip(vision_token_locations[:-1], vision_token_locations[1:]) - ] - # last image will attend to all subsequent text - vision_masks.append([vision_token_locations[-1], len(tokens)]) - - # if there are two or more consecutive vision tokens, - # they should all attend to all subsequent - # text present - last_mask_end = vision_masks[-1][1] - for vision_mask in vision_masks[::-1]: - if vision_mask[0] == vision_mask[1] - 1: - vision_mask[1] = last_mask_end - last_mask_end = vision_mask[1] - return vision_masks - - - -def _pad_masks( - all_masks: List[List[List[int]]], - all_num_chunks: List[List[int]], - total_len: int, - max_num_chunks: int, -) -> torch.Tensor: - dtype = torch.bfloat16 - inf_value = torch.finfo(dtype).min - - bsz = len(all_masks) - max_num_media = max([len(m) for m in all_masks]) - - out_masks = torch.full( - (bsz, total_len, max_num_media, max_num_chunks), - inf_value, - dtype=dtype, - ) - - for idx, (mask, num_chunks) in enumerate(zip(all_masks, all_num_chunks)): - for mask_idx, (mask_elem, mask_num_chunks) in enumerate(zip(mask, num_chunks)): - if len(mask_elem) == 2: - mask_elem[1] = min(mask_elem[1], total_len) - if mask_elem[1] == -1: - mask_elem[1] = total_len - out_masks[ - idx, mask_elem[0] : mask_elem[1], mask_idx, :mask_num_chunks - ].fill_(0.0) - - return out_masks diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index f72cd8b9792a..3be8791f5ee5 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -23,7 +23,7 @@ GraniteConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, - RWConfig, UltravoxConfig, LlamaVLConfig) + RWConfig, UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file @@ -51,7 +51,6 @@ "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "ultravox": UltravoxConfig, - "llamavl": LlamaVLConfig, # Granite can be removed from here once we have upgraded to # transformers 4.45+ "granite": GraniteConfig, diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index be6f775d7328..8381c5227584 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -14,7 +14,6 @@ from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig from vllm.transformers_utils.configs.ultravox import UltravoxConfig -from vllm.transformers_utils.configs.llamavl import LlamaVLConfig __all__ = [ "ChatGLMConfig", @@ -25,7 +24,6 @@ "JAISConfig", "MedusaConfig", "EAGLEConfig", - "LlamaVLConfig", "ExaoneConfig", "MLPSpeculatorConfig", "NemotronConfig", diff --git a/vllm/transformers_utils/configs/llamavl.py b/vllm/transformers_utils/configs/llamavl.py deleted file mode 100644 index d186ddac2e32..000000000000 --- a/vllm/transformers_utils/configs/llamavl.py +++ /dev/null @@ -1,52 +0,0 @@ -from transformers import PretrainedConfig -from typing import Optional, Any - - -class LlamaVLConfig(PretrainedConfig): - model_type = "llamavl" - - dim: int = 4096 - n_layers: int = 32 - n_heads: int = 32 - n_kv_heads: Optional[int] = None - vocab_size: int = -1 - multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 - ffn_dim_multiplier: Optional[float] = None - norm_eps: float = 1e-5 - rope_theta: float = 500000 - use_scaled_rope: bool = False - - max_batch_size: int = 32 - max_seq_len: int = 2048 - - # vision model params - vision_chunk_size: int = -1 # image resolution for image models - vision_max_num_chunks: int = 4 - vision_num_cross_attention_layers: int = -1 - - model_type: str = "llamavl" - architectures: list[str] = ["LlamaVLForCausalLM"] - - torch_dtype: str = "bfloat16" - - rope_scaling: Optional[dict[str, Any]] = None - - attribute_map = { - "num_hidden_layers": "n_layers", - "hidden_size": "dim", - "num_attention_heads": "n_heads", - "num_key_value_heads": "n_kv_heads", - } - - def __init__(self, **kwargs): - for k, v in kwargs.items(): - if hasattr(self, k): - setattr(self, k, v) - - if self.n_kv_heads is None: - self.n_kv_heads = self.n_heads - assert self.n_kv_heads <= self.n_heads - assert self.n_heads % self.n_kv_heads == 0 - assert self.dim % self.n_heads == 0 - - super().__init__(**kwargs) diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py index 2bb167a3fc65..4cffac3724ba 100644 --- a/vllm/transformers_utils/image_processor.py +++ b/vllm/transformers_utils/image_processor.py @@ -41,10 +41,6 @@ def get_image_processor( from transformers.image_processing_utils import BaseImageProcessor try: - print("processor_name", processor_name) - if "Vision-Early" in processor_name and "checkpoints" in processor_name: - from .multimodal_processors.llamavl import LlamaVLImageProcessor - return LlamaVLImageProcessor(processor_name, *args, **kwargs) processor = AutoImageProcessor.from_pretrained( processor_name, *args, diff --git a/vllm/transformers_utils/multimodal_processors/llamavl.py b/vllm/transformers_utils/multimodal_processors/llamavl.py deleted file mode 100644 index 2ac9fcdcb933..000000000000 --- a/vllm/transformers_utils/multimodal_processors/llamavl.py +++ /dev/null @@ -1,366 +0,0 @@ -from transformers.image_processing_base import BatchFeature -from transformers.image_processing_utils import BaseImageProcessor - -import torch -from typing import List, Tuple -from PIL import Image -from functools import partial - -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# top-level folder for each specific model found within the models/ directory at -# the top-level of this source tree. - -import math -from functools import reduce -from typing import Any, Tuple - -import numpy as np -import torch -import torchvision.transforms as tv -from PIL import Image -from torchvision.transforms import functional as F - -IMAGE_RES = 224 - -class TorchBF16Context: - - def __enter__(self): - self.prev_dtype = torch.get_default_dtype() - if torch.cuda.is_bf16_supported(): - torch.set_default_tensor_type(torch.cuda.BFloat16Tensor) - else: - torch.set_default_tensor_type(torch.cuda.HalfTensor) - - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.prev_dtype == torch.float32: - torch.set_default_tensor_type(torch.FloatTensor) - else: - raise ValueError("Unsupported dtype") - -class VariableSizeImageTransform(object): - """ - The variable size image transform will resize the image dynamically - based on the image aspect ratio and the number of image chunks we allow. - The algorithm will not upsample low-res images to fit a certain aspect - ratio, because that leads to a significant degradation in image quality. - For example, if an input image is of size 300x800, and we want to allow - a maximum of 16 image chunks, it will find the closest aspect ratio that - is allowed within 16 image chunks, i.e., 2:5 = 2 horizontal patches and - 5 vertical patches, giving a total of 10 chunks. - The image will then be resized to products of the base size (default is - 224px because MetaCLIP takes that), so in this case it will be resized to - 2*224:5*224 = 448:1120, where we maintain the original aspect ratio and - pad with the mean value for the rest. This approach minimizes the amount - of padding required for any arbitrary resolution. - The final output will therefore be of shape (11, 3, 224, 224), where 10 - patches are coming from the resizing and chunking, and the first patch - is a downsampled version of the image that preserves aspect ratios. - """ - - def __init__(self, size: int = IMAGE_RES) -> None: - self.size = size - self.to_tensor = tv.ToTensor() - self._mean = (0.48145466, 0.4578275, 0.40821073) - self._std = (0.26862954, 0.26130258, 0.27577711) - self.normalize = tv.Normalize( - mean=self._mean, - std=self._std, - inplace=True, - ) - - @staticmethod - def _factors(n: int): - """Return all factors of a number.""" - return set( - reduce( - list.__add__, - ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0), - ) - ) - - def _find_supported_aspect_ratios(self, num_chunks: int): - """ - This function computes all the allowed aspect ratios for a fixed - number of input chunks. - For example, with `num_chunks=5`, it will return: - { - 0.2: [(1, 5)], - 5.0: [(5, 1)], - 0.25: [(1, 4)], - 1.0: [(2, 2), (1, 1)], - 4.0: [(4, 1)], - 0.3333333333333333: [(1, 3)], - 3.0: [(3, 1)], - 0.5: [(1, 2)], - 2.0: [(2, 1)] - } - """ - asp_dict = {} - for chunk_size in range(num_chunks, 0, -1): - _factors = sorted(VariableSizeImageTransform._factors(chunk_size)) - _asp_ratios = [(x, chunk_size // x) for x in _factors] - for ratio in _asp_ratios: - k = ratio[0] / ratio[1] - if k not in asp_dict: - asp_dict[k] = [ratio] - else: - asp_dict[k].append(ratio) - return asp_dict - - def _find_closest_aspect_ratio( - self, num_chunks: int, img_width: int, img_height: int - ) -> Tuple: - """ - Given an image width, height and target number of chunks - this function will find the closest supported aspect ratio. - """ - tgt_ar = img_width / img_height - asp_dict = self._find_supported_aspect_ratios(num_chunks) - cl_d, cl_p = 1e23, None - if tgt_ar >= 1: - cl_p = min( - [k for k in asp_dict.keys() if k <= tgt_ar], - key=lambda x: abs(x - tgt_ar), - ) - v = asp_dict[cl_p] - # select width - widths = [(idx, self.size * vv[0]) for idx, vv in enumerate(v)] - tgt_idx = max(widths, key=lambda x: x[1])[0] - else: - cl_p = min( - [k for k in asp_dict.keys() if k > tgt_ar], - key=lambda x: abs(1 / x - 1 / tgt_ar), - ) - v = asp_dict[cl_p] - # select height - heights = [(idx, self.size * vv[1]) for idx, vv in enumerate(v)] - tgt_idx = max(heights, key=lambda x: x[1])[0] - out = v[tgt_idx] - return out - - def _resize( - self, image: Image.Image, target_width: int, target_height: int - ) -> Image.Image: - # Resize longer edge to given size. - w, h = image.size - scale = w / h - - if scale > 1.0: - # width > height - new_w = target_width - new_h = math.floor(new_w / scale) - else: - # height >= width - new_h = target_height - new_w = math.floor(new_h * scale) - - image = F.resize(image, (new_h, new_w)) - return image - - def _resize_max_side_to_size( - self, - image: Image.Image, - ) -> Image.Image: - # Resize longer edge to given size. - w, h = image.size - scale = w / h - - if scale > 1.0: - # width > height - new_w = max(self.size, w) - new_h = math.floor(new_w / scale) - else: - # height >= width - new_h = max(self.size, h) - new_w = math.floor(new_h * scale) - - image = F.resize(image, (new_h, new_w)) - return image - - def _pad(self, image: Image.Image, new_width: int, new_height: int) -> Image.Image: - mean_per_channel = tuple( - np.clip(np.array(image).mean(axis=(0, 1)), 0, 255).astype(np.uint8) - ) - new_im = Image.new(mode="RGB", size=(new_height, new_width), color=(0, 0, 0)) # type: ignore - new_im.paste(image) - return new_im - - def _split(self, image: torch.Tensor, ncw: int, nch: int) -> torch.Tensor: - # Split image into number of required tiles (width x height) - num_channels, height, width = image.size() - image = image.view(num_channels, nch, height // nch, ncw, width // ncw) - # Permute dimensions to reorder the axes - image = image.permute(1, 3, 0, 2, 4).contiguous() - # Reshape into the desired output shape (batch_size * 4, num_channels, width/2, height/2) - image = image.view(ncw * nch, num_channels, height // nch, width // ncw) - return image - - def _fit_image_to_canvas( - self, num_chunks: int, img_width: int, img_height: int - ) -> Any: - """ - Given an image width, height and target number of chunks this function will see if the image - can be fit into any of the canvases that can be build from arranging the tiles in a grid. - If the image can be fit onto several canvases, it will return the canvas where the shorter edge - of the image will be largest. - """ - # Initialize the optimal canvas to None. If no canvas is found where image fits, function returns None. - optimal_canvas = None - optimal_image_width_height = None - - scale = img_width / img_height - - # Gather all potential supported image resolutions and iterate through them to find best match - potential_arrangements = [ - item - for sublist in self._find_supported_aspect_ratios(num_chunks).values() - for item in sublist - ] - current_gap = 1e23 - for n_w, n_h in potential_arrangements: - # Compute the canvas size - canvas_width, canvas_height = n_w * self.size, n_h * self.size - - # Check if image can fit into the canvas without downsampling - if canvas_width >= img_width and canvas_height >= img_height: - # If we did not find a good canvas yet, we will use the current one - if optimal_canvas is None: - # Set optimal canvas and determine the actual image height and width in the canvas with aspect ratio preserving resampling - optimal_canvas = (n_w, n_h) - optimal_image_width_height = (n_w * self.size, n_h * self.size) - else: - # Find closest fit based on gap - image_width_height = (n_w * self.size, n_h * self.size) - gap = abs(img_width - image_width_height[0]) + abs( - img_height - image_width_height[1] - ) - if gap < current_gap: - # If the gap is smaller than the previous one, we will update our optimal canvas and image width height - optimal_canvas = (n_w, n_h) - optimal_image_width_height = image_width_height - current_gap = gap - return optimal_canvas - - def __call__(self, image: Image.Image, max_num_chunks: int) -> Tuple[Any, Any]: - assert max_num_chunks > 0 - assert isinstance(image, Image.Image), type(image) - - import numpy as np - w, h = image.size - # Check if the image can be fit to the canvas without downsampling - ar = self._fit_image_to_canvas( - num_chunks=max_num_chunks, img_width=w, img_height=h - ) - if ar is None: - # If we did not find a canvas, we have to find the closest aspect ratio and downsample the image - ar = self._find_closest_aspect_ratio( - num_chunks=max_num_chunks, img_width=w, img_height=h - ) - image = self._resize(image, ar[0] * self.size, ar[1] * self.size) - else: - image = self._resize_max_side_to_size(image) - - arr = np.array(image) - - image = self._pad(image, ar[1] * self.size, ar[0] * self.size) - image = self.to_tensor(image) - image = self.normalize(image) - image = self._split(image, ar[0], ar[1]) # type: ignore - return image, ar - - -def _stack_images( - images: List[List[Image.Image]], - max_num_chunks: int, - image_res: int, - max_num_images: int, -) -> Tuple[torch.Tensor, List[int]]: - """ - Takes a list of list of images and stacks them into a tensor. - This function is needed since images can be of completely - different resolutions and aspect ratios. - """ - out_images, out_num_chunks = [], [] - for imgs_sample in images: - out_images_i = torch.zeros( - max_num_images, - max_num_chunks, - 3, - image_res, - image_res, - ) - _num_chunks = [] - for j, chunks_image in enumerate(imgs_sample): - out_images_i[j, : chunks_image.shape[0]] = chunks_image - _num_chunks.append(chunks_image.shape[0]) - out_images.append(out_images_i) - out_num_chunks.append(_num_chunks) - return torch.stack(out_images), out_num_chunks - -class LlamaVLImageProcessor(BaseImageProcessor): - def __init__(self, name, *args, **kwargs): - if "11B" in name: - self.vision_chunk_size = 448 - elif "90B" in name: - self.vision_chunk_size = 560 - else: - raise ValueError(f"Unknown model name: {name}") - self.vision_max_num_chunks = 4 - self.max_num_chunks = self.vision_max_num_chunks - self.image_transform = partial( - VariableSizeImageTransform(size=self.vision_chunk_size), - max_num_chunks=self.vision_max_num_chunks, - ) - def preprocess(self, images, **kwargs) -> BatchFeature: - with TorchBF16Context(): - # assert len(images) == len( - # batch_masks - # ), "Images and masks must have the same length" - - # preprocess is called for each batch now, so add batch dimension here. - if not isinstance(images, list): - images = [images] - images = [images] - - max_num_images = max(len(x) for x in images) - bsz = len(images) - if max_num_images == 0: - data = None - else: - images_and_aspect_ratios = [ - [self.image_transform(im) for im in row] for row in images - ] - transformed_images = [ - [x[0] for x in row] for row in images_and_aspect_ratios - ] - - aspect_ratios = torch.ones(bsz, max_num_images, 2, dtype=torch.int64) - for i, row in enumerate(images_and_aspect_ratios): - if len(row) > 0: - aspect_ratios[i, : len(row)] = torch.stack( - [torch.tensor(x[1]) for x in row] - ) - assert bsz == 1, "the below code is not for batched images" - data = { - 'pixel_values': transformed_images[0], - 'aspect_ratios': aspect_ratios[0], - } - # print("transformed_images", transformed_images) - # for i, row in enumerate(transformed_images): - # for j, x in enumerate(row): - # print(i, j, x.shape) - # print("aspect_ratios", aspect_ratios) - # stacked_images, num_chunks = _stack_images( - # transformed_images, - # self.vision_max_num_chunks, - # self.vision_chunk_size, - # max_num_images, - # ) - # print("stacked_images", stacked_images.shape) - # print("num_chunks", num_chunks) - return BatchFeature(data, tensor_type=None) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index d904b035c7f5..2a2d74382e37 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -11,15 +11,14 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizers import (BaichuanTokenizer, - MistralTokenizer, - LlamaVLTokenizer) + MistralTokenizer) from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import make_async logger = init_logger(__name__) AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, - MistralTokenizer, LlamaVLTokenizer] + MistralTokenizer] def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py index 2fa055890da8..9433f2d48f6f 100644 --- a/vllm/transformers_utils/tokenizers/__init__.py +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -1,5 +1,4 @@ from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer -from vllm.transformers_utils.tokenizers.llamavl import LlamaVLTokenizer __all__ = ["BaichuanTokenizer", "MistralTokenizer"] diff --git a/vllm/transformers_utils/tokenizers/llamavl.py b/vllm/transformers_utils/tokenizers/llamavl.py deleted file mode 100644 index 532fa269bb8c..000000000000 --- a/vllm/transformers_utils/tokenizers/llamavl.py +++ /dev/null @@ -1,221 +0,0 @@ -import os -from logging import getLogger -from pathlib import Path -from typing import ( - AbstractSet, - cast, - Collection, - Dict, - Iterator, - List, - Literal, - Optional, - Sequence, - Union, - Any, -) -from transformers.tokenization_utils import PreTrainedTokenizer - -# TODO: now use tiktoken, but I believe it should be replaced with tokenizer in huggingface -import tiktoken - -from tiktoken.load import load_tiktoken_bpe - -logger = getLogger(__name__) - - -# The tiktoken tokenizer can handle <=400k chars without -# pyo3_runtime.PanicException. -TIKTOKEN_MAX_ENCODE_CHARS = 400_000 - -# https://github.com/openai/tiktoken/issues/195 -# Here we iterate over subsequences and split if we exceed the limit -# of max consecutive non-whitespace or whitespace characters. -MAX_NO_WHITESPACES_CHARS = 25_000 - -# TODO: this class is with some hack. need toreplace with official release -class LlamaVLTokenizer: - """ - Tokenizing and encoding/decoding text using the Tiktoken tokenizer. - """ - - special_tokens: Dict[str, int] - - num_reserved_special_tokens = 256 - - pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 - - def __init__(self, model_path: str): - """ - Initializes the Tokenizer with a Tiktoken model. - - Args: - model_path (str): The path to the Tiktoken model file. - """ - assert os.path.isfile(model_path), model_path - - mergeable_ranks = load_tiktoken_bpe(model_path) - num_base_tokens = len(mergeable_ranks) - special_tokens = [ - "<|begin_of_text|>", - "<|end_of_text|>", - "<|reserved_special_token_0|>", - "<|reserved_special_token_1|>", - "<|finetune_right_pad_id|>", - "<|step_id|>", - "<|start_header_id|>", - "<|end_header_id|>", - "<|eom_id|>", # end of message - "<|eot_id|>", # end of turn - "<|python_tag|>", - "<|image|>", - ] - reserved_tokens = [ - f"<|reserved_special_token_{2 + i}|>" - for i in range(self.num_reserved_special_tokens - len(special_tokens)) - ] - special_tokens = special_tokens + reserved_tokens - - self.special_tokens = { - token: num_base_tokens + i for i, token in enumerate(special_tokens) - } - self.special_tokens["<|image|>"] = 128256 - self.model = tiktoken.Encoding( - name=Path(model_path).name, - pat_str=self.pat_str, - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens, - ) - - self.n_words: int = num_base_tokens + len(special_tokens) - # BOS / EOS token IDs - self.bos_id: int = self.special_tokens["<|begin_of_text|>"] - self.eos_id: int = self.special_tokens["<|end_of_text|>"] - self.eot_id: int = self.special_tokens["<|eot_id|>"] - self.eom_id: int = self.special_tokens["<|eom_id|>"] - self.python_tag_id = self.special_tokens["<|python_tag|>"] - self.pad_id: int = self.special_tokens["<|finetune_right_pad_id|>"] - self.stop_tokens = [ - self.eos_id, - self.special_tokens["<|eom_id|>"], - self.special_tokens["<|eot_id|>"], - ] - - self.bos_token_id = self.bos_id - self.eos_token_id = self.eos_id - print("need to replace tokenizer with official release") - print("warning: recheck add bos and add eos of encode function") - - # the following attributes are set to fit VLLM's design (copied from MistralTokenizer) - self.is_fast = False - self.chat_template = True - self.all_special_ids: List[Any] = [] - self.all_special_tokens: List[Any] = [] - self.all_special_tokens_extended: List[Any] = [] - - def get_added_vocab(self) -> List[str]: - return [] - - def encode( - self, - s: str, - ) -> List[int]: - """ - Encodes a string into a list of token IDs. - - Args: - s (str): The input string to be encoded. - bos (bool): Whether to prepend the beginning-of-sequence token. - eos (bool): Whether to append the end-of-sequence token. - allowed_special ("all"|set[str]): allowed special tokens in string - disallowed_special ("all"|set[str]): special tokens that raise an error when in string - - Returns: - list[int]: A list of token IDs. - - By default, setting disallowed_special=() encodes a string by ignoring - special tokens. Specifically: - - Setting `disallowed_special` to () will cause all text corresponding - to special tokens to be encoded as natural text (insteading of raising - an error). - - Setting `allowed_special` to "all" will treat all text corresponding - to special tokens to be encoded as special tokens. - """ - bos = False - eos = False - assert type(s) is str - - substrs = ( - substr - for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) - for substr in self._split_whitespaces_or_nonwhitespaces( - s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS - ) - ) - t: List[int] = [] - for substr in substrs: - t.extend( - self.model.encode( - substr, - allowed_special="all", - disallowed_special=set(), - ) - ) - if bos: - t.insert(0, self.bos_id) - if eos: - t.append(self.eos_id) - return t - - def convert_ids_to_tokens( - self, - tokens: List[int], - skip_special_tokens: Optional[bool] = True) -> List[str]: - # TODO(Patrick) - potentially allow special tokens to not be skipped - assert ( - skip_special_tokens - ), "Skipping special tokens is not supported for Mistral tokenizers." - - # assert isinstance(self.tokenizer, - # (Tekkenizer, SentencePieceTokenizer)), type( - # self.tokenizer) - - # TODO: handle skip_special_tokens - # TODO: self.model.decode returns a string, but the interface expects a list of words - return [self.model.decode(tokens)] - - def convert_tokens_to_string(self, tokens: List[str]) -> str: - return "".join(tokens) - - @staticmethod - def _split_whitespaces_or_nonwhitespaces( - s: str, max_consecutive_slice_len: int - ) -> Iterator[str]: - """ - Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` - consecutive whitespaces or consecutive non-whitespaces. - """ - current_slice_len = 0 - current_slice_is_space = s[0].isspace() if len(s) > 0 else False - slice_start = 0 - - for i in range(len(s)): - is_now_space = s[i].isspace() - - if current_slice_is_space ^ is_now_space: - current_slice_len = 1 - current_slice_is_space = is_now_space - else: - current_slice_len += 1 - if current_slice_len > max_consecutive_slice_len: - yield s[slice_start:i] - slice_start = i - current_slice_len = 1 - yield s[slice_start:] - - @classmethod - def from_pretrained(cls, model_path: str) -> "LlamaVLTokenizer": - return cls(os.path.join(model_path, "tokenizer.model")) - - def __len__(self): - return self.n_words \ No newline at end of file From 65a470bee709c34cf490646eea569d87b2b1c5ed Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 19 Sep 2024 13:50:52 -0700 Subject: [PATCH 34/75] hardcode some config to read huggingface's config.json without modifying it --- vllm/config.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index 9684cea81313..77f7ac5e3c7f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -583,6 +583,8 @@ def get_multimodal_config(self) -> "MultiModalConfig": @property def is_encoder_decoder_model(self) -> bool: """Extract the HF encoder/decoder model flag.""" + if self.hf_config.model_type == "mllama": + return True return getattr(self.hf_config, "is_encoder_decoder", False) @property @@ -1627,6 +1629,12 @@ def _get_and_verify_dtype( "of float16 by default. Please specify `dtype` if you " "want to use float16.") torch_dtype = torch.bfloat16 + if config.model_type == "mllama_text_model": + # the config is MllamaConfig.text_config + logger.info( + "Some llama vision models lack default dtype. Hardcode to bfloat16." + ) + torch_dtype = torch.bfloat16 else: # Following the common practice, we use float16 for float32 # models. From 2146716c61197f20fb2226abdb2c3a57fada9641 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 19 Sep 2024 15:37:44 -0700 Subject: [PATCH 35/75] move prompt to encoder prompt --- examples/openai_vision_api_client.py | 2 +- examples/template_llama3.2.jinja | 26 +------------------------- vllm/model_executor/models/mllama.py | 3 +++ 3 files changed, 5 insertions(+), 26 deletions(-) diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index 09854003d501..6bb66b694b98 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -38,7 +38,7 @@ "content": [ { "type": "text", - "text": "What’s in this image?" + "text": "Describe image in two sentences" }, { "type": "image_url", diff --git a/examples/template_llama3.2.jinja b/examples/template_llama3.2.jinja index 66a074be5610..93049d23eaff 100644 --- a/examples/template_llama3.2.jinja +++ b/examples/template_llama3.2.jinja @@ -1,25 +1 @@ -{% for message in messages %} - {% if loop.index0 == 0 %} - {{ bos_token }} - {% endif %} - - {{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} - - {% if message['content'] is string %} - {{ message['content'] }} - {% else %} - {% for content in message['content'] %} - {% if content['type'] == 'image' %} - {{ '<|image|>' }} - {% elif content['type'] == 'text' %} - {{ content['text'] }} - {% endif %} - {% endfor %} - {% endif %} - - {{ '<|eot_id|>' }} -{% endfor %} - -{% if add_generation_prompt %} - {{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }} -{% endif %} \ No newline at end of file +{% for message in messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<|image|>' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{ '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %} \ No newline at end of file diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 669bb6f626ff..fc5c82b9a1b4 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -105,6 +105,8 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): assert hf_config.vision_config.image_size % 14 == 0, "chunk size should be multiple of 14" token_per_chunk = (hf_config.vision_config.image_size // 14) ** 2 + 1 num_tokens = num_tiles * token_per_chunk + llm_inputs["prompt"] = llm_inputs["encoder_prompt"] + llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] llm_inputs["encoder_prompt"] = "<|image|>" * num_tokens llm_inputs["encoder_prompt_token_ids"] = [LLAMA_IMAGE_TOKEN_ID] * num_tokens @@ -1328,6 +1330,7 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> Union[Tuple, CausalLMOutputWithPast]: + print("input_ids", input_ids) if attn_metadata.num_prefill_tokens > 0 and attn_metadata.num_decode_tokens > 0: raise ValueError("Chunk prefill not supported") image_inputs = self._parse_and_validate_image_input(**kwargs) From 062534bc64422a083d395233c1d95c5b18d36df2 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 19 Sep 2024 16:56:09 -0700 Subject: [PATCH 36/75] hardcode to match tokenizer result --- vllm/model_executor/models/mllama.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index fc5c82b9a1b4..4e934b3ef551 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -105,13 +105,17 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): assert hf_config.vision_config.image_size % 14 == 0, "chunk size should be multiple of 14" token_per_chunk = (hf_config.vision_config.image_size // 14) ** 2 + 1 num_tokens = num_tiles * token_per_chunk - llm_inputs["prompt"] = llm_inputs["encoder_prompt"] - llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] + if llm_inputs.get("prompt") is None: + llm_inputs["prompt"] = llm_inputs["encoder_prompt"] + llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] + if 198 in llm_inputs["prompt_token_ids"]: + index_198 = llm_inputs["prompt_token_ids"].index(198) + if index_198 > 0 and llm_inputs["prompt_token_ids"][index_198 - 1] == LLAMA_IMAGE_TOKEN_ID: + llm_inputs["prompt_token_ids"] = llm_inputs["prompt_token_ids"][:index_198] + llm_inputs["prompt_token_ids"][index_198+1:] llm_inputs["encoder_prompt"] = "<|image|>" * num_tokens llm_inputs["encoder_prompt_token_ids"] = [LLAMA_IMAGE_TOKEN_ID] * num_tokens assert "decoder_multi_modal_data" not in llm_inputs, "multi-modal data should be put in encoder message of LLaMA Vision" - return llm_inputs From 23f04b4b730da36354523eba74b5a9ff9da3eb7d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 19 Sep 2024 17:55:31 -0700 Subject: [PATCH 37/75] update test script --- examples/openai_vision_api_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index 6bb66b694b98..eb3ad9ffb0e9 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -28,7 +28,7 @@ model = models.data[0].id # Single-image input inference -image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" +image_url = "https://llava-vl.github.io/static/images/view.jpg" ## Use image url in the payload chat_completion_from_url = client.chat.completions.create( From 4ed4e6e4f97de0d585bdfd849f74333a01bd750b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Thu, 19 Sep 2024 19:09:58 -0700 Subject: [PATCH 38/75] update test script --- examples/openai_vision_api_client.py | 1 + vllm/model_executor/models/mllama.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index eb3ad9ffb0e9..442272c6d778 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -50,6 +50,7 @@ }], model=model, max_tokens=64, + temperature=0.0, ) result = chat_completion_from_url.choices[0].message.content diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 4e934b3ef551..f6d54a74cfa5 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -1334,7 +1334,6 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> Union[Tuple, CausalLMOutputWithPast]: - print("input_ids", input_ids) if attn_metadata.num_prefill_tokens > 0 and attn_metadata.num_decode_tokens > 0: raise ValueError("Chunk prefill not supported") image_inputs = self._parse_and_validate_image_input(**kwargs) From c140258667cffe9a0e7c64c3720c9382ef220a79 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 20 Sep 2024 20:41:32 -0700 Subject: [PATCH 39/75] support text-only input --- examples/openai_vision_api_client.py | 27 ++++++-- vllm/model_executor/models/mllama.py | 93 ++++++++++------------------ 2 files changed, 57 insertions(+), 63 deletions(-) diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index 442272c6d778..d48b2b3886c9 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -33,8 +33,7 @@ ## Use image url in the payload chat_completion_from_url = client.chat.completions.create( messages=[{ - "role": - "user", + "role": "user", "content": [ { "type": "text", @@ -47,14 +46,34 @@ }, }, ], - }], + } + ], model=model, max_tokens=64, temperature=0.0, ) result = chat_completion_from_url.choices[0].message.content -print("Chat completion output:", result) +print("Text + image output:", result) + +chat_completion_text_only = client.chat.completions.create( + messages=[{ + "role": "user", + "content": [ + { + "type": "text", + "text": "what is the recipe of mayonnaise in two sentences?" + }, + ] + } + ], + model=model, + max_tokens=64, + temperature=0.0, +) + +result = chat_completion_text_only.choices[0].message.content +print("Text-only output output:", result) print("remove me: testing done, exitting...") import sys; sys.exit(0) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index f6d54a74cfa5..fb9955a533cd 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -89,6 +89,13 @@ def recursive_sum(x): return 0 def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): + if llm_inputs.get("prompt") is None: + llm_inputs["prompt"] = llm_inputs["encoder_prompt"] + llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] + if 198 in llm_inputs["prompt_token_ids"]: + index_198 = llm_inputs["prompt_token_ids"].index(198) + if index_198 > 0 and llm_inputs["prompt_token_ids"][index_198 - 1] == LLAMA_IMAGE_TOKEN_ID: + llm_inputs["prompt_token_ids"] = llm_inputs["prompt_token_ids"][:index_198] + llm_inputs["prompt_token_ids"][index_198+1:] multi_modal_data = llm_inputs.get("encoder_multi_modal_data") hf_config = ctx.model_config.hf_config if multi_modal_data is None or "image" not in multi_modal_data: @@ -105,13 +112,6 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): assert hf_config.vision_config.image_size % 14 == 0, "chunk size should be multiple of 14" token_per_chunk = (hf_config.vision_config.image_size // 14) ** 2 + 1 num_tokens = num_tiles * token_per_chunk - if llm_inputs.get("prompt") is None: - llm_inputs["prompt"] = llm_inputs["encoder_prompt"] - llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] - if 198 in llm_inputs["prompt_token_ids"]: - index_198 = llm_inputs["prompt_token_ids"].index(198) - if index_198 > 0 and llm_inputs["prompt_token_ids"][index_198 - 1] == LLAMA_IMAGE_TOKEN_ID: - llm_inputs["prompt_token_ids"] = llm_inputs["prompt_token_ids"][:index_198] + llm_inputs["prompt_token_ids"][index_198+1:] llm_inputs["encoder_prompt"] = "<|image|>" * num_tokens llm_inputs["encoder_prompt_token_ids"] = [LLAMA_IMAGE_TOKEN_ID] * num_tokens @@ -937,7 +937,7 @@ def forward( hidden_states: torch.Tensor, cross_attention_states: torch.Tensor, cross_attention_mask: torch.Tensor, - full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], + full_text_row_masked_out_mask: torch.Tensor, kv_cache: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: @@ -956,8 +956,7 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) - if full_text_row_masked_out_mask is not None: - hidden_states = full_text_row_masked_out_mask[:, 0] * hidden_states # type: ignore + hidden_states = full_text_row_masked_out_mask * hidden_states hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states return hidden_states @@ -1000,21 +999,23 @@ def forward( full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + skip_cross_attention: bool, ) -> torch.Tensor: inputs_embeds = self.embed_tokens(input_ids) hidden_states = inputs_embeds for idx, decoder_layer in enumerate(self.layers): if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): - hidden_states = decoder_layer( - hidden_states=hidden_states, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - # xattn_cache=xattn_caches[xattn_layer_idx] if xattn_caches is not None else None, - kv_cache=kv_caches[idx], - attn_metadata=attn_metadata, - ) + if not skip_cross_attention: + hidden_states = decoder_layer( + hidden_states=hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + # xattn_cache=xattn_caches[xattn_layer_idx] if xattn_caches is not None else None, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) elif isinstance(decoder_layer, MllamaSelfAttentionDecoderLayer): hidden_states = decoder_layer( hidden_states=hidden_states, @@ -1124,6 +1125,7 @@ def forward( full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + skip_cross_attention: bool, ) -> torch.Tensor: hidden_states = self.model( input_ids=input_ids, @@ -1133,6 +1135,7 @@ def forward( full_text_row_masked_out_mask=full_text_row_masked_out_mask, kv_caches=kv_caches, attn_metadata=attn_metadata, + skip_cross_attention=skip_cross_attention, ) return hidden_states @@ -1339,11 +1342,11 @@ def forward( image_inputs = self._parse_and_validate_image_input(**kwargs) if image_inputs is None: cross_attention_mask = None - run_xattn_mask = (attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).cuda() + full_text_row_masked_out_mask = (attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(input_ids.device) xattn_caches = None vision_tokens = None cross_attention_states = None - full_text_row_masked_out_mask = None + skip_cross_attention = max(attn_metadata.encoder_seq_lens) > 0 else: # llama's reference implementation runs the vision model on CPU pixel_values = image_inputs['data'] @@ -1363,42 +1366,15 @@ def forward( start_pos = end_pos cross_attention_states = cross_attention_states_flat cross_attention_mask = None # TODO - full_text_row_masked_out_mask = None # TODO - - # run_xattn_mask = torch.ones((attn_metadata.num_prefill_tokens, 1), dtype=torch.bool, device=cross_attention_states.device) - # start_pos = 0 - # for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens_tensor, attn_metadata.encoder_seq_lens): - # if encoder_seq_len == 0: - # run_xattn_mask[start_pos:start_pos+seq_len] = False - # start_pos += seq_len - - # if pixel_values is not None: - # if aspect_ratio_ids is None: - # raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") - # # get vision tokens from vision model - # cross_attention_states = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask) - # cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( - # -1, cross_attention_states.shape[-2], self.hidden_size - # ) - - # cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( - # cross_attention_mask, - # past_key_values=past_key_values, - # num_vision_tokens=self.vision_model.num_patches, - # cross_attention_layers=self.language_model.model.cross_attention_layers, - # cross_attention_states=cross_attention_states, - # device=self.device, - # dtype=self.dtype, - # ) - - # if cross_attention_mask is not None and cache_position is not None: - # cross_attention_mask = cross_attention_mask[:, :, cache_position] - # full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] - - # print("input_ids", input_ids, cross_attention_states is None) - # if positions.numel() == 1: - # global step_name - # step_name = f"decode_{positions.item()}" + + full_text_row_masked_out_mask = torch.ones((attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) + start_pos = 0 + for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens_tensor.cpu(), attn_metadata.encoder_seq_lens): + if encoder_seq_len == 0: + full_text_row_masked_out_mask[start_pos:start_pos+seq_len] = False + start_pos += seq_len + full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(cross_attention_states.device) + skip_cross_attention = False outputs = self.language_model( input_ids=input_ids, @@ -1408,9 +1384,8 @@ def forward( full_text_row_masked_out_mask=full_text_row_masked_out_mask, kv_caches=kv_caches, attn_metadata=attn_metadata, + skip_cross_attention=skip_cross_attention, ) - # if positions.numel() == 1 and positions.item() == 20: - # exit(0) return outputs From f662fddc130f3114404c6d38a82c36351ff4ac95 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Fri, 20 Sep 2024 23:53:06 -0700 Subject: [PATCH 40/75] fix bug in text only prompt --- vllm/model_executor/models/mllama.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index fb9955a533cd..31a01a329009 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -99,13 +99,12 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("encoder_multi_modal_data") hf_config = ctx.model_config.hf_config if multi_modal_data is None or "image" not in multi_modal_data: + llm_inputs["encoder_prompt"] = "" + llm_inputs["encoder_prompt_token_ids"] = [] return llm_inputs global image_processor if image_processor is None: - image_processor = MllamaImageProcessor( - ctx.model_config.model, - size={"height": hf_config.vision_config.image_size, "width": hf_config.vision_config.image_size}, - ) + image_processor = MllamaImageProcessor.from_pretrained(ctx.model_config.model) processed_image = image_processor(multi_modal_data["image"]) llm_inputs["encoder_multi_modal_data"]["image"] = processed_image num_tiles = recursive_sum(processed_image["num_tiles"]) @@ -951,6 +950,7 @@ def forward( kv_cache=kv_cache, attn_metadata=attn_metadata, ) + hidden_states = full_text_row_masked_out_mask * hidden_states hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states residual = hidden_states @@ -1346,7 +1346,7 @@ def forward( xattn_caches = None vision_tokens = None cross_attention_states = None - skip_cross_attention = max(attn_metadata.encoder_seq_lens) > 0 + skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0 else: # llama's reference implementation runs the vision model on CPU pixel_values = image_inputs['data'] From 6cf166ad198ea41956bdb077cef5c1e26ed98e49 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 21 Sep 2024 12:50:51 -0700 Subject: [PATCH 41/75] add unit test --- .../vision_language/test_mllama.py | 236 ++++++++++++++++++ 1 file changed, 236 insertions(+) create mode 100644 tests/models/encoder_decoder/vision_language/test_mllama.py diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py new file mode 100644 index 000000000000..23c712e3366c --- /dev/null +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -0,0 +1,236 @@ +from typing import List, Optional, Tuple, Type, overload + +import pytest +from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, + BatchEncoding) + +from vllm.multimodal.utils import rescale_image_size +from vllm.sequence import SampleLogprobs +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE + +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import check_logprobs_close + +_LIMIT_IMAGE_PER_PROMPT = 1 + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "<|image|><|begin_of_text|>The meaning of the image is", + "cherry_blossom": + "<|image|><|begin_of_text|>The city is", + None: "<|begin_of_text|>The color of the sky is blue but sometimes it can also be", +}) + +models = [ + "nltpt/Llama-3.2-11B-Vision-Instruct", # TODO: Update model path to huggingface model +] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str, + Optional[SampleLogprobs]], + model: str): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + print("output_str:", output_str) + + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + + tokenizer = AutoTokenizer.from_pretrained(model) + eos_token_id = tokenizer.eos_token_id + + hf_output_ids = [ + token_id for idx, token_id in enumerate(output_ids) + if token_id != image_token_id or output_ids[idx - 1] != image_token_id + ] + + assert output_str[0] == " " + hf_output_str = output_str[1:] + if hf_output_ids[-1] == eos_token_id: + hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) + + return hf_output_ids, hf_output_str, out_logprobs + + +@overload +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: List[float], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... + + +@overload +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + sizes: List[Tuple[int, int]], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: Optional[List[float]] = None, + sizes: Optional[List[Tuple[int, int]]] = None, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + images = [asset.pil_image for asset in image_assets] + + if size_factors is not None: + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + elif sizes is not None: + inputs_per_image = [( + [prompt for _ in sizes], + [image.resize(size) for size in sizes], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + else: + raise ValueError("You must provide either `size_factors` or `sizes`") + + _run_test(hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend) + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test are from IMAGE_ASSETS. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + dtype=dtype, + max_num_seqs=16, + max_model_len=4096, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT + }) as vllm_model: + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs + ] + + + def process(hf_inputs: BatchEncoding): + return hf_inputs + + with hf_runner(model, + dtype=dtype, + postprocess_inputs=process, + auto_cls=AutoModelForVision2Seq) as hf_model: + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + # TODO: Check whether using original CLIPVisionModel can improve + # consistency against HF + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, model) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, + dtype, max_tokens, num_logprobs) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + size_factors=size_factors, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + From b7124e5cca75475090c671cf528eec897f485e5f Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 21 Sep 2024 14:11:56 -0700 Subject: [PATCH 42/75] add complex tests, but cannot run single-gpu and multi-gpu at the same time --- .../vision_language/__init__.py | 0 .../vision_language/test_mllama.py | 66 ++++++++++++++----- vllm/inputs/preprocess.py | 2 +- vllm/model_executor/models/mllama.py | 3 +- 4 files changed, 54 insertions(+), 17 deletions(-) create mode 100644 tests/models/encoder_decoder/vision_language/__init__.py diff --git a/tests/models/encoder_decoder/vision_language/__init__.py b/tests/models/encoder_decoder/vision_language/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index 23c712e3366c..d1934f0b15ca 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -11,6 +11,7 @@ from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, _ImageAssets) from ...utils import check_logprobs_close +from ....utils import multi_gpu_test _LIMIT_IMAGE_PER_PROMPT = 1 @@ -19,9 +20,12 @@ "<|image|><|begin_of_text|>The meaning of the image is", "cherry_blossom": "<|image|><|begin_of_text|>The city is", - None: "<|begin_of_text|>The color of the sky is blue but sometimes it can also be", }) +text_only_prompts = [ + "The color of the sky is blue but sometimes it can also be", +] + models = [ "nltpt/Llama-3.2-11B-Vision-Instruct", # TODO: Update model path to huggingface model ] @@ -31,8 +35,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], model: str): """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - print("output_str:", output_str) + output_ids, output_str, out_logprobs = vllm_outpu config = AutoConfig.from_pretrained(model) image_token_id = config.image_token_index @@ -110,9 +113,11 @@ def run_test( ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] elif sizes is not None: inputs_per_image = [( - [prompt for _ in sizes], - [image.resize(size) for size in sizes], + [prompt if size is not None else text_only_prompts[0] for size in sizes], + [image.resize(size) if size is not None else None for size in sizes], ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + if len(sizes) == 0: + inputs_per_image.append((text_only_prompts, [None] * len(text_only_prompts))) else: raise ValueError("You must provide either `size_factors` or `sizes`") @@ -205,32 +210,63 @@ def process(hf_inputs: BatchEncoding): @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( - "size_factors", + "sizes", [ - # No image + # Text only [], - # Single-scale - [1.0], - # Single-scale, batched - [1.0, 1.0, 1.0], - # Multi-scale - [0.25, 0.5, 1.0], + # Single-size + [(512, 512)], + # Single-size, batched + [(512, 512), (512, 512), (512, 512)], + # Multi-size, batched + [(512, 512), (1024, 512), (1536, 512), (2048, 512), + (512, 1024), (1024, 1024), (512, 1536), (512, 2028)], + # Multi-size, batched, including text only + [(512, 512), (1024, 512), (1536, 512), (2048, 512), + (512, 1024), (1024, 1024), (512, 1536), (512, 2028), None], + # mllama has 8 possible aspect ratios, carefully set the sizes to cover all of them ], ) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, +def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype, max_tokens, num_logprobs) -> None: run_test( hf_runner, vllm_runner, image_assets, model, - size_factors=size_factors, + sizes=sizes, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, tensor_parallel_size=1, ) + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "sizes", + [ + [(512, 512), (1024, 512), (1536, 512), (2048, 512), + (512, 1024), (1024, 1024), (512, 1536), (512, 2028), None], + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes, + dtype, max_tokens, num_logprobs) -> None: + run_test( + hf_runner, + vllm_runner, + image_assets, + model, + sizes=sizes, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=2, + ) \ No newline at end of file diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 84db659b9919..03759be7856b 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -300,7 +300,7 @@ def _build_enc_dec_llm_inputs( raise ValueError("Multi-modality decoder inputs of encoder-decoder models are " "not supported yet") - # For Multi-Modal models, the start token can be the image token + # For Multi-Modal models (e.g., mllama), the text input can be <|image|><|begin_of_text|>hello world. And we should not add another <|begin_of_text|> to the beginning. decoder_prompt_ids = ( self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids, force_bos=(encoder_mm_data is None and decoder_mm_data is None))) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 31a01a329009..a04e28b64e7a 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -98,9 +98,10 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): llm_inputs["prompt_token_ids"] = llm_inputs["prompt_token_ids"][:index_198] + llm_inputs["prompt_token_ids"][index_198+1:] multi_modal_data = llm_inputs.get("encoder_multi_modal_data") hf_config = ctx.model_config.hf_config - if multi_modal_data is None or "image" not in multi_modal_data: + if multi_modal_data is None or "image" not in multi_modal_data or multi_modal_data["image"] is None: llm_inputs["encoder_prompt"] = "" llm_inputs["encoder_prompt_token_ids"] = [] + llm_inputs["encoder_multi_modal_data"] = {} return llm_inputs global image_processor if image_processor is None: From e69f1273999117ced7f1fe343c4fcfa8c26f00cb Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sat, 21 Sep 2024 16:45:19 -0700 Subject: [PATCH 43/75] seperate encoder/decoder dummy input, support max_image=1 --- vllm/inputs/registry.py | 52 ++++++++++++++++++++++++---- vllm/model_executor/models/mllama.py | 34 ++++++++++++------ vllm/worker/enc_dec_model_runner.py | 23 ++++++++---- 3 files changed, 86 insertions(+), 23 deletions(-) diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index ae6c6c05d9f7..1cc777543e27 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -111,6 +111,8 @@ class InputRegistry: def __init__(self) -> None: self._dummy_factories_by_model_type: Dict[Type[nn.Module], DummyDataFactory] = {} + self._dummy_encoder_factories_by_model_type: Dict[Type[nn.Module], + DummyDataFactory] = {} self._input_processors_by_model_type: Dict[Type[nn.Module], InputProcessor] = {} @@ -157,12 +159,33 @@ def wrapper(model_cls: N) -> N: return model_cls return wrapper + + def register_dummy_encoder_data(self, factory: DummyDataFactory): + """ + Register a dummy encoder data factory to a model class + + This is similar to :meth:`~register_dummy_data`, but for encoder input. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._dummy_encoder_factories_by_model_type: + logger.warning( + "Model class %s already has dummy encoder data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_encoder_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper def dummy_data_for_profiling( self, model_config: "ModelConfig", seq_len: int, mm_registry: "MultiModalRegistry", + is_encoder_data: bool = False, ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data for profiling the memory usage of a model. @@ -180,8 +203,19 @@ def dummy_data_for_profiling( from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) - dummy_factory = self._dummy_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) + if is_encoder_data: + if model_cls in self._dummy_encoder_factories_by_model_type: + dummy_factory = self._dummy_encoder_factories_by_model_type[model_cls] + else: + logger.warning( + "No dummy encoder data factory registered to %s. " + "Using the dummy data factory for the model instead.", + model_cls) + dummy_factory = self._dummy_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) + else: + dummy_factory = self._dummy_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) seq_data, mm_data = dummy_factory( @@ -192,10 +226,16 @@ def dummy_data_for_profiling( # Having more tokens is over-conservative but otherwise fine num_tokens = seq_data.prompt_token_ids - assert len(num_tokens) >= seq_len, ( - f"Expected at least {seq_len} dummy tokens for profiling, " - f"but found {len(num_tokens)} tokens instead.") - + if len(num_tokens) < seq_len: + if is_encoder_data: + logger.warning( + "Expected at least %d dummy encoder tokens for profiling, " + "but found %d tokens instead.", + seq_len, len(num_tokens)) + else: + assert False, ( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but found {len(num_tokens)} tokens instead.") if mm_data is not None: for k, v in mm_data.items(): num_items = len(v) if isinstance(v, list) else 1 diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index a04e28b64e7a..308368326bcc 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -118,11 +118,17 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): assert "decoder_multi_modal_data" not in llm_inputs, "multi-modal data should be put in encoder message of LLaMA Vision" return llm_inputs +def get_max_mllama_image_tokens(ctx: InputContext) -> int: + hf_config = ctx.model_config.hf_config + token_per_chunk = (hf_config.vision_config.image_size // 14) ** 2 + 1 + return hf_config.vision_config.max_num_tiles * token_per_chunk -def dummy_seq_data( + +def dummy_decoder_seq_data( seq_len: int, num_images: int ): + # <|image|> * num_images + 0 * (seq_len - num_images) assert seq_len >= num_images, "seq_len should be greater than or equal to num_images" token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [LLAMA_IMAGE_TOKEN_ID]) * num_images @@ -131,22 +137,29 @@ def dummy_seq_data( return SequenceData(token_ids) +def dummy_encoder_seq_data( + ctx: InputContext, + num_images: int +): + num_tokens = get_max_mllama_image_tokens(ctx) * num_images + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [LLAMA_IMAGE_TOKEN_ID]) * num_tokens + return SequenceData(token_ids) + + def dummy_image( num_images: int, ): - width = height = 512 + width = height = 1024 image = Image.new("RGB", (width, height), color=0) return {"image": image if num_images == 1 else [image] * num_images} -def dummy_data_for_mllama(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): +def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): num_images = mm_counts["image"] - return dummy_seq_data(seq_len, num_images), dummy_image(num_images) - -def get_max_mllama_image_tokens(ctx: InputContext) -> int: - hf_config = ctx.model_config.hf_config - token_per_chunk = (hf_config.vision_config.image_size // 14) ** 2 + 1 - return hf_config.vision_config.max_num_tiles * token_per_chunk + return dummy_decoder_seq_data(seq_len, num_images), None +def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len:int, mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images) # Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position def _prepare_4d_causal_attention_mask_with_cache_position( @@ -1217,7 +1230,8 @@ def prepare_inputs_for_generation( @MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_mllama) +@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama) +@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama) @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, config: MllamaConfig, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 2330975b54ac..bf042f9f563c 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -291,25 +291,34 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - seq_data, dummy_multi_modal_data = self.input_registry \ + decoder_seq_data, decoder_dummy_multi_modal_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, seq_len, - self.mm_registry) + self.mm_registry, + is_encoder_data=False) + encoder_seq_data, encoder_dummy_multi_modal_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=True) # Having more tokens is over-conservative but otherwise fine - assert len(seq_data.prompt_token_ids) >= seq_len, ( + assert len(decoder_seq_data.prompt_token_ids) >= seq_len, ( f"Expected at least {seq_len} dummy tokens for profiling, " - f"but got: {len(seq_data.prompt_token_ids)}") + f"but got: {len(decoder_seq_data.prompt_token_ids)}") + + assert decoder_dummy_multi_modal_data is None or encoder_dummy_multi_modal_data is None, ( + "Multi-modal data cannot be provided for both encoder and decoder") seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: decoder_seq_data}, sampling_params=sampling_params, block_tables=None, - encoder_seq_data=seq_data, + encoder_seq_data=encoder_seq_data, cross_block_table=None, - multi_modal_data=dummy_multi_modal_data, + multi_modal_data=decoder_dummy_multi_modal_data or encoder_dummy_multi_modal_data, ) seqs.append(seq) From e0e297cc47170e9d1ab9fba2602a7b5fc67c8417 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 00:42:06 -0700 Subject: [PATCH 44/75] add mllamaconfig to override some params, simplying the model code (WIP) --- .../vision_language/test_mllama.py | 8 +- vllm/config.py | 10 +- vllm/model_executor/models/mllama.py | 439 +----------------- vllm/transformers_utils/config.py | 12 +- vllm/transformers_utils/configs/mllama.py | 26 ++ 5 files changed, 63 insertions(+), 432 deletions(-) create mode 100644 vllm/transformers_utils/configs/mllama.py diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index d1934f0b15ca..a7138a2529ac 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -35,7 +35,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], model: str): """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_outpu + output_ids, output_str, out_logprobs = vllm_output config = AutoConfig.from_pretrained(model) image_token_id = config.image_token_index @@ -181,6 +181,10 @@ def _run_test( def process(hf_inputs: BatchEncoding): return hf_inputs + from transformers import AutoConfig + from transformers.models.mllama import MllamaConfig as MllamaConfigHf + # use transformer's MllamaConfig for hf_runner and vllm's MllamaConfig for vllm_runner + AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True) with hf_runner(model, dtype=dtype, postprocess_inputs=process, @@ -193,6 +197,8 @@ def process(hf_inputs: BatchEncoding): for prompts, images in inputs ] + from vllm.transformers_utils.configs.mllama import MllamaConfig + AutoConfig.register("mllama", MllamaConfig, exist_ok=True) for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, vllm_outputs_per_image): # TODO: Check whether using original CLIPVisionModel can improve diff --git a/vllm/config.py b/vllm/config.py index 77f7ac5e3c7f..492b0c243554 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -583,9 +583,7 @@ def get_multimodal_config(self) -> "MultiModalConfig": @property def is_encoder_decoder_model(self) -> bool: """Extract the HF encoder/decoder model flag.""" - if self.hf_config.model_type == "mllama": - return True - return getattr(self.hf_config, "is_encoder_decoder", False) + return getattr(self.hf_config, "is_encoder_decoder", False) or ((hasattr(self.hf_config, "text_config") and getattr(self.hf_config.text_config, "is_encoder_decoder", False))) @property def is_embedding_model(self) -> bool: @@ -1629,12 +1627,6 @@ def _get_and_verify_dtype( "of float16 by default. Please specify `dtype` if you " "want to use float16.") torch_dtype = torch.bfloat16 - if config.model_type == "mllama_text_model": - # the config is MllamaConfig.text_config - logger.info( - "Some llama vision models lack default dtype. Hardcode to bfloat16." - ) - torch_dtype = torch.bfloat16 else: # Following the common practice, we use float16 for float32 # models. diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 308368326bcc..42c73e3f8bd3 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -36,6 +36,7 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -48,8 +49,7 @@ from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal -from .llama import LlamaAttention, LlamaMLP -# from vllm.model_executor.layers.layernorm import RMSNorm +from .llama import LlamaAttention, LlamaMLP, LlamaDecoderLayer from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, @@ -62,8 +62,6 @@ logger = init_logger(__name__) -MP_SCALE = 8 -IMAGE_RES = 224 LLAMA_IMAGE_TOKEN_ID = 128256 class MllamaImagePixelInputs(TypedDict): @@ -480,7 +478,6 @@ def forward( self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: bool = None, ): # Self Attention residual = hidden_state @@ -512,86 +509,36 @@ def __init__(self, config: MllamaVisionConfig, num_layers=32, is_gated=False): super().__init__() self.config = config self.layers = nn.ModuleList([MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)]) - self.gradient_checkpointing = False self.config = config def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: - r""" - Args: - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. - This is useful if you want more control over how to convert `input_ids` indices into associated vectors - than the model's internal embedding lookup matrix. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None for encoder_layer in self.layers: if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - encoder_layer.__call__, - hidden_states, - attention_mask, - output_attentions, - ) - else: - hidden_states = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) - - # SDPA never returns attn weights, so the kwarg isn't used at all - # TODO: fix this - # if output_attentions: - # all_attentions = all_attentions + (layer_outputs[1],) + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) - + return hidden_states, encoder_states class MllamaVisionModel(nn.Module): - config_class = MllamaVisionConfig - base_model_prefix = "vision_encoder" - _no_split_modules = ["MllamaVisionSdpaAttention"] - _supports_sdpa = True - def __init__(self, config: MllamaVisionConfig): super().__init__() self.image_size = config.image_size @@ -820,108 +767,6 @@ def forward( out, _ = self.o_proj(output) return out -# Copied from transformers.models.llama.modeling_llama.rotate_half -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): - """Applies Rotary Position Embedding to the query and key tensors. - - Args: - q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. - cos (`torch.Tensor`): The cosine part of the rotary embedding. - sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`, *optional*): - Deprecated and unused. - unsqueeze_dim (`int`, *optional*, defaults to 1): - The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and - sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note - that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and - k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes - cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have - the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. - Returns: - `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. - """ - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) - - -# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LlamaDecoder->MllamaSelfAttentionDecoder, Llama->MllamaText, LLAMA->MLLAMA_TEXT -class MllamaSelfAttentionDecoderLayer(nn.Module): - def __init__(self, config: MllamaTextConfig, layer_idx: int, cache_config: Optional[CacheConfig] = None): - super().__init__() - self.hidden_size = config.hidden_size - - self.self_attn = LlamaAttention( - config=config, - hidden_size=config.hidden_size, - num_heads=config.num_attention_heads, - num_kv_heads=config.num_key_value_heads, - rope_theta=config.rope_theta, - rope_scaling=config.rope_scaling, - max_position_embeddings=config.max_position_embeddings, - quant_config=None, - bias=False, - cache_config=cache_config) - - self.mlp = LlamaMLP(config.hidden_size, config.intermediate_size, hidden_act=config.hidden_activation) - self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - # Ignore copy - self.layer_idx = layer_idx - - def forward( - self, - hidden_states: torch.Tensor, - positions: Optional[torch.LongTensor], - kv_cache: List[torch.Tensor], - attn_metadata: AttentionMetadata, - ) -> torch.FloatTensor: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - hidden_states = self.self_attn( - hidden_states=hidden_states, - positions=positions, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) - hidden_states = residual + hidden_states - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - - return hidden_states - class MllamaCrossAttentionDecoderLayer(torch.nn.Module): """Cross-attention transformer block with tanh-gated attention and feedforward.""" @@ -934,15 +779,15 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: layer_idx=layer_idx, ) - self.input_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) self.mlp = LlamaMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - hidden_act=config.hidden_activation, + hidden_act=config.hidden_act, ) - self.post_attention_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) def forward( @@ -979,7 +824,7 @@ class MllamaTextModel(nn.Module): base_model_prefix = "model" _no_split_modules = ["MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"] - def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig]): + def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig], quant_config: Optional[QuantizationConfig]): super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size @@ -991,12 +836,11 @@ def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig]) if layer_idx in self.cross_attention_layers: layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx)) else: - layers.append(MllamaSelfAttentionDecoderLayer(config, layer_idx, cache_config=cache_config)) + layers.append(LlamaDecoderLayer(config, cache_config=cache_config, quant_config=quant_config)) self.layers = nn.ModuleList(layers) - self.norm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) # self.rotary_emb = MllamaRotaryEmbedding(config=config) - self.gradient_checkpointing = False def get_input_embeddings(self): return self.embed_tokens @@ -1030,86 +874,19 @@ def forward( kv_cache=kv_caches[idx], attn_metadata=attn_metadata, ) - elif isinstance(decoder_layer, MllamaSelfAttentionDecoderLayer): - hidden_states = decoder_layer( - hidden_states=hidden_states, + elif isinstance(decoder_layer, LlamaDecoderLayer): + hidden_states, residual = decoder_layer( positions=positions, + hidden_states=hidden_states, kv_cache=kv_caches[idx], attn_metadata=attn_metadata, + residual=None, ) + hidden_states = hidden_states + residual else: raise ValueError(f"Unknown decoder layer type {type(decoder_layer)}") hidden_states = self.norm(hidden_states) return hidden_states - - - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - # TODO: we have only SDPA currently and there's a bug when attn-bias is passed. Need to add eager attn and return the line - # self.config._attn_implementation == "sdpa" and - if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - if using_static_cache: - target_length = past_key_values.get_max_length() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - min_dtype=min_dtype, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type == "cuda" - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask class MllamaForCausalLM(nn.Module): @@ -1120,7 +897,7 @@ class MllamaForCausalLM(nn.Module): def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig], quant_config: Optional[QuantizationConfig]): super().__init__() self.vocab_size = config.vocab_size - self.model = MllamaTextModel(config, cache_config) + self.model = MllamaTextModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, @@ -1153,80 +930,6 @@ def forward( ) return hidden_states - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - num_logits_to_keep=None, - **kwargs, - ): - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - if past_key_values is not None: - if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. - model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - dtype = self.lm_head.weight.dtype - min_dtype = torch.finfo(dtype).min - - attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_length(), - dtype=dtype, - device=device, - min_dtype=min_dtype, - cache_position=cache_position, - batch_size=batch_size, - ) - - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - } - ) - return model_inputs - @MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens) @@ -1341,7 +1044,6 @@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") - def forward( self, @@ -1404,109 +1106,6 @@ def forward( return outputs - def prepare_inputs_for_generation( - self, - input_ids=None, - inputs_embeds=None, - attention_mask=None, - position_ids=None, - pixel_values=None, - aspect_ratio_ids=None, - aspect_ratio_mask=None, - cross_attention_mask=None, - past_key_values=None, - use_cache=False, - cache_position=None, - num_logits_to_keep=None, - **kwargs, - ): - # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens - # Exception 1: when passing input_embeds, input_ids may be missing entries - # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here - if past_key_values is not None: - if inputs_embeds is not None: # Exception 1 - input_ids = input_ids[:, -cache_position.shape[0] :] - elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) - input_ids = input_ids[:, cache_position] - - # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. - position_ids = position_ids.clone(memory_format=torch.contiguous_format) - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and cache_position[0] == 0: - model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} - else: - # The clone here is for the same reason as for `position_ids`. - model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} - - if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: - if model_inputs["inputs_embeds"] is not None: - batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape - device = model_inputs["inputs_embeds"].device - else: - batch_size, sequence_length = model_inputs["input_ids"].shape - device = model_inputs["input_ids"].device - - dtype = self.get_output_embeddings().weight.dtype - min_dtype = torch.finfo(dtype).min - - attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=past_key_values.get_max_length(), - dtype=dtype, - device=device, - min_dtype=min_dtype, - cache_position=cache_position, - batch_size=batch_size, - ) - - if num_logits_to_keep is not None: - model_inputs["num_logits_to_keep"] = num_logits_to_keep - - model_inputs.update( - { - "position_ids": position_ids, - "cache_position": cache_position, - "past_key_values": past_key_values, - "use_cache": use_cache, - "attention_mask": attention_mask, - "cross_attention_mask": cross_attention_mask, - } - ) - - # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios - # to compute image hidden states, otherwise they are cached within each cross attn layer - if (input_ids == self.config.image_token_index).any(): - model_inputs["pixel_values"] = pixel_values - model_inputs["aspect_ratio_ids"] = aspect_ratio_ids - model_inputs["aspect_ratio_mask"] = aspect_ratio_mask - - return model_inputs - - def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): - cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None) - model_kwargs = super()._update_model_kwargs_for_generation( - outputs=outputs, - model_kwargs=model_kwargs, - is_encoder_decoder=is_encoder_decoder, - **kwargs, - ) - - # add cross-attn mask for new token - if cross_attention_mask_prev is not None: - model_kwargs["cross_attention_mask"] = torch.cat( - [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1 - ) - return model_kwargs def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 3be8791f5ee5..0d111998036a 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -21,7 +21,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, ExaoneConfig, GraniteConfig, InternVLChatConfig, JAISConfig, - MedusaConfig, MLPSpeculatorConfig, + MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, RWConfig, UltravoxConfig) @@ -37,6 +37,10 @@ logger = init_logger(__name__) +_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = { + "mllama": MllamaConfig +} + _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { "chatglm": ChatGLMConfig, "dbrx": DbrxConfig, @@ -54,11 +58,15 @@ # Granite can be removed from here once we have upgraded to # transformers 4.45+ "granite": GraniteConfig, + **_CONFIG_REGISTRY_OVERRIDE_HF } for name, cls in _CONFIG_REGISTRY.items(): with contextlib.suppress(ValueError): - AutoConfig.register(name, cls) + if name in _CONFIG_REGISTRY_OVERRIDE_HF: + AutoConfig.register(name, cls, exist_ok=True) + else: + AutoConfig.register(name, cls) class ConfigFormat(str, enum.Enum): diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py new file mode 100644 index 000000000000..93b1e2b66f8b --- /dev/null +++ b/vllm/transformers_utils/configs/mllama.py @@ -0,0 +1,26 @@ +from transformers.models.mllama.configuration_mllama import MllamaTextConfig as MllamaTextConfigHf, MllamaConfig as MllamaConfigHf + +class MllamaTextConfig(MllamaTextConfigHf): + ''' + Use this class to override is_encoder_decoder: + - transformers regards mllama as is_encoder_decoder=False + - vllm needs is_encoder_decoder=True to enable cross-attention + ''' + def __init__( + self, + **kwargs, + ): + super().__init__(**kwargs) + self.hidden_act = self.hidden_activation + self.is_encoder_decoder = True + + +class MllamaConfig(MllamaConfigHf): + def __init__( + self, + text_config=None, + **kwargs, + ): + if isinstance(text_config, dict): + text_config = MllamaTextConfig(**text_config) + super().__init__(text_config=text_config, **kwargs) From f6732cf4c87d4d38dcf8be6fcca2123b170ffec8 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 00:42:57 -0700 Subject: [PATCH 45/75] upd --- vllm/transformers_utils/configs/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 8381c5227584..bda59811a438 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig +from vllm.transformers_utils.configs.mllama import MllamaConfig from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig @@ -25,10 +26,12 @@ "MedusaConfig", "EAGLEConfig", "ExaoneConfig", + "MllamaConfig", "MLPSpeculatorConfig", "NemotronConfig", "UltravoxConfig", # Granite can be removed from here once we have upgraded to # transformers 4.45+ "GraniteConfig", + "" ] From 228b66bcc5f7a960c8e002aaf479d13d6bc711a8 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 14:04:40 -0700 Subject: [PATCH 46/75] code cleanup --- vllm/model_executor/models/mllama.py | 148 ++------------------------- 1 file changed, 9 insertions(+), 139 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 42c73e3f8bd3..7ab5888de7d8 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -17,28 +17,21 @@ import math from PIL import Image from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, - TypedDict, Union, Callable, Dict, Any, Set) + TypedDict, Union,) import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from transformers.activations import ACT2FN -from transformers.cache_utils import Cache, StaticCache -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig from transformers.models.mllama.image_processing_mllama import MllamaImageProcessor from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.attention.ops.paged_attn import PagedAttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs -from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -49,13 +42,12 @@ from vllm.logger import init_logger from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal -from .llama import LlamaAttention, LlamaMLP, LlamaDecoderLayer -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, - QKVParallelLinear, +from .llama import LlamaMLP, LlamaDecoderLayer +from .clip import CLIPMLP +from vllm.model_executor.layers.linear import (QKVParallelLinear, RowParallelLinear, ColumnParallelLinear) -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_tensor_model_parallel_world_size import vllm.distributed.parallel_state as ps from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -159,111 +151,6 @@ def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len:int, mm_counts: Map num_images = mm_counts["image"] return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images) -# Copied from transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_with_cache_position -def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - min_dtype: float, - cache_position: torch.Tensor, - batch_size: int, -): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to plcae the 4D attention mask on. - min_dtype (`float`): - The minimum value representable with the dtype `dtype`. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -def _prepare_cross_attention_mask( - cross_attention_mask: torch.Tensor, - past_key_values: Cache, - num_vision_tokens: int, - cross_attention_states: torch.Tensor, - cross_attention_layers: List[int], - device: str, - dtype: str, -) -> Tuple[torch.Tensor, torch.Tensor]: - if cross_attention_mask is None: - # should we raise error or prepare a full attn mask with all ones? - return None, None - else: - # reshape so it can be used by attn module - batch_size, text_total_length, *_ = cross_attention_mask.shape - cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3) - cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1) - cross_attention_mask = cross_attention_mask.unsqueeze(1) - - # invert the mask - inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype) - cross_attention_mask = inverted_cross_attn_mask.masked_fill( - inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min - ) - - # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's - # last dimension contains negative infinity values, otherwise it's 1 - negative_inf_value = torch.finfo(dtype).min - full_text_row_masked_out_mask = ( - (cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None] - ) - cross_attention_mask *= full_text_row_masked_out_mask - - # In case we receive a new image but already have previous cross-attention key/values in cache, - # then we need to extend the attention-mask and add previous images' lengths - if ( - past_key_values is not None - and cross_attention_states is not None - and past_key_values.get_seq_length(cross_attention_layers[0]) != 0 - ): - # make all zeros mask for cross-attn-mask from previuos cached hidden_states, all zeros right? - # i.e. extend current cross-attn-mask on image-seq-length dimension to account for past_seen_tokens - past_cross_attn_kv_length = past_key_values.get_seq_length(cross_attention_layers[0]) - past_cross_attn_mask = torch.zeros( - (*cross_attention_mask.shape[:-1], past_cross_attn_kv_length), dtype=dtype, device=device - ) - # concatenate both on image-seq-length dimension - cross_attention_mask = torch.cat([past_cross_attn_mask, cross_attention_mask], dim=-1) - - return cross_attention_mask, full_text_row_masked_out_mask - def _prepare_aspect_ratio_attention_mask( aspect_ratio_mask: torch.Tensor, @@ -330,7 +217,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x - class MllamaPrecomputedAspectRatioEmbedding(nn.Module): def __init__(self, config: MllamaVisionConfig, is_gated: bool = True): super().__init__() @@ -391,22 +277,6 @@ def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> return hidden_state -# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->MllamaVision -class MllamaVisionMLP(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.activation_fn = ACT2FN[config.hidden_act] - self.fc1 = ColumnParallelLinear(config.hidden_size, config.intermediate_size, bias=True) - self.fc2 = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=True) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states, _ = self.fc1(hidden_states) - hidden_states = self.activation_fn(hidden_states) - hidden_states, _ = self.fc2(hidden_states) - return hidden_states - - class MllamaVisionSdpaAttention(nn.Module): def __init__(self, config: MllamaVisionConfig): super().__init__() @@ -464,7 +334,7 @@ def __init__(self, config, is_gated: bool = False): self.intermediate_size = config.intermediate_size self.self_attn = MllamaVisionSdpaAttention(config) - self.mlp = MllamaVisionMLP(config) + self.mlp = CLIPMLP(config) self.input_layernorm = nn.LayerNorm(self.hidden_size) self.post_attention_layernorm = nn.LayerNorm(self.hidden_size) @@ -726,8 +596,8 @@ def __init__( bias=False, input_is_parallel=True, ) - - self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + # vllm.model_executor.layers.layernorm.RMSNorm will cause precision issue self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.scaling = self.head_dim**-0.5 From f30319c19a049b8de2a27223b866657e2e782ae6 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 14:39:25 -0700 Subject: [PATCH 47/75] remove image processing from input processor --- vllm/model_executor/models/mllama.py | 56 ++++++++++++++-------------- 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 7ab5888de7d8..a2859f16d852 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -26,7 +26,7 @@ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig -from transformers.models.mllama.image_processing_mllama import MllamaImageProcessor +from transformers.models.mllama.image_processing_mllama import get_optimal_tiled_canvas from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs @@ -67,45 +67,53 @@ class MllamaImagePixelInputs(TypedDict): # TODO: support LlamaImageEmbeddingInputs -image_processor = None - -def recursive_sum(x): - if isinstance(x, torch.Tensor): - return x.sum() - if isinstance(x, (list, tuple)): - return sum(recursive_sum(v) for v in x) - if isinstance(x, (int, float)): - return x - return 0 - def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): + # move prompt to encoder_prompt if llm_inputs.get("prompt") is None: llm_inputs["prompt"] = llm_inputs["encoder_prompt"] llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] + # TODO: remove this hack if 198 in llm_inputs["prompt_token_ids"]: index_198 = llm_inputs["prompt_token_ids"].index(198) if index_198 > 0 and llm_inputs["prompt_token_ids"][index_198 - 1] == LLAMA_IMAGE_TOKEN_ID: llm_inputs["prompt_token_ids"] = llm_inputs["prompt_token_ids"][:index_198] + llm_inputs["prompt_token_ids"][index_198+1:] + + # process multi-modal data + assert "decoder_multi_modal_data" not in llm_inputs, "multi-modal data should be put in encoder message of LLaMA Vision models" multi_modal_data = llm_inputs.get("encoder_multi_modal_data") - hf_config = ctx.model_config.hf_config + if multi_modal_data is None or "image" not in multi_modal_data or multi_modal_data["image"] is None: + # text-only llm_inputs["encoder_prompt"] = "" llm_inputs["encoder_prompt_token_ids"] = [] llm_inputs["encoder_multi_modal_data"] = {} return llm_inputs - global image_processor - if image_processor is None: - image_processor = MllamaImageProcessor.from_pretrained(ctx.model_config.model) - processed_image = image_processor(multi_modal_data["image"]) - llm_inputs["encoder_multi_modal_data"]["image"] = processed_image - num_tiles = recursive_sum(processed_image["num_tiles"]) + + # get num_tiles + if isinstance(multi_modal_data['image'], Image.Image): + multi_modal_data['image'] = [multi_modal_data['image']] + hf_config = ctx.model_config.hf_config + num_tiles = 0 + for image in multi_modal_data["image"]: + width, height = image.size + tile_size = hf_config.vision_config.image_size + canvas_height, canvas_width = get_optimal_tiled_canvas( + image_height=height, + image_width=width, + max_image_tiles=hf_config.vision_config.max_num_tiles, + tile_size=tile_size, + ) + num_tiles_height = canvas_height // tile_size + num_tiles_width = canvas_width // tile_size + num_tiles += num_tiles_height * num_tiles_width + + # set encoder prompt based on num_tiles assert hf_config.vision_config.image_size % 14 == 0, "chunk size should be multiple of 14" token_per_chunk = (hf_config.vision_config.image_size // 14) ** 2 + 1 num_tokens = num_tiles * token_per_chunk llm_inputs["encoder_prompt"] = "<|image|>" * num_tokens llm_inputs["encoder_prompt_token_ids"] = [LLAMA_IMAGE_TOKEN_ID] * num_tokens - assert "decoder_multi_modal_data" not in llm_inputs, "multi-modal data should be put in encoder message of LLaMA Vision" return llm_inputs def get_max_mllama_image_tokens(ctx: InputContext) -> int: @@ -616,7 +624,6 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: - """Input shape: Batch x Time x Channel""" qkv_dec, _ = self.qkv_proj(hidden_states) q, _, _ = qkv_dec.split([self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1) @@ -710,13 +717,6 @@ def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig], self.layers = nn.ModuleList(layers) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # self.rotary_emb = MllamaRotaryEmbedding(config=config) - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value def forward( self, From 471e79f519ad9f7c64b3855b71c0d7f37ac971b0 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 15:46:48 -0700 Subject: [PATCH 48/75] fix precision issue of RMSNorm --- vllm/model_executor/models/mllama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index a2859f16d852..4c46d6b317d2 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -604,7 +604,7 @@ def __init__( bias=False, input_is_parallel=True, ) - self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) # vllm.model_executor.layers.layernorm.RMSNorm will cause precision issue self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.scaling = self.head_dim**-0.5 From 2a0cb7e2b014559cde1ab861f09438685125d85c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 15:57:49 -0700 Subject: [PATCH 49/75] only keep usefull vision encoder layer --- vllm/model_executor/models/mllama.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 4c46d6b317d2..bb8b20d748e9 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -383,35 +383,31 @@ class MllamaVisionEncoder(nn.Module): config: MllamaConfig """ - def __init__(self, config: MllamaVisionConfig, num_layers=32, is_gated=False): + def __init__(self, config: MllamaVisionConfig, num_layers=32, is_gated=False, output_hidden_states=None): super().__init__() self.config = config self.layers = nn.ModuleList([MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)]) - self.config = config + self.output_hidden_states = output_hidden_states or [] def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - encoder_states = () if output_hidden_states else None + encoder_states = () - for encoder_layer in self.layers: - if output_hidden_states: + for i, encoder_layer in enumerate(self.layers): + if i in self.output_hidden_states: encoder_states = encoder_states + (hidden_states,) hidden_states = encoder_layer( hidden_states, attention_mask, ) - if output_hidden_states: + if len(self.layers) - 1 in self.output_hidden_states: encoder_states = encoder_states + (hidden_states,) return hidden_states, encoder_states @@ -448,7 +444,7 @@ def __init__(self, config: MllamaVisionConfig): self.layernorm_post = nn.LayerNorm(self.hidden_size) # encoders - self.transformer = MllamaVisionEncoder(config, config.num_hidden_layers, is_gated=False) + self.transformer = MllamaVisionEncoder(config, config.num_hidden_layers, is_gated=False, output_hidden_states=config.intermediate_layers_indices) self.global_transformer = MllamaVisionEncoder(config, config.num_global_layers, is_gated=True) def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: @@ -508,14 +504,8 @@ def forward( output = self.transformer( hidden_state, attention_mask=attention_mask, - output_hidden_states=True, ) - hidden_state, all_intermediate_hidden_states = output[0], output[1] - intermediate_hidden_states = [ - hidden_state - for idx, hidden_state in enumerate(all_intermediate_hidden_states) - if idx in self.intermediate_layers_indices - ] + hidden_state, intermediate_hidden_states = output[0], output[1] intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) # apply global encoder From a596997552feb1e9ea511c534f24feceb63ab76c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 17:32:57 -0700 Subject: [PATCH 50/75] format code --- examples/offline_inference_vision_language.py | 1 - examples/openai_vision_api_client.py | 17 +- .../vision_language/test_mllama.py | 50 +- tests/models/test_llamavl.py | 67 --- vllm/config.py | 4 +- vllm/engine/llm_engine.py | 3 +- vllm/inputs/preprocess.py | 18 +- vllm/inputs/registry.py | 14 +- vllm/model_executor/models/__init__.py | 3 +- vllm/model_executor/models/mllama.py | 536 +++++++++++------- vllm/multimodal/base.py | 4 +- vllm/multimodal/image.py | 4 +- vllm/sequence.py | 11 +- vllm/transformers_utils/config.py | 9 +- vllm/transformers_utils/configs/mllama.py | 8 +- vllm/worker/enc_dec_model_runner.py | 27 +- 16 files changed, 439 insertions(+), 337 deletions(-) delete mode 100644 tests/models/test_llamavl.py diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 789bed04b6fb..a2da4cae6bc1 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -12,7 +12,6 @@ from vllm.assets.video import VideoAsset from vllm.utils import FlexibleArgumentParser - # Input image and question image = ImageAsset("cherry_blossom").pil_image.convert("RGB") question = "What is the content of this image?" diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index d48b2b3886c9..475a8a9dc1bf 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -33,7 +33,8 @@ ## Use image url in the payload chat_completion_from_url = client.chat.completions.create( messages=[{ - "role": "user", + "role": + "user", "content": [ { "type": "text", @@ -46,8 +47,7 @@ }, }, ], - } - ], + }], model=model, max_tokens=64, temperature=0.0, @@ -58,15 +58,15 @@ chat_completion_text_only = client.chat.completions.create( messages=[{ - "role": "user", + "role": + "user", "content": [ { "type": "text", "text": "what is the recipe of mayonnaise in two sentences?" }, ] - } - ], + }], model=model, max_tokens=64, temperature=0.0, @@ -75,8 +75,9 @@ result = chat_completion_text_only.choices[0].message.content print("Text-only output output:", result) -print("remove me: testing done, exitting...") -import sys; sys.exit(0) +print("remove me: testing done, exiting...") +exit(0) + ## Use base64 encoded image in the payload def encode_image_base64_from_url(image_url: str) -> str: diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index a7138a2529ac..44c657bc317d 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -6,12 +6,11 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, _ImageAssets) -from ...utils import check_logprobs_close from ....utils import multi_gpu_test +from ...utils import check_logprobs_close _LIMIT_IMAGE_PER_PROMPT = 1 @@ -27,8 +26,9 @@ ] models = [ - "nltpt/Llama-3.2-11B-Vision-Instruct", # TODO: Update model path to huggingface model + "nltpt/Llama-3.2-11B-Vision-Instruct", ] +# TODO: Update model path to huggingface model def vllm_to_hf_output(vllm_output: Tuple[List[int], str, @@ -113,11 +113,18 @@ def run_test( ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] elif sizes is not None: inputs_per_image = [( - [prompt if size is not None else text_only_prompts[0] for size in sizes], - [image.resize(size) if size is not None else None for size in sizes], + [ + prompt if size is not None else text_only_prompts[0] + for size in sizes + ], + [ + image.resize(size) if size is not None else None + for size in sizes + ], ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] if len(sizes) == 0: - inputs_per_image.append((text_only_prompts, [None] * len(text_only_prompts))) + inputs_per_image.append( + (text_only_prompts, [None] * len(text_only_prompts))) else: raise ValueError("You must provide either `size_factors` or `sizes`") @@ -153,7 +160,6 @@ def _run_test( Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ - # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it @@ -177,13 +183,14 @@ def _run_test( for prompts, images in inputs ] - def process(hf_inputs: BatchEncoding): return hf_inputs from transformers import AutoConfig from transformers.models.mllama import MllamaConfig as MllamaConfigHf - # use transformer's MllamaConfig for hf_runner and vllm's MllamaConfig for vllm_runner + + # use transformer's MllamaConfig for hf_runner + # and vllm's MllamaConfig for vllm_runner AutoConfig.register("mllama", MllamaConfigHf, exist_ok=True) with hf_runner(model, dtype=dtype, @@ -201,8 +208,6 @@ def process(hf_inputs: BatchEncoding): AutoConfig.register("mllama", MllamaConfig, exist_ok=True) for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, vllm_outputs_per_image): - # TODO: Check whether using original CLIPVisionModel can improve - # consistency against HF check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=[ @@ -225,19 +230,20 @@ def process(hf_inputs: BatchEncoding): # Single-size, batched [(512, 512), (512, 512), (512, 512)], # Multi-size, batched - [(512, 512), (1024, 512), (1536, 512), (2048, 512), - (512, 1024), (1024, 1024), (512, 1536), (512, 2028)], + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028)], # Multi-size, batched, including text only - [(512, 512), (1024, 512), (1536, 512), (2048, 512), - (512, 1024), (1024, 1024), (512, 1536), (512, 2028), None], - # mllama has 8 possible aspect ratios, carefully set the sizes to cover all of them + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028), None], + # mllama has 8 possible aspect ratios, carefully set the sizes + # to cover all of them ], ) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, image_assets, model, sizes, - dtype, max_tokens, num_logprobs) -> None: +def test_models(hf_runner, vllm_runner, image_assets, model, sizes, dtype, + max_tokens, num_logprobs) -> None: run_test( hf_runner, vllm_runner, @@ -256,15 +262,15 @@ def test_models(hf_runner, vllm_runner, image_assets, model, sizes, @pytest.mark.parametrize( "sizes", [ - [(512, 512), (1024, 512), (1536, 512), (2048, 512), - (512, 1024), (1024, 1024), (512, 1536), (512, 2028), None], + [(512, 512), (1024, 512), (1536, 512), (2048, 512), (512, 1024), + (1024, 1024), (512, 1536), (512, 2028), None], ], ) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes, - dtype, max_tokens, num_logprobs) -> None: + dtype, max_tokens, num_logprobs) -> None: run_test( hf_runner, vllm_runner, @@ -275,4 +281,4 @@ def test_models_distributed(hf_runner, vllm_runner, image_assets, model, sizes, max_tokens=max_tokens, num_logprobs=num_logprobs, tensor_parallel_size=2, - ) \ No newline at end of file + ) diff --git a/tests/models/test_llamavl.py b/tests/models/test_llamavl.py deleted file mode 100644 index fcf6c5c48d8a..000000000000 --- a/tests/models/test_llamavl.py +++ /dev/null @@ -1,67 +0,0 @@ -from transformers import AutoTokenizer - -from vllm import LLM, SamplingParams -from vllm.assets.image import ImageAsset -from vllm.utils import FlexibleArgumentParser - -from functools import partial -from PIL import Image as PIL_Image - - -if __name__ == "__main__": - model_size_map = { - "llama-3.2-11b": "11B", - "llama-3.2-90b": "90B", - } - parser = FlexibleArgumentParser( - description='Demo on using vLLM for offline inference with ' - 'vision language models') - parser.add_argument('--model-type', - '-m', - type=str, - default="llama-3.2-11b", - choices=model_size_map.keys(), - help='Huggingface "model_type".') - - args = parser.parse_args() - - size = model_size_map[args.model_type] - # checkpoint_dir = "/data/zhang-chen/llama/checkpoints" # update checkpoint path here - model_id = "/data/zhang-chen/Llama-3.2-11B-Vision-Early" - llm = LLM(model=model_id, - enforce_eager=True, - limit_mm_per_prompt={"image": 2}, - max_num_seqs=16, - tensor_parallel_size=1, - # load_format="dummy" - ) - - resource_dir = "/home/eecs/zhang-chen/venv/vllm-multimodal/lib/python3.10/site-packages/llama_models/scripts/resources/" - # Input image and question - with open(f"{resource_dir}/dog.jpg", "rb") as f: - image = PIL_Image.open(f).convert("RGB") - with open(f"{resource_dir}/pasta.jpeg", "rb") as f: - image2 = PIL_Image.open(f).convert("RGB") - - inputs = [ - { - "encoder_prompt":{ - "prompt": "", - "multi_modal_data": { - "image": [image] - } - }, - "decoder_prompt": "<|image|><|begin_of_text|>If I had to write a haiku for this one", - }, - { - "encoder_prompt":{ - "prompt": "", - }, - "decoder_prompt": "The color of the sky is blue but sometimes it can also be", - }, - ] - outputs = llm.generate(inputs, SamplingParams(temperature=0, top_p=0.9, max_tokens=512)) - for o in outputs: - generated_text = o.outputs[0].text - print(generated_text) - print("==================================") diff --git a/vllm/config.py b/vllm/config.py index 0a7c8de48ada..df8c8419f4cc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -555,7 +555,9 @@ def get_multimodal_config(self) -> "MultiModalConfig": @property def is_encoder_decoder_model(self) -> bool: """Extract the HF encoder/decoder model flag.""" - return getattr(self.hf_config, "is_encoder_decoder", False) or ((hasattr(self.hf_config, "text_config") and getattr(self.hf_config.text_config, "is_encoder_decoder", False))) + return getattr(self.hf_config, "is_encoder_decoder", False) or ( + (hasattr(self.hf_config, "text_config") and getattr( + self.hf_config.text_config, "is_encoder_decoder", False))) @property def is_embedding_model(self) -> bool: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 678a9417e923..ebaefe8cde06 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1698,7 +1698,8 @@ def is_embedding_model(self): def _validate_model_inputs(self, inputs: Union[LLMInputs, EncoderDecoderLLMInputs]): if self.model_config.is_multimodal_model: - # For encoder-decoder multimodal models, the max_prompt_len restricts the decoder prompt length + # For encoder-decoder multimodal models, the max_prompt_len + # restricts the decoder prompt length prompt_ids = inputs.get("prompt_token_ids") elif self.is_encoder_decoder_model(): prompt_ids = inputs.get("encoder_prompt_token_ids") diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 8c9624303860..22f65ed5a324 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -159,7 +159,7 @@ def _prepare_decoder_input_ids_for_generation( decoder_input_ids = self._get_default_enc_dec_decoder_prompt() if force_bos and (len(decoder_input_ids) == 0 - or decoder_input_ids[0] != decoder_start_token_id): + or decoder_input_ids[0] != decoder_start_token_id): decoder_input_ids = [decoder_start_token_id] + decoder_input_ids return decoder_input_ids @@ -297,12 +297,16 @@ def _build_enc_dec_llm_inputs( decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps if decoder_mm_data is not None: - raise ValueError("Multi-modality decoder inputs of encoder-decoder models are " - "not supported yet") + raise ValueError( + "Multi-modality decoder inputs of encoder-decoder models are " + "not supported yet") - # For Multi-Modal models (e.g., mllama), the text input can be <|image|><|begin_of_text|>hello world. And we should not add another <|begin_of_text|> to the beginning. - decoder_prompt_ids = ( - self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids, force_bos=(encoder_mm_data is None and decoder_mm_data is None))) + # For Multi-Modal models (e.g., mllama), the text input can be + # <|image|><|begin_of_text|>hello world. And we should not add + # another <|begin_of_text|> to the beginning. + decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation( + decoder_prompt_ids, + force_bos=(encoder_mm_data is None and decoder_mm_data is None))) return EncoderDecoderLLMInputs( prompt_token_ids=decoder_prompt_ids, @@ -310,7 +314,7 @@ def _build_enc_dec_llm_inputs( multi_modal_data=decoder_mm_data, encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt=encoder_prompt, - encoder_multi_modal_data = encoder_mm_data, + encoder_multi_modal_data=encoder_mm_data, ) def _process_encoder_decoder_prompt( diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index 33e1205278a1..c64a65b89fd3 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -106,8 +106,8 @@ class InputRegistry: def __init__(self) -> None: self._dummy_factories_by_model_type: Dict[Type[nn.Module], DummyDataFactory] = {} - self._dummy_encoder_factories_by_model_type: Dict[Type[nn.Module], - DummyDataFactory] = {} + self._dummy_encoder_factories_by_model_type: Dict[ + Type[nn.Module], DummyDataFactory] = {} self._input_processors_by_model_type: Dict[Type[nn.Module], InputProcessor] = {} @@ -153,7 +153,7 @@ def wrapper(model_cls: N) -> N: return model_cls return wrapper - + def register_dummy_encoder_data(self, factory: DummyDataFactory): """ Register a dummy encoder data factory to a model class @@ -199,7 +199,8 @@ def dummy_data_for_profiling( model_cls, _ = get_model_architecture(model_config) if is_encoder_data: if model_cls in self._dummy_encoder_factories_by_model_type: - dummy_factory = self._dummy_encoder_factories_by_model_type[model_cls] + dummy_factory = self._dummy_encoder_factories_by_model_type[ + model_cls] else: logger.warning( "No dummy encoder data factory registered to %s. " @@ -224,10 +225,9 @@ def dummy_data_for_profiling( if is_encoder_data: logger.warning( "Expected at least %d dummy encoder tokens for profiling, " - "but found %d tokens instead.", - seq_len, len(num_tokens)) + "but found %d tokens instead.", seq_len, len(num_tokens)) else: - assert False, ( + raise AssertionError( f"Expected at least {seq_len} dummy tokens for profiling, " f"but found {len(num_tokens)} tokens instead.") if mm_data is not None: diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 3ae9e1f3b725..3a6fa9e26ff4 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -101,7 +101,8 @@ "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), "UltravoxModel": ("ultravox", "UltravoxModel"), - "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), + "MllamaForConditionalGeneration": ("mllama", + "MllamaForConditionalGeneration"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index bb8b20d748e9..00d809d79a26 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -13,60 +13,65 @@ # See the License for the specific language governing permissions and # limitations under the License. """PyTorch Mllama model.""" -from array import array import math -from PIL import Image +from array import array from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, - TypedDict, Union,) + TypedDict, Union) import torch import torch.nn.functional as F import torch.utils.checkpoint +from PIL import Image from torch import nn +from transformers.modeling_outputs import (BaseModelOutput, + CausalLMOutputWithPast) +from transformers.models.mllama.configuration_mllama import ( + MllamaConfig, MllamaTextConfig, MllamaVisionConfig) +from transformers.models.mllama.image_processing_mllama import ( + get_optimal_tiled_canvas) -from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast -from transformers.models.mllama.configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig -from transformers.models.mllama.image_processing_mllama import get_optimal_tiled_canvas +import vllm.distributed.parallel_state as ps from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.logger import init_logger -from vllm.sequence import IntermediateTensors -from .interfaces import SupportsMultiModal -from .llama import LlamaMLP, LlamaDecoderLayer -from .clip import CLIPMLP -from vllm.model_executor.layers.linear import (QKVParallelLinear, - RowParallelLinear, - ColumnParallelLinear) -from vllm.distributed import get_tensor_model_parallel_world_size - -import vllm.distributed.parallel_state as ps from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from .clip import CLIPMLP +from .interfaces import SupportsMultiModal +from .llama import LlamaDecoderLayer, LlamaMLP logger = init_logger(__name__) LLAMA_IMAGE_TOKEN_ID = 128256 + class MllamaImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: `(batch_size, max_num_image, max_num_chunk, num_channels, height, width)`""" + """Shape: """ + """(batch_size, max_num_image, max_num_chunk, num_channel, height, width)""" aspect_ratio_ids: torch.Tensor """Shape: `(batch_size, max_num_image)`""" aspect_ratio_mask: torch.Tensor """Shape: `(batch_size, max_num_image, max_num_tiles)`""" + # TODO: support LlamaImageEmbeddingInputs + def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): # move prompt to encoder_prompt if llm_inputs.get("prompt") is None: @@ -75,14 +80,19 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): # TODO: remove this hack if 198 in llm_inputs["prompt_token_ids"]: index_198 = llm_inputs["prompt_token_ids"].index(198) - if index_198 > 0 and llm_inputs["prompt_token_ids"][index_198 - 1] == LLAMA_IMAGE_TOKEN_ID: - llm_inputs["prompt_token_ids"] = llm_inputs["prompt_token_ids"][:index_198] + llm_inputs["prompt_token_ids"][index_198+1:] - + if index_198 > 0 and llm_inputs["prompt_token_ids"][ + index_198 - 1] == LLAMA_IMAGE_TOKEN_ID: + llm_inputs["prompt_token_ids"] = llm_inputs[ + "prompt_token_ids"][:index_198] + llm_inputs[ + "prompt_token_ids"][index_198 + 1:] + # process multi-modal data - assert "decoder_multi_modal_data" not in llm_inputs, "multi-modal data should be put in encoder message of LLaMA Vision models" + assert "decoder_multi_modal_data" not in llm_inputs, \ + "multi-modal data should be put in encoder message of mllama" multi_modal_data = llm_inputs.get("encoder_multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data or multi_modal_data["image"] is None: + if multi_modal_data is None or "image" not in multi_modal_data \ + or multi_modal_data["image"] is None: # text-only llm_inputs["encoder_prompt"] = "" llm_inputs["encoder_prompt_token_ids"] = [] @@ -108,54 +118,54 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): num_tiles += num_tiles_height * num_tiles_width # set encoder prompt based on num_tiles - assert hf_config.vision_config.image_size % 14 == 0, "chunk size should be multiple of 14" - token_per_chunk = (hf_config.vision_config.image_size // 14) ** 2 + 1 + assert hf_config.vision_config.image_size % 14 == 0, \ + "chunk size should be multiple of 14" + token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 num_tokens = num_tiles * token_per_chunk llm_inputs["encoder_prompt"] = "<|image|>" * num_tokens - llm_inputs["encoder_prompt_token_ids"] = [LLAMA_IMAGE_TOKEN_ID] * num_tokens + llm_inputs["encoder_prompt_token_ids"] = [LLAMA_IMAGE_TOKEN_ID + ] * num_tokens return llm_inputs + def get_max_mllama_image_tokens(ctx: InputContext) -> int: hf_config = ctx.model_config.hf_config - token_per_chunk = (hf_config.vision_config.image_size // 14) ** 2 + 1 - return hf_config.vision_config.max_num_tiles * token_per_chunk + token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + return hf_config.vision_config.max_num_tiles * token_per_chunk -def dummy_decoder_seq_data( - seq_len: int, - num_images: int -): +def dummy_decoder_seq_data(seq_len: int, num_images: int): # <|image|> * num_images + 0 * (seq_len - num_images) - assert seq_len >= num_images, "seq_len should be greater than or equal to num_images" + assert seq_len >= num_images, \ + "seq_len should be greater than or equal to num_images" token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [LLAMA_IMAGE_TOKEN_ID]) * num_images - token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, - [0]) * (seq_len - num_images) + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images) return SequenceData(token_ids) -def dummy_encoder_seq_data( - ctx: InputContext, - num_images: int -): +def dummy_encoder_seq_data(ctx: InputContext, num_images: int): num_tokens = get_max_mllama_image_tokens(ctx) * num_images - token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [LLAMA_IMAGE_TOKEN_ID]) * num_tokens + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [LLAMA_IMAGE_TOKEN_ID]) * num_tokens return SequenceData(token_ids) -def dummy_image( - num_images: int, -): +def dummy_image(num_images: int, ): width = height = 1024 image = Image.new("RGB", (width, height), color=0) return {"image": image if num_images == 1 else [image] * num_images} -def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): + +def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): num_images = mm_counts["image"] return dummy_decoder_seq_data(seq_len, num_images), None -def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len:int, mm_counts: Mapping[str, int]): + +def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): num_images = mm_counts["image"] return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images) @@ -168,7 +178,8 @@ def _prepare_aspect_ratio_attention_mask( ) -> torch.Tensor: # Expand aspect ratio mask to target_length batch_size, max_num_tiles = aspect_ratio_mask.shape - attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype) + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, + 1).to(dtype) attention_mask = attention_mask.repeat(1, 1, target_length, 1) # Mask padding patches @@ -179,9 +190,11 @@ def _prepare_aspect_ratio_attention_mask( attention_mask = 1 - attention_mask # Reshape to 2D and create 4D attention mask - # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length) - attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1) - attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min + # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length) + attention_mask = attention_mask.reshape(batch_size, + max_num_tiles * target_length, 1) + attention_mask = attention_mask @ attention_mask.transpose( + -1, -2) * torch.finfo(dtype).min attention_mask = attention_mask.unsqueeze(1) return attention_mask @@ -226,6 +239,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + def __init__(self, config: MllamaVisionConfig, is_gated: bool = True): super().__init__() self.max_num_tiles = config.max_num_tiles @@ -233,13 +247,16 @@ def __init__(self, config: MllamaVisionConfig, is_gated: bool = True): self.max_aspect_ratio_id = config.max_aspect_ratio_id self.is_gated = is_gated - self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size) + self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.hidden_size) if is_gated: self.gate = nn.Parameter(torch.zeros(1)) - def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: embeddings = self.embedding(aspect_ratio_ids) - embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, + self.hidden_size) if self.is_gated: embeddings = embeddings * self.gate.tanh() @@ -249,11 +266,12 @@ def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> class MllamaPrecomputedPositionEmbedding(nn.Module): + def __init__(self, config: MllamaVisionConfig): super().__init__() self.max_num_tiles = config.max_num_tiles self.max_aspect_ratio_id = config.max_aspect_ratio_id - self.num_patches = (config.image_size // config.patch_size) ** 2 + 1 + self.num_patches = (config.image_size // config.patch_size)**2 + 1 self.hidden_size = config.hidden_size self.scale = config.hidden_size**-0.5 @@ -265,27 +283,30 @@ def __init__(self, config: MllamaVisionConfig): # tile position embedding self.tile_embedding = nn.Embedding( - self.max_aspect_ratio_id + 1, self.max_num_tiles * self.num_patches * self.hidden_size - ) + self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.num_patches * self.hidden_size) - def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: # position embeddings gated_position_embedding = (1 - self.gate.tanh()) * self.embedding - hidden_state = hidden_state + gated_position_embedding.view(1, 1, self.num_patches, self.hidden_size) + hidden_state = hidden_state + gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size) # precomputed tile position embeddings tile_position_embedding = self.tile_embedding(aspect_ratio_ids) batch_size = hidden_state.shape[0] tile_position_embedding = tile_position_embedding.reshape( - batch_size, self.max_num_tiles, self.num_patches, self.hidden_size - ) - gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size) + gated_tile_position_embedding = self.gate.tanh( + ) * tile_position_embedding hidden_state = hidden_state + gated_tile_position_embedding return hidden_state class MllamaVisionSdpaAttention(nn.Module): + def __init__(self, config: MllamaVisionConfig): super().__init__() @@ -317,22 +338,29 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_state) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q.view(q.shape[0], q.shape[1], self.num_local_heads, self.head_dim).transpose(1, 2) - k = k.view(k.shape[0], k.shape[1], self.num_local_heads, self.head_dim).transpose(1, 2) - v = v.view(v.shape[0], v.shape[1], self.num_local_heads, self.head_dim).transpose(1, 2) + q = q.view(q.shape[0], q.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + k = k.view(k.shape[0], k.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + v = v.view(v.shape[0], v.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) # TODO: remove padding in image encoder - attn_output = F.scaled_dot_product_attention( - q, k, v, attn_mask=attention_mask, dropout_p=0.0 - ) + attn_output = F.scaled_dot_product_attention(q, + k, + v, + attn_mask=attention_mask, + dropout_p=0.0) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(attn_output.shape[0], attn_output.shape[1], -1) + attn_output = attn_output.reshape(attn_output.shape[0], + attn_output.shape[1], -1) output, _ = self.o_proj(attn_output) return output class MllamaVisionEncoderLayer(nn.Module): + def __init__(self, config, is_gated: bool = False): super().__init__() @@ -360,7 +388,8 @@ def forward( # Self Attention residual = hidden_state hidden_state = self.input_layernorm(hidden_state) - hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask) + hidden_state = self.self_attn(hidden_state, + attention_mask=attention_mask) gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() hidden_state = residual + gate_attn * hidden_state @@ -376,43 +405,49 @@ def forward( class MllamaVisionEncoder(nn.Module): """ - Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a - [`MllamaEncoderLayer`]. + Transformer encoder consisting of `config.num_hidden_layers` self attention + layers. Each layer is a [`MllamaEncoderLayer`]. Args: config: MllamaConfig """ - def __init__(self, config: MllamaVisionConfig, num_layers=32, is_gated=False, output_hidden_states=None): + def __init__(self, + config: MllamaVisionConfig, + num_layers=32, + is_gated=False, + output_hidden_states=None): super().__init__() self.config = config - self.layers = nn.ModuleList([MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)]) + self.layers = nn.ModuleList([ + MllamaVisionEncoderLayer(config, is_gated) + for _ in range(num_layers) + ]) self.output_hidden_states = output_hidden_states or [] def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutput]: - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - encoder_states = () for i, encoder_layer in enumerate(self.layers): if i in self.output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + encoder_states = encoder_states + (hidden_states, ) hidden_states = encoder_layer( hidden_states, attention_mask, ) if len(self.layers) - 1 in self.output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + encoder_states = encoder_states + (hidden_states, ) return hidden_states, encoder_states + class MllamaVisionModel(nn.Module): + def __init__(self, config: MllamaVisionConfig): super().__init__() self.image_size = config.image_size @@ -422,7 +457,7 @@ def __init__(self, config: MllamaVisionConfig): self.in_channels = config.in_channels self.intermediate_layers_indices = config.intermediate_layers_indices - self.num_patches = (self.image_size // self.patch_size) ** 2 + 1 + self.num_patches = (self.image_size // self.patch_size)**2 + 1 self.scale = config.hidden_size**-0.5 self.patch_embedding = ColumnParallelConv2dPatch( @@ -432,53 +467,75 @@ def __init__(self, config: MllamaVisionConfig): stride=self.patch_size, bias=False, ) - - self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size)) - self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config) - self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) - self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + self.class_embedding = nn.Parameter(self.scale * + torch.randn(self.hidden_size)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + config) + + self.pre_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + self.post_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) # layer norms self.layernorm_pre = nn.LayerNorm(self.hidden_size) self.layernorm_post = nn.LayerNorm(self.hidden_size) # encoders - self.transformer = MllamaVisionEncoder(config, config.num_hidden_layers, is_gated=False, output_hidden_states=config.intermediate_layers_indices) - self.global_transformer = MllamaVisionEncoder(config, config.num_global_layers, is_gated=True) - - def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor: + self.transformer = MllamaVisionEncoder( + config, + config.num_hidden_layers, + is_gated=False, + output_hidden_states=config.intermediate_layers_indices) + self.global_transformer = MllamaVisionEncoder(config, + config.num_global_layers, + is_gated=True) + + def apply_class_embedding(self, + hidden_state: torch.Tensor) -> torch.Tensor: batch_size, _, hidden_size = hidden_state.shape - class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size) + class_embedding = self.class_embedding.expand(batch_size, 1, + hidden_size) hidden_state = torch.cat([class_embedding, hidden_state], dim=1) return hidden_state - def forward( - self, pixel_values: torch.Tensor, aspect_ratio_ids: torch.Tensor, attention_mask: torch.Tensor - ) -> torch.Tensor: - batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape + def forward(self, pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + attention_mask: torch.Tensor) -> torch.Tensor: + batch_size, num_concurrent_media, num_tiles, num_channels, \ + height, width = pixel_values.shape - pixel_values = pixel_values.reshape(batch_size * num_concurrent_media * num_tiles, num_channels, height, width) - aspect_ratio_ids = aspect_ratio_ids.reshape(batch_size * num_concurrent_media, -1) + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, + height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1) # patch embedding - patch_embeds = self.patch_embedding(pixel_values.to(self.layernorm_pre.weight.dtype)) + patch_embeds = self.patch_embedding( + pixel_values.to(self.layernorm_pre.weight.dtype)) hidden_state = patch_embeds hidden_state = ps.get_tp_group().all_gather(hidden_state) # tile embeddings _, num_patches, dim = hidden_state.shape - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, -1, dim) - hidden_state = self.pre_tile_positional_embedding(hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids) # apply cls token - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim) hidden_state = self.apply_class_embedding(hidden_state) num_patches += 1 # apply position embeddings - hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, num_tiles, num_patches, dim) - hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, + aspect_ratio_ids) # apply encoder hidden_state = self.layernorm_pre(hidden_state) @@ -486,13 +543,16 @@ def forward( # Compute the number of tokens to pad num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 # Compute padding tuple for pad function - padding = (0, 0, 0, num_padding_patches) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + padding = ( + 0, 0, 0, num_padding_patches + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) # Pad the tensor hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) slice_index = -num_padding_patches if num_padding_patches > 0 else None if attention_mask is not None: - attention_mask = attention_mask.reshape(batch_size * num_concurrent_media, -1) + attention_mask = attention_mask.reshape( + batch_size * num_concurrent_media, -1) attention_mask = _prepare_aspect_ratio_attention_mask( aspect_ratio_mask=attention_mask, num_patches=self.num_patches, @@ -500,44 +560,52 @@ def forward( dtype=self.layernorm_pre.weight.dtype, ) - hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, + dim) output = self.transformer( hidden_state, attention_mask=attention_mask, ) hidden_state, intermediate_hidden_states = output[0], output[1] - intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1) + intermediate_hidden_states = torch.stack(intermediate_hidden_states, + dim=-1) # apply global encoder hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids) hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim - ) - hidden_state = self.post_tile_positional_embedding(hidden_state, aspect_ratio_ids) - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, num_tiles * (num_patches + num_padding_patches), dim - ) - hidden_state = self.global_transformer(hidden_state, attention_mask=attention_mask)[0] - hidden_state = hidden_state.reshape( - batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, dim - ) + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), dim) + hidden_state = self.global_transformer( + hidden_state, attention_mask=attention_mask)[0] + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) hidden_state = hidden_state[:, :, :slice_index] # adding intermediate layer outputs - hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, num_tiles, num_patches, dim) + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, + num_tiles, num_patches, dim) intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size * num_concurrent_media, num_tiles, num_patches + num_padding_patches, -1 - ) - intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index] + batch_size * num_concurrent_media, num_tiles, + num_patches + num_padding_patches, -1) + intermediate_hidden_states = intermediate_hidden_states[:, :, : + slice_index] intermediate_hidden_states = intermediate_hidden_states.reshape( - batch_size, num_concurrent_media, num_tiles, num_patches, -1 - ) - hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1) + batch_size, num_concurrent_media, num_tiles, num_patches, -1) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], + dim=-1) return hidden_state -# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MllamaText class MllamaTextRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): """ MllamaTextRMSNorm is equivalent to T5LayerNorm @@ -550,7 +618,8 @@ def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) return self.weight * hidden_states.to(input_dtype) def extra_repr(self): @@ -571,7 +640,8 @@ def __init__( self.num_heads = self.config.num_attention_heads self.num_local_heads = self.num_heads // self.model_parallel_size self.num_key_value_heads = self.config.num_key_value_heads - self.num_local_key_value_heads = self.num_key_value_heads // self.model_parallel_size + self.num_local_key_value_heads = \ + self.num_key_value_heads // self.model_parallel_size self.dropout = config.dropout self.hidden_size = config.hidden_size self.head_dim = config.hidden_size // self.num_heads @@ -580,7 +650,7 @@ def __init__( self.q_local_size = self.num_local_heads * self.head_dim self.kv_local_size = self.num_local_key_value_heads * self.head_dim - # TODO(heheda12345): change to Q/KV seperate linear after #7448 is merged + # TODO: change to Q/KV separate linear after #7448 is merged self.qkv_proj = QKVParallelLinear( self.hidden_size, self.head_dim, @@ -594,8 +664,9 @@ def __init__( bias=False, input_is_parallel=True, ) + # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, + # use huggingface's instead self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) - # vllm.model_executor.layers.layernorm.RMSNorm will cause precision issue self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.scaling = self.head_dim**-0.5 @@ -615,28 +686,36 @@ def forward( attn_metadata: AttentionMetadata, ) -> torch.Tensor: qkv_dec, _ = self.qkv_proj(hidden_states) - q, _, _ = qkv_dec.split([self.q_local_size, self.kv_local_size, self.kv_local_size], - dim=-1) + q, _, _ = qkv_dec.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) if cross_attention_states is None: k = None v = None else: qkv_enc, _ = self.qkv_proj(cross_attention_states) - _, k, v = qkv_enc.split([self.q_local_size, self.kv_local_size, self.kv_local_size], - dim=-1) + _, k, v = qkv_enc.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) k = k.view(-1, self.num_local_key_value_heads, self.head_dim) v = v.view(-1, self.num_local_key_value_heads, self.head_dim) k = self.k_norm(k) q = q.view(-1, self.num_local_heads, self.head_dim) q = self.q_norm(q) - output = self.attn(q, k, v, kv_cache, attn_metadata, attn_type=AttentionType.ENCODER_DECODER) + output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER_DECODER) out, _ = self.o_proj(output) return out class MllamaCrossAttentionDecoderLayer(torch.nn.Module): - """Cross-attention transformer block with tanh-gated attention and feedforward.""" + """Cross-attention transformer block with tanh-gated attention + and feedforward.""" def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: super().__init__() @@ -646,7 +725,8 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: layer_idx=layer_idx, ) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) self.mlp = LlamaMLP( @@ -654,7 +734,8 @@ def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, ) - self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) def forward( @@ -677,33 +758,45 @@ def forward( attn_metadata=attn_metadata, ) hidden_states = full_text_row_masked_out_mask * hidden_states - hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states + hidden_states = residual + self.cross_attn_attn_gate.tanh( + ) * hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = full_text_row_masked_out_mask * hidden_states - hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states + hidden_states = residual + self.cross_attn_mlp_gate.tanh( + ) * hidden_states return hidden_states + class MllamaTextModel(nn.Module): config_class = MllamaTextConfig base_model_prefix = "model" - _no_split_modules = ["MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"] + _no_split_modules = [ + "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" + ] - def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig], quant_config: Optional[QuantizationConfig]): + def __init__(self, config: MllamaTextConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig]): super().__init__() self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, config.hidden_size) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, + config.hidden_size) self.cross_attention_layers = config.cross_attention_layers layers = [] for layer_idx in range(config.num_hidden_layers): if layer_idx in self.cross_attention_layers: - layers.append(MllamaCrossAttentionDecoderLayer(config, layer_idx)) + layers.append( + MllamaCrossAttentionDecoderLayer(config, layer_idx)) else: - layers.append(LlamaDecoderLayer(config, cache_config=cache_config, quant_config=quant_config)) + layers.append( + LlamaDecoderLayer(config, + cache_config=cache_config, + quant_config=quant_config)) self.layers = nn.ModuleList(layers) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -714,7 +807,8 @@ def forward( positions: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor], - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, skip_cross_attention: bool, @@ -729,8 +823,8 @@ def forward( hidden_states=hidden_states, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - # xattn_cache=xattn_caches[xattn_layer_idx] if xattn_caches is not None else None, + full_text_row_masked_out_mask= + full_text_row_masked_out_mask, kv_cache=kv_caches[idx], attn_metadata=attn_metadata, ) @@ -744,7 +838,8 @@ def forward( ) hidden_states = hidden_states + residual else: - raise ValueError(f"Unknown decoder layer type {type(decoder_layer)}") + raise ValueError( + f"Unknown decoder layer type {type(decoder_layer)}") hidden_states = self.norm(hidden_states) return hidden_states @@ -752,9 +847,13 @@ def forward( class MllamaForCausalLM(nn.Module): config_class = MllamaTextConfig base_model_prefix = "language_model" - _no_split_modules = ["MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer"] + _no_split_modules = [ + "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" + ] - def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig], quant_config: Optional[QuantizationConfig]): + def __init__(self, config: MllamaTextConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig]): super().__init__() self.vocab_size = config.vocab_size self.model = MllamaTextModel(config, cache_config, quant_config) @@ -766,14 +865,14 @@ def __init__(self, config: MllamaTextConfig, cache_config:Optional[CacheConfig], quant_config=quant_config, ) - def forward( self, input_ids: torch.LongTensor, positions: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor], - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, skip_cross_attention: bool, @@ -797,7 +896,9 @@ def forward( @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama) @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): - def __init__(self, config: MllamaConfig, + + def __init__(self, + config: MllamaConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): @@ -806,12 +907,11 @@ def __init__(self, config: MllamaConfig, self.hidden_size = config.text_config.hidden_size self.max_num_tiles = config.vision_config.max_num_tiles self.vision_output_dim = config.vision_config.vision_output_dim - self.pad_token_id = config.pad_token_id if config.pad_token_id is not None else -1 + self.pad_token_id = \ + config.pad_token_id if config.pad_token_id is not None else -1 self.image_size = config.vision_config.image_size - self.vision_model = MllamaVisionModel( - config.vision_config, - ) + self.vision_model = MllamaVisionModel(config.vision_config, ) self.language_model = MllamaForCausalLM( config.text_config, cache_config=cache_config, @@ -822,19 +922,19 @@ def __init__(self, config: MllamaConfig, config.text_config.hidden_size, bias=True, ) - self.logits_processor = LogitsProcessor(config.output_hidden_states, config.text_config.vocab_size) + self.logits_processor = LogitsProcessor(config.output_hidden_states, + config.text_config.vocab_size) self.sampler = Sampler() - def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.language_model.lm_head, hidden_states, - sampling_metadata) + logits = self.logits_processor(self.language_model.lm_head, + hidden_states, sampling_metadata) return logits - + def sample( self, logits: torch.Tensor, @@ -842,23 +942,39 @@ def sample( ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - - def _parse_and_validate_image_input( - self, **kwargs: object): - # tensor with the same shape will be batched together by MultiModalInputs.batch, so pixel_values here can be: - # - List[List[torch.Tensor]]: with shape (num_tiles, 3, image_res, image_res) - # - List[torch.Tensor]: with shape (num_image_in_batch, num_tiles, 3, image_res, image_res) - # - torch.Tensor: with shape (bs, num_image_in_batch, num_tiles, 3, image_res, image_res) - pixel_values: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = kwargs.pop("pixel_values", None) - image_embeds: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = kwargs.pop("image_embeds", None) - aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = kwargs.pop("aspect_ratio_ids", None) - aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = kwargs.pop("aspect_ratio_mask", None) + + def _parse_and_validate_image_input(self, **kwargs: object): + # tensor with the same shape will be batched together by + # MultiModalInputs.batch, so pixel_values here can be: + # - List[List[torch.Tensor]]: + # with shape (num_tiles, 3, image_res, image_res) + # - List[torch.Tensor]: + # with shape (num_image, num_tiles, 3, image_res, image_res) + # - torch.Tensor: + # with shape (bs, num_image, num_tiles, 3, image_res, image_res) + pixel_values: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "aspect_ratio_ids", None) + aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "aspect_ratio_mask", None) if pixel_values is None and image_embeds is None: return None - + if pixel_values is not None and image_embeds is not None: - raise ValueError("Both pixel values and image embeds are provided.") + raise ValueError( + "Both pixel values and image embeds are provided.") if pixel_values is not None: assert aspect_ratio_ids is not None @@ -866,7 +982,8 @@ def _parse_and_validate_image_input( max_num_images = max([len(x[0]) for x in pixel_values]) if max_num_images == 0: raise ValueError("No images provided.") - max_num_tiles = max(max([len(x) for x in y[0]]) for y in pixel_values) + max_num_tiles = max( + max([len(x) for x in y[0]]) for y in pixel_values) device = self.multi_modal_projector.weight.device bsz = len(pixel_values) out_num_tiles = [] @@ -880,8 +997,15 @@ def _parse_and_validate_image_input( dtype=torch.float32, device=device, ) - out_ar_ids = torch.ones(bsz, max_num_images, dtype=torch.int64, device=device) - out_ar_mask = torch.zeros(bsz, max_num_images, max_num_tiles, dtype=torch.int64, device=device) + out_ar_ids = torch.ones(bsz, + max_num_images, + dtype=torch.int64, + device=device) + out_ar_mask = torch.zeros(bsz, + max_num_images, + max_num_tiles, + dtype=torch.int64, + device=device) for b in range(len(pixel_values)): _num_tiles = [] for i in range(len(pixel_values[b][0])): @@ -904,24 +1028,23 @@ def _parse_and_validate_image_input( raise AssertionError("This line should be unreachable.") - def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ) -> Union[Tuple, CausalLMOutputWithPast]: - if attn_metadata.num_prefill_tokens > 0 and attn_metadata.num_decode_tokens > 0: + if attn_metadata.num_prefill_tokens > 0 and \ + attn_metadata.num_decode_tokens > 0: raise ValueError("Chunk prefill not supported") image_inputs = self._parse_and_validate_image_input(**kwargs) if image_inputs is None: cross_attention_mask = None - full_text_row_masked_out_mask = (attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(input_ids.device) - xattn_caches = None - vision_tokens = None + full_text_row_masked_out_mask = ( + attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to( + input_ids.device) cross_attention_states = None skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0 else: @@ -929,28 +1052,43 @@ def forward( pixel_values = image_inputs['data'] aspect_ratio_ids = image_inputs['aspect_ratio_ids'] aspect_ratio_mask = image_inputs['aspect_ratio_mask'] - cross_attention_states = self.vision_model(pixel_values, aspect_ratio_ids, aspect_ratio_mask) - cross_attention_states = self.multi_modal_projector(cross_attention_states) + cross_attention_states = self.vision_model(pixel_values, + aspect_ratio_ids, + aspect_ratio_mask) + cross_attention_states = self.multi_modal_projector( + cross_attention_states) bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape) - cross_attention_states = cross_attention_states.view(bsz, -1, image_token_dim) - - cross_attention_states_flat = torch.zeros(sum(attn_metadata.encoder_seq_lens), image_token_dim, device=cross_attention_states.device, dtype=cross_attention_states.dtype) + cross_attention_states = cross_attention_states.view( + bsz, -1, image_token_dim) + + cross_attention_states_flat = torch.zeros( + sum(attn_metadata.encoder_seq_lens), + image_token_dim, + device=cross_attention_states.device, + dtype=cross_attention_states.dtype) start_pos = 0 - for seq_len, vision_token_in_batch in zip(attn_metadata.encoder_seq_lens, cross_attention_states): + for seq_len, vision_token_in_batch in zip( + attn_metadata.encoder_seq_lens, cross_attention_states): end_pos = start_pos + seq_len - cross_attention_states_flat[start_pos:end_pos] = vision_token_in_batch[:seq_len] + cross_attention_states_flat[ + start_pos:end_pos] = vision_token_in_batch[:seq_len] start_pos = end_pos cross_attention_states = cross_attention_states_flat - cross_attention_mask = None # TODO + cross_attention_mask = None # TODO - full_text_row_masked_out_mask = torch.ones((attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) + full_text_row_masked_out_mask = torch.ones( + (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) start_pos = 0 - for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens_tensor.cpu(), attn_metadata.encoder_seq_lens): + for seq_len, encoder_seq_len in zip( + attn_metadata.seq_lens_tensor.cpu(), + attn_metadata.encoder_seq_lens): if encoder_seq_len == 0: - full_text_row_masked_out_mask[start_pos:start_pos+seq_len] = False + full_text_row_masked_out_mask[start_pos:start_pos + + seq_len] = False start_pos += seq_len - full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(cross_attention_states.device) + full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( + cross_attention_states.device) skip_cross_attention = False outputs = self.language_model( @@ -966,7 +1104,6 @@ def forward( return outputs - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -980,7 +1117,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): updated_params = set() for name, loaded_weight in weights: if 'patch_embedding.weight' in name: - name = name.replace('patch_embedding.weight', 'patch_embedding._linear.weight') + name = name.replace('patch_embedding.weight', + 'patch_embedding._linear.weight') loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 499380292b79..e8589525a558 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -52,10 +52,10 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: """ if isinstance(nested_tensors, torch.Tensor): return nested_tensors - + if isinstance(nested_tensors, np.ndarray): return torch.from_numpy(nested_tensors) - + if isinstance(nested_tensors, (int, float)): return torch.tensor(nested_tensors) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 5b10e199c562..9969336c61d0 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -2,13 +2,13 @@ import torch from PIL import Image +from transformers.image_processing_base import BatchFeature from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_image_processor from vllm.utils import is_list_of -from transformers.image_processing_base import BatchFeature from .base import MultiModalData, MultiModalInputs, MultiModalPlugin @@ -34,7 +34,7 @@ def _default_input_mapper( data: MultiModalData[object], ) -> MultiModalInputs: model_config = ctx.model_config - + # Processed by input processor if isinstance(data, BatchFeature): return MultiModalInputs(data.data) diff --git a/vllm/sequence.py b/vllm/sequence.py index 59795a354350..8d486395fe5d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -13,6 +13,7 @@ import msgspec import torch +from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams @@ -21,7 +22,6 @@ from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics if TYPE_CHECKING: - from vllm.inputs import LLMInputs from vllm.multimodal.base import MultiModalDataDict VLLM_TOKEN_ID_ARRAY_TYPE = "l" @@ -471,8 +471,13 @@ def prompt_token_ids(self) -> List[int]: def multi_modal_data(self) -> "MultiModalDataDict": if self.inputs.get("multi_modal_data") and self.inputs.get( "encoder_multi_modal_data"): - raise ValueError("Multi-modal data in both encoder and decoder is not supported yet.") - return self.inputs.get("multi_modal_data") or self.inputs.get("encoder_multi_modal_data") or {} + raise ValueError( + "Multi-modal data in both encoder and decoder is not supported." + ) + inputs = self.inputs + return self.inputs.get("multi_modal_data") or (cast( + EncoderDecoderLLMInputs, + inputs).get("encoder_multi_modal_data")) or {} @property def lora_int_id(self) -> int: diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index b9556fc01b53..3871c0cb8b81 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -21,10 +21,11 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, EAGLEConfig, ExaoneConfig, GraniteConfig, InternVLChatConfig, - JAISConfig, MedusaConfig, MllamaConfig, - MLPSpeculatorConfig, MPTConfig, - NemotronConfig, RWConfig, - SolarConfig, UltravoxConfig) + JAISConfig, MedusaConfig, + MllamaConfig, MLPSpeculatorConfig, + MPTConfig, NemotronConfig, + RWConfig, SolarConfig, + UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py index 93b1e2b66f8b..8c79855c7f45 100644 --- a/vllm/transformers_utils/configs/mllama.py +++ b/vllm/transformers_utils/configs/mllama.py @@ -1,4 +1,8 @@ -from transformers.models.mllama.configuration_mllama import MllamaTextConfig as MllamaTextConfigHf, MllamaConfig as MllamaConfigHf +from transformers.models.mllama.configuration_mllama import ( + MllamaConfig as MllamaConfigHf) +from transformers.models.mllama.configuration_mllama import ( + MllamaTextConfig as MllamaTextConfigHf) + class MllamaTextConfig(MllamaTextConfigHf): ''' @@ -6,6 +10,7 @@ class MllamaTextConfig(MllamaTextConfigHf): - transformers regards mllama as is_encoder_decoder=False - vllm needs is_encoder_decoder=True to enable cross-attention ''' + def __init__( self, **kwargs, @@ -16,6 +21,7 @@ def __init__( class MllamaConfig(MllamaConfigHf): + def __init__( self, text_config=None, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 7c7f36aff6cf..14be7d1bd2b3 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -18,7 +18,8 @@ from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry, MultiModalInputs +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, + MultiModalRegistry) from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) @@ -293,8 +294,7 @@ def profile_run(self) -> None: max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( self.model_config) if max_mm_tokens > 0: - logger.warning( - "profile run for multi-modal models") + logger.warning("profile run for multi-modal models") batch_size = 0 for group_id in range(max_num_seqs): @@ -302,13 +302,15 @@ def profile_run(self) -> None: (group_id < max_num_batched_tokens % max_num_seqs)) batch_size += seq_len - decoder_seq_data, decoder_dummy_multi_modal_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, + decoder_seq_data, decoder_dummy_multi_modal_data \ + = self.input_registry.dummy_data_for_profiling( + self.model_config, seq_len, self.mm_registry, is_encoder_data=False) - encoder_seq_data, encoder_dummy_multi_modal_data = self.input_registry \ - .dummy_data_for_profiling(self.model_config, + encoder_seq_data, encoder_dummy_multi_modal_data \ + = self.input_registry.dummy_data_for_profiling( + self.model_config, seq_len, self.mm_registry, is_encoder_data=True) @@ -317,9 +319,11 @@ def profile_run(self) -> None: assert len(decoder_seq_data.prompt_token_ids) >= seq_len, ( f"Expected at least {seq_len} dummy tokens for profiling, " f"but got: {len(decoder_seq_data.prompt_token_ids)}") - - assert decoder_dummy_multi_modal_data is None or encoder_dummy_multi_modal_data is None, ( - "Multi-modal data cannot be provided for both encoder and decoder") + + assert decoder_dummy_multi_modal_data is None or \ + encoder_dummy_multi_modal_data is None, ( + "Multi-modal data can't be provided in both encoder and decoder" + ) seq = SequenceGroupMetadata( request_id=str(group_id), @@ -329,7 +333,8 @@ def profile_run(self) -> None: block_tables=None, encoder_seq_data=encoder_seq_data, cross_block_table=None, - multi_modal_data=decoder_dummy_multi_modal_data or encoder_dummy_multi_modal_data, + multi_modal_data=decoder_dummy_multi_modal_data + or encoder_dummy_multi_modal_data, ) seqs.append(seq) From 70b6bb38981e59ba8288afb332fe12e28e9b6cc9 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 17:49:47 -0700 Subject: [PATCH 51/75] try formater again --- vllm/transformers_utils/configs/mllama.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py index 8c79855c7f45..99370defe597 100644 --- a/vllm/transformers_utils/configs/mllama.py +++ b/vllm/transformers_utils/configs/mllama.py @@ -1,7 +1,9 @@ -from transformers.models.mllama.configuration_mllama import ( - MllamaConfig as MllamaConfigHf) -from transformers.models.mllama.configuration_mllama import ( - MllamaTextConfig as MllamaTextConfigHf) +from transformers.models.mllama.configuration_mllama import (MllamaConfig as + MllamaConfigHf) +from transformers.models.mllama.configuration_mllama import (MllamaTextConfig + as + MllamaTextConfigHf + ) class MllamaTextConfig(MllamaTextConfigHf): From 31000d0d189dd9a0d8e77d25c965b44152905560 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 17:51:29 -0700 Subject: [PATCH 52/75] try formater again --- vllm/transformers_utils/configs/mllama.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py index 99370defe597..8c79855c7f45 100644 --- a/vllm/transformers_utils/configs/mllama.py +++ b/vllm/transformers_utils/configs/mllama.py @@ -1,9 +1,7 @@ -from transformers.models.mllama.configuration_mllama import (MllamaConfig as - MllamaConfigHf) -from transformers.models.mllama.configuration_mllama import (MllamaTextConfig - as - MllamaTextConfigHf - ) +from transformers.models.mllama.configuration_mllama import ( + MllamaConfig as MllamaConfigHf) +from transformers.models.mllama.configuration_mllama import ( + MllamaTextConfig as MllamaTextConfigHf) class MllamaTextConfig(MllamaTextConfigHf): From 5be8a6556d84e081559f33b6244c47adf3276d5b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 17:59:39 -0700 Subject: [PATCH 53/75] try formater again again again --- vllm/transformers_utils/configs/mllama.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py index 8c79855c7f45..11a34e0c658f 100644 --- a/vllm/transformers_utils/configs/mllama.py +++ b/vllm/transformers_utils/configs/mllama.py @@ -1,10 +1,7 @@ -from transformers.models.mllama.configuration_mllama import ( - MllamaConfig as MllamaConfigHf) -from transformers.models.mllama.configuration_mllama import ( - MllamaTextConfig as MllamaTextConfigHf) +from transformers.models.mllama import configuration_mllama as mllama_hf_config -class MllamaTextConfig(MllamaTextConfigHf): +class MllamaTextConfig(mllama_hf_config.MllamaTextConfig): ''' Use this class to override is_encoder_decoder: - transformers regards mllama as is_encoder_decoder=False @@ -20,7 +17,7 @@ def __init__( self.is_encoder_decoder = True -class MllamaConfig(MllamaConfigHf): +class MllamaConfig(mllama_hf_config.MllamaConfig): def __init__( self, From 8505a8f528d27f2fe82d064f7aa18adf3c405bc4 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 18:06:51 -0700 Subject: [PATCH 54/75] try formater again again again again --- vllm/model_executor/models/mllama.py | 30 +++++++++++++++------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 00d809d79a26..e8c7b3edafc3 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -21,12 +21,11 @@ import torch import torch.nn.functional as F import torch.utils.checkpoint +import transformers.models.mllama.configuration_mllama as config_mllama from PIL import Image from torch import nn from transformers.modeling_outputs import (BaseModelOutput, CausalLMOutputWithPast) -from transformers.models.mllama.configuration_mllama import ( - MllamaConfig, MllamaTextConfig, MllamaVisionConfig) from transformers.models.mllama.image_processing_mllama import ( get_optimal_tiled_canvas) @@ -240,7 +239,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MllamaPrecomputedAspectRatioEmbedding(nn.Module): - def __init__(self, config: MllamaVisionConfig, is_gated: bool = True): + def __init__(self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = True): super().__init__() self.max_num_tiles = config.max_num_tiles self.hidden_size = config.hidden_size @@ -267,7 +268,7 @@ def forward(self, hidden_state: torch.Tensor, class MllamaPrecomputedPositionEmbedding(nn.Module): - def __init__(self, config: MllamaVisionConfig): + def __init__(self, config: config_mllama.MllamaVisionConfig): super().__init__() self.max_num_tiles = config.max_num_tiles self.max_aspect_ratio_id = config.max_aspect_ratio_id @@ -307,7 +308,7 @@ def forward(self, hidden_state: torch.Tensor, class MllamaVisionSdpaAttention(nn.Module): - def __init__(self, config: MllamaVisionConfig): + def __init__(self, config: config_mllama.MllamaVisionConfig): super().__init__() model_parallel_size = get_tensor_model_parallel_world_size() @@ -413,7 +414,7 @@ class MllamaVisionEncoder(nn.Module): """ def __init__(self, - config: MllamaVisionConfig, + config: config_mllama.MllamaVisionConfig, num_layers=32, is_gated=False, output_hidden_states=None): @@ -448,7 +449,7 @@ def forward( class MllamaVisionModel(nn.Module): - def __init__(self, config: MllamaVisionConfig): + def __init__(self, config: config_mllama.MllamaVisionConfig): super().__init__() self.image_size = config.image_size self.patch_size = config.patch_size @@ -631,7 +632,7 @@ class MllamaTextCrossAttention(nn.Module): def __init__( self, - config: Optional[MllamaTextConfig] = None, + config: Optional[config_mllama.MllamaTextConfig] = None, layer_idx: Optional[int] = None, ): super().__init__() @@ -717,7 +718,8 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module): """Cross-attention transformer block with tanh-gated attention and feedforward.""" - def __init__(self, config: MllamaTextConfig, layer_idx: int) -> None: + def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int) \ + -> None: super().__init__() self.layer_idx = layer_idx self.cross_attn = MllamaTextCrossAttention( @@ -771,13 +773,13 @@ def forward( class MllamaTextModel(nn.Module): - config_class = MllamaTextConfig + config_class = config_mllama.MllamaTextConfig base_model_prefix = "model" _no_split_modules = [ "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" ] - def __init__(self, config: MllamaTextConfig, + def __init__(self, config: config_mllama.MllamaTextConfig, cache_config: Optional[CacheConfig], quant_config: Optional[QuantizationConfig]): super().__init__() @@ -845,13 +847,13 @@ def forward( class MllamaForCausalLM(nn.Module): - config_class = MllamaTextConfig + config_class = config_mllama.MllamaTextConfig base_model_prefix = "language_model" _no_split_modules = [ "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" ] - def __init__(self, config: MllamaTextConfig, + def __init__(self, config: config_mllama.MllamaTextConfig, cache_config: Optional[CacheConfig], quant_config: Optional[QuantizationConfig]): super().__init__() @@ -898,7 +900,7 @@ def forward( class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): def __init__(self, - config: MllamaConfig, + config: config_mllama.MllamaConfig, multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): From a32c3ab05215c67e33c93deab1a84397ddbdfb50 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Sun, 22 Sep 2024 21:43:10 -0700 Subject: [PATCH 55/75] update example --- examples/offline_inference_vision_language.py | 22 ++++++++++--- examples/openai_vision_api_client.py | 31 +++---------------- 2 files changed, 22 insertions(+), 31 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index a2da4cae6bc1..1a4be093074b 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -12,10 +12,6 @@ from vllm.assets.video import VideoAsset from vllm.utils import FlexibleArgumentParser -# Input image and question -image = ImageAsset("cherry_blossom").pil_image.convert("RGB") -question = "What is the content of this image?" - # LLaVA-1.5 def run_llava(question, modality): @@ -232,6 +228,23 @@ def run_qwen2_vl(question, modality): return llm, prompt, stop_token_ids +# LLama +def run_mllama(question, modality): + assert modality == "image" + + model_name = "/data/zhang-chen/Llama-3.2-11B-Vision-Instruct" + + llm = LLM( + model=model_name, + max_num_seqs=16, + enforce_eager=True, + ) + + prompt = f"<|image|><|begin_of_text|>{question}" + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -246,6 +259,7 @@ def run_qwen2_vl(question, modality): "internvl_chat": run_internvl, "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, + "mllama": run_mllama, } diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index 475a8a9dc1bf..71ae03e4d148 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -28,7 +28,7 @@ model = models.data[0].id # Single-image input inference -image_url = "https://llava-vl.github.io/static/images/view.jpg" +image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" ## Use image url in the payload chat_completion_from_url = client.chat.completions.create( @@ -38,7 +38,7 @@ "content": [ { "type": "text", - "text": "Describe image in two sentences" + "text": "What's in this image?" }, { "type": "image_url", @@ -50,33 +50,10 @@ }], model=model, max_tokens=64, - temperature=0.0, ) result = chat_completion_from_url.choices[0].message.content -print("Text + image output:", result) - -chat_completion_text_only = client.chat.completions.create( - messages=[{ - "role": - "user", - "content": [ - { - "type": "text", - "text": "what is the recipe of mayonnaise in two sentences?" - }, - ] - }], - model=model, - max_tokens=64, - temperature=0.0, -) - -result = chat_completion_text_only.choices[0].message.content -print("Text-only output output:", result) - -print("remove me: testing done, exiting...") -exit(0) +print("Chat completion output:", result) ## Use base64 encoded image in the payload @@ -98,7 +75,7 @@ def encode_image_base64_from_url(image_url: str) -> str: "content": [ { "type": "text", - "text": "What’s in this image?" + "text": "What's in this image?" }, { "type": "image_url", From 10d17367f481fd369c621bad688c066b9a012991 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 23 Sep 2024 00:58:16 -0700 Subject: [PATCH 56/75] fix bug in openai api -> chat template --- vllm/entrypoints/chat_utils.py | 32 ++++++++++++++++++++++++++++ vllm/model_executor/models/mllama.py | 8 ------- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index a4d0f7c44437..291239f43358 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -360,6 +360,7 @@ def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], _ImageParser = partial(cast, ChatCompletionContentPartImageParam) _AudioParser = partial(cast, ChatCompletionContentPartAudioParam) _RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) +MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} def _parse_chat_message_content_parts( @@ -370,6 +371,37 @@ def _parse_chat_message_content_parts( texts: List[str] = [] mm_parser = mm_tracker.create_parser() + keep_multimodal_content = \ + mm_tracker._model_config.hf_config.model_type in \ + MODEL_KEEP_MULTI_MODAL_CONTENT + + if keep_multimodal_content: + is_image = False + for part in parts: + part_type = part["type"] + if part_type == "text": + text = _TextParser(part)["text"] + texts.append(text) + elif part_type == "image_url": + image_url = _ImageParser(part)["image_url"] + + if image_url.get("detail", "auto") != "auto": + logger.warning( + "'image_url.detail' is currently not supported and " + "will be ignored.") + + mm_parser.parse_image(image_url["url"]) + is_image = True + else: + raise NotImplementedError(f"Unknown part type: {part_type}") + + text_prompt = "\n".join(texts) + role_content = [{'type': 'text', 'text': text_prompt}] + + if is_image: + role_content = [{'type': 'image'}] + role_content + return [ConversationMessage(role=role, + content=role_content)] # type: ignore for part in parts: part_type = part["type"] diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index e8c7b3edafc3..adf6c67e634f 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -76,14 +76,6 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): if llm_inputs.get("prompt") is None: llm_inputs["prompt"] = llm_inputs["encoder_prompt"] llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] - # TODO: remove this hack - if 198 in llm_inputs["prompt_token_ids"]: - index_198 = llm_inputs["prompt_token_ids"].index(198) - if index_198 > 0 and llm_inputs["prompt_token_ids"][ - index_198 - 1] == LLAMA_IMAGE_TOKEN_ID: - llm_inputs["prompt_token_ids"] = llm_inputs[ - "prompt_token_ids"][:index_198] + llm_inputs[ - "prompt_token_ids"][index_198 + 1:] # process multi-modal data assert "decoder_multi_modal_data" not in llm_inputs, \ From 0aa61b0f3efcde098368fac7bbf9a75fc99f707c Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 23 Sep 2024 13:51:32 -0700 Subject: [PATCH 57/75] change model based on new hf --- vllm/model_executor/models/mllama.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index adf6c67e634f..9c22d5fd0e23 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -365,8 +365,8 @@ def __init__(self, config, is_gated: bool = False): self.self_attn = MllamaVisionSdpaAttention(config) self.mlp = CLIPMLP(config) - self.input_layernorm = nn.LayerNorm(self.hidden_size) - self.post_attention_layernorm = nn.LayerNorm(self.hidden_size) + self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) # there used to be an if else here, no code path if is_gated: @@ -447,14 +447,14 @@ def __init__(self, config: config_mllama.MllamaVisionConfig): self.patch_size = config.patch_size self.max_num_tiles = config.max_num_tiles self.hidden_size = config.hidden_size - self.in_channels = config.in_channels + self.in_channels = config.num_channels self.intermediate_layers_indices = config.intermediate_layers_indices self.num_patches = (self.image_size // self.patch_size)**2 + 1 self.scale = config.hidden_size**-0.5 self.patch_embedding = ColumnParallelConv2dPatch( - in_channels=config.in_channels, + in_channels=config.num_channels, out_channels=self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size, @@ -495,7 +495,7 @@ def apply_class_embedding(self, def forward(self, pixel_values: torch.Tensor, aspect_ratio_ids: torch.Tensor, - attention_mask: torch.Tensor) -> torch.Tensor: + aspect_ratio_mask: torch.Tensor) -> torch.Tensor: batch_size, num_concurrent_media, num_tiles, num_channels, \ height, width = pixel_values.shape @@ -543,15 +543,14 @@ def forward(self, pixel_values: torch.Tensor, hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) slice_index = -num_padding_patches if num_padding_patches > 0 else None - if attention_mask is not None: - attention_mask = attention_mask.reshape( - batch_size * num_concurrent_media, -1) - attention_mask = _prepare_aspect_ratio_attention_mask( - aspect_ratio_mask=attention_mask, - num_patches=self.num_patches, - target_length=hidden_state.shape[2], - dtype=self.layernorm_pre.weight.dtype, - ) + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.layernorm_pre.weight.dtype, + ) hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim) @@ -787,6 +786,7 @@ def __init__(self, config: config_mllama.MllamaTextConfig, layers.append( MllamaCrossAttentionDecoderLayer(config, layer_idx)) else: + # TODO: force LlamaDecoderLayer to config.attention_bias=False layers.append( LlamaDecoderLayer(config, cache_config=cache_config, From b993988acfdaa1d2484d701ebef4b4575e285cc2 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 23 Sep 2024 14:34:01 -0700 Subject: [PATCH 58/75] make formater happy --- vllm/model_executor/models/mllama.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 9c22d5fd0e23..6dac117cdc6c 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -365,8 +365,10 @@ def __init__(self, config, is_gated: bool = False): self.self_attn = MllamaVisionSdpaAttention(config) self.mlp = CLIPMLP(config) - self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) - self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) + self.input_layernorm = nn.LayerNorm(self.hidden_size, + eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, + eps=config.norm_eps) # there used to be an if else here, no code path if is_gated: From 9065770cf61a7337786386d5267bd30130031435 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 23 Sep 2024 14:49:19 -0700 Subject: [PATCH 59/75] update model name in example --- examples/offline_inference_vision_language.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 1a4be093074b..de0947115bbf 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -232,7 +232,7 @@ def run_qwen2_vl(question, modality): def run_mllama(question, modality): assert modality == "image" - model_name = "/data/zhang-chen/Llama-3.2-11B-Vision-Instruct" + model_name = "nltpt/Llama-3.2-11B-Vision-Instruct" llm = LLM( model=model_name, From bc34aa44bf49ae26ef417f9ccbc2d8eacd2fdbef Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 23 Sep 2024 14:52:53 -0700 Subject: [PATCH 60/75] remove mllama chat template, use HF's instead --- examples/template_llama3.2.jinja | 1 - 1 file changed, 1 deletion(-) delete mode 100644 examples/template_llama3.2.jinja diff --git a/examples/template_llama3.2.jinja b/examples/template_llama3.2.jinja deleted file mode 100644 index 93049d23eaff..000000000000 --- a/examples/template_llama3.2.jinja +++ /dev/null @@ -1 +0,0 @@ -{% for message in messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }}{% if message['content'] is string %}{{ message['content'] }}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' %}{{ '<|image|>' }}{% elif content['type'] == 'text' %}{{ content['text'] }}{% endif %}{% endfor %}{% endif %}{{ '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %} \ No newline at end of file From a25e38308d0d9f8a9fadba2d5a396dc91e491d1d Mon Sep 17 00:00:00 2001 From: Chang Su Date: Mon, 23 Sep 2024 14:18:36 -0700 Subject: [PATCH 61/75] [Bugfix] Include encoder_prompt_tokens in num_prompt_tokensin UsageInfo --- vllm/entrypoints/openai/serving_chat.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1ee4b3ce17cf..35b2fe29778f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -300,6 +300,8 @@ async def chat_completion_stream_generator( async for res in result_generator: if res.prompt_token_ids is not None: num_prompt_tokens = len(res.prompt_token_ids) + if res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(res.encoder_prompt_token_ids) # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST From 1eefdc7eb726296dcfd68bb33f57ddb5bf1a7868 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 24 Sep 2024 12:17:24 -0700 Subject: [PATCH 62/75] update config based on HF update --- vllm/transformers_utils/configs/mllama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py index 11a34e0c658f..49e766d7fa1f 100644 --- a/vllm/transformers_utils/configs/mllama.py +++ b/vllm/transformers_utils/configs/mllama.py @@ -13,7 +13,6 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.hidden_act = self.hidden_activation self.is_encoder_decoder = True From ccebf14085ad02c0319b4aa33659236aa4c310e8 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 24 Sep 2024 22:12:19 -0700 Subject: [PATCH 63/75] Merge branch 'main' of github.com:vllm-project/vllm --- .gitignore | 4 +- CMakeLists.txt | 10 + Dockerfile | 5 +- Dockerfile.cpu | 4 +- Dockerfile.neuron | 23 +- Dockerfile.openvino | 5 +- Dockerfile.ppc64le | 12 +- Dockerfile.rocm | 61 +- Dockerfile.tpu | 17 +- Dockerfile.xpu | 21 +- benchmarks/benchmark_latency.py | 8 +- benchmarks/benchmark_prioritization.py | 295 ++++ benchmarks/benchmark_throughput.py | 24 +- benchmarks/kernels/benchmark_machete.py | 74 +- benchmarks/kernels/requirements.txt | 1 + cmake/cpu_extension.cmake | 1 - csrc/custom_all_reduce.cuh | 139 +- csrc/custom_all_reduce_test.cu | 21 +- csrc/cutlass_extensions/torch_utils.hpp | 8 +- csrc/moe/marlin_kernels/marlin_moe_kernel.h | 1425 ++++++++++++++++ .../marlin_kernels/marlin_moe_kernel_ku4b8.cu | 29 + .../marlin_kernels/marlin_moe_kernel_ku4b8.h | 20 + .../marlin_moe_kernel_ku8b128.cu | 29 + .../marlin_moe_kernel_ku8b128.h | 18 + csrc/moe/marlin_moe_ops.cu | 1453 +---------------- csrc/ops.h | 2 + csrc/permute_cols.cu | 88 + csrc/quantization/machete/generate.py | 173 +- .../machete/machete_mm_kernel.cuh | 3 +- .../machete/machete_mm_launcher.cuh | 2 +- .../machete/machete_prepack_launcher.cuh | 2 +- csrc/torch_bindings.cpp | 3 + .../dev/multimodal/multimodal_index.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 2 +- .../getting_started/amd-installation.rst | 65 +- .../getting_started/cpu-installation.rst | 2 +- .../getting_started/xpu-installation.rst | 6 +- docs/source/models/vlm.rst | 2 +- docs/source/quantization/bnb.rst | 2 +- examples/lora_with_quantization_inference.py | 26 +- examples/offline_inference_chat.py | 27 + examples/offline_inference_vision_language.py | 14 + ...e_inference_vision_language_multi_image.py | 13 + pyproject.toml | 7 +- requirements-build.txt | 3 +- requirements-test.txt | 2 +- requirements-xpu.txt | 4 +- setup.py | 65 +- tests/conftest.py | 18 +- tests/core/test_chunked_prefill_scheduler.py | 225 ++- tests/core/test_scheduler.py | 363 ++-- tests/distributed/test_pipeline_parallel.py | 8 + tests/engine/test_arg_utils.py | 21 + tests/entrypoints/llm/test_generate.py | 35 + tests/entrypoints/openai/test_accuracy.py | 6 +- tests/kernels/test_machete_gemm.py | 3 + tests/kernels/test_permute_cols.py | 15 + tests/lora/test_punica_sizes.py | 5 + tests/lora/test_punica_variation.py | 5 + .../vision_language/test_phi3v.py | 186 ++- .../decoder_only/vision_language/test_qwen.py | 29 +- tests/models/utils.py | 35 + tests/mq_llm_engine/test_error_handling.py | 27 +- tests/mq_llm_engine/utils.py | 2 +- tests/multimodal/test_processor_kwargs.py | 339 ++++ tests/quantization/test_bitsandbytes.py | 2 +- tests/samplers/test_beam_search.py | 6 +- .../test_typical_acceptance_sampler.py | 17 +- tests/spec_decode/e2e/conftest.py | 139 +- .../spec_decode/e2e/test_eagle_correctness.py | 58 + tests/spec_decode/e2e/test_logprobs.py | 95 +- .../e2e/test_medusa_correctness.py | 59 + tests/spec_decode/e2e/test_mlp_correctness.py | 57 +- .../spec_decode/e2e/test_ngram_correctness.py | 59 + tests/test_embedded_commit.py | 7 +- vllm/__init__.py | 8 +- vllm/_custom_ops.py | 19 +- vllm/config.py | 45 +- vllm/core/scheduler.py | 77 + .../device_communicators/shm_broadcast.py | 7 +- vllm/engine/arg_utils.py | 14 + vllm/engine/async_llm_engine.py | 24 +- vllm/engine/llm_engine.py | 71 +- vllm/engine/multiprocessing/__init__.py | 11 +- vllm/engine/multiprocessing/client.py | 71 +- vllm/engine/multiprocessing/engine.py | 88 +- vllm/engine/output_processor/multi_step.py | 9 +- vllm/engine/protocol.py | 8 +- vllm/entrypoints/llm.py | 294 +++- vllm/envs.py | 5 + vllm/inputs/__init__.py | 6 +- vllm/inputs/data.py | 26 +- vllm/inputs/parse.py | 22 +- vllm/inputs/preprocess.py | 86 +- vllm/inputs/registry.py | 60 +- vllm/lora/ops/bgmv_expand.py | 2 +- vllm/lora/ops/bgmv_expand_slice.py | 2 +- vllm/lora/ops/sgmv_expand.py | 16 +- vllm/lora/ops/sgmv_expand_slice.py | 18 +- vllm/lora/ops/sgmv_shrink.py | 16 +- vllm/lora/punica.py | 38 +- .../layers/quantization/awq_marlin.py | 9 +- .../layers/quantization/bitsandbytes.py | 8 +- .../schemes/compressed_tensors_wNa16.py | 114 +- .../layers/quantization/gptq_marlin.py | 133 +- .../quantization/kernels/MPLinearKernel.py | 83 + .../layers/quantization/kernels/__init__.py | 72 + .../layers/quantization/kernels/machete.py | 118 ++ .../layers/quantization/kernels/marlin.py | 132 ++ .../layers/quantization/utils/__init__.py | 3 + .../layers/quantization/utils/layer_utils.py | 37 + .../quantization/utils/machete_utils.py | 30 + .../layers/quantization/utils/marlin_utils.py | 29 +- .../layers/quantization/utils/quant_utils.py | 43 + vllm/model_executor/layers/sampler.py | 11 +- .../layers/typical_acceptance_sampler.py | 28 +- vllm/model_executor/model_loader/loader.py | 68 +- vllm/model_executor/models/fuyu.py | 2 +- vllm/model_executor/models/paligemma.py | 3 +- vllm/model_executor/models/persimmon.py | 12 +- vllm/model_executor/models/phi3v.py | 31 +- vllm/model_executor/models/qwen2.py | 22 +- vllm/model_executor/models/qwen2_vl.py | 29 +- vllm/model_executor/models/ultravox.py | 30 +- vllm/model_executor/parameter.py | 58 + vllm/multimodal/base.py | 19 +- vllm/multimodal/image.py | 10 +- vllm/multimodal/registry.py | 9 + vllm/multimodal/video.py | 9 +- vllm/outputs.py | 96 +- vllm/sampling_params.py | 5 + vllm/sequence.py | 34 +- vllm/spec_decode/batch_expansion.py | 10 +- vllm/spec_decode/spec_decode_worker.py | 62 +- vllm/spec_decode/util.py | 45 +- vllm/transformers_utils/detokenizer.py | 16 +- vllm/transformers_utils/image_processor.py | 64 - vllm/transformers_utils/processor.py | 65 +- vllm/utils.py | 57 + vllm/version.py | 12 +- vllm/worker/cpu_model_runner.py | 304 ++-- vllm/worker/model_runner_base.py | 10 +- 142 files changed, 6115 insertions(+), 2903 deletions(-) create mode 100644 benchmarks/benchmark_prioritization.py create mode 100644 benchmarks/kernels/requirements.txt create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel.h create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu create mode 100644 csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h create mode 100644 csrc/permute_cols.cu create mode 100644 tests/kernels/test_permute_cols.py create mode 100644 tests/multimodal/test_processor_kwargs.py create mode 100644 vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py create mode 100644 vllm/model_executor/layers/quantization/kernels/__init__.py create mode 100644 vllm/model_executor/layers/quantization/kernels/machete.py create mode 100644 vllm/model_executor/layers/quantization/kernels/marlin.py create mode 100644 vllm/model_executor/layers/quantization/utils/layer_utils.py create mode 100644 vllm/model_executor/layers/quantization/utils/machete_utils.py delete mode 100644 vllm/transformers_utils/image_processor.py diff --git a/.gitignore b/.gitignore index bc7236ea1869..abeaf0a82e30 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ -# vllm commit id, generated by setup.py -vllm/commit_id.py +# version file generated by setuptools-scm +/vllm/_version.py # vllm-flash-attn built from source vllm/vllm_flash_attn/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 03937e4e0658..b2fa72d4775c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -192,6 +192,10 @@ set(VLLM_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + + # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. + set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") + FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git @@ -219,6 +223,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/fp8/fp8_marlin.cu" "csrc/custom_all_reduce.cu" + "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") @@ -311,6 +316,11 @@ set(VLLM_MOE_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC + "csrc/moe/marlin_kernels/marlin_moe_kernel.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" "csrc/moe/marlin_moe_ops.cu") endif() diff --git a/Dockerfile b/Dockerfile index 30e27620574a..ec803764a128 100644 --- a/Dockerfile +++ b/Dockerfile @@ -79,15 +79,13 @@ ENV MAX_JOBS=${max_jobs} ARG nvcc_threads=8 ENV NVCC_THREADS=$nvcc_threads -ARG buildkite_commit -ENV BUILDKITE_COMMIT=${buildkite_commit} - ARG USE_SCCACHE ARG SCCACHE_BUCKET_NAME=vllm-build-sccache ARG SCCACHE_REGION_NAME=us-west-2 ARG SCCACHE_S3_NO_CREDENTIALS=0 # if USE_SCCACHE is set, use sccache to speed up compilation RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ if [ "$USE_SCCACHE" = "1" ]; then \ echo "Installing sccache..." \ && curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \ @@ -107,6 +105,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ if [ "$USE_SCCACHE" != "1" ]; then \ python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ fi diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 4d7289366296..a9d97a3e0bde 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -62,8 +62,10 @@ ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=bind,source=.git,target=.git \ VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ - pip install dist/*.whl + pip install dist/*.whl && \ + rm -rf dist WORKDIR /workspace/ diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 647ed99a41e7..adae6db87ba8 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -6,9 +6,12 @@ FROM $BASE_IMAGE RUN echo "Base image is $BASE_IMAGE" # Install some basic utilities -RUN apt-get update \ - && apt-get install python3 python3-pip -y \ - && apt-get install -y ffmpeg libsm6 libxext6 libgl1 +RUN apt-get update && \ + apt-get install -y \ + git \ + python3 \ + python3-pip \ + ffmpeg libsm6 libxext6 libgl1 ### Mount Point ### # When launching the container, mount the code directory to /app @@ -22,17 +25,17 @@ RUN python3 -m pip install sentencepiece transformers==4.36.2 -U RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -COPY ./vllm /app/vllm/vllm -COPY ./setup.py /app/vllm/setup.py -COPY ./requirements-common.txt /app/vllm/requirements-common.txt -COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt +COPY . /app/vllm RUN cd /app/vllm \ - && python3 -m pip install -U -r requirements-neuron.txt + && python3 -m pip install -U \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + -r requirements-neuron.txt ENV VLLM_TARGET_DEVICE neuron -RUN cd /app/vllm \ - && pip install -e . \ +RUN --mount=type=bind,source=.git,target=.git \ + cd /app/vllm \ + && pip install --no-build-isolation -v -e . \ && cd .. CMD ["/bin/bash"] diff --git a/Dockerfile.openvino b/Dockerfile.openvino index 96b9593a2bfa..95714a3d1718 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -4,8 +4,9 @@ FROM ubuntu:22.04 AS dev RUN apt-get update -y && \ - apt-get install -y python3-pip git && \ - apt-get install -y ffmpeg libsm6 libxext6 libgl1 + apt-get install -y \ + git python3-pip \ + ffmpeg libsm6 libxext6 libgl1 WORKDIR /workspace # copy requirements diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index 3313162bf28e..1f374b01b9bc 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -16,9 +16,15 @@ COPY ./ /workspace/vllm WORKDIR /workspace/vllm # These packages will be in rocketce eventually -RUN pip install -v cmake xformers torch==2.3.1 uvloop==0.20.0 -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing - -RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + torch==2.3.1 \ + -r requirements-cpu.txt \ + xformers uvloop==0.20.0 + +RUN --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=cpu python3 setup.py install WORKDIR /workspace/ diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 33423fde4ff9..9aa3a974e704 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,5 +1,5 @@ -# Default ROCm 6.1 base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" +# Default ROCm 6.2 base image +ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0" # Default ROCm ARCHes to build vLLM for. ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" @@ -7,18 +7,12 @@ ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" # Whether to install CK-based flash-attention # If 0, will not install flash-attention ARG BUILD_FA="1" -# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL` -# If this succeeds, we use the downloaded wheel and skip building flash-attention. -# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the -# architectures specified in `FA_GFX_ARCHS` -ARG TRY_FA_WHEEL="1" -ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl" ARG FA_GFX_ARCHS="gfx90a;gfx942" -ARG FA_BRANCH="23a2b1c2" +ARG FA_BRANCH="3cea2fb" # Whether to build triton on rocm ARG BUILD_TRITON="1" -ARG TRITON_BRANCH="e0fc12c" +ARG TRITON_BRANCH="e192dba" ### Base image build stage FROM $BASE_IMAGE AS base @@ -50,14 +44,17 @@ RUN python3 -m pip install --upgrade pip # Remove sccache so it doesn't interfere with ccache # TODO: implement sccache support across components RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" -# Install torch == 2.5.0 on ROCm -RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.1"*) \ + +# Install torch == 2.6.0 on ROCm +RUN --mount=type=cache,target=/root/.cache/pip \ + case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ + *"rocm-6.2"*) \ python3 -m pip uninstall -y torch torchvision \ - && python3 -m pip install --no-cache-dir --pre \ - torch==2.5.0.dev20240726 \ - torchvision==0.20.0.dev20240726 \ - --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ + && python3 -m pip install --pre \ + torch==2.6.0.dev20240918 \ + setuptools-scm>=8 \ + torchvision==0.20.0.dev20240918 \ + --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \ *) ;; esac ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer @@ -79,25 +76,18 @@ RUN cd /opt/rocm/share/amd_smi \ ### Flash-Attention wheel build stage FROM base AS build_fa ARG BUILD_FA -ARG TRY_FA_WHEEL -ARG FA_WHEEL_URL ARG FA_GFX_ARCHS ARG FA_BRANCH # Build ROCm flash-attention wheel if `BUILD_FA = 1` RUN --mount=type=cache,target=${CCACHE_DIR} \ if [ "$BUILD_FA" = "1" ]; then \ - if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \ - # If a suitable wheel exists, we download it instead of building FA - mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \ - else \ - mkdir -p libs \ - && cd libs \ - && git clone https://github.com/ROCm/flash-attention.git \ - && cd flash-attention \ - && git checkout "${FA_BRANCH}" \ - && git submodule update --init \ - && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ - fi; \ + mkdir -p libs \ + && cd libs \ + && git clone https://github.com/ROCm/flash-attention.git \ + && cd flash-attention \ + && git checkout "${FA_BRANCH}" \ + && git submodule update --init \ + && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ # Create an empty directory otherwise as later build stages expect one else mkdir -p /install; \ fi @@ -112,6 +102,7 @@ RUN --mount=type=cache,target=${CCACHE_DIR} \ if [ "$BUILD_TRITON" = "1" ]; then \ mkdir -p libs \ && cd libs \ + && python3 -m pip install ninja cmake wheel pybind11 \ && git clone https://github.com/OpenAI/triton.git \ && cd triton \ && git checkout "${TRITON_BRANCH}" \ @@ -138,15 +129,9 @@ ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 ENV TOKENIZERS_PARALLELISM=false RUN --mount=type=cache,target=${CCACHE_DIR} \ + --mount=type=bind,source=.git,target=.git \ --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -Ur requirements-rocm.txt \ - && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.1"*) \ - # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM - wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib \ - # Prevent interference if torch bundles its own HIP runtime - && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \ - *) ;; esac \ && python3 setup.py clean --all \ && python3 setup.py develop diff --git a/Dockerfile.tpu b/Dockerfile.tpu index 04cd4d79f404..d8f1a42c4517 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -5,16 +5,25 @@ FROM $BASE_IMAGE WORKDIR /workspace # Install some basic utilities -RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 libgl1 +RUN apt-get update && apt-get install -y \ + git \ + ffmpeg libsm6 libxext6 libgl1 # Install the TPU and Pallas dependencies. -RUN python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html -RUN python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html # Build vLLM. COPY . /workspace/vllm ENV VLLM_TARGET_DEVICE="tpu" -RUN cd /workspace/vllm && python3 -m pip install -r requirements-tpu.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ + cd /workspace/vllm && \ + python3 -m pip install \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + -r requirements-tpu.txt RUN cd /workspace/vllm && python3 setup.py develop CMD ["/bin/bash"] diff --git a/Dockerfile.xpu b/Dockerfile.xpu index 50bbd8f7dad8..8471edd16e4b 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -7,23 +7,20 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ chmod 644 /usr/share/keyrings/intel-graphics.gpg -RUN apt-get update -y \ -&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1 - -RUN git clone https://github.com/intel/pti-gpu && \ - cd pti-gpu/sdk && \ - mkdir build && \ - cd build && \ - cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE=../cmake/toolchains/icpx_toolchain.cmake -DBUILD_TESTING=OFF .. && \ - make -j && \ - cmake --install . --config Release --prefix "/usr/local" +RUN apt-get update -y && \ + apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1 COPY ./ /workspace/vllm WORKDIR /workspace/vllm -RUN pip install -v -r requirements-xpu.txt +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -v --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + -r requirements-xpu.txt -RUN VLLM_TARGET_DEVICE=xpu python3 setup.py install +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=xpu python3 setup.py install CMD ["/bin/bash"] diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index eadf994cacd3..a39d1cf842f0 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompts: List[PromptType] = [{ + dummy_inputs: List[PromptInputs] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_prompts, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_prompts, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py new file mode 100644 index 000000000000..0ba29fabca59 --- /dev/null +++ b/benchmarks/benchmark_prioritization.py @@ -0,0 +1,295 @@ +"""Benchmark offline prioritization.""" +import argparse +import json +import random +import time +from typing import List, Optional, Tuple + +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int], +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + + #Select a equi-probable random priority + priority = 0 if random.random() < 0.5 else 1 + + filtered_dataset.append((prompt, prompt_len, output_len, priority)) + + return filtered_dataset + + +def run_vllm( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + use_beam_search: bool, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + gpu_memory_utilization: float = 0.9, + download_dir: Optional[str] = None, +) -> float: + from vllm import LLM, SamplingParams + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + disable_log_stats=False, + ) + + # Add the requests to the engine. + prompts = [] + sampling_params = [] + priority = [] + for prompt, _, output_len, _priority in requests: + prompts.append(prompt) + priority.append(_priority) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) + + start = time.perf_counter() + llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) + end = time.perf_counter() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: + # Synthesize a prompt with the given input length. + prompt = "hi" * (args.input_len - 1) + requests = [(prompt, args.input_len, args.output_len) + for _ in range(args.num_prompts)] + else: + requests = sample_requests(args.dataset, args.num_prompts, tokenizer, + args.output_len) + + if args.backend == "vllm": + elapsed_time = run_vllm( + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.gpu_memory_utilization, + args.download_dir) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum(prompt_len + output_len + for _, prompt_len, output_len, priority in requests) + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument('--quantization', + '-q', + choices=[*QUANTIZATION_METHODS, None], + default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--num-prompts", + type=int, + default=200, + help="Number of prompts to process.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument( + '--max-model-len', + type=int, + default=None, + help='Maximum length of a sequence (including prompt and output). ' + 'If None, will be derived from the model.') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') + parser.add_argument("--enforce-eager", + action="store_true", + help="enforce eager execution") + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda", "cpu"], + help='device type for vLLM execution, supporting CUDA and CPU.') + parser.add_argument( + "--enable-prefix-caching", + action='store_true', + help="enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--download-dir', + type=str, + default=None, + help='directory to download and load the weights, ' + 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index e1a5d4ee28ea..68b401d5bbbb 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -90,6 +90,7 @@ def run_vllm( download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, + use_new_beam_search_impl: bool = False, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -132,9 +133,23 @@ def run_vllm( max_tokens=output_len, )) - start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) - end = time.perf_counter() + if not use_new_beam_search_impl: + start = time.perf_counter() + llm.generate(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() + else: + assert use_beam_search + prompts = [prompt for prompt, _, _ in requests] + # output_len should be the same for all requests. + output_len = requests[0][2] + for prompt, input_len, _output_len in requests: + assert _output_len == output_len + start = time.perf_counter() + llm.beam_search(prompts, + beam_width=n, + max_tokens=output_len, + ignore_eos=True) + end = time.perf_counter() return end - start @@ -336,7 +351,7 @@ def main(args: argparse.Namespace): run_args.append(args.disable_frontend_multiprocessing) elapsed_time = uvloop.run(run_vllm_async(*run_args)) else: - elapsed_time = run_vllm(*run_args) + elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -396,6 +411,7 @@ def main(args: argparse.Namespace): default=1, help="Number of generated sequences per prompt.") parser.add_argument("--use-beam-search", action="store_true") + parser.add_argument("--use-new-beam-search-impl", action="store_true") parser.add_argument("--num-prompts", type=int, default=1000, diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index ca45cba6f816..b70c4b94c97a 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -4,8 +4,10 @@ import math import pickle as pkl import time -from typing import Callable, Iterable, List, Tuple +from itertools import product +from typing import Callable, Iterable, List, Optional, Tuple +import pandas as pd import torch import torch.utils.benchmark as TBenchmark from torch.utils.benchmark import Measurement as TMeasurement @@ -84,6 +86,10 @@ def loop_over_weights( fn(a, w_ref, w_q, w_s) +_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None +_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None + + def bench(atype: torch.dtype, wtype: ScalarType, group_size: int, @@ -94,6 +100,8 @@ def bench(atype: torch.dtype, sub_label: str, benchmark_marlinv1: bool = True, sweep_schedules: bool = True) -> Iterable[TMeasurement]: + global _SWEEP_SCHEDULES_RESULTS + a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) sub_label += f", L={len(weights)}" @@ -163,6 +171,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: best_schedule = None schedules = ops.machete_supported_schedules(wtype) for schedule in reversed(schedules): + schedule_M = int(schedule.split("_")[0].split("x")[1]) + + # Prune known bad schedules + if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: + continue def run(a, _, w_q, w_s, schedule=schedule): ops.machete_gemm(a, @@ -175,6 +188,20 @@ def run(a, _, w_q, w_s, schedule=schedule): res = bench_fn(label, sub_label, "machete_best", lambda: loop_over_weights(a, weights_machete, run)) + results_row = { + "M": m, + "K": k, + "N": n, + "group_size": group_size, + "schedule": schedule, + "median": res.median, + } + if _SWEEP_SCHEDULES_RESULTS is None: + _SWEEP_SCHEDULES_RESULTS = pd.DataFrame( + columns=results_row.keys()) + _SWEEP_SCHEDULES_RESULTS.\ + loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row + print(f" {res.median:5.5} ", schedule) if not best or res.median < best.median: best = res @@ -235,18 +262,22 @@ def run_square_bench(args): dim_sizes = list( range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + data = run(args.dtype, args.sweep_schedules, MKNs) make_output(data, MKNs, f"square_bench-{args.dtype}") def run_range_bench(args): - dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) - n = len(dim_sizes) - Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes - Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes - Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes - MKNs = list(zip(Ms, Ks, Ns)) + m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")] + m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")] + m_increment, k_increment, n_increment = \ + [int(x) for x in args.dim_increment.split(",")] + Ms = list(range(m_start, m_end + 1, m_increment)) + Ks = list(range(k_start, k_end + 1, k_increment)) + Ns = list(range(n_start, n_end + 1, n_increment)) + MKNs = list(product(Ms, Ks, Ns)) + data = run(args.dtype, args.sweep_schedules, MKNs) make_output(data, MKNs, f"range_bench-{args.dtype}") @@ -333,6 +364,9 @@ def to_torch_dtype(dt): action="store_true", help="Run a sweep over all supported schedules", ) + parser.add_argument("--sweep-csv-out", + help="CSV to store sweep results", + default="sch_sweep_results.csv") subparsers = parser.add_subparsers(dest="cmd", required=True) square_parser = subparsers.add_parser("square_bench") @@ -342,12 +376,21 @@ def to_torch_dtype(dt): square_parser.set_defaults(func=run_square_bench) range_parser = subparsers.add_parser("range_bench") - range_parser.add_argument("--dim-start", type=int, required=True) - range_parser.add_argument("--dim-end", type=int, required=True) - range_parser.add_argument("--dim-increment", type=int, required=True) - range_parser.add_argument("--m-constant", type=int, default=None) - range_parser.add_argument("--n-constant", type=int, default=None) - range_parser.add_argument("--k-constant", type=int, default=None) + range_parser.add_argument( + "--dim-start", + type=str, + required=True, + help="Start value for M,K,N as common separated list") + range_parser.add_argument( + "--dim-end", + type=str, + required=True, + help="End value (inclusive) for M,K,N as common separated list") + range_parser.add_argument( + "--dim-increment", + type=str, + required=True, + help="Increment value for M,K,N as common separated list") range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") @@ -369,4 +412,9 @@ def to_torch_dtype(dt): model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() + + _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out args.func(args) + + if _SWEEP_SCHEDULES_RESULTS is not None: + _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV) diff --git a/benchmarks/kernels/requirements.txt b/benchmarks/kernels/requirements.txt new file mode 100644 index 000000000000..1411a4a0b5ab --- /dev/null +++ b/benchmarks/kernels/requirements.txt @@ -0,0 +1 @@ +pandas \ No newline at end of file diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 8470e9ea9ebd..3c474bd58d04 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -120,4 +120,3 @@ define_gpu_extension_target( ) message(STATUS "Enabling C extension.") -add_dependencies(default _C) diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 1ed49b8aa9ca..a2f7e4330000 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -23,17 +24,23 @@ namespace vllm { -constexpr int kMaxBlocks = 64; -// note: we don't want to use atomics for signals because peer atomics are no -// supported on PCIe links +constexpr int kMaxBlocks = 36; +// Counter may overflow, but it's fine since unsigned int overflow is +// well-defined behavior. +using FlagType = uint32_t; struct Signal { - alignas(128) uint32_t start[kMaxBlocks][8]; - alignas(128) uint32_t end[kMaxBlocks][8]; + alignas(128) FlagType self_counter[kMaxBlocks][8]; + // Two sets of peer counters are needed for two syncs. The reason is that + // it's possible for peer GPU block to arrive at the second sync point while + // the current GPU block haven't passed the first sync point. Thus, peer GPU + // may write counter+1 while current GPU is busy waiting for counter. We use + // alternating counter array to avoid this possibility. + alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; }; struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; -struct __align__(16) RankSignals { volatile Signal* signals[8]; }; +struct __align__(16) RankSignals { Signal* signals[8]; }; // like std::array, but aligned template @@ -123,47 +130,71 @@ DINLINE O downcast(array_t val) { } } -// This function is meant to be used as the first synchronization in the all -// reduce kernel. Thus, it doesn't need to make any visibility guarantees for -// prior memory accesses. Note: volatile writes will not be reordered against -// other volatile writes. -template -DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, - int rank) { - if (threadIdx.x < ngpus) { - // reset flag for next time - self_sg->end[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]); - } - __syncthreads(); +static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#else + asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#endif +} + +static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { + FlagType flag; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); +#else + asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" + : "=r"(flag) + : "l"(flag_addr)); +#endif + return flag; } -// This function is meant to be used as the second or the final synchronization -// barrier in the all reduce kernel. If it's the final synchronization barrier, -// we don't need to make any visibility guarantees for prior memory accesses. -template -DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, - int rank) { - __syncthreads(); - // eliminate the case that prior writes are not visible after signals become - // visible. Note that I did not managed to make this happen through a lot of - // testing. Might be the case that hardware provides stronger guarantee than - // the memory model. - if constexpr (!final_sync) __threadfence_system(); +static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) { + asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +} + +static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { + FlagType flag; + asm volatile("ld.volatile.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); + return flag; +} + +// is_start: whether this is the very first synchronization barrier. +// need_fence: whether a memory fence is needed. If true, a release-acquire +// semantic is used to enforce memory access order before and after this +// barrier. +template +DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, + int rank) { + if constexpr (!is_start) __syncthreads(); + static_assert( + !(is_start && need_fence)); // Start barrier shouldn't need fence. if (threadIdx.x < ngpus) { - // reset flag for next time - self_sg->start[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]); + // Increment the counter. Technically we only need one counter, but we use + // multiple per block to eliminate the need to share the counter via smem. + auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; + // Write the expected counter value to peer and wait for correct value from + // peer. + auto peer_counter_ptr = + &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; + auto self_counter_ptr = + &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; + if constexpr (need_fence) { + st_flag_release(peer_counter_ptr, val); + while (ld_flag_acquire(self_counter_ptr) != val); + } else { + st_flag_volatile(peer_counter_ptr, val); + while (ld_flag_volatile(self_counter_ptr) != val); + } } - if constexpr (!final_sync) __syncthreads(); + if constexpr (is_start || need_fence) __syncthreads(); } template @@ -178,33 +209,31 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, - int rank, int size) { + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_dp; - start_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } - end_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); } template -DINLINE P* get_tmp_buf(volatile Signal* sg) { +DINLINE P* get_tmp_buf(Signal* sg) { return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, - int rank, int size) { + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P = typename packed_t::P; @@ -222,12 +251,12 @@ __global__ void __launch_bounds__(512, 1) tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; - start_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { tmp_out[idx - start] = packed_reduce(ptrs, idx); } - end_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between // the two stages, because visibility across devices is only guaranteed @@ -437,6 +466,8 @@ class CustomAllreduce { #define KL(ngpus, name) \ name<<>>(ptrs, sg_, self_sg_, output, \ rank_, size); + // TODO(hanzhi713): Threshold is different for A100 and H100. + // Add per device threshold. #define REDUCE_CASE(ngpus) \ case ngpus: { \ if (world_size_ == 2) { \ diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index f7868233076c..376687e91cfd 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -1,15 +1,15 @@ /** * This is a standalone test for custom allreduce. * To compile, make sure you have MPI and NCCL installed in your system. - * export MPI_HOME=XXX + * export MPI_HOME=xxx * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o - * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi + * custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi * * Warning: this C++ test is not designed to be very readable and was used * during the rapid prototyping process. * * To run: - * mpirun -np 8 ./custom_all_reduce_test + * mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test */ #include #include @@ -44,7 +44,14 @@ } while (0) __global__ void dummy_kernel() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms +#else + for (int i = 0; i < 100; i++) { + long long int start = clock64(); + while (clock64() - start < 150000000); // approximately 98.4ms on P40 + } +#endif } template @@ -302,15 +309,19 @@ int main(int argc, char** argv) { bool performance_test = true; cudaProfilerStart(); - // for (int threads : {256, 512}) { + // Uncomment to scan through different block size configs. + // for (int threads : {256, 512, 1024}) { // for (int block_limit = 16; block_limit < 112; block_limit += 4) { - // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); + // run(myRank, nRanks, comm, threads, block_limit, 1024 * 1024, + // performance_test); // } // } + // Scan through different sizes to test performance. for (int sz = 512; sz <= (8 << 20); sz *= 2) { run(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test); } cudaProfilerStop(); + MPICHECK(MPI_Finalize()); return EXIT_SUCCESS; } diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp index 1618a340ce10..2c78572521ee 100644 --- a/csrc/cutlass_extensions/torch_utils.hpp +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor, name, ".stride(", idx, ") to be ", StrideEle::value); return StrideEle{}; } else { - return tensor.stride(idx); + if (tensor.size(idx) == 1) { + // use 0 stride for dim with size 1, this is easier for + // cute/cutlass to optimize (helps the TMA code flatten dims) + return StrideEle{0}; + } else { + return tensor.stride(idx); + } } } else { // Extra strides are assumed to be 0 or 1 diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h new file mode 100644 index 000000000000..0bd3017226c9 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h @@ -0,0 +1,1425 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +#include "core/scalar_type.hpp" + +namespace marlin_moe { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float* c, FragS& s) { + __half* s_ptr = reinterpret_cast<__half*>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, + FragS& frag_s_3, FragS& frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__device__ inline void MarlinMoESingle( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block // current m block to start kernel computation from +) { + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr int pack_factor = 32 / w_type.size_bits(); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + constexpr int sorted_sh_stride = threads; + constexpr int sorted_gl_stride = threads; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + int shs_size; + if constexpr (has_act_order) + shs_size = sh_max_num_groups * s_sh_stride + threads; + else + shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + int* sh_sorted = (int*)(sh_s + shs_size); + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_sh_wr_delta * i + a_sh_wr; + int row = a_idx / a_gl_rd_delta_o; + if (row >= prob_m) { + a_sh_wr_pred[i] = false; + } else { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // TODO we are currently hitting illegal memory accesses when fetching + // sorted_ids to shared data: fix this + auto fetch_sorted_ids_to_shared = [&]() { + const int mpt = ceildiv(prob_m, threads); + for (int i = 0; i < mpt; i++) { + if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { + sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = + sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; + } + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int sorted_row = sorted_ids[c_idx / c_gl_stride]; + int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], + sorted_row < tot_m * topk && + (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk))); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int row = sorted_ids[c_idx / c_gl_stride]; + if (row < tot_m * topk) { + int new_idx = row * c_gl_stride + c_idx % c_gl_stride; + C[new_idx] = c; + } + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + // TODO re-enable after fixing this function + // fetch_sorted_ids_to_shared(); + // __syncthreads(); + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks +) { + int m_block_ctr = current_m_block; + + const int* sorted_ids_expert = + sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; + int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; + if (tot_its == 0) { + return; + } + int tot_m_blocks = ceildiv(tot_its, 16); + int pad = 16 * tot_m_blocks - tot_its; + + if (m_block_ctr >= tot_m_blocks) { + return; + } + + int max_block = tot_m_blocks - m_block_ctr; + prob_m = tot_its - 16 * m_block_ctr; + + int par = 1; + if (max_block > cfg_max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; + } + + if (max_block == 1) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 2) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 3) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } +} + +#else + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks + +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory +// const int SHARED_MEM = +// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + cfg_max_m_blocks); \ + } + +#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu new file mode 100644 index 000000000000..cbafd9ffe747 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu @@ -0,0 +1,29 @@ +#include "marlin_moe_kernel_ku4b8.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks) { + if (false) { + } + GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h new file mode 100644 index 000000000000..9eacb42c115f --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h @@ -0,0 +1,20 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu new file mode 100644 index 000000000000..c46712474f71 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu @@ -0,0 +1,29 @@ +#include "marlin_moe_kernel_ku8b128.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks) { + if (false) { + } + GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h new file mode 100644 index 000000000000..7cd9acafb3b8 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h @@ -0,0 +1,18 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr, + int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts, + int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks, + bool replicate_input, bool apply_weights, int m_block, int max_par, + int cfg_max_m_blocks); + +} diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 293a6fad72c2..dfe043741401 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -26,6 +26,8 @@ #include #include "core/scalar_type.hpp" +#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" +#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" template inline std::string str(T x) { @@ -34,230 +36,8 @@ inline std::string str(T x) { namespace marlin_moe { -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline FragB dequant(int q); - -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); - int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 -// Reference: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float* c, FragS& s) { - __half* s_ptr = reinterpret_cast<__half*>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); -} - -// Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, - FragS& frag_s_3, FragS& frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -335,1106 +115,6 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, __syncthreads(); } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__device__ inline void MarlinMoESingle( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block // current m block to start kernel computation from -) { - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); - constexpr int pack_factor = 32 / w_type.size_bits(); - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - sorted_ids += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - constexpr int sorted_sh_stride = threads; - constexpr int sorted_gl_stride = threads; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - int shs_size; - if constexpr (has_act_order) - shs_size = sh_max_num_groups * s_sh_stride + threads; - else - shs_size = group_blocks > 0 ? stages * s_sh_stage : threads; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); - int* sh_sorted = (int*)(sh_s + shs_size); - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_sh_wr_delta * i + a_sh_wr; - int row = a_idx / a_gl_rd_delta_o; - if (row >= prob_m) { - a_sh_wr_pred[i] = false; - } else { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - } - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; - int sorted_row = - replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; - int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; - if (sorted_row < tot_m * (replicate_input ? 1 : topk) && - new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], - a_sh_wr_pred[i]); - } - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // TODO we are currently hitting illegal memory accesses when fetching - // sorted_ids to shared data: fix this - auto fetch_sorted_ids_to_shared = [&]() { - const int mpt = ceildiv(prob_m, threads); - for (int i = 0; i < mpt; i++) { - if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { - sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = - sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; - } - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } - - FragB frag_b0 = dequant(b_quant_0); - FragB frag_b1 = dequant(b_quant_1); - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int sorted_row = sorted_ids[c_idx / c_gl_stride]; - int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], - sorted_row < tot_m * topk && - (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk))); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int row = sorted_ids[c_idx / c_gl_stride]; - if (row < tot_m * topk) { - int new_idx = row * c_gl_stride + c_idx % c_gl_stride; - C[new_idx] = c; - } - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); - } - - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - int row = sorted_ids[c_gl_wr / c_gl_stride]; - if (row < tot_m * topk) { - int off = row * c_gl_stride + c_gl_wr % c_gl_stride; - if (!apply_weights) { - C[off] = sh[c_sh_rd]; - } else { - __half* ctrg = reinterpret_cast<__half*>(&C[off]); - __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); - for (int j = 0; j < 8; ++j) { - ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); - } - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - // TODO re-enable after fixing this function - // fetch_sorted_ids_to_shared(); - // __syncthreads(); - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } - start_pipes(); - } - } - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -) { - int m_block_ctr = current_m_block; - - const int* sorted_ids_expert = - sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; - int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; - if (tot_its == 0) { - return; - } - int tot_m_blocks = ceildiv(tot_its, 16); - int pad = 16 * tot_m_blocks - tot_its; - - if (m_block_ctr >= tot_m_blocks) { - return; - } - - int max_block = tot_m_blocks - m_block_ctr; - prob_m = tot_its - 16 * m_block_ctr; - - int par = 1; - if (max_block > cfg_max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * cfg_max_m_blocks) * par; - m_block_ctr += cfg_max_m_blocks * (par - 1); - max_block = cfg_max_m_blocks; - } - - if (max_block == 1) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 2) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 3) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } -} - #else __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -1454,81 +134,8 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids, return; } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks - -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - #endif -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory -// const int SHARED_MEM = -// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ - GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ - num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks); \ - } - typedef struct { int thread_k; int thread_n; @@ -1703,25 +310,27 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, return exec_config_t{0, {-1, -1, -1}}; } -#define CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ - \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) - -void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, - const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, const void* g_idx, - const void* perm, void* a_tmp, void* expert_offsets, - int prob_m, int prob_n, int prob_k, void* workspace, - vllm::ScalarType const& q_type, bool has_act_order, - bool is_k_full, int num_groups, int group_size, - int num_experts, int topk, int moe_block_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool replicate_input, - bool apply_weights) { +#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ + else if (KERNEL_FUNCTION(q_type, thread_n_blocks, thread_k_blocks, \ + has_act_order, group_blocks, num_threads, blocks, \ + max_shared_mem, stream, A_ptr, B_ptr, C_ptr, \ + sorted_ids_ptr, topk_weights_ptr, s_ptr, g_idx_ptr, \ + expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, \ + locks, replicate_input, apply_weights, m_block, \ + max_par, exec_cfg.max_m_blocks)) { \ + } + +void marlin_mm_moe(const void* A, const void* B, void* C, + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, const void* g_idx, + const void* perm, void* a_tmp, void* expert_offsets, + int prob_m, int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, int num_groups, int group_size, + int num_experts, int topk, int moe_block_size, int dev, + cudaStream_t stream, int thread_k, int thread_n, int sms, + int max_par, bool replicate_input, bool apply_weights) { TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); @@ -1845,26 +454,16 @@ void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, int tot_m_blocks = ceildiv(tot_m, 16); for (int m_block = 0; m_block < tot_m_blocks; m_block += 4 * exec_cfg.max_m_blocks) { - // make it max possible value - int thread_m_blocks = exec_cfg.max_m_blocks; - if (false) { } - CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) - CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) - CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) - CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) - CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) - CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) else { TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + str(prob_n) + ", " + str(prob_k) + "]" + ", has_act_order = " + str(has_act_order) + ", num_groups = " + str(num_groups) + ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + ", thread_n_blocks = " + str(thread_n_blocks) + ", thread_k_blocks = " + str(thread_k_blocks)); } @@ -1943,7 +542,7 @@ torch::Tensor marlin_gemm_moe( } } - marlin_moe::marlin_mm_moe_f16i4( + marlin_moe::marlin_mm_moe( a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), diff --git a/csrc/ops.h b/csrc/ops.h index 15e9ebe87408..7ad0abd46c82 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -113,6 +113,8 @@ torch::Tensor prepack_B(torch::Tensor const& B, }; // namespace machete +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); + torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor& b_meta, torch::Tensor& b_scales, diff --git a/csrc/permute_cols.cu b/csrc/permute_cols.cu new file mode 100644 index 000000000000..f51fa73298cc --- /dev/null +++ b/csrc/permute_cols.cu @@ -0,0 +1,88 @@ +#include + +#include +#include + +#include + +static constexpr int default_threads = 256; +static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +// Currently only supports 16bit types (since we permute half types) +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = std::max(finish_row - start_row, 0); + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +// More efficient version of A[..., perm] +// taken from gptq_marlin.cu +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto dev = A.get_device(); + auto stream = at::cuda::getCurrentCUDAStream(dev); + + TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16, + "Currently only 16bit types are supported"); + TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); + TORCH_CHECK(A.size(-1) % 8 == 0, + "A columns must be a multiple of 8 (128bits)"); + auto A_2d = A.view({-1, A.size(-1)}); + + torch::Tensor D = torch::empty_like(A); + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + int block_rows = div_ceil(A_2d.size(0), sms); + permute_cols_kernel<<>>( + reinterpret_cast(A_2d.const_data_ptr()), + perm.const_data_ptr(), reinterpret_cast(D.mutable_data_ptr()), + A_2d.size(0), A_2d.size(1), block_rows); + return D; +} \ No newline at end of file diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 09a98a5dd1fd..8ed81ea727aa 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -157,7 +157,7 @@ TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative -@dataclass +@dataclass(frozen=True) class ScheduleConfig: tile_shape_mn: Tuple[int, int] cluster_shape_mnk: Tuple[int, int, int] @@ -328,56 +328,137 @@ def generate(): # about how this works SCRIPT_DIR = os.path.dirname(__file__) - schedules = [ - ScheduleConfig( - tile_shape_mn=tile_shape_mn, - cluster_shape_mnk=cluster_shape_mnk, - kernel_schedule=kernel_schedule, - epilogue_schedule=epilogue_schedule, - tile_scheduler=tile_scheduler, - ) for tile_shape_mn, cluster_shape_mnk in ( - ((128, 16), (1, 1, 1)), - ((128, 32), (1, 1, 1)), - ((128, 64), (1, 1, 1)), - ((128, 128), (1, 1, 1)), - ) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, ) - for tile_scheduler in (TileSchedulerType.StreamK, ) - ] + schedule_common_params = dict( + kernel_schedule=TmaMI, + epilogue_schedule=TmaCoop, + tile_scheduler=TileSchedulerType.StreamK, + ) # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s default_heuristic = [ - ("M > 64", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK, - )), - ("M > 32", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK, - )), - ("M > 16", - ScheduleConfig( - tile_shape_mn=(128, 32), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK, - )), - (None, - ScheduleConfig(tile_shape_mn=(128, 16), - cluster_shape_mnk=(1, 1, 1), - kernel_schedule=TmaMI, - epilogue_schedule=TmaCoop, - tile_scheduler=TileSchedulerType.StreamK)) + #### M = 257+ + ( + "M > 256 && K <= 16384 && N <= 4096", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 256", + ScheduleConfig( + tile_shape_mn=(128, 256), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 129-256 + ( + "M > 128 && K <= 4096 && N <= 4096", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 128 && K <= 8192 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 128", + ScheduleConfig( + tile_shape_mn=(128, 256), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 65-128 + ( + "M > 64 && K <= 4069 && N <= 4069", + ScheduleConfig( + tile_shape_mn=(128, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64 && K <= 4069 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64 && K >= 8192 && N >= 12288", + ScheduleConfig( + tile_shape_mn=(256, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 64", + ScheduleConfig( + tile_shape_mn=(128, 128), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 33-64 + ( + "M > 32 && K <= 6144 && N <= 6144", + ScheduleConfig( + tile_shape_mn=(128, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 32 && K >= 16384 && N >= 12288", + ScheduleConfig( + tile_shape_mn=(256, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 32", + ScheduleConfig( + tile_shape_mn=(128, 64), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 17-32 + ( + "M > 16 && K <= 12288 && N <= 8192", + ScheduleConfig( + tile_shape_mn=(128, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + ( + "M > 16", + ScheduleConfig( + tile_shape_mn=(256, 32), + cluster_shape_mnk=(2, 1, 1), + **schedule_common_params # type: ignore + )), + #### M = 1-16 + ( + "N >= 26624", + ScheduleConfig( + tile_shape_mn=(256, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), + ( + None, + ScheduleConfig( + tile_shape_mn=(128, 16), + cluster_shape_mnk=(1, 1, 1), + **schedule_common_params # type: ignore + )), ] + schedules = list(set([x[1] for x in default_heuristic])) + impl_configs = [] GPTQ_kernel_type_configs = list( diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index 046e6e5a5365..4d41b8d29148 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -152,7 +152,8 @@ struct MacheteKernelTemplate { int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); - int const group_size = maybe_group_size.value_or(K); + int const group_size = + maybe_group_size == -1 ? K : maybe_group_size.value_or(K); int const scale_k = (K + group_size - 1) / group_size; TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh index e2604d4bed3e..60a4ed60535b 100644 --- a/csrc/quantization/machete/machete_mm_launcher.cuh +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) { auto arguments = MacheteKernel::create_arguments( stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), - args.group_size.value_or(K)); + args.group_size); TORCH_CHECK(MacheteKernel::can_implement(arguments), "Machete kernel cannot be run with these arguments"); diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh index 686dd68bd52b..df78312997fb 100644 --- a/csrc/quantization/machete/machete_prepack_launcher.cuh +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) { // clang-format on // Allocate output - torch::Tensor D = torch::empty_like(B); + torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous); prepack_B(stream, B_ptr, layout_Bt, static_cast(D.mutable_data_ptr())); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 045203c3de8a..4b374af5ae24 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -192,6 +192,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "-> Tensor"); ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); + ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); + ops.impl("permute_cols", torch::kCUDA, &permute_cols); + // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index e112b43aade5..241b2ccd0991 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 0d47281db485..9adf82d43f3e 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptType +.. autodata:: vllm.inputs.PromptInputs .. autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index d169fe676dc9..4ed0bfe70071 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -3,15 +3,17 @@ Installation with ROCm ====================== -vLLM supports AMD GPUs with ROCm 6.1. +vLLM supports AMD GPUs with ROCm 6.2. Requirements ------------ * OS: Linux -* Python: 3.8 -- 3.11 +* Python: 3.9 -- 3.12 * GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100) -* ROCm 6.1 +* ROCm 6.2 + +Note: PyTorch 2.5+/ROCm6.2 dropped the support for python 3.8. Installation options: @@ -27,7 +29,7 @@ You can build and install vLLM from source. First, build a docker image from `Dockerfile.rocm `_ and launch a docker container from the image. -`Dockerfile.rocm `_ uses ROCm 6.1 by default, but also supports ROCm 5.7 and 6.0 in older vLLM branches. +`Dockerfile.rocm `_ uses ROCm 6.2 by default, but also supports ROCm 5.7, 6.0 and 6.1 in older vLLM branches. It provides flexibility to customize the build of docker image using the following arguments: * `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. @@ -39,13 +41,13 @@ It provides flexibility to customize the build of docker image using the followi Their values can be passed in when running ``docker build`` with ``--build-arg`` options. -To build vllm on ROCm 6.1 for MI200 and MI300 series, you can use the default: +To build vllm on ROCm 6.2 for MI200 and MI300 series, you can use the default: .. code-block:: console $ DOCKER_BUILDKIT=1 docker build -f Dockerfile.rocm -t vllm-rocm . -To build vllm on ROCm 6.1 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below: +To build vllm on ROCm 6.2 for Radeon RX7900 series (gfx1100), you should specify ``BUILD_FA`` as below: .. code-block:: console @@ -79,9 +81,8 @@ Option 2: Build from source - `ROCm `_ - `PyTorch `_ -- `hipBLAS `_ -For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging`, `rocm/pytorch-nightly`. +For installing PyTorch, you can start from a fresh docker image, e.g, `rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0`, `rocm/pytorch-nightly`. Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTorch installation guide in PyTorch `Getting Started `_ @@ -90,26 +91,45 @@ Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTor Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton `_ + .. code-block:: console + + $ python3 -m pip install ninja cmake wheel pybind11 + $ pip uninstall -y triton + $ git clone https://github.com/OpenAI/triton.git + $ cd triton + $ git checkout e192dba + $ cd python + $ pip3 install . + $ cd ../.. + +.. note:: + - If you see HTTP issue related to downloading packages during building triton, please try again as the HTTP error is intermittent. + + 2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm `_ + Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention `_ Alternatively, wheels intended for vLLM use can be accessed under the releases. -.. note:: - - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) +For example, for ROCm 6.2, suppose your gfx arch is `gfx90a`. +Note to get your gfx architecture, run `rocminfo |grep gfx`. -3. Build vLLM. - -.. code-block:: console + .. code-block:: console - $ cd vllm - $ pip install -U -r requirements-rocm.txt - $ python setup.py develop # This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation + $ git clone https://github.com/ROCm/flash-attention.git + $ cd flash-attention + $ git checkout 3cea2fb + $ git submodule update --init + $ GPU_ARCHS="gfx90a" python3 setup.py install + $ cd .. +.. note:: + - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) -.. tip:: +3. Build vLLM. - For example, vLLM v0.5.3 on ROCM 6.1 can be built with the following steps: + For example, vLLM on ROCM 6.2 can be built with the following steps: .. code-block:: console @@ -117,7 +137,7 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases. $ # Install PyTorch $ pip uninstall torch -y - $ pip install --no-cache-dir --pre torch==2.5.0.dev20240726 --index-url https://download.pytorch.org/whl/nightly/rocm6.1 + $ pip install --no-cache-dir --pre torch==2.6.0.dev20240918 --index-url https://download.pytorch.org/whl/nightly/rocm6.2 $ # Build & install AMD SMI $ pip install /opt/rocm/share/amd_smi @@ -127,15 +147,14 @@ Alternatively, wheels intended for vLLM use can be accessed under the releases. $ pip install "numpy<2" $ pip install -r requirements-rocm.txt - $ # Apply the patch to ROCM 6.1 (requires root permission) - $ wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib - $ rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* - $ # Build vLLM for MI210/MI250/MI300. $ export PYTORCH_ROCM_ARCH="gfx90a;gfx942" $ python3 setup.py develop + This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation + + .. tip:: - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 816e0a29ef28..c8947beb3494 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -56,7 +56,7 @@ Build from source .. code-block:: console $ pip install --upgrade pip - $ pip install wheel packaging ninja "setuptools>=49.4.0" numpy + $ pip install cmake>=3.26 wheel packaging ninja "setuptools-scm>=8" numpy $ pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu - Third, build and install oneDNN library from source: diff --git a/docs/source/getting_started/xpu-installation.rst b/docs/source/getting_started/xpu-installation.rst index a0118e20c49d..151ebb5f1811 100644 --- a/docs/source/getting_started/xpu-installation.rst +++ b/docs/source/getting_started/xpu-installation.rst @@ -17,8 +17,8 @@ Requirements ------------ * OS: Linux -* Supported Hardware: Intel Data Center GPU (Intel ARC GPU WIP) -* OneAPI requirements: oneAPI 2024.1 +* Supported Hardware: Intel Data Center GPU, Intel ARC GPU +* OneAPI requirements: oneAPI 2024.2 .. _xpu_backend_quick_start_dockerfile: @@ -40,7 +40,7 @@ Quick start using Dockerfile Build from source ----------------- -- First, install required driver and intel OneAPI 2024.1 or later. +- First, install required driver and intel OneAPI 2024.2 or later. - Second, install Python packages for vLLM XPU backend building: diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index ca5b125369c8..08db89166504 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. diff --git a/docs/source/quantization/bnb.rst b/docs/source/quantization/bnb.rst index aefb54a8acb6..682938cc63d4 100644 --- a/docs/source/quantization/bnb.rst +++ b/docs/source/quantization/bnb.rst @@ -11,7 +11,7 @@ Below are the steps to utilize BitsAndBytes with vLLM. .. code-block:: console - $ pip install bitsandbytes>=0.42.0 + $ pip install bitsandbytes>=0.44.0 vLLM reads the model's config file and supports both in-flight quantization and pre-quantized checkpoint. diff --git a/examples/lora_with_quantization_inference.py b/examples/lora_with_quantization_inference.py index 3b2347c1115e..0c454ea50f66 100644 --- a/examples/lora_with_quantization_inference.py +++ b/examples/lora_with_quantization_inference.py @@ -79,23 +79,17 @@ def initialize_engine(model: str, quantization: str, # It quantizes the model when loading, with some config info from the # LoRA adapter repo. So need to set the parameter of load_format and # qlora_adapter_name_or_path as below. - engine_args = EngineArgs( - model=model, - quantization=quantization, - qlora_adapter_name_or_path=lora_repo, - load_format="bitsandbytes", - enable_lora=True, - max_lora_rank=64, - # set it only in GPUs of limited memory - enforce_eager=True) + engine_args = EngineArgs(model=model, + quantization=quantization, + qlora_adapter_name_or_path=lora_repo, + load_format="bitsandbytes", + enable_lora=True, + max_lora_rank=64) else: - engine_args = EngineArgs( - model=model, - quantization=quantization, - enable_lora=True, - max_loras=4, - # set it only in GPUs of limited memory - enforce_eager=True) + engine_args = EngineArgs(model=model, + quantization=quantization, + enable_lora=True, + max_loras=4) return LLMEngine.from_engine_args(engine_args) diff --git a/examples/offline_inference_chat.py b/examples/offline_inference_chat.py index c2020724c72f..8814f4d7bef0 100644 --- a/examples/offline_inference_chat.py +++ b/examples/offline_inference_chat.py @@ -39,6 +39,33 @@ def print_outputs(outputs): use_tqdm=False) print_outputs(outputs) +# You can run batch inference with llm.chat API +conversation = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "Hello" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "Write an essay about the importance of higher education.", + }, +] +conversations = [conversation for _ in range(10)] + +# We turn on tqdm progress bar to verify it's indeed running batch inference +outputs = llm.chat(messages=conversations, + sampling_params=sampling_params, + use_tqdm=True) +print_outputs(outputs) + # A chat template can be optionally supplied. # If not, the model will use its default chat template. diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index de0947115bbf..d4cb5e285a87 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -83,10 +83,24 @@ def run_phi3v(question, modality): # In this example, we override max_num_seqs to 5 while # keeping the original context length of 128k. + + # num_crops is an override kwarg to the multimodal image processor; + # For some models, e.g., Phi-3.5-vision-instruct, it is recommended + # to use 16 for single frame scenarios, and 4 for multi-frame. + # + # Generally speaking, a larger value for num_crops results in more + # tokens per image instance, because it may scale the image more in + # the image preprocessing. Some references in the model docs and the + # formula for image tokens after the preprocessing + # transform can be found below. + # + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 llm = LLM( model="microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True, max_num_seqs=5, + mm_processor_kwargs={"num_crops": 16}, ) stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 92ab4f42baa8..8c5f1a7b7af0 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -67,11 +67,24 @@ def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: + # num_crops is an override kwarg to the multimodal image processor; + # For some models, e.g., Phi-3.5-vision-instruct, it is recommended + # to use 16 for single frame scenarios, and 4 for multi-frame. + # + # Generally speaking, a larger value for num_crops results in more + # tokens per image instance, because it may scale the image more in + # the image preprocessing. Some references in the model docs and the + # formula for image tokens after the preprocessing + # transform can be found below. + # + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct#loading-the-model-locally + # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/main/processing_phi3_v.py#L194 llm = LLM( model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, max_model_len=4096, limit_mm_per_prompt={"image": len(image_urls)}, + mm_processor_kwargs={"num_crops": 4}, ) placeholders = "\n".join(f"<|image_{i}|>" for i, _ in enumerate(image_urls, start=1)) diff --git a/pyproject.toml b/pyproject.toml index 14f0934499c4..c9057b061aad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,8 @@ requires = [ "cmake>=3.26", "ninja", "packaging", - "setuptools >= 49.4.0", + "setuptools>=61", + "setuptools-scm>=8.0", "torch == 2.4.0", "wheel", "jinja2", @@ -19,6 +20,10 @@ exclude = [ "examples/fp8/quantizer/quantize.py" ] +[tool.ruff.lint.per-file-ignores] +"vllm/version.py" = ["F401"] +"vllm/_version.py" = ["ALL"] + [tool.ruff.lint] select = [ # pycodestyle diff --git a/requirements-build.txt b/requirements-build.txt index 3f08f5d67b6d..6144a56da8c4 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -2,7 +2,8 @@ cmake>=3.26 ninja packaging -setuptools>=49.4.0 +setuptools>=61 +setuptools-scm>=8 torch==2.4.0 wheel jinja2 diff --git a/requirements-test.txt b/requirements-test.txt index 10d463de27be..9c6fadb88865 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -30,5 +30,5 @@ datamodel_code_generator # required for minicpm3 test aiohttp # quantization -bitsandbytes==0.42.0 +bitsandbytes>=0.44.0 buildkite-test-collector==0.1.8 diff --git a/requirements-xpu.txt b/requirements-xpu.txt index f07211b48b68..9b21845e084d 100644 --- a/requirements-xpu.txt +++ b/requirements-xpu.txt @@ -3,10 +3,10 @@ setuptools < 70.0.0 # IPEX's torch have some dependency. to be removed. +ray >= 2.9 +# Following pkgs retrieved from https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ torch == 2.3.1+cxx11.abi intel-extension-for-pytorch == 2.3.110+xpu oneccl_bind_pt == 2.3.100+xpu triton-xpu == 3.0.0b2 - ---extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ diff --git a/setup.py b/setup.py index 60e31af0a8d3..8ef759f5245f 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,6 @@ import re import subprocess import sys -import warnings from pathlib import Path from shutil import which from typing import Dict, List @@ -14,6 +13,7 @@ from packaging.version import Version, parse from setuptools import Extension, find_packages, setup from setuptools.command.build_ext import build_ext +from setuptools_scm import get_version from torch.utils.cpp_extension import CUDA_HOME @@ -28,34 +28,6 @@ def load_module_from_path(module_name, path): ROOT_DIR = os.path.dirname(__file__) logger = logging.getLogger(__name__) - -def embed_commit_hash(): - try: - if "BUILDKITE_COMMIT" in os.environ: - # ci build - commit_id = os.environ["BUILDKITE_COMMIT"] - else: - commit_id = subprocess.check_output(["git", "rev-parse", "HEAD"], - encoding="utf-8").strip() - - commit_contents = f'__commit__ = "{commit_id}"\n' - - version_file = os.path.join(ROOT_DIR, "vllm", "commit_id.py") - with open(version_file, "w", encoding="utf-8") as f: - f.write(commit_contents) - - except subprocess.CalledProcessError as e: - warnings.warn(f"Failed to get commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) - except Exception as e: - warnings.warn(f"Failed to embed commit hash:\n{e}", - RuntimeWarning, - stacklevel=2) - - -embed_commit_hash() - # cannot import envs directly because it depends on vllm, # which is not installed yet envs = load_module_from_path('envs', os.path.join(ROOT_DIR, 'vllm', 'envs.py')) @@ -381,52 +353,43 @@ def get_path(*filepath) -> str: return os.path.join(ROOT_DIR, *filepath) -def find_version(filepath: str) -> str: - """Extract version information from the given filepath. - - Adapted from https://github.com/ray-project/ray/blob/0b190ee1160eeca9796bc091e07eaebf4c85b511/python/setup.py - """ - with open(filepath) as fp: - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - fp.read(), re.M) - if version_match: - return version_match.group(1) - raise RuntimeError("Unable to find version string.") - - def get_vllm_version() -> str: - version = find_version(get_path("vllm", "version.py")) + version = get_version( + write_to="vllm/_version.py", # TODO: move this to pyproject.toml + ) + + sep = "+" if "+" not in version else "." # dev versions might contain + if _no_device(): if envs.VLLM_TARGET_DEVICE == "empty": - version += "+empty" + version += f"{sep}empty" elif _is_cuda(): cuda_version = str(get_nvcc_cuda_version()) if cuda_version != MAIN_CUDA_VERSION: cuda_version_str = cuda_version.replace(".", "")[:3] # skip this for source tarball, required for pypi if "sdist" not in sys.argv: - version += f"+cu{cuda_version_str}" + version += f"{sep}cu{cuda_version_str}" elif _is_hip(): # Get the HIP version hipcc_version = get_hipcc_rocm_version() if hipcc_version != MAIN_CUDA_VERSION: rocm_version_str = hipcc_version.replace(".", "")[:3] - version += f"+rocm{rocm_version_str}" + version += f"{sep}rocm{rocm_version_str}" elif _is_neuron(): # Get the Neuron version neuron_version = str(get_neuronxcc_version()) if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] - version += f"+neuron{neuron_version_str}" + version += f"{sep}neuron{neuron_version_str}" elif _is_openvino(): - version += "+openvino" + version += f"{sep}openvino" elif _is_tpu(): - version += "+tpu" + version += f"{sep}tpu" elif _is_cpu(): - version += "+cpu" + version += f"{sep}cpu" elif _is_xpu(): - version += "+xpu" + version += f"{sep}xpu" else: raise RuntimeError("Unknown runtime environment") diff --git a/tests/conftest.py b/tests/conftest.py index c2616bcf7091..dcd9afdae3c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -675,8 +675,6 @@ def generate_w_logprobs( videos: Optional[PromptVideoInput] = None, ) -> Union[List[TokensTextLogprobs], List[TokensTextLogprobsPromptLogprobs]]: - assert sampling_params.logprobs is not None - if images is not None: assert len(prompts) == len(images) @@ -754,7 +752,7 @@ def generate_greedy_logprobs( temperature=0.0, max_tokens=max_tokens, logprobs=num_logprobs, - prompt_logprobs=(num_prompt_logprobs), + prompt_logprobs=num_prompt_logprobs, stop_token_ids=stop_token_ids) return self.generate_w_logprobs(prompts, @@ -798,6 +796,20 @@ def generate_beam_search( outputs = self.generate(prompts, beam_search_params) return outputs + def generate_beam_search_new( + self, + prompts: Union[List[str], List[List[int]]], + beam_width: int, + max_tokens: int, + ) -> List[Tuple[List[List[int]], List[str]]]: + outputs = self.model.beam_search(prompts, beam_width, max_tokens) + returned_outputs = [] + for output in outputs: + token_ids = [x.tokens for x in output.sequences] + texts = [x.text for x in output.sequences] + returned_outputs.append((token_ids, texts)) + return returned_outputs + def encode(self, prompts: List[str]) -> List[List[float]]: req_outputs = self.model.encode(prompts) outputs = [] diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 2f6ea632a5d9..9dddd751c785 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -27,16 +27,19 @@ def schedule_and_update_computed_tokens(scheduler): return metas, out -def test_simple(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_simple(use_v2_block_manager: bool): """Verify basic scheduling works.""" block_size = 4 num_seq_group = 4 max_model_len = 16 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - num_seq_group, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + num_seq_group, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -45,7 +48,9 @@ def test_simple(): # Add seq groups to scheduler. for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -69,30 +74,36 @@ def test_simple(): assert len(seq_group_meta) == num_seq_group -def test_chunk(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_chunk(use_v2_block_manager: bool): """Verify prefills are chunked properly.""" block_size = 4 max_seqs = 60 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) # Verify the second request is chunked. seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + print() assert set(get_sequence_groups(out)) == set(running) assert seq_group_meta[0].token_chunk_size == 60 # Verify it is chunked. @@ -113,24 +124,29 @@ def test_chunk(): assert out.num_batched_tokens == 57 -def test_complex(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_complex(use_v2_block_manager: bool): block_size = 4 max_seqs = 60 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 64 + cache_config.num_gpu_blocks = 64 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -151,7 +167,9 @@ def test_complex(): # Add 2 more requests. for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -176,16 +194,19 @@ def test_complex(): assert running[2].is_prefill() -def test_maximal_decoding(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_maximal_decoding(use_v2_block_manager: bool): """Verify decoding requests are prioritized.""" block_size = 4 max_seqs = 2 max_model_len = 8 max_num_batched_tokens = 2 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -194,7 +215,9 @@ def test_maximal_decoding(): # Add seq groups to scheduler. for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=2, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -211,7 +234,9 @@ def test_maximal_decoding(): append_new_token(running[0], 1) # Create one more seq_group. - _, seq_group = create_dummy_prompt("3", prompt_length=2) + _, seq_group = create_dummy_prompt("3", + prompt_length=2, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -263,23 +288,28 @@ def test_maximal_decoding(): assert out.num_batched_tokens == 2 -def test_prompt_limit(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prompt_limit(use_v2_block_manager: bool): """Verify max_num_batched_tokens < max_model_len is possible.""" block_size = 4 max_seqs = 32 max_model_len = 64 max_num_batched_tokens = 32 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("1", prompt_length=48) + _, seq_group = create_dummy_prompt("1", + prompt_length=48, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -293,7 +323,8 @@ def test_prompt_limit(): assert out.num_batched_tokens == 32 -def test_prompt_limit_exceed(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prompt_limit_exceed(use_v2_block_manager: bool): block_size = 4 max_seqs = 64 max_model_len = 32 @@ -303,12 +334,13 @@ def test_prompt_limit_exceed(): max_model_len, enable_chunked_prefill=True) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] - - _, seq_group = create_dummy_prompt("2", prompt_length=48) + _, seq_group = create_dummy_prompt("2", + prompt_length=48, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) assert seq_group.is_prefill() @@ -317,22 +349,28 @@ def test_prompt_limit_exceed(): assert out.ignored_seq_groups[0] == seq_group -def test_swap(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_swap(use_v2_block_manager: bool): """Verify swapping works with chunked prefill requests""" block_size = 4 max_seqs = 30 max_model_len = 200 max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) # The request is chunked. @@ -369,21 +407,27 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.blocks_to_swap_out == [] -def test_running_prefill_prioritized_over_swap(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_running_prefill_prioritized_over_swap(use_v2_block_manager: bool): block_size = 4 max_seqs = 30 max_model_len = 200 max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 32 + cache_config.num_gpu_blocks = 32 scheduler = Scheduler(scheduler_config, cache_config, None) - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) # The request is chunked. @@ -413,7 +457,9 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): scheduler.block_manager.can_swap_in = MagicMock() scheduler.block_manager.can_swap_in.return_value = AllocStatus.LATER - _, seq_group2 = create_dummy_prompt("2", prompt_length=60) + _, seq_group2 = create_dummy_prompt("2", + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group2) _, out = schedule_and_update_computed_tokens(scheduler) assert len(out.scheduled_seq_groups) == 1 @@ -455,22 +501,27 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.blocks_to_swap_out == [] -def test_chunked_prefill_preempt(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_chunked_prefill_preempt(use_v2_block_manager: bool): """Verify preempt works with chunked prefill requests""" block_size = 4 max_seqs = 30 max_model_len = 200 max_num_batched_tokens = 30 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) - _, seq_group = create_dummy_prompt("1", prompt_length=60) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) # The request is chunked. @@ -517,22 +568,27 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots): assert out.num_batched_tokens == max_num_batched_tokens -def test_chunked_prefill_max_seqs(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_chunked_prefill_max_seqs(use_v2_block_manager: bool): block_size = 4 max_seqs = 2 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = 128 + cache_config.num_gpu_blocks = 128 scheduler = Scheduler(scheduler_config, cache_config, None) running: List[SequenceGroup] = [] - _, seq_group = create_dummy_prompt("1", prompt_length=65) + _, seq_group = create_dummy_prompt("1", + prompt_length=65, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) # The first prefill is chunked. @@ -542,7 +598,9 @@ def test_chunked_prefill_max_seqs(): # Add new requests. for i in range(4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=65) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=65, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -564,16 +622,19 @@ def test_chunked_prefill_max_seqs(): assert not running[1].is_prefill() -def test_perfix_caching(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_perfix_caching(use_v2_block_manager: bool): """Verify allocating full blocks when prefix caching is enabled.""" block_size = 4 max_seqs = 10 max_model_len = 80 max_num_batched_tokens = 64 - scheduler_config = SchedulerConfig(max_num_batched_tokens, - max_seqs, - max_model_len, - enable_chunked_prefill=True) + scheduler_config = SchedulerConfig( + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 11168d2423b0..88c6c3bb28e4 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -3,7 +3,8 @@ from typing import List, Set, Tuple from unittest.mock import MagicMock -import pytest # noqa +import pytest +from torch import Use # noqa from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus @@ -16,9 +17,11 @@ schedule_and_update_computed_tokens) -def test_scheduler_add_seq_group(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_add_seq_group(use_v2_block_manager: bool): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1) + scheduler_config = SchedulerConfig( + 100, 64, 1, use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -27,14 +30,18 @@ def test_scheduler_add_seq_group(): # Add seq group to scheduler. num_seq_group = 4 for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), block_size) + _, seq_group = create_dummy_prompt(str(i), + block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) assert scheduler.get_num_unfinished_seq_groups() == i + 1 -def test_scheduler_abort_seq_group(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_abort_seq_group(use_v2_block_manager: bool): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 1) + scheduler_config = SchedulerConfig( + 100, 64, 1, use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 4 cache_config.num_gpu_blocks = 4 @@ -54,11 +61,16 @@ def test_scheduler_abort_seq_group(): assert scheduler.get_num_unfinished_seq_groups() == 0 -def test_scheduler_schedule_simple(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_schedule_simple(use_v2_block_manager: bool): block_size = 4 num_seq_group = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len) + scheduler_config = SchedulerConfig( + 64, + num_seq_group, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -67,7 +79,9 @@ def test_scheduler_schedule_simple(): # Add seq groups to scheduler. for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) running.append(seq_group) @@ -91,20 +105,24 @@ def test_scheduler_schedule_simple(): append_new_token(out, 1) -def test_scheduler_prefill_prioritized(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_prefill_prioritized(use_v2_block_manager: bool): """Verify running batched tokens are not applied to prefill requests.""" block_size = 4 max_model_len = 30 max_batched_num_tokens = 30 - scheduler_config = SchedulerConfig(max_batched_num_tokens, 2, - max_model_len) + scheduler_config = SchedulerConfig( + max_batched_num_tokens, + 2, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 2 - cache_config.num_gpu_blocks = 2 + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq groups to scheduler. - _, seq_group_a = create_dummy_prompt("1", 1) + _, seq_group_a = create_dummy_prompt("1", 1, block_size=block_size) scheduler.add_seq_group(seq_group_a) # Schedule seq groups prompts. @@ -112,7 +130,7 @@ def test_scheduler_prefill_prioritized(): assert get_sequence_groups(out) == [seq_group_a] # Add a new prefill request B. - _, seq_group_b = create_dummy_prompt("2", 30) + _, seq_group_b = create_dummy_prompt("2", 30, block_size=block_size) scheduler.add_seq_group(seq_group_b) # Verify prefill requests are prioritized. Since max_batched_num_tokens @@ -121,18 +139,24 @@ def test_scheduler_prefill_prioritized(): assert get_sequence_groups(out) == [seq_group_b] -def test_scheduler_schedule_preempt_abort(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_schedule_preempt_abort(use_v2_block_manager: bool): block_size = 4 max_model_len = 16 - scheduler_config = SchedulerConfig(64, 2, max_model_len) + scheduler_config = SchedulerConfig( + 64, 2, max_model_len, use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 2 cache_config.num_gpu_blocks = 2 scheduler = Scheduler(scheduler_config, cache_config, None) # Add seq groups to scheduler. - seq_a, seq_group_a = create_dummy_prompt("1", block_size) - seq_b, seq_group_b = create_dummy_prompt("2", block_size) + seq_a, seq_group_a = create_dummy_prompt("1", + block_size, + block_size=block_size) + seq_b, seq_group_b = create_dummy_prompt("2", + block_size, + block_size=block_size) scheduler.add_seq_group(seq_group_a) scheduler.add_seq_group(seq_group_b) @@ -170,12 +194,17 @@ def test_scheduler_schedule_preempt_abort(): assert scheduler.get_num_unfinished_seq_groups() == 1 -def test_scheduler_max_seqs(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_max_seqs(use_v2_block_manager: bool): block_size = 4 num_seq_group = 4 max_seq_group = 2 max_model_len = 16 - scheduler_config = SchedulerConfig(64, max_seq_group, max_model_len) + scheduler_config = SchedulerConfig( + 64, + max_seq_group, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -184,7 +213,9 @@ def test_scheduler_max_seqs(): all_seq_groups: List[SequenceGroup] = [] # Add seq groups to scheduler. for i in range(num_seq_group): - _, seq_group = create_dummy_prompt(str(i), prompt_length=block_size) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=block_size, + block_size=block_size) all_seq_groups.append(seq_group) # Append 1 seq group @@ -211,9 +242,15 @@ def test_scheduler_max_seqs(): assert set(get_sequence_groups(out)) == set([all_seq_groups[1]]) -def test_scheduler_delay_factor(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_scheduler_delay_factor(use_v2_block_manager: bool): block_size = 4 - scheduler_config = SchedulerConfig(100, 64, 16, delay_factor=0.5) + scheduler_config = SchedulerConfig( + 100, + 64, + 16, + delay_factor=0.5, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 8 cache_config.num_gpu_blocks = 8 @@ -221,7 +258,8 @@ def test_scheduler_delay_factor(): # schedule first prompt seq_group_meta, seq_group = create_dummy_prompt("0", - prompt_length=block_size) + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) assert out.num_prefill_groups > 0 @@ -231,7 +269,8 @@ def test_scheduler_delay_factor(): # wait for a second before scheduling next prompt time.sleep(1) seq_group_meta, seq_group = create_dummy_prompt("1", - prompt_length=block_size) + prompt_length=block_size, + block_size=block_size) scheduler.add_seq_group(seq_group) # second prompt should *not* be scheduled @@ -248,11 +287,20 @@ def test_scheduler_delay_factor(): append_new_token(out, 1) -def test_swapped_out_prioritized(): - scheduler = initialize_scheduler(max_num_seqs=6) +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_swapped_out_prioritized(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(max_num_seqs=6, + block_size=block_size, + use_v2_block_manager=use_v2_block_manager, + num_cpu_blocks=64, + num_gpu_blocks=64) # best_of=2 * 3 == 6 sequences. for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) # prefill scheduled now. @@ -276,7 +324,10 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): append_new_token(out, 1) # Add 1 more task. Swap should be prioritized over prefill. - _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler.add_seq_group(seq_group) seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) append_new_token(out, 1) @@ -287,17 +338,26 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert out.blocks_to_swap_out == [] -def initialize_scheduler(*, - max_num_seqs=1000, - max_token_budget=1000, - max_model_len=1000, - lora_config=None): - block_size = 4 - scheduler_config = SchedulerConfig(max_token_budget, max_num_seqs, - max_model_len) +def initialize_scheduler( + *, + max_num_seqs=1000, + max_token_budget=1000, + max_model_len=1000, + lora_config=None, + use_v2_block_manager=False, + block_size=4, + num_cpu_blocks=8, + num_gpu_blocks=8, +): + block_size = block_size + scheduler_config = SchedulerConfig( + max_token_budget, + max_num_seqs, + max_model_len, + use_v2_block_manager=use_v2_block_manager) cache_config = CacheConfig(block_size, 1.0, 1, "auto") - cache_config.num_cpu_blocks = 8 - cache_config.num_gpu_blocks = 8 + cache_config.num_cpu_blocks = num_cpu_blocks + cache_config.num_gpu_blocks = num_gpu_blocks scheduler = Scheduler(scheduler_config, cache_config, lora_config) return scheduler @@ -319,12 +379,18 @@ def add_token_budget(budget: SchedulingBudget, budget.add_num_seqs(mock_seq_group.request_id, num_curr_seqs) -def test_prefill_schedule_max_prompt_len(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_max_prompt_len(use_v2_block_manager: bool): """ Test prompt longer than max_prompt_len is aborted. """ - scheduler = initialize_scheduler(max_model_len=30) - _, seq_group = create_dummy_prompt("0", prompt_length=60) + block_size = 4 + scheduler = initialize_scheduler(max_model_len=30, + use_v2_block_manager=use_v2_block_manager, + block_size=block_size) + _, seq_group = create_dummy_prompt("0", + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) budget = create_token_budget() output = scheduler._schedule_prefills(budget, None) @@ -336,14 +402,21 @@ def test_prefill_schedule_max_prompt_len(): assert len(remaining_waiting) == 0 -def test_prefill_schedule_token_budget(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_token_budget(use_v2_block_manager: bool): """ Test token budget respected. """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) budget = create_token_budget(token_budget=0) for i in range(2): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) # 0 token budget == nothing is scheduled. @@ -366,10 +439,15 @@ def test_prefill_schedule_token_budget(): assert len(remaining_waiting) == 1 # Test when current_batched_tokens respected. - scheduler = initialize_scheduler() + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16) budget = create_token_budget(token_budget=60) add_token_budget(budget, 30, 0) - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) # Cannot schedule a prompt that doesn't fit the budget. scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) @@ -389,14 +467,21 @@ def test_prefill_schedule_token_budget(): assert len(remaining_waiting) == 0 -def test_prefill_schedule_max_seqs(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_max_seqs(use_v2_block_manager: bool): """ Test max seq respected. """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) budget = create_token_budget(max_num_seqs=2) for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) remaining_waiting = scheduler.waiting @@ -410,7 +495,9 @@ def test_prefill_schedule_max_seqs(): scheduler.waiting = deque() budget = create_token_budget(max_num_seqs=2) add_token_budget(budget, 0, 2) - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) output = scheduler._schedule_prefills(budget, None) remaining_waiting = scheduler.waiting @@ -421,17 +508,24 @@ def test_prefill_schedule_max_seqs(): assert len(remaining_waiting) == 1 -def test_prefill_schedule_max_lora(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_max_lora(use_v2_block_manager: bool): """ Test max lora is respected and prioritized. """ + block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config) + scheduler = initialize_scheduler(lora_config=lora_config, + use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) budget = create_token_budget(token_budget=120) curr_loras: Set[int] = set() for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, + block_size=block_size, lora_request=LoRARequest( lora_name=str(i), lora_int_id=i + 1, @@ -443,7 +537,9 @@ def test_prefill_schedule_max_lora(): # If a request is not scheduled because it hits max lora, it is # prioritized. Verify that. for i in range(2, 4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) # Schedule 2 requests (0 and 2) output = scheduler._schedule_prefills(budget, curr_loras) @@ -467,14 +563,21 @@ def test_prefill_schedule_max_lora(): assert budget.num_batched_tokens == 60 -def test_prefill_schedule_no_block_manager_capacity(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_prefill_schedule_no_block_manager_capacity(use_v2_block_manager): """ Test sequence cannot be scheduled due to block manager has no capacity. """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_gpu_blocks=128, + num_cpu_blocks=128) budget = create_token_budget() for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.LATER @@ -489,7 +592,9 @@ def test_prefill_schedule_no_block_manager_capacity(): scheduler = initialize_scheduler() budget = create_token_budget() for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler.add_seq_group(seq_group) scheduler.block_manager.can_allocate = MagicMock() scheduler.block_manager.can_allocate.return_value = AllocStatus.NEVER @@ -502,14 +607,21 @@ def test_prefill_schedule_no_block_manager_capacity(): assert len(remaining_waiting) == 0 -def test_decode_schedule_preempted(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_decode_schedule_preempted(use_v2_block_manager: bool): """ Test decodes cannot be scheduled and preempted. """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) curr_loras = None for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._add_seq_group_to_running(seq_group) @@ -541,15 +653,23 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert output.blocks_to_copy == [] -def test_decode_swap_beam_search(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_decode_swap_beam_search(use_v2_block_manager: bool): """ Test best_of > 1 swap out blocks """ - scheduler = initialize_scheduler() + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_gpu_blocks=64, + num_cpu_blocks=64) curr_loras = None budget = create_token_budget() for i in range(3): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) scheduler._add_seq_group_to_running(seq_group) append_new_token_seq_group(60, seq_group, 1) @@ -589,12 +709,20 @@ def cannot_append_second_group(seq_group, num_lookahead_slots): assert output.blocks_to_copy == [] -def test_schedule_decode_blocks_to_copy_update(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_decode_blocks_to_copy_update(use_v2_block_manager: bool): """ Verify blocks_to_copy is updated. """ - scheduler = initialize_scheduler() - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=4, + num_cpu_blocks=16, + num_gpu_blocks=16) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) curr_loras = None scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) @@ -619,13 +747,19 @@ def test_schedule_decode_blocks_to_copy_update(): assert output.blocks_to_copy == [(2, 3)] -def test_schedule_swapped_simple(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_simple(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=4, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) - append_new_token_seq_group(60, seq_group, 1) + append_new_token_seq_group(4, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) scheduler._add_seq_group_to_swapped(seq_group) @@ -644,12 +778,17 @@ def test_schedule_swapped_simple(): assert blocks_to_swap_out == blocks_to_swap_in_reverse -def test_schedule_swapped_max_token_budget(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_max_token_budget(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] - for _ in range(2): - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), prompt_length=60, best_of=2) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -676,12 +815,19 @@ def test_schedule_swapped_max_token_budget(): assert len(output.prefill_seq_groups) == 0 -def test_schedule_swapped_max_seqs(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_max_seqs(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=64, + num_gpu_blocks=64) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] for i in range(4): - _, seq_group = create_dummy_prompt(str(i), prompt_length=60) + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + block_size=4) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -706,14 +852,21 @@ def test_schedule_swapped_max_seqs(): assert len(output.prefill_seq_groups) == 0 -def test_schedule_swapped_max_loras(): +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_max_loras(use_v2_block_manager: bool): + block_size = 4 lora_config = LoRAConfig(max_lora_rank=8, max_loras=1) - scheduler = initialize_scheduler(lora_config=lora_config) + scheduler = initialize_scheduler(lora_config=lora_config, + use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras: Set[int] = set() blocks_to_swap_out: List[Tuple[int, int]] = [] for i in range(2): _, seq_group = create_dummy_prompt(str(i), prompt_length=60, + block_size=block_size, lora_request=LoRARequest( lora_name=str(i), lora_int_id=i + 1, @@ -734,12 +887,20 @@ def test_schedule_swapped_max_loras(): assert len(curr_loras) == 1 -def test_schedule_swapped_cannot_swap_in(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_cannot_swap_in(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] - for _ in range(2): - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -759,12 +920,20 @@ def test_schedule_swapped_cannot_swap_in(): assert len(output.prefill_seq_groups) == 0 -def test_infeasible_swap(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_infeasible_swap(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None blocks_to_swap_out: List[Tuple[int, int]] = [] - for _ in range(2): - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) scheduler._swap_out(seq_group, blocks_to_swap_out) @@ -785,10 +954,18 @@ def test_infeasible_swap(): assert len(output.prefill_seq_groups) == 0 -def test_schedule_swapped_blocks_to_copy(): - scheduler = initialize_scheduler() +@pytest.mark.parametrize('use_v2_block_manager', [True, False]) +def test_schedule_swapped_blocks_to_copy(use_v2_block_manager: bool): + block_size = 4 + scheduler = initialize_scheduler(use_v2_block_manager=use_v2_block_manager, + block_size=block_size, + num_cpu_blocks=32, + num_gpu_blocks=32) curr_loras = None - _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2) + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + best_of=2, + block_size=block_size) scheduler._allocate_and_set_running(seq_group) append_new_token_seq_group(60, seq_group, 1) blocks_to_swap_out: List[Tuple[int, int]] = [] diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 02288dc9dac9..280a8abdd13a 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -8,6 +8,8 @@ import os import pytest +from packaging import version +from transformers import __version__ as transformers_version from vllm.logger import init_logger @@ -37,6 +39,7 @@ (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"), (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"), (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"), + (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp") ], ) @fork_new_process_for_each_test @@ -46,6 +49,11 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") + # Skip tests that require transformers>=4.45.0 + if "Qwen2-VL" in MODEL_NAME and version.parse( + transformers_version) < version.parse("4.45.0.dev0"): + pytest.skip("This test requires transformers>=4.45.0") + pp_args = [ # use half precision for speed and memory savings in CI environment "--dtype", diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 8dd200b35d0f..360ac1bfbad9 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -40,3 +40,24 @@ def test_limit_mm_per_prompt_parser(arg, expected): def test_bad_nullable_kvs(arg): with pytest.raises(ArgumentTypeError): nullable_kvs(arg) + + +@pytest.mark.parametrize(("arg", "expected"), [ + (None, None), + ("{}", {}), + ('{"num_crops": 4}', { + "num_crops": 4 + }), + ('{"foo": {"bar": "baz"}}', { + "foo": { + "bar": "baz" + } + }), +]) +def test_mm_processor_kwargs_prompt_parser(arg, expected): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + if arg is None: + args = parser.parse_args([]) + else: + args = parser.parse_args(["--mm-processor-kwargs", arg]) + assert args.mm_processor_kwargs == expected diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index ef34bebbb0f8..cd989225e248 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -162,6 +162,41 @@ def test_chat(): assert len(outputs) == 1 +def test_multi_chat(): + + llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct") + + prompt1 = "Explain the concept of entropy." + prompt2 = "Explain what among us is." + + conversation1 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt1 + }, + ] + + conversation2 = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": prompt2 + }, + ] + + messages = [conversation1, conversation2] + + outputs = llm.chat(messages) + assert len(outputs) == 2 + + @pytest.mark.parametrize("image_urls", [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) def test_chat_multi_image(image_urls: List[str]): diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 2ad8460023c2..63beaaba29a8 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -19,7 +19,11 @@ RTOL = 0.03 EXPECTED_VALUE = 0.58 DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] -MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] +MORE_ARGS_LIST = [ + ["--enable-chunked-prefill"], # Chunked + ["--num-scheduler-steps", "8"], # MS + ["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream +] @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py index 0a9088222307..0dfa79e9af8e 100644 --- a/tests/kernels/test_machete_gemm.py +++ b/tests/kernels/test_machete_gemm.py @@ -31,6 +31,8 @@ (257, 4224, 4160), (257, 4096, 4096), (64, 4096, 4096), + (1024, 4096, 8192), + (1024, 8192, 4096), ] ACT_TYPES = [torch.float16, torch.bfloat16] @@ -139,6 +141,7 @@ def test_machete_all_schedules(shape, atype: torch.dtype, output_ref = torch.matmul(a, w_ref) for schedule in ops.machete_supported_schedules(wtype): + print(f"Testing schedule {schedule}") output = ops.machete_gemm( a, b_q=w_q_machete, diff --git a/tests/kernels/test_permute_cols.py b/tests/kernels/test_permute_cols.py new file mode 100644 index 000000000000..14ad7a22cf7c --- /dev/null +++ b/tests/kernels/test_permute_cols.py @@ -0,0 +1,15 @@ +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm._custom_ops import permute_cols + + +@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)]) +@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16]) +def test_permute_cols(shape, dtype): + x = torch.randn(shape, dtype=dtype).cuda() + perm = torch.randperm(x.shape[1]).to(torch.int).cuda() + opcheck(torch.ops._C.permute_cols, (x, perm)) + y = permute_cols(x, perm) + torch.testing.assert_close(y, x[:, perm]) \ No newline at end of file diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 314d6215cbd9..41c37a4813c6 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -169,6 +169,7 @@ def test_punica_sgmv( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -183,6 +184,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, scaling, ) else: @@ -195,6 +197,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, add_inputs=True, ) ref_torch_groupgemm( @@ -347,6 +350,7 @@ def test_punica_expand_nslices( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -364,6 +368,7 @@ def test_punica_expand_nslices( lora_indices_tensor, batches, max_seq_length, + token_nums, slice_offset, hidden_size, add_inputs=True, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 28a395af19e6..185da6399a06 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -84,6 +84,7 @@ def test_punica_sgmv( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -98,6 +99,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, scaling, ) else: @@ -110,6 +112,7 @@ def test_punica_sgmv( lora_indices_tensor, batches, max_seq_length, + token_nums, add_inputs=True, ) ref_torch_groupgemm( @@ -262,6 +265,7 @@ def test_punica_expand_nslices( device, ) max_seq_length = seq_len_tensor.max() + token_nums = seq_len_tensor.sum().item() if isinstance(max_seq_length, tuple): max_seq_length = max_seq_length[0].item() else: @@ -279,6 +283,7 @@ def test_punica_expand_nslices( lora_indices_tensor, batches, max_seq_length, + token_nums, slice_offset, hidden_size, add_inputs=True, diff --git a/tests/models/decoder_only/vision_language/test_phi3v.py b/tests/models/decoder_only/vision_language/test_phi3v.py index e248151c40a6..eba0a1a1bce4 100644 --- a/tests/models/decoder_only/vision_language/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/test_phi3v.py @@ -1,16 +1,21 @@ import os import re -from typing import List, Optional, Tuple, Type +from typing import Callable, List, Optional, Tuple, Type import pytest -from transformers import AutoTokenizer +import torch +from transformers import AutoImageProcessor, AutoTokenizer +from vllm.inputs import InputContext, LLMInputs +from vllm.model_executor.models.phi3v import _IMAGE_TOKEN_ID +from vllm.multimodal import MultiModalRegistry from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu, is_hip -from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner -from ...utils import check_logprobs_close +from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) +from ...utils import build_model_context, check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -71,7 +76,7 @@ def run_test( All the image fixtures for the test are from IMAGE_ASSETS. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. @@ -230,3 +235,174 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, mm_limit=2, tensor_parallel_size=1, ) + + +### Fast tests for correctness in processor_kwarg override handling + + +# Wrap lazy imports to avoid initializing CUDA during test collection +@pytest.fixture() +def input_processor_for_phi3v(): + from vllm.model_executor.models.phi3v import input_processor_for_phi3v + return input_processor_for_phi3v + + +@pytest.fixture() +def dummy_data_for_phi3v(): + from vllm.model_executor.models.phi3v import dummy_data_for_phi3v + return dummy_data_for_phi3v + + +@pytest.fixture() +def get_max_phi3v_image_tokens(): + from vllm.model_executor.models.phi3v import get_max_phi3v_image_tokens + return get_max_phi3v_image_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops", [4, 16, None]) +def test_input_mapper_override(model: str, image_assets: _ImageAssets, + num_crops: Optional[int]): + """Ensure that the [default] input mapper handles num_crops properly.""" + # We pass the processor kwargs here since for this model, we fall back to + # the default mapper; this will fall back to the HF mapper and forward + # mm_processor_kwargs to it. + mm_processor_kwargs = { + "num_crops": num_crops + } if num_crops is not None else {} + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + ) + + hf_processor = AutoImageProcessor.from_pretrained(model, + trust_remote_code=True, + **mm_processor_kwargs) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + image = image_assets[0].pil_image + hf_result = hf_processor.preprocess( + image, + return_tensors="pt", + ) + + vllm_result = mm_registry.map_input( + ctx.model_config, + {"image": image}, + ) + + assert torch.all(hf_result["image_sizes"] == vllm_result["image_sizes"]) + assert torch.all( + hf_result["num_img_tokens"] == vllm_result["num_img_tokens"]) + + # For pixel values, the second axis should be the num_crops + 1 + # for the rescaled original image. The default value in VLLM falls + # back to the HF config, which is why we compare to the processor num_crops + assert torch.all(hf_result["pixel_values"] == vllm_result["pixel_values"]) + assert vllm_result["pixel_values"].shape[1] == hf_processor.num_crops + 1 + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,expected_max_tokens", [ + (4, 781), + (16, 2653), +]) +def test_max_tokens_override(get_max_phi3v_image_tokens: Callable, model: str, + num_crops: int, expected_max_tokens: int): + """Ensure get_max_phi3v_image_tokens handles num_crops properly.""" + # NOTE: mm_processor_kwargs on the context in this test is unused, since + # this is testing the mapper directly. In practice, the processor kwargs + # are wrapped in a closure when calling the max tokens func. We explicitly + # do NOT use the mm_processor_kwargs in the model context here to ensure + # that the max image tokens implementation is referencing a mix of the + # kwargs to the function and the original mm_processor_kwargs in case + # values are somehow updated and end up in a bad state. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + actual_max_tokens = get_max_phi3v_image_tokens( + InputContext(ctx.model_config), + num_crops=num_crops, + ) + + assert expected_max_tokens == actual_max_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,toks_per_img,num_imgs", [ + (4, 781, 1), + (4, 781, 2), + (16, 2653, 1), + (16, 2653, 2), +]) +def test_dummy_data_override(dummy_data_for_phi3v: Callable, model: str, + num_crops: int, toks_per_img: int, num_imgs: int): + """Ensure dummy_data_for_phi3v handles num_crops properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the dummy data func. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + sequence_data, _, = dummy_data_for_phi3v( + ctx=ctx, + seq_len=8192, # Should be bigger than num_imgs * toks_per_img + mm_counts={"image": num_imgs}, + num_crops=num_crops, + ) + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = sequence_data.get_token_ids().count(_IMAGE_TOKEN_ID) + assert img_tok_count == toks_per_img * num_imgs + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_crops,expected_toks_per_img,num_imgs", [ + (4, 757, 1), + (4, 757, 2), + (16, 1921, 1), + (16, 1921, 2), +]) +def test_input_processor_override(input_processor_for_phi3v: Callable, + image_assets: _ImageAssets, model: str, + num_crops: int, expected_toks_per_img: int, + num_imgs: int): + """Ensure input_processor_for_phi3v handles num_crops properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the custom input processor. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + ) + tokenizer = AutoTokenizer.from_pretrained(model) + # Build the image str / prompt based on the number of images we pass + img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) + prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" + images = [image_assets[0].pil_image] * num_imgs + + llm_inputs = LLMInputs(prompt_token_ids=tokenizer.encode(prompt), + prompt=prompt, + multi_modal_data={"image": images}) + + proc_llm_inputs = input_processor_for_phi3v( + ctx=ctx, + llm_inputs=llm_inputs, + num_crops=num_crops, + ) + + # Ensure we have the right number of placeholders per num_crops size + img_tok_count = proc_llm_inputs["prompt_token_ids"].count(_IMAGE_TOKEN_ID) + assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/tests/models/decoder_only/vision_language/test_qwen.py b/tests/models/decoder_only/vision_language/test_qwen.py index e4f79092b760..638fb68b8f87 100644 --- a/tests/models/decoder_only/vision_language/test_qwen.py +++ b/tests/models/decoder_only/vision_language/test_qwen.py @@ -5,14 +5,13 @@ import torch from PIL.Image import Image -from vllm.config import ModelConfig from vllm.inputs import InputContext, LLMInputs from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer, rescale_image_size from ....conftest import (IMAGE_ASSETS, HfRunner, ImageAsset, PromptImageInput, VllmRunner, _ImageAssets) -from ...utils import check_logprobs_close +from ...utils import build_model_context, check_logprobs_close text_only_models = [ "Qwen/Qwen-7B-Chat" # Has no visual component @@ -42,32 +41,6 @@ IMG_SIZE = 448 -def build_model_context(model_name: str, - tokenizer_name: Optional[str] = None, - trust_remote_code: bool = False): - """Creates an InputContext for a given model. - - Args: - model_name: Name of the model being considered. - tokenizer_name: Name of the tokenizer being considered. - trust_remote_code: Whether or not to allow loading remote code. - - Returns: - InputContext for the model being considered. - """ - if tokenizer_name is None: - tokenizer_name = model_name - model_config = ModelConfig( - model_name, - tokenizer_name, - tokenizer_mode="auto", - trust_remote_code=trust_remote_code, - dtype="float32", - seed=0, - ) - return InputContext(model_config) - - @pytest.fixture() def input_mapper_for_qwen(): # Lazy import to avoid initializing CUDA during test collection diff --git a/tests/models/utils.py b/tests/models/utils.py index 8e31a1d6eefe..eb6254f18182 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,6 +1,8 @@ import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union +from vllm.config import ModelConfig +from vllm.inputs import InputContext from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs TokensText = Tuple[List[int], str] @@ -240,3 +242,36 @@ def check_logprobs_close( warnings.simplefilter("always") warnings.warn(fail_msg, stacklevel=2) + + +def build_model_context(model_name: str, + tokenizer_name: Optional[str] = None, + trust_remote_code: bool = False, + mm_processor_kwargs: Optional[Dict] = None, + limit_mm_per_prompt: Optional[Dict] = None): + """Creates an InputContext for a given model. + + Args: + model_name: Name of the model being considered. + tokenizer_name: Name of the tokenizer being considered. + trust_remote_code: Whether or not to allow loading remote code. + mm_processor_kwargs: optional processor kwargs for to be leveraged + in the input processor, mapper, dummy data creation, etc. + limit_mm_per_prompt: Multimodal limits. + + Returns: + InputContext for the model being considered. + """ + if tokenizer_name is None: + tokenizer_name = model_name + model_config = ModelConfig( + model_name, + tokenizer_name, + tokenizer_mode="auto", + trust_remote_code=trust_remote_code, + dtype="float32", + seed=0, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt=limit_mm_per_prompt, + ) + return InputContext(model_config) diff --git a/tests/mq_llm_engine/test_error_handling.py b/tests/mq_llm_engine/test_error_handling.py index 7c466c92d529..76b2f494d5b2 100644 --- a/tests/mq_llm_engine/test_error_handling.py +++ b/tests/mq_llm_engine/test_error_handling.py @@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket): # Throws an error in first forward pass. with pytest.raises(RAISED_ERROR): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket): # Engine is errored, should get ENGINE_DEAD_ERROR. with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket): # Generate call should throw ENGINE_DEAD_ERROR with pytest.raises(MQEngineDeadError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id=uuid.uuid4()): pass @@ -153,27 +153,20 @@ async def test_failed_abort(tmp_socket): await client.check_health() # Trigger an abort on the client side. - async def bad_abort_after_2s(): - await asyncio.sleep(2.0) - await client.abort(request_id="foo") + # This request ID does not exist, and will cause the engine to error + await client.abort(request_id="foo") - # Trigger an abort in 2s from now. - abort_task = asyncio.create_task(bad_abort_after_2s()) - - # Exception in abort() will happen during this generation. - # This will kill the engine and should return ENGINE_DEAD_ERROR + # Future generation requests will now fail # with reference to the original KeyError("foo") with pytest.raises(MQEngineDeadError) as execinfo: async for _ in client.generate( - prompt="Hello my name is", - sampling_params=SamplingParams(max_tokens=2000), + inputs="Hello my name is", + sampling_params=SamplingParams(max_tokens=10), request_id=uuid.uuid4()): pass assert "KeyError" in repr(execinfo.value) assert client.errored - await abort_task - # This should raise the original error. with pytest.raises(RAISED_ERROR): await client.check_health() @@ -190,7 +183,7 @@ async def test_bad_request(tmp_socket): # Invalid request should fail, but not crash the server. with pytest.raises(ValueError): - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-1", lora_request=LoRARequest( @@ -199,7 +192,7 @@ async def test_bad_request(tmp_socket): pass # This request should be okay. - async for _ in client.generate(prompt="Hello my name is", + async for _ in client.generate(inputs="Hello my name is", sampling_params=SamplingParams(), request_id="abcd-2"): pass diff --git a/tests/mq_llm_engine/utils.py b/tests/mq_llm_engine/utils.py index 3ffa126070ca..e27fd7792341 100644 --- a/tests/mq_llm_engine/utils.py +++ b/tests/mq_llm_engine/utils.py @@ -20,7 +20,7 @@ async def generate( count = 0 async for out in client.generate( request_id=request_id, - prompt="Hello my name is Robert and", + inputs="Hello my name is Robert and", sampling_params=SamplingParams(max_tokens=num_tokens, temperature=0)): diff --git a/tests/multimodal/test_processor_kwargs.py b/tests/multimodal/test_processor_kwargs.py new file mode 100644 index 000000000000..5529ccd4fa57 --- /dev/null +++ b/tests/multimodal/test_processor_kwargs.py @@ -0,0 +1,339 @@ +from array import array +from typing import Mapping +from unittest.mock import patch + +import pytest +import torch + +from vllm.inputs import InputContext, LLMInputs +from vllm.inputs.registry import InputRegistry +from vllm.multimodal import MultiModalRegistry +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData + +from ..models.utils import build_model_context + +# Used for fast tests where the model doesn't matter +DUMMY_MODEL_ID = "facebook/opt-125m" +# Used for tests that need a multimodal model +MULTIMODAL_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" + +# For mm_processor_kwargs - we test overrides by defining mocks for each place +# it is used, and ensuring that we can pass processor kwargs an override value +# to receive the intended result for things like sequence length etc. +DEFAULT_NUM_CROPS = 4 +NUM_CROPS_OVERRIDE = 16 + + +# Mocks for all of the places that we use the mm_processor_kwargs +# to override values in different callables +@pytest.fixture +def use_processor_mock(): + """Patches the internal model input processor with an override callable.""" + + def custom_processor(ctx: InputContext, + llm_inputs: LLMInputs, + *, + num_crops=DEFAULT_NUM_CROPS): + # For testing purposes, we don't worry about the llm inputs / return + # type validation, and just return the value of the kwarg that we + # clobber. + return num_crops + + with patch("vllm.inputs.registry.InputRegistry._get_model_input_processor", + return_value=custom_processor): + yield + + +@pytest.fixture +def use_dummy_data_mock(): + """Patches the internal model input processor with an override callable.""" + + def custom_dummy_data_factory(self, + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + num_crops=DEFAULT_NUM_CROPS): + seq_data = SequenceData( + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * num_crops)) + return seq_data, None + + with patch( + "vllm.inputs.registry.InputRegistry._default_dummy_data_factory", + custom_dummy_data_factory): + yield + + +# Lazy import to avoid CUDA reinitialization error +def mm_model_cls(): + from vllm.model_executor.models.phi3v import Phi3VForCausalLM + + return Phi3VForCausalLM + + +# lambda whose signature matches max token calcs extra & mapper + extra kwargs +get_num_crops = lambda ctx, *, num_crops=DEFAULT_NUM_CROPS: num_crops +custom_mapper = lambda ctx, data, *, num_crops=DEFAULT_NUM_CROPS: { + "num_pixels": torch.zeros(size=(1, num_crops + 1, 3, 336, 336)) +} + + +### Test for default processor logic & mm_processor_kwargs wrapping +def test_default_processor_is_a_noop(): + """Ensure that by default, there is no processor override.""" + dummy_registry = InputRegistry() + ctx = build_model_context(DUMMY_MODEL_ID) + processor = dummy_registry.create_input_processor(ctx.model_config) + proc_inputs = LLMInputs(prompt_token_ids=[], prompt="") + proc_outputs = processor(inputs=proc_inputs) + assert proc_inputs is proc_outputs + + +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_processor_default_kwargs(use_processor_mock, num_crops): + """Ensure input processors can use processor kwargs.""" + dummy_registry = InputRegistry() + # If we have a value for num_crops, pass the override value and make + # sure we get that value as a return-value from out mock processor, + # otherwise fall back to the default value + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } + expected_num_crops = DEFAULT_NUM_CROPS if num_crops is None else num_crops + ctx = build_model_context(DUMMY_MODEL_ID, + mm_processor_kwargs=mm_processor_kwargs) + processor = dummy_registry.create_input_processor(ctx.model_config) + + num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) + assert num_crops_val == expected_num_crops + + +@pytest.mark.parametrize( + "mm_processor_kwargs", + [ + # Not part of the signature + { + "does_not_exist": 100 + }, + # Part of the signature, not keyword only + { + "ctx": "something bad" + } + ]) +def test_processor_with_sad_kwarg_overrides(use_processor_mock, + mm_processor_kwargs): + """Ensure that input processors filter out invalid mm_processor_kwargs""" + dummy_registry = InputRegistry() + ctx = build_model_context(DUMMY_MODEL_ID, + mm_processor_kwargs=mm_processor_kwargs) + + processor = dummy_registry.create_input_processor(ctx.model_config) + num_crops_val = processor(LLMInputs(prompt_token_ids=[], prompt="")) + assert num_crops_val == DEFAULT_NUM_CROPS + + +### Test overrides for the dummy data +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_dummy_data_kwarg_overrides(use_dummy_data_mock, num_crops): + """Ensure dummy data factories can use processor kwargs.""" + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + dummy_registry = InputRegistry() + ctx = build_model_context(DUMMY_MODEL_ID, + mm_processor_kwargs=mm_processor_kwargs) + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + # NOTE: seq_len is thrown away here since this will leverage the + # default dummy data factory that we have patched in, whose seq + # len is solely dependent on the value of the mm_processor_kwargs. + seq_data, _ = dummy_registry.dummy_data_for_profiling( + ctx.model_config, seq_len=-1, mm_registry=mm_registry) + assert len(seq_data.prompt_token_ids) == expected_seq_count + + +@pytest.mark.parametrize( + "mm_processor_kwargs", + [ + # Not part of the signature + { + "does_not_exist": 100 + }, + # Part of the signature, not keyword only + { + "ctx": "something bad" + } + ]) +def test_dummy_data_with_sad_kwarg_overrides(use_dummy_data_mock, + mm_processor_kwargs): + """Ensure the dummy data factory filters out invalid mm_processor_kwargs""" + dummy_registry = InputRegistry() + ctx = build_model_context(DUMMY_MODEL_ID, + mm_processor_kwargs=mm_processor_kwargs) + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + # NOTE: seq_len is thrown away here since this will leverage the + # default dummy data factory that we have patched in, whose seq + # len is solely dependent on the value of the mm_processor_kwargs. + seq_data, _ = dummy_registry.dummy_data_for_profiling( + ctx.model_config, seq_len=-1, mm_registry=mm_registry) + assert len(seq_data.prompt_token_ids) == DEFAULT_NUM_CROPS + + +### Test overrides for the max token count per multimodal instance +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_max_tokens_kwarg_overrides(num_crops): + """Ensure max token calcs can use processor kwargs.""" + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt={"image": 1}) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the mm_processor_kwargs. + with patch.object( + mm_registry._get_plugin("image"), + "_max_mm_tokens", + {mm_model_cls(): get_num_crops}, + ): + max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( + ctx.model_config) + + assert expected_seq_count == max_multimodal_tokens + + +@pytest.mark.parametrize( + "mm_processor_kwargs", + [ + # Not part of the signature + { + "does_not_exist": 100 + }, + # Part of the signature, not keyword only + { + "ctx": "something bad" + } + ]) +def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs): + """Ensure that max token calcs filters out invalid mm_processor_kwargs""" + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt={"image": 1}) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + # Similar before, but since these kwargs get filtered, + # we always get our default value back. + with patch.object( + mm_registry._get_plugin("image"), + "_max_mm_tokens", + {mm_model_cls(): get_num_crops}, + ): + max_multimodal_tokens = mm_registry.get_max_multimodal_tokens( + ctx.model_config) + + assert max_multimodal_tokens == DEFAULT_NUM_CROPS + + +### Test overrides for the mapper +@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE]) +def test_default_mapper_with_processer_kwargs(image_assets, num_crops): + """Ensure that the mapper processor kwargs can fall back to HF models.""" + # NOTE - we don't validate bad inputs for the default mapper, because it's + # through the automodel interface in transformers, so we can't easily + # inspect what kwargs are or are not allowed. + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + mm_processor_kwargs={"num_crops": num_crops}, + limit_mm_per_prompt={"image": 1}) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + image = image_assets[0].pil_image + mm_inputs = {"image": image} + + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) + # Phi3v pixel vals should have shape: [batch, num_crops+1, 3, 336, 336] + assert mapped_inputs["pixel_values"].shape[1] == num_crops + 1 + + +@pytest.mark.parametrize("num_crops", [None, NUM_CROPS_OVERRIDE]) +def test_custom_mapper_kwarg_overrides(image_assets, num_crops): + """Ensure custom mappers can use processor kwargs.""" + mm_processor_kwargs = None if num_crops is None else { + "num_crops": num_crops + } + expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt={"image": 1}) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the mm_processor_kwargs. + image = image_assets[0].pil_image + mm_inputs = {"image": image} + + with patch.object( + mm_registry._get_plugin("image"), + "_default_input_mapper", + {mm_model_cls(): custom_mapper}, + ): + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) + + assert mapped_inputs["pixel_values"].shape[1] == expected_seq_count + 1 + + +@pytest.mark.parametrize( + "mm_processor_kwargs", + [ + # Not part of the signature + { + "does_not_exist": 100 + }, + # Part of the signature, not keyword only + { + "ctx": "something bad" + } + ]) +def test_custom_mapper_with_sad_kwarg_overrides(image_assets, + mm_processor_kwargs): + """Ensure that custom mappers filters out invalid mm_processor_kwargs""" + ctx = build_model_context(MULTIMODAL_MODEL_ID, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + limit_mm_per_prompt={"image": 1}) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + # Patch the image registry for phi3v with our lambda that is compatible + # with overrides, then ensure that calling the method correctly echos + # our num_crops value back from the mm_processor_kwargs. + image = image_assets[0].pil_image + mm_inputs = {"image": image} + + with patch.object( + mm_registry._get_plugin("image"), + "_default_input_mapper", + {mm_model_cls(): custom_mapper}, + ): + mapped_inputs = mm_registry.map_input(ctx.model_config, mm_inputs) + + assert mapped_inputs["pixel_values"].shape[1] == DEFAULT_NUM_CROPS + 1 diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index 36167cf95f58..ac2ebc622ba6 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -107,7 +107,7 @@ def validate_generated_texts(hf_runner, quantization='bitsandbytes', load_format='bitsandbytes', tensor_parallel_size=vllm_tp_size, - enforce_eager=True, + enforce_eager=False, gpu_memory_utilization=0.8) as llm: vllm_outputs = llm.generate_greedy(prompts, 8) vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index 98a02dec895d..a9bedc2956fd 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -9,7 +9,7 @@ # 1. Increase max_tokens to 256. # 2. Increase beam_width to 8. # 3. Use the model "huggyllama/llama-7b". -MAX_TOKENS = [128] +MAX_TOKENS = [64] BEAM_WIDTHS = [4] MODELS = ["TinyLlama/TinyLlama-1.1B-Chat-v1.0"] @@ -33,8 +33,8 @@ def test_beam_search_single_input( max_tokens) with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_beam_search(example_prompts, - beam_width, max_tokens) + vllm_outputs = vllm_model.generate_beam_search_new( + example_prompts, beam_width, max_tokens) for i in range(len(example_prompts)): hf_output_ids, hf_output_texts = hf_outputs[i] diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index 1eba98cefd04..4ddad66dce1f 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -365,7 +365,7 @@ def test_accept_tokens_partially(seed: int, device: str): # Next only keep the first 2 draft tokens same as the zero temperature # tokens. For the remaining 3 choose some other tokens. In the # response we will expect the first 2 tokens to be the same as the - # draft tokens and the rest as -1 + # draft tokens and the recovered token and rest as -1 draft_token_ids_to_replace = get_draft_token_ids( batch_size, k, vocab_size, zero_temperature_token_ids) draft_token_ids = torch.cat( @@ -378,6 +378,8 @@ def test_accept_tokens_partially(seed: int, device: str): assert output_token_ids.shape[0] == batch_size assert output_token_ids.shape[1] == (k + 1) assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2]) + assert torch.all( + output_token_ids[:, 2] == target_with_bonus_probs.argmax(-1)[:, 2]) assert torch.all(output_token_ids[:, -3:] == -1) @@ -443,14 +445,14 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str): @pytest.mark.parametrize("seed", list(range(10))) @pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_replacement_token_ids(seed: int, device: str): +def test_get_recovered_token_ids(seed: int, device: str): """ Test the TypicalAcceptanceSampler's method for generating replacement token IDs. - This test verifies that the `_replacement_token_ids` method of the + This test verifies that the `_get_recovered_token_ids` method of the TypicalAcceptanceSampler correctly identifies the token IDs to be used - as replacements based on the target probability distribution. + as recovered token IDs based on the target probability distribution. Specifically, it ensures that the method correctly identifies the tokens with the highest probability for each sequence in the batch. """ @@ -462,10 +464,7 @@ def test_replacement_token_ids(seed: int, device: str): typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(device=device) target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - expected_replacement_tokens = -torch.ones( - (batch_size, k), dtype=torch.long) - expected_replacement_tokens[:, 0] = torch.argmax(target_probs[:, 0, :], - dim=1) + expected_replacement_tokens = torch.argmax(target_probs, dim=-1) actual_replacement_tokens = ( - typical_acceptance_sampler._replacement_token_ids(target_probs)) + typical_acceptance_sampler._get_recovered_token_ids(target_probs)) assert torch.all(expected_replacement_tokens == actual_replacement_tokens) diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 3d93f4a23b68..b450ef97c89d 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -1,13 +1,16 @@ from itertools import cycle -from typing import List, Optional, Tuple +from typing import List, Optional, Sequence, Tuple, Union import pytest from vllm import LLM, SamplingParams from vllm.model_executor.utils import set_random_seed +from vllm.sequence import PromptLogprobs, SampleLogprobs from ...conftest import cleanup -from ...models.utils import check_logprobs_close, check_outputs_equal +from ...models.utils import (TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs, + check_logprobs_close, check_outputs_equal) from ...utils import RemoteOpenAIServer PROMPTS = [ @@ -81,45 +84,77 @@ def get_output_from_llm_generator( return tokens, token_ids, acceptance_rate -def run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size: int, - max_output_len: int, - seed: Optional[int] = 0, - temperature: float = 0.0, - logprobs: int = 1): - org_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **baseline_llm_kwargs, - } - - sd_args = { - **common_llm_kwargs, - **per_test_common_llm_kwargs, - **test_llm_kwargs, - } - - prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))] - - sampling_params = SamplingParams(temperature=temperature, - max_tokens=max_output_len, - seed=seed, - logprobs=logprobs) - - with vllm_runner(**org_args) as vllm_model: - org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - with vllm_runner(**sd_args) as vllm_model: - sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) - - check_logprobs_close(outputs_0_lst=org_outputs, - outputs_1_lst=sd_outputs, - name_0="org", - name_1="sd") +def check_logprobs_correctness( + spec_outputs: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs]], + baseline_outputs: Sequence[Union[TokensTextLogprobs, + TokensTextLogprobsPromptLogprobs]], + disable_logprobs: bool = False, +): + """Compare sampled and prompt logprobs between baseline and spec decoding + """ + if not disable_logprobs: + return check_logprobs_close( + outputs_0_lst=baseline_outputs, + outputs_1_lst=spec_outputs, + name_0="org", + name_1="sd", + ) + + # Check correctness when disable_logprobs == True + for spec_output, baseline_output in zip(spec_outputs, baseline_outputs): + # Check generated token logprobs. + spec_logprobs = spec_output[2] + baseline_logprobs = baseline_output[2] + _check_logprobs_when_output_disabled(spec_logprobs, + baseline_logprobs, + is_prompt_logprobs=False) + + # Check prompt logprobs too, if they exist + if len(baseline_output) == 4: + assert len(spec_output) == 4 + spec_prompt_logprobs = spec_output[3] + baseline_prompt_logprobs = baseline_output[3] + _check_logprobs_when_output_disabled(spec_prompt_logprobs, + baseline_prompt_logprobs, + is_prompt_logprobs=True) + + +def _check_logprobs_when_output_disabled( + spec_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], + baseline_logprobs: Union[Optional[PromptLogprobs], SampleLogprobs], + is_prompt_logprobs: bool = False, +): + # Prompt logprobs are optional + if is_prompt_logprobs and baseline_logprobs is None: + assert spec_logprobs is None + return + + assert spec_logprobs is not None + assert baseline_logprobs is not None + assert len(spec_logprobs) == len(baseline_logprobs) + + # For each generated position of the sequence. + for pos, (spec_pos_logprobs, baseline_pos_logprobs) in enumerate( + zip(spec_logprobs, baseline_logprobs)): + + # First prompt logprob is expected to be None + if is_prompt_logprobs and baseline_pos_logprobs is None: + assert spec_pos_logprobs is None + assert pos == 0 + continue + + assert spec_pos_logprobs is not None + assert baseline_pos_logprobs is not None + + # When disabled, the 1 logprob is returned with dummy values for the + # score and rank, but the token id should match the baseline model + assert len(spec_pos_logprobs) == 1 + (spec_pos_logprob_token_id, + spec_pos_logprob) = next(iter(spec_pos_logprobs.items())) + assert spec_pos_logprob.rank == -1 + assert spec_pos_logprob.logprob == 0.0 + assert spec_pos_logprob_token_id in baseline_pos_logprobs def run_equality_correctness_test( @@ -135,7 +170,10 @@ def run_equality_correctness_test( disable_seed: bool = False, ignore_eos: bool = True, ensure_all_accepted: bool = False, - expected_acceptance_rate: Optional[float] = None): + expected_acceptance_rate: Optional[float] = None, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + disable_logprobs: bool = False): org_args = { **common_llm_kwargs, @@ -157,10 +195,12 @@ def run_equality_correctness_test( sampling_params = SamplingParams(temperature=temperature, max_tokens=max_output_len, seed=seed, - ignore_eos=ignore_eos) + ignore_eos=ignore_eos, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs) with vllm_runner(**org_args) as vllm_model: - org_outputs = vllm_model.generate(prompts, sampling_params) + org_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) with vllm_runner(**sd_args) as vllm_model: if ensure_all_accepted or expected_acceptance_rate is not None: @@ -169,7 +209,7 @@ def run_equality_correctness_test( 'prometheus'] stat_logger.local_interval = -100 - sd_outputs = vllm_model.generate(prompts, sampling_params) + sd_outputs = vllm_model.generate_w_logprobs(prompts, sampling_params) if ensure_all_accepted or expected_acceptance_rate is not None: acceptance_rate = (stat_logger.metrics. @@ -185,11 +225,16 @@ def run_equality_correctness_test( if expected_acceptance_rate is not None: assert acceptance_rate >= expected_acceptance_rate - 1e-2 - check_outputs_equal(outputs_0_lst=org_outputs, - outputs_1_lst=sd_outputs, + # Only pass token entries, not the logprobs + check_outputs_equal(outputs_0_lst=[out[0:2] for out in org_outputs], + outputs_1_lst=[out[0:2] for out in sd_outputs], name_0="org", name_1="sd") + # Check logprobs if requested + if logprobs is not None or prompt_logprobs is not None: + check_logprobs_correctness(sd_outputs, org_outputs, disable_logprobs) + def run_equality_correctness_test_tp(model, common_llm_kwargs, diff --git a/tests/spec_decode/e2e/test_eagle_correctness.py b/tests/spec_decode/e2e/test_eagle_correctness.py index f2af2c2bedb1..d7ca8815ec25 100644 --- a/tests/spec_decode/e2e/test_eagle_correctness.py +++ b/tests/spec_decode/e2e/test_eagle_correctness.py @@ -80,6 +80,64 @@ def test_eagle_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, batch_size, output_len, seed) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 128, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_eagle_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_logprobs.py b/tests/spec_decode/e2e/test_logprobs.py index 03c1733f104f..b7d54991e053 100644 --- a/tests/spec_decode/e2e/test_logprobs.py +++ b/tests/spec_decode/e2e/test_logprobs.py @@ -4,7 +4,7 @@ from vllm import SamplingParams -from .conftest import run_logprob_correctness_test +from .conftest import run_equality_correctness_test @pytest.mark.parametrize( @@ -25,6 +25,10 @@ "speculative_model": "JackFram/llama-160m", "num_speculative_tokens": 3, "disable_logprobs_during_spec_decoding": False, + }, { + "speculative_model": "JackFram/llama-160m", + "num_speculative_tokens": 3, + "disable_logprobs_during_spec_decoding": True, }]) @pytest.mark.parametrize("batch_size", [8]) @pytest.mark.parametrize( @@ -41,16 +45,19 @@ def test_logprobs_equality(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify output logprobs are equal with and without speculative decoding. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -91,16 +98,18 @@ def test_logprobs_different_k(vllm_runner, common_llm_kwargs, output_len: int, seed: int, logprobs: int): """Veriy logprob greedy equality with different speculation lens. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -143,16 +152,18 @@ def test_logprobs_when_skip_speculation(vllm_runner, common_llm_kwargs, seed: int, logprobs: int): """Verify logprobs greedy equality when some sequences skip speculation. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) @pytest.mark.parametrize( @@ -267,13 +278,15 @@ def test_logprobs_disabled(vllm_runner, common_llm_kwargs, """Check the behavior when logprobs are disabled. Token choices should match with the base model. """ - run_logprob_correctness_test(vllm_runner, - common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, - test_llm_kwargs, - batch_size, - output_len, - seed, - temperature=0.0, - logprobs=logprobs) + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + output_len, + seed, + temperature=0.0, + logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 7cefe99d026c..8c90e147df23 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -87,6 +87,65 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 8, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, + seed: int, logprobs: int): + """Verify greedy equality with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 2d0d6fb923ad..7f3180befaff 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -16,7 +16,7 @@ * Test greedy equality under various number of speculative tokens. With those tests, we can say at least, MLPSpeculator would not break the -correctess for the target model outputs. +correctness for the target model outputs. """ from unittest.mock import patch @@ -88,6 +88,61 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": SPEC_MODEL, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": SPEC_MODEL, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [8]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + """Verify greedy equality with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 89301f24e115..850114eb7f5a 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -76,6 +76,65 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Print spec metrics. + "disable_log_stats": False, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [ + { + "model_name": "JackFram/llama-68m", + }, +]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", [ + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs_during_spec_decoding": False, + }, + { + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + "disable_logprobs_during_spec_decoding": True, + }, +]) +@pytest.mark.parametrize("output_len", [ + 8, +]) +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("logprobs", [1, 6]) +def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, output_len: int, seed: int, + logprobs: int): + """Verify greedy equality on a tiny model with different batch size.""" + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0, + logprobs=logprobs, + prompt_logprobs=logprobs, + disable_logprobs=test_llm_kwargs[ + 'disable_logprobs_during_spec_decoding']) + + @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/test_embedded_commit.py b/tests/test_embedded_commit.py index 17b01651e39a..ffeacf34b7ba 100644 --- a/tests/test_embedded_commit.py +++ b/tests/test_embedded_commit.py @@ -2,6 +2,7 @@ def test_embedded_commit_defined(): - assert vllm.__commit__ != "COMMIT_HASH_PLACEHOLDER" - # 7 characters is the length of a short commit hash - assert len(vllm.__commit__) >= 7 + assert hasattr(vllm, "__version__") + assert hasattr(vllm, "__version_tuple__") + assert vllm.__version__ != "dev" + assert vllm.__version_tuple__ != (0, 0, "dev") diff --git a/vllm/__init__.py b/vllm/__init__.py index 59af68fb493e..90363b3e49b7 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,21 +5,21 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams -from .version import __commit__, __version__ +from .version import __version__, __version_tuple__ __all__ = [ - "__commit__", "__version__", + "__version_tuple__", "LLM", "ModelRegistry", - "PromptType", + "PromptInputs", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 678700055c99..a71bafc974ad 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -438,7 +438,8 @@ def machete_gemm_fake( @torch.library.register_fake("_C::machete_prepack_B") def machete_prepack_B_fake(b_q_weight: torch.Tensor, b_type: ScalarType) -> torch.Tensor: - return torch.empty_like(b_q_weight) + return torch.empty_like(b_q_weight, + memory_format=torch.contiguous_format) @torch.library.register_fake("_C::causal_conv1d_fwd") def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, @@ -625,6 +626,22 @@ def machete_prepack_B(b_q_weight: torch.Tensor, return torch.ops._C.machete_prepack_B(b_q_weight, b_type) +# TODO: has to be a better way to do this +try: + torch.ops._C.permute_cols # noqa B018 + + @torch.library.register_fake("_C::permute_cols") + def _permute_cols_fake(a: torch.Tensor, + perm: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) +except Exception: + pass + + +def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + return torch.ops._C.permute_cols(a, perm) + + # fp8 def scaled_fp8_quant( input: torch.Tensor, diff --git a/vllm/config.py b/vllm/config.py index df8c8419f4cc..108badf150c8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -51,6 +51,7 @@ "Qwen2ForCausalLM", "Qwen2MoeForCausalLM", "QWenLMHeadModel", + "Qwen2VLForConditionalGeneration", ] @@ -122,6 +123,8 @@ class ModelConfig: can not be gathered from the vllm arguments. config_format: The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'. + mm_processor_kwargs: Arguments to be forwarded to the model's processor + for multi-modal data, e.g., image processor. """ def __init__(self, @@ -150,7 +153,8 @@ def __init__(self, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, use_async_output_proc: bool = True, override_neuron_config: Optional[Dict[str, Any]] = None, - config_format: ConfigFormat = ConfigFormat.AUTO) -> None: + config_format: ConfigFormat = ConfigFormat.AUTO, + mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -184,6 +188,7 @@ def __init__(self, self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.use_async_output_proc = use_async_output_proc + self.mm_processor_kwargs = mm_processor_kwargs # Set enforce_eager to False if the value is unset. if self.enforce_eager is None: @@ -217,6 +222,7 @@ def __init__(self, self._verify_embedding_mode() self._verify_quantization() self._verify_cuda_graph() + self._verify_bnb_config() def _init_multimodal_config( self, limit_mm_per_prompt: Optional[Mapping[str, int]] @@ -332,6 +338,28 @@ def _verify_cuda_graph(self) -> None: self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) + def _verify_bnb_config(self) -> None: + """ + The current version of bitsandbytes (0.44.0) with 8-bit models does not + yet support CUDA graph. + """ + is_bitsandbytes = self.quantization == "bitsandbytes" + has_quantization_config = (getattr(self.hf_config, + "quantization_config", None) + is not None) + is_8bit = (self.hf_config.quantization_config.get( + "load_in_8bit", False) if has_quantization_config else False) + if all([ + is_bitsandbytes, + has_quantization_config, + is_8bit, + not self.enforce_eager, + ]): + logger.warning( + "CUDA graph is not supported on BitAndBytes 8bit yet, " + "fallback to the eager mode.") + self.enforce_eager = True + def verify_async_output_proc(self, parallel_config, speculative_config, device_config) -> None: if not self.use_async_output_proc: @@ -396,13 +424,6 @@ def verify_with_parallel_config( "Pipeline parallelism is only supported for the following " f" architectures: {_PP_SUPPORTED_MODELS}.") - # Remove the constraint after the bitsandbytes issue is fixed: - # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308 - if self.quantization == "bitsandbytes" and self.enforce_eager is False: - logger.warning("CUDA graph is not supported on BitAndBytes yet, " - "fallback to the eager mode.") - self.enforce_eager = True - if pipeline_parallel_size > 1 and self.use_async_output_proc: logger.warning("Async output processor is not supported with " "pipeline parallelism currently. Disabling it.") @@ -942,7 +963,7 @@ class SchedulerConfig: workers instead of an entire data. It should be enabled only when SPMD worker architecture is enabled. I.e., VLLM_USE_RAY_SPMD_WORKER=1 - + policy: The scheduling policy to use. "fcfs" (default) or "priority". """ def __init__(self, @@ -957,7 +978,9 @@ def __init__(self, is_multimodal_model: bool = False, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, - send_delta_data: bool = False) -> None: + multi_step_stream_outputs: bool = False, + send_delta_data: bool = False, + policy: str = "fcfs") -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: # It is the values that have the best balance between ITL @@ -997,7 +1020,9 @@ def __init__(self, self.embedding_mode = embedding_mode self.preemption_mode = preemption_mode self.num_scheduler_steps = num_scheduler_steps + self.multi_step_stream_outputs = multi_step_stream_outputs self.send_delta_data = send_delta_data + self.policy = policy self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index c3fa95f57b73..b707d87c3af8 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -766,6 +766,79 @@ def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: else: return prompt_limit + def _get_priority(self, + seq_group: SequenceGroup) -> Tuple[Optional[int], float]: + """ Get the priority of the sequence group. + Highest preference to user-defined priority, followed by arrival time. + Args: + seq_group: The sequence group input. + Returns: + The priority of the sequence group. + """ + return seq_group.priority, seq_group.arrival_time + + def _schedule_priority_preemption( + self, + budget: SchedulingBudget, + ) -> int: + """Sorts waiting and running queue. Also, force preempt requests + from the running queue if their priority is lower. + Priority-based preemption is used with the priority policy. + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + Returns: + A count of priority-based preemptions. + """ + + waiting_queue = self.waiting + + running_queue = deque(sorted(self.running, key=self._get_priority)) + + blocks_to_swap_out: List[Tuple[int, int]] = [] + force_preemption_count = 0 + + if waiting_queue: + seq_group = waiting_queue.popleft() + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.WAITING, + False, budget) + + #Only preempt if priority inversion exists + while running_queue and self._get_priority( + running_queue[-1]) > self._get_priority(seq_group): + #Only preempt if waiting sequence cannot be allocated + can_allocate = self.block_manager.can_allocate(seq_group) + if (num_new_tokens and can_allocate == AllocStatus.OK + and budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + #Adjust budget to remove the victim sequence group + vseq_group = running_queue.pop() + num_running_tokens = self._get_num_new_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget) + budget.subtract_num_batched_tokens(vseq_group.request_id, + num_running_tokens) + num_running_seqs = vseq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) + + #Preempt out the victim sequence group + self._preempt(vseq_group, blocks_to_swap_out, + PreemptionMode.RECOMPUTE) + waiting_queue.appendleft(vseq_group) + force_preemption_count += 1 + #Put the sequence back into the waiting queue + waiting_queue.appendleft(seq_group) + + waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) + + self.waiting = waiting_queue + self.running = running_queue + return force_preemption_count + def _schedule_prefills( self, budget: SchedulingBudget, @@ -917,6 +990,10 @@ def _schedule_default(self) -> SchedulerOutputs: curr_loras, enable_chunking=False) + if len(prefills.seq_groups + ) == 0 and self.scheduler_config.policy == "priority": + self._schedule_priority_preemption(budget) + # Don't schedule decodes if prefills are scheduled. # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # only contains decode requests, not chunked prefills. diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index b507cd2e1cdd..7d526b25ed19 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -9,11 +9,12 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup +from zmq import IPV6 # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore import vllm.envs as envs from vllm.logger import init_logger -from vllm.utils import get_ip, get_open_port +from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL @@ -214,6 +215,8 @@ def __init__( self.remote_socket = context.socket(XPUB) self.remote_socket.setsockopt(XPUB_VERBOSE, True) remote_subscribe_port = get_open_port() + if is_valid_ipv6_address(connect_ip): + self.remote_socket.setsockopt(IPV6, 1) socket_addr = f"tcp://*:{remote_subscribe_port}" self.remote_socket.bind(socket_addr) @@ -274,6 +277,8 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.remote_socket = context.socket(SUB) self.remote_socket.setsockopt_string(SUBSCRIBE, "") + if is_valid_ipv6_address(handle.connect_ip): + self.remote_socket.setsockopt(IPV6, 1) socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" logger.debug("Connecting to %s", socket_addr) self.remote_socket.connect(socket_addr) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4139eca9c183..0d4559e37742 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -145,6 +145,7 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' num_scheduler_steps: int = 1 + multi_step_stream_outputs: bool = False ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 @@ -175,6 +176,7 @@ class EngineArgs: collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False override_neuron_config: Optional[Dict[str, Any]] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None def __post_init__(self): if self.tokenizer is None: @@ -513,6 +515,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'e.g.: `image=16,video=2` allows a maximum of 16 ' 'images and 2 videos per prompt. Defaults to 1 for ' 'each modality.')) + parser.add_argument( + '--mm-processor-kwargs', + default=None, + type=json.loads, + help=('Overrides for the multimodal input mapping/processing,' + 'e.g., image processor. For example: {"num_crops": 4}.')) # LoRA related configs parser.add_argument('--enable-lora', @@ -588,6 +596,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help=('Maximum number of forward steps per ' 'scheduler call.')) + parser.add_argument( + '--multi-step-stream-outputs', + action='store_true', + help='If True, then multi-step will stream outputs for every step') parser.add_argument( '--scheduler-delay-factor', type=float, @@ -822,6 +834,7 @@ def create_model_config(self) -> ModelConfig: use_async_output_proc=not self.disable_async_output_proc, override_neuron_config=self.override_neuron_config, config_format=self.config_format, + mm_processor_kwargs=self.mm_processor_kwargs, ) def create_load_config(self) -> LoadConfig: @@ -991,6 +1004,7 @@ def create_engine_config(self) -> EngineConfig: is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, + multi_step_stream_outputs=self.multi_step_stream_outputs, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), ) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f108751056ab..34e7e05341f0 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -17,7 +17,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput @@ -405,7 +405,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None: async def add_request_async( self, request_id: str, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -420,7 +420,7 @@ async def add_request_async( arrival_time = time.time() preprocessed_inputs = await self.input_preprocessor.preprocess_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -777,7 +777,7 @@ async def run_engine_loop(engine_ref: ReferenceType): async def add_request( self, request_id: str, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, @@ -797,7 +797,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, verbose=self.log_requests, - prompt=prompt, + inputs=inputs, params=params, arrival_time=arrival_time or time.time(), lora_request=lora_request, @@ -808,7 +808,7 @@ async def add_request( async def generate( self, - prompt: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -822,7 +822,8 @@ async def generate( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -880,7 +881,7 @@ async def generate( """ async for output in await self.add_request( request_id, - prompt, + inputs, sampling_params, lora_request=lora_request, trace_headers=trace_headers, @@ -890,7 +891,7 @@ async def generate( async def encode( self, - prompt: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -903,7 +904,8 @@ async def encode( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -957,7 +959,7 @@ async def encode( """ async for output in await self.add_request( request_id, - prompt, + inputs, pooling_params, lora_request=lora_request, trace_headers=trace_headers, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ebaefe8cde06..768ac69c3692 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -29,7 +29,7 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, - InputRegistry, LLMInputs, PromptType) + InputRegistry, LLMInputs, PromptInputs) from vllm.inputs.preprocess import InputPreprocessor from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -95,7 +95,7 @@ class OutputData(NamedTuple): class SchedulerContext: - def __init__(self): + def __init__(self, multi_step_stream_outputs: bool = False): self.output_queue: Deque[OutputData] = deque() self.request_outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] @@ -103,6 +103,8 @@ def __init__(self): List[SequenceGroupMetadata]] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None + self.multi_step_stream_outputs: bool = multi_step_stream_outputs + def append_output(self, outputs: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduler_outputs: SchedulerOutputs, is_async: bool, @@ -219,6 +221,7 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + use_cached_outputs: bool = False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -234,8 +237,9 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s)", + "num_scheduler_steps=%d, multi_step_stream_outputs=%s, " + "enable_prefix_caching=%s, use_async_output_proc=%s, " + "use_cached_outputs=%s, mm_processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -266,8 +270,11 @@ def __init__( model_config.served_model_name, scheduler_config.use_v2_block_manager, scheduler_config.num_scheduler_steps, + scheduler_config.multi_step_stream_outputs, cache_config.enable_prefix_caching, model_config.use_async_output_proc, + use_cached_outputs, + model_config.mm_processor_kwargs, ) # TODO(woosuk): Print more configs in debug mode. from vllm.plugins import load_general_plugins @@ -286,6 +293,7 @@ def __init__( self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -378,7 +386,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ] self.scheduler_contexts = [ - SchedulerContext() + SchedulerContext(multi_step_stream_outputs=self.scheduler_config. + multi_step_stream_outputs) for _ in range(self.parallel_config.pipeline_parallel_size) ] @@ -622,6 +631,7 @@ def _add_processed_request( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, ) -> None: self._validate_model_inputs(processed_inputs) # Create the sequences. @@ -652,7 +662,8 @@ def _add_processed_request( lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -661,7 +672,8 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) else: raise ValueError( "Either SamplingParams or PoolingParams must be provided.") @@ -680,12 +692,13 @@ def stop_remote_worker_execution_loop(self) -> None: def add_request( self, request_id: str, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: """Add a request to the engine's request pool. @@ -695,7 +708,8 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. @@ -703,6 +717,8 @@ def add_request( arrival_time: The arrival time of the request. If None, we use the current monotonic time. trace_headers: OpenTelemetry trace headers. + priority: The priority of the request. + Only applicable with priority scheduling. Details: - Set arrival_time to the current time if it is None. @@ -731,11 +747,16 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") + + if priority > 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + if arrival_time is None: arrival_time = time.time() preprocessed_inputs = self.input_preprocessor.preprocess( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -750,6 +771,7 @@ def add_request( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, + priority=priority, ) def _create_sequence_group_with_sampling( @@ -762,6 +784,7 @@ def _create_sequence_group_with_sampling( trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, encoder_seq: Optional[Sequence] = None, + priority: int = 0, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -788,7 +811,8 @@ def _create_sequence_group_with_sampling( lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) return seq_group @@ -801,6 +825,7 @@ def _create_sequence_group_with_pooling( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], encoder_seq: Optional[Sequence] = None, + priority: int = 0, ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -813,7 +838,8 @@ def _create_sequence_group_with_pooling( lora_request=lora_request, pooling_params=pooling_params, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + priority=priority) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: @@ -997,7 +1023,8 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) @@ -1018,8 +1045,8 @@ def _process_model_outputs(self, for scheduler in self.scheduler: scheduler.free_finished_seq_groups() - # For multi-step, do not create outputs each iteration - if not is_last_step: + # For multi-step without streaming, don't create outputs each iteration + if not is_last_step and not ctx.multi_step_stream_outputs: # Immediately process request outputs here (if callback is given) if (finished_now and self.process_request_outputs_callback is not None): @@ -1036,17 +1063,27 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) + # For multi-step with streaming, create outputs each iteration + if not is_last_step and ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if self.process_request_outputs_callback is not None: + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + for seq_group in scheduler_outputs.ignored_seq_groups: params = seq_group.sampling_params if params is not None and params.output_kind == ( RequestOutputKind.DELTA) and not seq_group.is_finished(): continue - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 09aa279f1e22..165e6cc2146c 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -3,7 +3,7 @@ from typing import List, Mapping, Optional, Union from vllm import PoolingParams -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.prompt_adapter.request import PromptAdapterRequest @@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError): @dataclass class RPCProcessRequest: - prompt: PromptType + inputs: PromptInputs params: Union[SamplingParams, PoolingParams] request_id: str lora_request: Optional[LoRARequest] = None @@ -43,10 +43,6 @@ class RPCAbortRequest: request_id: str -class RPCHealthRequest: - pass - - class RPCStartupRequest(Enum): IS_SERVER_READY = 1 @@ -56,8 +52,7 @@ class RPCStartupResponse: tracing_enabled: bool -RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCHealthRequest, - RPCStartupRequest] +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 71099115ea12..7e397cf408fb 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -20,12 +20,11 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, RPC_REQUEST_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCHealthRequest, - RPCProcessRequest, RPCStartupRequest, - RPCStartupResponse) + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse) # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT -from vllm.inputs import PromptType +from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -95,9 +94,9 @@ def __init__(self, ipc_path: str, engine_config: EngineConfig): self.output_socket: Socket = self.context.socket(zmq.constants.PULL) self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") - # IPC path for ack of check_health requests. - self.health_socket: Socket = self.context.socket(zmq.constants.PULL) - self.health_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + # IPC path for acking heartbeats. + self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) + self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -124,34 +123,28 @@ def get_data_socket(self) -> Iterator[Socket]: finally: socket.close(linger=0) - async def run_check_health_loop(self, timeout: int): - """Background loop that continually probes the RPCServer for health. - - The loop sends CHECK_HEALTH requests to the INPUT_SOCKET, which - the MQLLMEngine server is blocking on. - - The Server replies on the HEALTH_SOCKET (rather than on the - OUTPUT_SOCKET such that the messages are not intermingled with - output streaming). + async def run_heartbeat_loop(self, timeout: int): + """Background loop that continually listens to the RPCServer for + heartbeats. """ - try: while True: - if await self.health_socket.poll(timeout=timeout) == 0: - # Wakeup every N seconds and do a health probe. - await self._send_one_way_rpc_request( - RPCHealthRequest(), self.input_socket) - - # Wait for ack from the health socket. - await self._await_ack(error_message="Health check failed.", - socket=self.health_socket) + if await self.heartbeat_socket.poll(timeout=timeout) == 0: + # No heartbeat was received. Set error and exit the loop + self._set_errored( + TimeoutError("No heartbeat received " + "from MQLLMEngine")) + logger.debug("Shutting down MQLLMEngineClient check " + "health loop due to timeout") + break + else: - # Server sent a health status message unprompted. + # Heartbeat received- check the message await self._check_success( - error_message="Health check failed.", - socket=self.health_socket) + error_message="Heartbeat failed.", + socket=self.heartbeat_socket) - logger.debug("Health probe successful.") + logger.debug("Heartbeat successful.") except asyncio.CancelledError: logger.debug("Shutting down MQLLMEngineClient check health loop.") @@ -234,7 +227,7 @@ async def setup(self): # Start health_loop. self.health_loop = asyncio.create_task( - self.run_check_health_loop(timeout=VLLM_RPC_TIMEOUT)) + self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) def close(self): """Destroy the ZeroMQ Context.""" @@ -375,7 +368,7 @@ def dead_error(self) -> BaseException: def generate( self, - prompt: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -389,7 +382,8 @@ def generate( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. @@ -398,13 +392,13 @@ def generate( prompt_adapter_request: Prompt Adapter request to use for generation, if any. """ - return self._process_request(prompt, sampling_params, request_id, + return self._process_request(inputs, sampling_params, request_id, lora_request, trace_headers, prompt_adapter_request) def encode( self, - prompt: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, @@ -417,7 +411,8 @@ def encode( from the LLMEngine to the caller. Args: - prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. @@ -428,12 +423,12 @@ def encode( The output `EmbeddingRequestOutput` objects from the LLMEngine for the request. """ - return self._process_request(prompt, pooling_params, request_id, + return self._process_request(inputs, pooling_params, request_id, lora_request, trace_headers) async def _process_request( self, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], request_id: str, lora_request: Optional[LoRARequest] = None, @@ -466,7 +461,7 @@ async def _process_request( request_bytes = pickle.dumps( RPCProcessRequest( - prompt=prompt, + inputs=inputs, params=params, request_id=request_id, lora_request=lora_request, diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 788c1573ae25..b1dd9915cbbf 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,5 +1,7 @@ import pickle import signal +import threading +import time from contextlib import contextmanager from typing import Iterator, List, Optional, Union @@ -15,10 +17,10 @@ IPC_HEALTH_EXT, IPC_INPUT_EXT, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, VLLM_RPC_SUCCESS_STR, RPCAbortRequest, - RPCError, RPCHealthRequest, - RPCProcessRequest, RPCStartupRequest, - RPCStartupResponse) + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse) # yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -66,7 +68,14 @@ def __init__(self, *args, log_requests: bool = True, **kwargs) -> None: - self.engine = LLMEngine(*args, **kwargs) + # For MQLLMEngine, we can use cached outputs, since each new request + # output is immediately pickled and send over the socket, which frees + # the python object to be reused again. + use_cached_outputs = True + + self.engine = LLMEngine(*args, + **kwargs, + use_cached_outputs=use_cached_outputs) self.log_requests = log_requests self.use_async_sockets = use_async_sockets @@ -84,9 +93,9 @@ def __init__(self, self.output_socket = self.ctx.socket(zmq.constants.PUSH) self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") - # Send health status back to client. - self.health_socket = self.ctx.socket(zmq.constants.PUSH) - self.health_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + # Send heartbeats back to client. + self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) + self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") # IPC path for the data socket. self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" @@ -94,6 +103,20 @@ def __init__(self, # Error state. self._errored_with: Optional[BaseException] = None + # Heartbeat thread + self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop, + daemon=True) + self._heartbeat_stop_event = threading.Event() + # The heartbeat needs to be faster than what the client will wait for + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0 + + self._last_alive_time = time.time() + # The heartbeats can tolerate a long period of the engine chugging + # away at a generation request. + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0 + @property def dead_error(self) -> BaseException: if self._errored_with is not None: @@ -124,6 +147,8 @@ def start(self): try: logger.debug("Starting Startup Loop.") self.run_startup_loop() + logger.debug("Starting heartbeat thread") + self.heartbeat_thread.start() logger.debug("Starting Engine Loop.") self.run_engine_loop() except Exception as e: @@ -137,6 +162,7 @@ def start(self): def cleanup(self): """Cleanup zeromq state on shutdown.""" # Closes all sockets and destroys context. + self._heartbeat_stop_event.set() self.ctx.destroy(linger=0) del self.engine @@ -175,9 +201,11 @@ def run_engine_loop(self): """Core busy loop of the LLMEngine.""" while True: + self._alive() if not self.engine.has_unfinished_requests(): # Poll until there is work to do. while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + self._alive() self.engine.do_log_stats() logger.debug("Waiting for new requests in engine loop.") @@ -193,7 +221,6 @@ def run_engine_loop(self): def engine_step(self) -> List[RequestOutput]: """Engine step wrapper with error handling.""" - try: return self.engine.step() except SystemExit: @@ -222,10 +249,9 @@ def handle_new_input(self): self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) - elif isinstance(request, RPCHealthRequest): - self._handle_health_request() else: - raise ValueError("Unknown RPCRequest Type: {request}") + raise ValueError("Unknown RPCRequest Type: " + f"{type(request)}") except Exception as e: self._set_errored(e) @@ -245,7 +271,7 @@ def _handle_process_request(self, request: RPCProcessRequest): try: self.engine.add_request( request_id=request_id, - prompt=request.prompt, + inputs=request.inputs, params=request.params, lora_request=request.lora_request, trace_headers=request.trace_headers, @@ -272,13 +298,32 @@ def _handle_abort_request(self, request: RPCAbortRequest): if self.log_requests: logger.info("Aborted request %s.", request.request_id) - def _handle_health_request(self): + def _heartbeat_loop(self): + while not self._heartbeat_stop_event.wait( + timeout=self.heartbeat_interval_seconds): + # Loops until the stop event is set + self._heartbeat() + + logger.debug("Exiting MQLLMEngine heartbeat thread") + + def _heartbeat(self): + # Send unhealthy if engine has already errored if self._errored_with is not None: self._send_unhealthy(self._errored_with) - # Raises error if unhealthy. - self.engine.check_health() - self._send_healthy() + # Check for life of the main loop + elif time.time() - self._last_alive_time > self.last_alive_threshold: + self._send_unhealthy(RuntimeError("Engine loop has died")) + + else: + # Otherwise- check health of the engine + # self.engine.check_health() raises on unhealthy + try: + self.engine.check_health() + self._send_healthy() + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): """Send List of RequestOutput to RPCClient.""" @@ -288,12 +333,14 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): def _send_healthy(self): """Send HEALTHY message to RPCClient.""" - self.health_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + if not self.heartbeat_socket.closed: + self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) def _send_unhealthy(self, error: BaseException): """Send UNHEALTHY message to RPCClient.""" - error_bytes = pickle.dumps(error) - self.health_socket.send_multipart((error_bytes, ), copy=False) + if not self.heartbeat_socket.closed: + error_bytes = pickle.dumps(error) + self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): @@ -306,6 +353,9 @@ def _set_errored(self, e: BaseException): if self._errored_with is None: self._errored_with = e + def _alive(self): + self._last_alive_time = time.time() + def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, ipc_path: str): diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index c73db765fc3b..31c2bbc8e712 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -9,8 +9,8 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams -from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, - SequenceOutput, SequenceStatus) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup, + SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Counter @@ -110,10 +110,11 @@ def process_outputs(self, # we can take the first sample. samples = [output.samples[0] for output in outputs] - # -1 means the output token is not valid (eg. due to spec decode + # entries in sample tokens may be invalid (eg. due to spec decode # rejecting tokens). valid_samples = [ - sample for sample in samples if sample.output_token != -1 + sample for sample in samples + if sample.output_token != VLLM_INVALID_TOKEN_ID ] assert valid_samples diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index d0bbeb357b50..70444faa670a 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -3,7 +3,7 @@ from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs -from vllm.inputs.data import PromptType +from vllm.inputs.data import PromptInputs from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput @@ -35,19 +35,19 @@ def dead_error(self) -> BaseException: def generate( self, - prompt: PromptType, + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: - """Generate outputs for a request.""" + """Generates outputs for a request""" ... def encode( self, - prompt: PromptType, + inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c7548ca4bcfb..77ae7b088398 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,6 +1,8 @@ +import itertools from contextlib import contextmanager -from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast, - overload) +from dataclasses import dataclass +from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, + Union, cast, overload) from tqdm import tqdm @@ -10,7 +12,7 @@ apply_hf_chat_template, apply_mistral_chat_template, parse_chat_messages) -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -30,6 +32,37 @@ logger = init_logger(__name__) +@dataclass +class BeamSearchSequence: + """A sequence for beam search. + It keeps track of the tokens and the log probability of the sequence. + The text field is optional and will only be filled when the sequence is + about to be returned to the user. + """ + # The tokens includes the prompt. + tokens: List[int] + cum_logprob: float = 0.0 + text: Optional[str] = None + + +@dataclass +class BeamSearchOutput: + """The output of beam search. + It contains the list of the best beam search sequences. + The length of the list is equal to the beam width. + """ + sequences: List[BeamSearchSequence] + + +class BeamSearchInstance: + + def __init__(self, prompt_tokens: List[int]): + self.beams: List[BeamSearchSequence] = [ + BeamSearchSequence(tokens=prompt_tokens) + ] + self.completed: List[BeamSearchSequence] = [] + + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -134,6 +167,7 @@ def __init__( max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, **kwargs, ) -> None: ''' @@ -174,6 +208,7 @@ def __init__( max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, disable_async_output_proc=disable_async_output_proc, + mm_processor_kwargs=mm_processor_kwargs, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( @@ -258,8 +293,8 @@ def generate( @overload def generate( self, - prompts: Union[PromptType, Sequence[PromptType]], - /, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + /, # We may enable `inputs` keyword after removing the old API *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -276,7 +311,7 @@ def generate( ) def generate( self, - prompts: Union[Union[PromptType, Sequence[PromptType]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -285,7 +320,8 @@ def generate( lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, guided_options_request: Optional[Union[LLMGuidedOptions, - GuidedDecodingRequest]] = None + GuidedDecodingRequest]] = None, + priority: Optional[List[int]] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -294,9 +330,7 @@ def generate( into a single list and pass it to this method. Args: - prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. See :class:`~vllm.inputs.PromptType` - for more details about the format of each prompts. + inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. @@ -306,6 +340,8 @@ def generate( lora_request: LoRA request to use for generation, if any. prompt_adapter_request: Prompt Adapter request to use for generation, if any. + priority: The priority of the requests, if any. + Only applicable when priority scheduling policy is enabled. Returns: A list of ``RequestOutput`` objects containing the @@ -322,13 +358,12 @@ def generate( "models (XForCausalLM, XForConditionalGeneration).") if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( + inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if isinstance(guided_options_request, dict): if len(guided_options_request) > 1: @@ -343,18 +378,119 @@ def generate( sampling_params = SamplingParams() self._validate_and_add_requests( - prompts=parsed_prompts, + inputs=inputs, params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - guided_options=guided_options_request) + guided_options=guided_options_request, + priority=priority) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) + def beam_search( + self, + prompts: List[Union[str, List[int]]], + beam_width: int, + max_tokens: int, + ignore_eos: bool = False, + ) -> List[BeamSearchOutput]: + """ + Generate sequences using beam search. + + Args: + prompts: A list of prompts. Each prompt can be a string or a list + of token IDs. + beam_width: The number of beams to keep at each step. + max_tokens: The max number of tokens to generate for each prompt. + + TODO: how does beam search work together with length penalty, frequency + penalty, and stopping criteria, etc.? + """ + + tokenizer = self.get_tokenizer() + # generate 2 * beam_width candidates at each step + # following the huggingface transformers implementation + # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa + beam_search_params = SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=0.0) + instances: List[BeamSearchInstance] = [] + + for prompt in prompts: + prompt_tokens = prompt if isinstance( + prompt, list) else tokenizer.encode(prompt) + instances.append(BeamSearchInstance(prompt_tokens)) + + for _ in range(max_tokens): + all_beams: List[BeamSearchSequence] = list( + sum((instance.beams for instance in instances), [])) + pos = [0] + list( + itertools.accumulate( + len(instance.beams) for instance in instances)) + instance_start_and_end: List[Tuple[int, int]] = list( + zip(pos[:-1], pos[1:])) + + if len(all_beams) == 0: + break + + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + # only runs for one step + # we don't need to use tqdm here + output = self.generate(prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False) + + for (start, end), instance in zip(instance_start_and_end, + instances): + instance_new_beams = [] + for i in range(start, end): + current_beam = all_beams[i] + result = output[i] + + if result.outputs[0].logprobs is not None: + # if `result.outputs[0].logprobs` is None, it means + # the sequence is completed because of the max-model-len + # or abortion. we don't need to add it to the new beams. + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + instance.completed.append(new_beam) + else: + instance_new_beams.append(new_beam) + sorted_beams = sorted(instance_new_beams, + key=lambda x: x.cum_logprob, + reverse=True) + instance.beams = sorted_beams[:beam_width] + + outputs = [] + for instance in instances: + instance.completed.extend(instance.beams) + sorted_completed = sorted(instance.completed, + key=lambda x: x.cum_logprob, + reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens) + outputs.append(BeamSearchOutput(sequences=best_beams)) + + return outputs + def chat( self, - messages: List[ChatCompletionMessageParam], + messages: Union[List[ChatCompletionMessageParam], + List[List[ChatCompletionMessageParam]]], sampling_params: Optional[Union[SamplingParams, List[SamplingParams]]] = None, use_tqdm: bool = True, @@ -374,8 +510,9 @@ def chat( to the OpenAI API. Args: - messages: A single conversation represented as a list of messages. - Each message is a dictionary with 'role' and 'content' keys. + messages: A list of conversations or a single conversation. + - Each conversation is represented as a list of messages. + - Each message is a dictionary with 'role' and 'content' keys. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it @@ -392,42 +529,56 @@ def chat( A list of ``RequestOutput`` objects containing the generated responses in the same order as the input messages. """ + list_of_messages: List[List[ChatCompletionMessageParam]] - tokenizer = self.get_tokenizer() - model_config = self.llm_engine.get_model_config() - - conversation, mm_data = parse_chat_messages(messages, model_config, - tokenizer) - - prompt_data: Union[str, List[int]] - if isinstance(tokenizer, MistralTokenizer): - prompt_data = apply_mistral_chat_template( - tokenizer, - messages=messages, - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - tools=tools, - ) + # Handle multi and single conversations + if is_list_of(messages, list): + # messages is List[List[...]] + list_of_messages = messages else: - prompt_data = apply_hf_chat_template( - tokenizer, - conversation=conversation, - chat_template=chat_template, - add_generation_prompt=add_generation_prompt, - tools=tools, - ) + # messages is List[...] + list_of_messages = [messages] + + prompts: List[Union[TokensPrompt, TextPrompt]] = [] + + for msgs in list_of_messages: + tokenizer = self.get_tokenizer() + model_config = self.llm_engine.get_model_config() + + conversation, mm_data = parse_chat_messages( + msgs, model_config, tokenizer) + + prompt_data: Union[str, List[int]] + if isinstance(tokenizer, MistralTokenizer): + prompt_data = apply_mistral_chat_template( + tokenizer, + messages=msgs, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + else: + prompt_data = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + + prompt: Union[TokensPrompt, TextPrompt] + if is_list_of(prompt_data, int): + prompt = TokensPrompt(prompt_token_ids=prompt_data) + else: + prompt = TextPrompt(prompt=prompt_data) - prompt: PromptType - if is_list_of(prompt_data, int): - prompt = TokensPrompt(prompt_token_ids=prompt_data) - else: - prompt = TextPrompt(prompt=prompt_data) + if mm_data is not None: + prompt["multi_modal_data"] = mm_data - if mm_data is not None: - prompt["multi_modal_data"] = mm_data + prompts.append(prompt) return self.generate( - prompt, + prompts, sampling_params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, @@ -497,8 +648,8 @@ def encode( @overload def encode( self, - prompts: Union[PromptType, Sequence[PromptType]], - /, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + /, # We may enable `inputs` keyword after removing the old API *, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -515,7 +666,7 @@ def encode( ) def encode( self, - prompts: Union[Union[PromptType, Sequence[PromptType]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -531,9 +682,9 @@ def encode( into a single list and pass it to this method. Args: - prompts: The prompts to the LLM. You may pass a sequence of prompts - for batch inference. See :class:`~vllm.inputs.PromptType` - for more details about the format of each prompts. + inputs: The inputs to the LLM. You may pass a sequence of inputs for + batch inference. See :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. @@ -556,20 +707,19 @@ def encode( ) if prompt_token_ids is not None: - parsed_prompts = self._convert_v1_inputs( + inputs = self._convert_v1_inputs( prompts=cast(Optional[Union[str, List[str]]], prompts), prompt_token_ids=prompt_token_ids, ) else: - parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if pooling_params is None: # Use default pooling params. pooling_params = PoolingParams() self._validate_and_add_requests( - prompts=parsed_prompts, + inputs=inputs, params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -613,9 +763,9 @@ def _convert_v1_inputs( raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - parsed_prompts: List[PromptType] = [] + inputs: List[PromptInputs] = [] for i in range(num_requests): - item: PromptType + item: PromptInputs if prompts is not None: item = TextPrompt(prompt=prompts[i]) @@ -624,24 +774,25 @@ def _convert_v1_inputs( else: raise AssertionError - parsed_prompts.append(item) + inputs.append(item) - return parsed_prompts + return inputs def _validate_and_add_requests( self, - prompts: Union[PromptType, Sequence[PromptType]], + inputs: Union[PromptInputs, Sequence[PromptInputs]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], prompt_adapter_request: Optional[PromptAdapterRequest], guided_options: Optional[GuidedDecodingRequest] = None, + priority: Optional[List[int]] = None, ) -> None: - if isinstance(prompts, (str, dict)): + if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. - prompts = [prompts] + inputs = [inputs] - num_requests = len(prompts) + num_requests = len(inputs) if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -658,29 +809,32 @@ def _validate_and_add_requests( sp.output_kind = RequestOutputKind.FINAL_ONLY # Add requests to the engine. - for i, prompt in enumerate(prompts): + for i, request_inputs in enumerate(inputs): self._add_request( - prompt, + request_inputs, params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, prompt_adapter_request=prompt_adapter_request, + priority=priority[i] if priority else 0, ) def _add_request( self, - prompt: PromptType, + inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request( request_id, - prompt, + inputs, params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, + priority=priority, ) def _add_guided_processor( diff --git a/vllm/envs.py b/vllm/envs.py index 43c7aa8af85b..705d858e71a6 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -62,6 +62,7 @@ VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False + VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False def get_default_cache_root(): @@ -195,6 +196,10 @@ def get_default_config_root(): lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")), + # If set, allowing the use of deprecated beam search implementation + "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH": + lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1", + # Internal flag to enable Dynamo graph capture "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index ba1bef1ab3ec..0b08e9691f91 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,5 +1,5 @@ from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, TokensPrompt, build_explicit_enc_dec_prompt, to_enc_dec_tuple_list, zip_enc_dec_prompts) from .registry import InputContext, InputRegistry @@ -16,8 +16,8 @@ __all__ = [ "TextPrompt", "TokensPrompt", - "PromptType", - "SingletonPrompt", + "PromptInputs", + "SingletonPromptInputs", "ExplicitEncoderDecoderPrompt", "LLMInputs", "EncoderDecoderLLMInputs", diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index b8c8646c4084..a71e9a7b5db6 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -33,7 +33,7 @@ class TokensPrompt(TypedDict): """ -SingletonPrompt = Union[str, TextPrompt, TokensPrompt] +SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] """ Set of possible schemas for a single LLM input: @@ -46,7 +46,7 @@ class TokensPrompt(TypedDict): the user desires to express both the encoder & decoder prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` -A prompt of type :class:`SingletonPromptType` may be employed +A prompt of type :class:`SingletonPromptInputs` may be employed as (1) input to a decoder-only model, (2) input to the encoder of an encoder/decoder model, in the scenario where the decoder-prompt is not specified explicitly, or @@ -55,12 +55,12 @@ class TokensPrompt(TypedDict): """ _T1_co = TypeVar("_T1_co", - bound=SingletonPrompt, - default=SingletonPrompt, + bound=SingletonPromptInputs, + default=SingletonPromptInputs, covariant=True) _T2_co = TypeVar("_T2_co", - bound=SingletonPrompt, - default=SingletonPrompt, + bound=SingletonPromptInputs, + default=SingletonPromptInputs, covariant=True) @@ -72,7 +72,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): The encoder and decoder prompts, respectively, may formatted according to any of the - :class:`SingletonPromptType` schemas, and are not + :class:`SingletonPromptInputs` schemas, and are not required to have the same schema. Only the encoder prompt may have multi-modal data. @@ -81,7 +81,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): be used as an input to a decoder-only model, and that the `encoder_prompt` and `decoder_prompt` fields of this data structure themselves must be - :class:`SingletonPromptType` instances. + :class:`SingletonPromptInputs` instances. """ encoder_prompt: _T1_co @@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): decoder_prompt: Optional[_T2_co] -PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] +PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] """ Set of possible schemas for an LLM input, including both decoder-only and encoder/decoder input types: @@ -146,8 +146,12 @@ class EncoderDecoderLLMInputs(LLMInputs): """ -_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) -_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) +_T1 = TypeVar("_T1", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) +_T2 = TypeVar("_T2", + bound=SingletonPromptInputs, + default=SingletonPromptInputs) def build_explicit_enc_dec_prompt( diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py index e5fa1e418427..ac9d355c64c8 100644 --- a/vllm/inputs/parse.py +++ b/vllm/inputs/parse.py @@ -5,7 +5,7 @@ from vllm.utils import is_list_of from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, - LLMInputs, PromptType, SingletonPrompt, TextPrompt, + LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, TokensPrompt) @@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict): def parse_singleton_prompt( - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: - if isinstance(prompt, str): - return ParsedStrPrompt(type="str", content=prompt) - elif isinstance(prompt, dict): - if "prompt_token_ids" in prompt: + if isinstance(inputs, str): + return ParsedStrPrompt(type="str", content=inputs) + elif isinstance(inputs, dict): + if "prompt_token_ids" in inputs: return ParsedTokensPrompt(type="tokens", - content=prompt) # type: ignore - elif "prompt" in prompt: - return ParsedTextPrompt(type="text", content=prompt) + content=inputs) # type: ignore + elif "prompt" in inputs: + return ParsedTextPrompt(type="text", content=inputs) raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") def is_explicit_encoder_decoder_prompt( - prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: - return isinstance(prompt, dict) and "encoder_prompt" in prompt + inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(inputs, dict) and "encoder_prompt" in inputs def is_valid_encoder_decoder_llm_inputs( diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index 22f65ed5a324..bee3d1ed75cb 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -9,8 +9,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, - SingletonPrompt) +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs, + SingletonPromptInputs) from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt if TYPE_CHECKING: @@ -207,7 +207,7 @@ async def _tokenize_prompt_async( def _extract_prompt_components( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: @@ -217,7 +217,7 @@ def _extract_prompt_components( Arguments: * request_id - * prompt: single encoder or decoder input prompt + * inputs: single encoder or decoder input prompt * lora_request: this is only valid for decoder prompts Returns: @@ -227,24 +227,24 @@ def _extract_prompt_components( * multi_modal_data ''' - parsed = parse_singleton_prompt(prompt) + parsed = parse_singleton_prompt(inputs) if parsed["type"] == "str": - prompt_text = parsed["content"] + prompt = parsed["content"] prompt_token_ids = self._tokenize_prompt( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt_text = None + prompt = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + prompt = parsed["content"]["prompt"] prompt_token_ids = self._tokenize_prompt( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -252,33 +252,33 @@ def _extract_prompt_components( else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return prompt, prompt_token_ids, multi_modal_data async def _extract_prompt_components_async( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> PromptComponents: """Async version of :meth:`_extract_prompt_components`.""" - parsed = parse_singleton_prompt(prompt) + parsed = parse_singleton_prompt(inputs) if parsed["type"] == "str": - prompt_text = parsed["content"] + prompt = parsed["content"] prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) multi_modal_data = None elif parsed["type"] == "tokens": - prompt_text = None + prompt = None prompt_token_ids = parsed["content"]["prompt_token_ids"] multi_modal_data = parsed["content"].get("multi_modal_data") elif parsed["type"] == "text": - prompt_text = parsed["content"]["prompt"] + prompt = parsed["content"]["prompt"] prompt_token_ids = await self._tokenize_prompt_async( - prompt_text, + prompt, request_id=request_id, lora_request=lora_request, ) @@ -286,7 +286,7 @@ async def _extract_prompt_components_async( else: assert_never(parsed) - return prompt_text, prompt_token_ids, multi_modal_data + return prompt, prompt_token_ids, multi_modal_data def _build_enc_dec_llm_inputs( self, @@ -319,7 +319,7 @@ def _build_enc_dec_llm_inputs( def _process_encoder_decoder_prompt( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, ) -> EncoderDecoderLLMInputs: ''' @@ -347,7 +347,7 @@ def _process_encoder_decoder_prompt( Arguments: - * prompt: an input prompt + * inputs: an input prompt * request_id Returns: @@ -358,13 +358,13 @@ def _process_encoder_decoder_prompt( encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): encoder_comps = self._extract_prompt_components( - prompt["encoder_prompt"], + inputs["encoder_prompt"], request_id=request_id, ) - if (decoder_input := prompt["decoder_prompt"]) is None: + if (decoder_input := inputs["decoder_prompt"]) is None: decoder_comps = None, None, None else: decoder_comps = self._extract_prompt_components( @@ -373,7 +373,7 @@ def _process_encoder_decoder_prompt( ) else: encoder_comps = self._extract_prompt_components( - prompt, + inputs, request_id=request_id, ) @@ -383,20 +383,20 @@ def _process_encoder_decoder_prompt( async def _process_encoder_decoder_prompt_async( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, ) -> EncoderDecoderLLMInputs: """Async version of :meth:`_process_encoder_decoder_prompt`.""" encoder_comps: PromptComponents decoder_comps: DecoderPromptComponents - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): encoder_task = self._extract_prompt_components_async( - prompt["encoder_prompt"], + inputs["encoder_prompt"], request_id=request_id, ) - if (decoder_input := prompt["decoder_prompt"]) is None: + if (decoder_input := inputs["decoder_prompt"]) is None: encoder_comps = await encoder_task decoder_comps = None, None, None else: @@ -409,7 +409,7 @@ async def _process_encoder_decoder_prompt_async( encoder_task, decoder_task) else: encoder_comps = await self._extract_prompt_components_async( - prompt, + inputs, request_id=request_id, ) @@ -433,7 +433,7 @@ def _build_decoder_only_llm_inputs( def _process_decoder_only_prompt( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -444,7 +444,7 @@ def _process_decoder_only_prompt( Arguments: - * prompt: input prompt + * inputs: input prompt * request_id * lora_request * prompt_adapter_request @@ -455,7 +455,7 @@ def _process_decoder_only_prompt( ''' prompt_comps = self._extract_prompt_components( - prompt, + inputs, request_id=request_id, lora_request=lora_request, ) @@ -467,14 +467,14 @@ def _process_decoder_only_prompt( async def _process_decoder_only_prompt_async( self, - prompt: SingletonPrompt, + inputs: SingletonPromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: """Async version of :meth:`_process_decoder_only_prompt`.""" prompt_comps = await self._extract_prompt_components_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, ) @@ -486,7 +486,7 @@ async def _process_decoder_only_prompt_async( def preprocess( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -496,17 +496,17 @@ def preprocess( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return self._process_encoder_decoder_prompt( - prompt, + inputs, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return self._process_decoder_only_prompt( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, @@ -514,7 +514,7 @@ def preprocess( async def preprocess_async( self, - prompt: PromptType, + inputs: PromptInputs, request_id: str, lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, @@ -524,17 +524,17 @@ async def preprocess_async( # Encoder-decoder model requires special mapping of # input prompts to encoder & decoder return await self._process_encoder_decoder_prompt_async( - prompt, + inputs, request_id=request_id, ) - if is_explicit_encoder_decoder_prompt(prompt): + if is_explicit_encoder_decoder_prompt(inputs): raise ValueError("Cannot pass encoder-decoder prompt " "to decoder-only models") # Decoder-only operation return await self._process_decoder_only_prompt_async( - prompt, + inputs, request_id=request_id, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py index c64a65b89fd3..159d958ebf67 100644 --- a/vllm/inputs/registry.py +++ b/vllm/inputs/registry.py @@ -9,6 +9,7 @@ from typing_extensions import TypeVar from vllm.logger import init_logger +from vllm.utils import get_allowed_kwarg_only_overrides from .data import LLMInputs @@ -68,12 +69,17 @@ def __call__( ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int], + **mm_processor_kwargs: Any, ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: """ Create dummy data to be inputted into the model. Note: :data:`InputProcessor` is not applied to the dummy data. + + The :code:`mm_processor_kwargs` are overrides provided at + initialization time to values in the config whose values + may affect the number of tokens per instance. """ ... @@ -154,6 +160,10 @@ def wrapper(model_cls: N) -> N: return wrapper + def _get_dummy_data_factory(self, model_cls: Type[nn.Module]): + return self._dummy_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) + def register_dummy_encoder_data(self, factory: DummyDataFactory): """ Register a dummy encoder data factory to a model class @@ -174,6 +184,18 @@ def wrapper(model_cls: N) -> N: return wrapper + def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]): + if model_cls in self._dummy_encoder_factories_by_model_type: + dummy_factory = self._dummy_encoder_factories_by_model_type[ + model_cls] + else: + logger.warning( + "No dummy encoder data factory registered to %s. " + "Using the dummy data factory for the model instead.", + model_cls) + dummy_factory = self._get_dummy_data_factory(model_cls) + return dummy_factory + def dummy_data_for_profiling( self, model_config: "ModelConfig", @@ -198,26 +220,16 @@ def dummy_data_for_profiling( model_cls, _ = get_model_architecture(model_config) if is_encoder_data: - if model_cls in self._dummy_encoder_factories_by_model_type: - dummy_factory = self._dummy_encoder_factories_by_model_type[ - model_cls] - else: - logger.warning( - "No dummy encoder data factory registered to %s. " - "Using the dummy data factory for the model instead.", - model_cls) - dummy_factory = self._dummy_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) + dummy_factory = self._get_dummy_encoder_data_factory(model_cls) else: - dummy_factory = self._dummy_factories_by_model_type \ - .get(model_cls, self._default_dummy_data_factory) + dummy_factory = self._get_dummy_data_factory(model_cls) mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + dummy_factory, overrides=model_config.mm_processor_kwargs) - seq_data, mm_data = dummy_factory( - InputContext(model_config), - seq_len, - _MultiModalCounts(mm_counts), - ) + seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **mm_processor_kwargs) # Having more tokens is over-conservative but otherwise fine num_tokens = seq_data.prompt_token_ids @@ -269,6 +281,10 @@ def wrapper(model_cls: N) -> N: return wrapper + def _get_model_input_processor(self, model_cls: Type[nn.Module]): + return self._input_processors_by_model_type \ + .get(model_cls, self._default_input_processor) + def process_input(self, model_config: "ModelConfig", inputs: LLMInputs) -> LLMInputs: """ @@ -283,15 +299,17 @@ def process_input(self, model_config: "ModelConfig", from vllm.model_executor.model_loader import get_model_architecture model_cls, _ = get_model_architecture(model_config) + processor = self._get_model_input_processor(model_cls) - processor = self._input_processors_by_model_type \ - .get(model_cls, self._default_input_processor) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + processor, overrides=model_config.mm_processor_kwargs) - return processor(InputContext(model_config), inputs) + return processor(InputContext(model_config), inputs, + **mm_processor_kwargs) def create_input_processor(self, model_config: "ModelConfig"): """ - Create an input processor (see :meth:`process_input`) for a + Create an input processor (see :meth:`_process_input`) for a specific model. """ return functools.partial(self.process_input, model_config) diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py index 619408b9315c..6a32387a6f36 100644 --- a/vllm/lora/ops/bgmv_expand.py +++ b/vllm/lora/ops/bgmv_expand.py @@ -100,7 +100,7 @@ def _bgmv_expand( corresponding to each batch, An index of -1 means no lora should be applied. batches (int): batch size - add_inputs (bool, optional): Defaults to False. adds the final lora + add_inputs (bool, optional): Defaults to False, adds the final lora results to the output. """ assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py index c16db233891a..73628fd20d32 100644 --- a/vllm/lora/ops/bgmv_expand_slice.py +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -104,7 +104,7 @@ def _bgmv_expand_slice( lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch, An index of -1 means no lora should be applied. - slice_offst (int): output_tensor's offst + slice_offset (int): output_tensor's offset slice_size (int): current output_tensor's size batches (int): batch size add_inputs (bool, optional): Defaults to False. diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index c71332d8bdfb..adb3ab5b46b8 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -106,6 +106,7 @@ def _sgmv_expand( lora_indices_tensor: torch.Tensor, batches: int, max_seq_length: int, + token_nums: int, add_inputs: bool = False, ) -> None: """ @@ -115,17 +116,19 @@ def _sgmv_expand( output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index - into sequence. E.g.,if the sequence length is [4, 6], it is + into sequence. E.g., if the sequence length is [4, 6], it is [0, 4, 10]. - seq_len_tensor (torch.Tensor): (batch_size,). record the sequence - length of the sequences in the batch + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch. An index of -1 means no lora should be applied. batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences - in the batch - add_inputs (bool, optional): Defaults to False. adds the final lora + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + add_inputs (bool, optional): Defaults to False, adds the final lora results to the output. """ @@ -134,6 +137,7 @@ def _sgmv_expand( torch.float16, torch.bfloat16, ] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_b_weights.size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index b4ae9a2acbb5..efa234520ab8 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -112,6 +112,7 @@ def _sgmv_expand_slice( lora_indices_tensor: torch.Tensor, batches: int, max_seq_length: int, + token_nums: int, slice_offset: int, slice_size: int, add_inputs: bool = False, @@ -124,20 +125,22 @@ def _sgmv_expand_slice( output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index - into sequence. E.g.,if the sequence length is [4, 6], it is + into sequence. E.g., if the sequence length is [4, 6], it is [0, 4, 10]. - seq_len_tensor (torch.Tensor): (batch_size,). record the sequence - length of the sequences in the batch + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch. An index of -1 means no lora should be applied. batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences + max_seq_length (int): The max sequence lengths of the sequences in the batch - slice_offst (int): output_tensor's offst + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + slice_offset (int): output_tensor's offset slice_size (int): current output_tensor's size - add_inputs (bool, optional): Defaults to False. adds the final lora - results to the output.. + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. """ assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] @@ -145,6 +148,7 @@ def _sgmv_expand_slice( torch.float16, torch.bfloat16, ] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_b_weights.size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index c0791c260e91..c003f3dc0ce9 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -110,6 +110,7 @@ def _sgmv_shrink( lora_indices_tensor: torch.Tensor, batches: int, max_seq_length: int, + token_nums: int, scaling: float, ) -> None: """ @@ -120,17 +121,19 @@ def _sgmv_shrink( output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index - into sequence. E.g.,if the sequence length is [4, 6], it is + into sequence. E.g., if the sequence length is [4, 6], it is [0, 4]. - seq_len_tensor (torch.Tensor): (batch_size,). record the sequence - length of the sequences in the batch + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index corresponding to each batch. An index of -1 means no lora should be applied. batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences - in the batch - scaling (float): Scaling factor. + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + scaling (float): Scaling factor. """ assert inputs.dtype == lora_a_weights.dtype assert inputs.dtype in [torch.float16, torch.bfloat16] @@ -138,6 +141,7 @@ def _sgmv_shrink( torch.float16, torch.bfloat16, ] + assert inputs.size(0) == token_nums assert inputs.size(1) == lora_a_weights.size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 6d5c83429996..5033ce412692 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -27,7 +27,7 @@ def compute_meta( token_lora_tensor: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: """ Get the information required for the sgmv kernel. With the features: 1. If consecutive requests in the batch use the same LoRA, this function @@ -43,7 +43,7 @@ def compute_meta( b_seq_start_tensor = torch.zeros_like(seq_length_tensor) b_seq_start_tensor[1:].copy_(cum_result[:-1]) max_length = seq_length_tensor.max().item() - + token_nums = seq_length_tensor.sum().item() batch_size = lora_indices_tensor.size(0) no_lora = False # -1 means no lora should be applied. Use `no_lora` to determine whether @@ -52,7 +52,7 @@ def compute_meta( if batch_size == 1 and lora_indices_tensor == -1: no_lora = True return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, no_lora) + batch_size, max_length, token_nums, no_lora) # TODO see if this can be vectorized @@ -178,7 +178,7 @@ def convert_mapping( class PunicaWrapper: """ PunicaWrapper is designed to manage and provide metadata for the punica - kernel. The main function is to maintain the state information for + kernel. The main function is to maintain the state information for Multi-LoRA, and to provide the interface for the punica kernel. """ @@ -216,6 +216,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, dtype=torch.long, device=device) self.max_length: int = 0 + self.token_nums: int = 0 self.batch_size: int = -1 self.is_prefill = False self.no_lora = False @@ -276,13 +277,13 @@ def _update_base_metadata( long_lora_offsets_tensor) else: self._long_lora_indices.zero_() - self.indices_len[:] = indices_len def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, - batch_size, max_length, no_lora) = compute_meta(token_lora_tensor) + batch_size, max_length, token_nums, + no_lora) = compute_meta(token_lora_tensor) self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( b_seq_start_tensor) @@ -291,25 +292,28 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: lora_indices_tensor) self.batch_size = batch_size self.max_length = max_length + self.token_nums = token_nums self.no_lora = no_lora @property def prefill_metadata( - self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: + self + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: """ This property provides a convenient way to access the necessary metadata for prefill-related kernel computations. - 1. seq_start_locs: Tensor of sequence start positions - 2. seq_lengths: Tensor of sequence lengths + 1. seq_start_locs: Tensor of sequence start positions. + 2. seq_lengths: Tensor of sequence lengths. 3. lora_indices_per_batch: Tensor of lora indices, and an index of -1 means no lora should be applied. - 4. batch_size: batch size after clustering identical lora indices - 5. max_length: The maximum sequence length in the batch + 4. batch_size: Batch size after clustering identical lora indices. + 5. max_length: The maximum sequence length in the batch. + 6. token_nums: The token numbers in the batch. """ return (self._seq_start_locs[:self.batch_size], self._seq_lengths[:self.batch_size], self._lora_indices_per_batch[:self.batch_size], - self.batch_size, self.max_length) + self.batch_size, self.max_length, self.token_nums) @property def token_lora_indices(self) -> torch.Tensor: @@ -324,7 +328,7 @@ def token_lora_indices(self) -> torch.Tensor: def sampler_indices(self) -> torch.Tensor: """ This property is used to access the lora indices specifically for - LogitsProcessorWithLoRA + LogitsProcessorWithLoRA. """ sampler_indices_len = self.indices_len[1] return self._sampler_indices[:sampler_indices_len] @@ -332,7 +336,7 @@ def sampler_indices(self) -> torch.Tensor: @property def sampler_indices_padded(self) -> torch.Tensor: """ - This property provides access to padded sampler indices + This property provides access to padded sampler indices. """ indices_padded_len = self.indices_len[2] return self._sampler_indices_padded[:indices_padded_len] @@ -341,7 +345,7 @@ def sampler_indices_padded(self) -> torch.Tensor: def embeddings_indices(self) -> torch.Tensor: """ This property provides access to the indices used for lora embeddings, - specifically for VocabParallelEmbeddingWithLoRA + specifically for VocabParallelEmbeddingWithLoRA. """ embeddings_indices_len = self.indices_len[3] return self._embeddings_indices[:, :embeddings_indices_len] @@ -350,7 +354,7 @@ def embeddings_indices(self) -> torch.Tensor: def long_lora_indices(self) -> torch.Tensor: """ This property provides access to the indices used for long context - lora, specifically for LinearScalingRotaryEmbeddingWithLora + lora, specifically for LinearScalingRotaryEmbeddingWithLora. """ long_lora_len = self.indices_len[4] return self._long_lora_indices[:long_lora_len] @@ -524,7 +528,7 @@ def add_lora(self, scale (float): Scaling factor. y_offset (Optional[int], optional): Offset to apply to the starting column of y. - y_slice_size (Optional[int], optional): Size of the y column slice.. + y_slice_size (Optional[int], optional): Size of the y column slice. buffer (Optional[torch.Tensor], optional): Defaults to None. """ y_org = y diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index eed01953fb4a..fe33b7341fd3 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -7,10 +7,11 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) + verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (GroupQuantScaleParameter, PackedvLLMParameter) @@ -231,7 +232,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qweight", marlin_qweight) + replace_parameter(layer, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. marlin_scales = marlin_permute_scales( @@ -239,7 +240,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, group_size=self.quant_config.group_size) - replace_tensor(layer, "scales", marlin_scales) + replace_parameter(layer, "scales", marlin_scales) # Permute zero-points from AWQ format to marlin format. marlin_zp = awq_to_marlin_zero_points( @@ -247,7 +248,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=layer.num_groups, size_n=layer.output_size_per_partition, num_bits=self.quant_config.quant_type.size_bits) - replace_tensor(layer, "qzeros", marlin_zp) + replace_parameter(layer, "qzeros", marlin_zp) # Not-used layer.g_idx = marlin_make_empty_g_idx(device) diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 66bc5395dbd7..38495d5a5a86 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -121,12 +121,12 @@ class BitsAndBytesLinearMethod(LinearMethodBase): def __init__(self, quant_config: BitsAndBytesConfig): try: import bitsandbytes - if bitsandbytes.__version__ < "0.42.0": + if bitsandbytes.__version__ < "0.44.0": raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.42.0.") + "install bitsandbytes>=0.44.0.") except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.42.0 via " - "`pip install bitsandbytes>=0.42.0` to use " + raise ImportError("Please install bitsandbytes>=0.44.0 via " + "`pip install bitsandbytes>=0.44.0` to use " "bitsandbytes quantizer.") from err self.quant_config = quant_config diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 3cade3d3fbcd..cb65557be8f9 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -1,17 +1,16 @@ -from typing import Callable, List, Optional +from typing import Callable, List, Optional, Set import torch -from vllm import _custom_ops as ops +from vllm.logger import init_logger from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( ActivationOrdering) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, marlin_repeat_scales_on_all_ranks, - marlin_sort_g_idx, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) + marlin_repeat_scales_on_all_ranks) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -19,6 +18,8 @@ RowvLLMParameter) from vllm.scalar_type import scalar_types +logger = init_logger(__name__) + __all__ = ["CompressedTensorsWNA16"] WNA16_SUPPORTED_TYPES_MAP = { 4: scalar_types.uint4b8, @@ -28,6 +29,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): + _kernel_backends_being_used: Set[str] = set() def __init__(self, strategy: str, @@ -52,35 +54,43 @@ def __init__(self, self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] - # Verify supported on platform. - verify_marlin_supported(quant_type=self.quant_type, - group_size=self.group_size) - @classmethod def get_min_capability(cls) -> int: # ampere and up return 80 - def create_weights(self, layer: torch.nn.Module, input_size: int, - output_partition_sizes: List[int], + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: List[int], input_size_per_partition: int, params_dtype: torch.dtype, weight_loader: Callable, **kwargs): output_size_per_partition = sum(output_partition_sizes) + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=params_dtype, + group_size=self.group_size, + zero_points=False, + has_g_idx=self.has_g_idx + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsWNA16", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + # If group_size is -1, we are in channelwise case. group_size = self.group_size if self.group_size != -1 else input_size row_parallel = (input_size != input_size_per_partition) partition_scales = not marlin_repeat_scales_on_all_ranks( self.has_g_idx, self.group_size, row_parallel) - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size) - scales_and_zp_size = input_size // group_size if partition_scales: @@ -137,69 +147,17 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, weight_loader=weight_loader) layer.register_parameter("weight_g_idx", weight_g_idx) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.group_size = group_size + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name=None, + w_gidx_param_name="weight_g_idx") # Checkpoints are serialized in compressed-tensors format, which is - # different from marlin format. Handle repacking here. + # different from the format the kernel may want. Handle repacking here. def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.weight_packed.device - - # Allocate marlin workspace. - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) - - # Handle sorting for activation reordering if needed. - if self.has_g_idx: - g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - replace_tensor(layer, "weight_g_idx", g_idx) - else: - layer.weight_g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.weight_zp = marlin_make_empty_g_idx(device) - # Update for kernel - layer.weight_packed = torch.nn.Parameter( - layer.weight_packed.t().contiguous(), requires_grad=False) - layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.squeeze().t().contiguous(), requires_grad=False) - - # Repack weights from compressed-tensors format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.weight_packed, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_type.size_bits) - replace_tensor(layer, "weight_packed", marlin_qweight) - - # Permute scales from compressed-tensors format to marlin format. - # scale is required on all partitions if activation reordering - marlin_scales = marlin_permute_scales( - layer.weight_scale, - size_k=(layer.input_size - if self.has_g_idx else layer.input_size_per_partition), - size_n=layer.output_size_per_partition, - group_size=layer.group_size) - replace_tensor(layer, "weight_scale", marlin_scales) + self.kernel.process_weights_after_loading(layer) def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: - - return apply_gptq_marlin_linear( - input=x, - weight=layer.weight_packed, - weight_scale=layer.weight_scale, - weight_zp=layer.weight_zp, - g_idx=layer.weight_g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=True, - bias=bias) + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 5a1b2d701ab0..3d3ce711e58b 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,7 +1,6 @@ -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Set, Union import torch -from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -11,12 +10,12 @@ set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, - marlin_permute_scales, marlin_repeat_scales_on_all_ranks, - marlin_sort_g_idx, replace_tensor, verify_marlin_supported, - verify_marlin_supports_shape) + check_marlin_supported, marlin_moe_permute_scales, + marlin_repeat_scales_on_all_ranks, verify_marlin_supported) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase): quant_config: The GPTQ Marlin quantization config. """ + _kernel_backends_being_used: Set[str] = set() + def __init__(self, quant_config: GPTQMarlinConfig) -> None: self.quant_config = quant_config @@ -176,25 +177,34 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: - - del output_size output_size_per_partition = sum(output_partition_sizes) is_row_parallel = input_size != input_size_per_partition weight_loader = extra_weight_attrs.get("weight_loader") + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for GPTQMarlinLinearMethod", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + # Normalize group_size if self.quant_config.group_size != -1: group_size = self.quant_config.group_size else: group_size = input_size - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size, - ) - # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, self.quant_config.group_size, @@ -275,57 +285,15 @@ def create_weights( layer.register_parameter("g_idx", g_idx) layer.register_parameter("scales", scales) layer.register_parameter("qzeros", qzeros) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.input_size = input_size - layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act, - is_row_parallel) - - # Checkpoints are serialized in AutoGPTQ format, which is different from the - # marlin format. This function is called after the weights are loaded. - # Here, we handle the repacking, including the activation reordering case. - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.qweight.device - # required by torch.compile - layer.qweight = Parameter(layer.qweight.data, requires_grad=False) - layer.scales = Parameter(layer.scales.data, requires_grad=False) + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx") - # Allocate marlin workspace - layer.workspace = marlin_make_workspace( - layer.output_size_per_partition, device) - - # Handle sorting for activation reordering if needed. - if self.quant_config.desc_act: - g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx) - layer.g_idx_sort_indices = g_idx_sort_indices - replace_tensor(layer, "g_idx", g_idx) - else: - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - # No zero-point - layer.zp = marlin_make_empty_g_idx(device) - - # Repack weights from autogptq format to marlin format. - marlin_qweight = ops.gptq_marlin_repack( - layer.qweight, - perm=layer.g_idx_sort_indices, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits, - ) - replace_tensor(layer, "qweight", marlin_qweight) - - # Permute scales from autogptq format to marlin format. - marlin_scales = marlin_permute_scales( - layer.scales, - size_k=(layer.input_size if self.quant_config.desc_act else - layer.input_size_per_partition), - size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size, - ) - replace_tensor(layer, "scales", marlin_scales) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) def apply( self, @@ -333,20 +301,7 @@ def apply( x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return apply_gptq_marlin_linear( - input=x, - weight=layer.qweight, - weight_scale=layer.scales, - weight_zp=layer.zp, - g_idx=layer.g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - wtype=self.quant_config.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - is_k_full=layer.is_k_full, - bias=bias, - ) + return self.kernel.apply_weights(layer, x, bias) class GPTQMarlinMoEMethod(FusedMoEMethodBase): @@ -506,12 +461,12 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13_g_idx_sort_indices[e]] w2_sorted_g_idx[e] = layer.w2_g_idx[e][ w2_g_idx_sort_indices[e]] - replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) - replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) - replace_tensor(layer, "w13_g_idx_sort_indices", - w13_g_idx_sort_indices) - replace_tensor(layer, "w2_g_idx_sort_indices", - w2_g_idx_sort_indices) + replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx) + replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx) + replace_parameter(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) else: # Reset g_idx related tensors num_experts = layer.w13_g_idx.shape[0] @@ -544,7 +499,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_qweight.shape[2], self.quant_config.quant_type.size_bits, ) - replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) marlin_w2_qweight = ops.gptq_marlin_moe_repack( layer.w2_qweight, layer.w2_g_idx_sort_indices, @@ -552,7 +507,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w2_qweight.shape[2], self.quant_config.quant_type.size_bits, ) - replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) # Repack scales marlin_w13_scales = marlin_moe_permute_scales( s=layer.w13_scales, @@ -560,14 +515,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_n=layer.w13_scales.shape[2], group_size=self.quant_config.group_size, ) - replace_tensor(layer, "w13_scales", marlin_w13_scales) + replace_parameter(layer, "w13_scales", marlin_w13_scales) marlin_w2_scales = marlin_moe_permute_scales( s=layer.w2_scales, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, size_n=layer.w2_scales.shape[2], group_size=self.quant_config.group_size, ) - replace_tensor(layer, "w2_scales", marlin_w2_scales) + replace_parameter(layer, "w2_scales", marlin_w2_scales) def apply( self, diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py new file mode 100644 index 000000000000..fe50c4930d04 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +import torch + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.scalar_type import ScalarType + + +@dataclass +class MPLinearLayerConfig: + full_weight_shape: Tuple[int, int] # [in, out] + partition_weight_shape: Tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + has_g_idx: bool + + +class MPLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.w_zp_name = w_zp_param_name + self.w_gidx_name = w_gidx_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _transform_param(self, layer: torch.nn.Module, name: Optional[str], + fn: Callable) -> None: + if name is not None and getattr(layer, name, None) is not None: + + old_param = getattr(layer, name) + new_param = fn(old_param) + # replace the parameter with torch.nn.Parameter for TorchDynamo + # compatibility + replace_parameter( + layer, name, + torch.nn.Parameter(new_param.data, requires_grad=False)) + + def _get_weight_params( + self, layer: torch.nn.Module + ) -> Tuple[torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor] # w_gidx + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py new file mode 100644 index 000000000000..47591c2aa644 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -0,0 +1,72 @@ +import os +from typing import List, Optional, Type + +from vllm.model_executor.layers.quantization.kernels.machete import ( + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.marlin import ( + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( + MPLinearKernel, MPLinearLayerConfig) +from vllm.platforms import current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + """ + Choose an MPLinearKernel that can implement the given config for the given + compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the compute + capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[MPLinearKernel]: Chosen kernel. + """ + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py new file mode 100644 index 000000000000..fa39cb511528 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/machete.py @@ -0,0 +1,118 @@ +from functools import partial +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_weights_into_int32, unpack_weights_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MacheteLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Machete, "\ + "when the input features are partitioned across "\ + "devices" + + if c.zero_points: + return False, "Zero points currently not supported by "\ + " Compressed Tensors + Machete. (Kernel supports it"\ + " but CompressedTensorsWNA16 does not so support has"\ + " not been added to MacheteWNA16Kernel yet" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + if c.has_g_idx: + assert self.w_gidx_name is not None + perm = torch.argsort(getattr(layer, self.w_gidx_name))\ + .to(torch.int) + + self.act_perm = lambda x: x[:, perm] + # use `ops.permute_cols` if possible + if c.act_type in [torch.float16, torch.bfloat16] \ + and c.partition_weight_shape[0] % 8 == 0: + self.act_perm = partial(ops.permute_cols, perm=perm) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + if c.has_g_idx: + x_unpacked = unpack_weights_into_int32(x.data, + c.weight_type, + packed_dim=0) + x_perm = x_unpacked[perm, :] + x.data = pack_weights_into_int32(x_perm, + c.weight_type, + packed_dim=0) + x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), + self.config.weight_type) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + if c.has_g_idx: + x_2d = self.act_perm(x_2d) + + output = ops.machete_gemm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_zeros=None, + b_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/marlin.py b/vllm/model_executor/layers/quantization/kernels/marlin.py new file mode 100644 index 000000000000..5b4bba76ee0c --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/marlin.py @@ -0,0 +1,132 @@ +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, + check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, + query_marlin_supported_quant_types) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MarlinLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.zero_points: + return False, "Zero points currently not supported by "\ + " MarlinLinearKernel. Will be added when AWQMarlin "\ + "is migrated over to using MPLinearKernel backend" + + quant_types = query_marlin_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, f"Quant type ({c.weight_type}) not supported by"\ + f" Marlin, supported types are: {quant_types}" + + if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Marlin, supported group sizes are: "\ + f"{MARLIN_SUPPORTED_GROUP_SIZES}" + + return check_marlin_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1], + c.full_weight_shape[1], + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + + # Allocate marlin workspace. + self.workspace = marlin_make_workspace(c.partition_weight_shape[1], + device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "w_zp" + + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + pass + # TODO (lucas): add the following when AWQMarlin is migrated over to + # using MPLinearKernel backend + # self._transform_param(layer, self.w_zp_name, lambda x: \ + # marlin_zero_points( + # x, + # size_k=c.partition_weight_shape[0], + # size_n=c.partition_weight_shape[1], + # num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.gptq_marlin_repack(x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales(x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size) + return x + + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) + + # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py index e69de29bb2d1..e60f0c79ac1f 100644 --- a/vllm/model_executor/layers/quantization/utils/__init__.py +++ b/vllm/model_executor/layers/quantization/utils/__init__.py @@ -0,0 +1,3 @@ +from .layer_utils import replace_parameter, update_tensor_inplace + +__all__ = ['update_tensor_inplace', 'replace_parameter'] diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py new file mode 100644 index 000000000000..edce6d19b6c4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -0,0 +1,37 @@ +from typing import Union + +import torch + + +def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): + assert dst.dtype == src.dtype, "Tensors must have the same dtype" + + # update tensor shape and stride + dst.as_strided_(src.shape, src.stride()) + + # If not the same underlying storage move tensor data + if dst.data_ptr() != src.data_ptr(): + dst.copy_(src) + del src + + +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_parameter(mod: torch.nn.Module, name: str, + new: Union[torch.Tensor, torch.nn.Parameter]) -> None: + + old = getattr(mod, name) + if type(old) is type(new) and old.dtype == new.dtype and \ + old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): + # If we can just update in-place to avoid re-registering + # can be faster if the underlying storage is the same + update_tensor_inplace(old, new) + else: + # Fallback re-register parameter, convert to Parameter if necessary + # this not only ensures we don't register a tensor as a parameter, but + # also ensures that all parameter subclasses get re-registered as + # parameters for `torch.compile` compatibility + if not isinstance(new, torch.nn.Parameter): + new = torch.nn.Parameter(new, requires_grad=False) + mod.register_parameter(name, + torch.nn.Parameter(new, requires_grad=False)) diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py new file mode 100644 index 000000000000..18e1332050cd --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -0,0 +1,30 @@ +from typing import List, Optional, Tuple + +import torch + +from vllm.scalar_type import ScalarType, scalar_types + +MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128] +MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] + + +def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: + if zero_points: + return [scalar_types.uint4, scalar_types.uint8] + else: + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]: + return [torch.float16, torch.bfloat16] + + +def check_machete_supports_shape(in_features: int, out_featrues: int) \ + -> Tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return False, "Input features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return False, "Output features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return True, None diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index fea94cf7322a..53762965732c 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -120,6 +120,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int, "with --quantization gptq.") +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + def marlin_make_workspace(output_size_per_partition: int, device: torch.device) -> torch.Tensor: max_workspace_size = (output_size_per_partition // @@ -148,6 +161,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: requires_grad=False) +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + def marlin_sort_g_idx( g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) @@ -240,17 +258,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, return marlin_zp -# Newly generated tensors need to replace existing tensors that are -# already registered as parameters by vLLM (and won't be freed) -def replace_tensor(layer: torch.nn.Module, name: str, - new_t: torch.Tensor) -> None: - # It is important to use resize_() here since it ensures - # the same buffer is reused - getattr(layer, name).resize_(new_t.shape) - getattr(layer, name).copy_(new_t) - del new_t - - def apply_gptq_marlin_linear( input: torch.Tensor, weight: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index bdfda31de852..833d00073564 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -20,6 +20,49 @@ } +def pack_weights_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_weights_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: # prefix: model.layers.0.self_attn.q_proj # proj_name: q_proj diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2ca86a4653cf..583bb02dcb5b 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -15,7 +15,8 @@ SamplingTensors, SequenceGroupToSample) from vllm.sampling_params import SamplingType -from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics @@ -759,10 +760,10 @@ def _sample_with_torch( # Create output tensor for sampled token ids. if include_gpu_probs_tensor: - sampled_token_ids_tensor = torch.empty(logprobs.shape[0], - 1, - dtype=torch.long, - device=logprobs.device) + sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), + VLLM_INVALID_TOKEN_ID, + dtype=torch.long, + device=logprobs.device) else: sampled_token_ids_tensor = None diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index 8c03e4692775..584cf971d9c0 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -80,7 +80,7 @@ def forward( target_probs = target_with_bonus_probs[:, :-1] accepted = self._evaluate_accepted_tokens(target_probs, draft_token_ids) - recovered_token_ids = self._replacement_token_ids(target_probs) + recovered_token_ids = self._get_recovered_token_ids(target_probs) output_token_ids = self._create_output(accepted, recovered_token_ids, draft_token_ids, bonus_token_ids) @@ -148,16 +148,10 @@ def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): accepted_mask = candidates_prob > threshold return accepted_mask - def _replacement_token_ids(self, target_probs): + def _get_recovered_token_ids(self, target_probs): """ - Generate one replacement token ID for each sequence based on target - probabilities. The replacement token is used as the fallback option - if typical acceptance sampling does not accept any draft tokens for - that particular sequence. - - This method computes the token IDs to be replaced by selecting the - token with the highest probability for each sequence in the first - position. The rest of the output is filled with -1. + The recovered token ids will fill the first unmatched token + by the target token. Parameters ---------- @@ -168,13 +162,9 @@ def _replacement_token_ids(self, target_probs): Returns ------- torch.Tensor - A tensor of shape (batch_size, k) with the replacement - token IDs. Only the first column is set, and the rest of the - columns are filled with -1. + A tensor of shape (batch_size, k) with the recovered token + ids which are selected from target probs. """ - max_indices = torch.argmax(target_probs[:, 0, :], dim=1) - output = -torch.ones((target_probs.shape[0], target_probs.shape[1]), - dtype=self.token_id_dtype, - device=target_probs.device) - output[:, 0] = max_indices - return output + max_indices = torch.argmax(target_probs, dim=-1) + + return max_indices diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index f0d2a9e7f06b..c21b10d661ec 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1,6 +1,7 @@ # ruff: noqa: SIM117 import collections import copy +import dataclasses import fnmatch import glob import json @@ -8,7 +9,8 @@ import os from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Tuple, Type +from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple, + Type, cast) import gguf import huggingface_hub @@ -207,6 +209,22 @@ def load_model(self, *, model_config: ModelConfig, class DefaultModelLoader(BaseModelLoader): """Model loader that can load different file types from disk.""" + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + revision: Optional[str] + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: @@ -313,17 +331,16 @@ def _prepare_weights(self, model_name_or_path: str, return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( - self, model_name_or_path: str, revision: Optional[str], - fall_back_to_pt: bool + self, source: "Source" ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( - model_name_or_path, revision, fall_back_to_pt) + source.model_or_path, source.revision, source.fall_back_to_pt) if self.load_config.load_format == LoadFormat.NPCACHE: # Currently np_cache only support *.bin checkpoints assert use_safetensors is False weights_iterator = np_cache_weights_iterator( - model_name_or_path, self.load_config.download_dir, hf_folder, + source.model_or_path, self.load_config.download_dir, hf_folder, hf_weights_files) elif use_safetensors: weights_iterator = safetensors_weights_iterator(hf_weights_files) @@ -341,7 +358,29 @@ def _xla_weights_iterator(iterator: Generator): xm.mark_step() weights_iterator = _xla_weights_iterator(weights_iterator) - return weights_iterator + + # Apply the prefix. + return ((source.prefix + name, tensor) + for (name, tensor) in weights_iterator) + + def _get_all_weights( + self, + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + + primary_weights = DefaultModelLoader.Source( + model_config.model, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", + True)) + yield from self._get_weights_iterator(primary_weights) + + secondary_weights = cast(Iterable[DefaultModelLoader.Source], + getattr(model, "secondary_weights", ())) + for source in secondary_weights: + yield from self._get_weights_iterator(source) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, @@ -360,13 +399,8 @@ def load_model(self, *, model_config: ModelConfig, model = _initialize_model(model_config, self.load_config, lora_config, cache_config, scheduler_config) - model.load_weights( - self._get_weights_iterator(model_config.model, - model_config.revision, - fall_back_to_pt=getattr( - model, - "fall_back_to_pt_during_load", - True)), ) + + model.load_weights(self._get_all_weights(model_config, model)) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) @@ -817,12 +851,12 @@ def _get_quantized_weights_iterator( # only load the bitsandbytes module when needed try: import bitsandbytes - if bitsandbytes.__version__ < "0.42.0": + if bitsandbytes.__version__ < "0.44.0": raise ImportError("bitsandbytes version is wrong. Please " - "install bitsandbytes>=0.42.0.") + "install bitsandbytes>=0.44.0.") except ImportError as err: - raise ImportError("Please install bitsandbytes>=0.42.0 via " - "`pip install bitsandbytes>=0.42.0` to use " + raise ImportError("Please install bitsandbytes>=0.44.0 via " + "`pip install bitsandbytes>=0.44.0` to use " "bitsandbytes quantizer.") from err hf_weights_files, use_safetensors = self._prepare_weights( diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 4cf3b0b93dcf..d50f4fb9e6ed 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -229,7 +229,7 @@ def __init__(self, self.multimodal_config = multimodal_config self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + self.vocab_size = config.text_config.vocab_size self.image_token_id = _IMAGE_TOKEN_ID self.image_feature_size = config.patch_size**2 * config.num_channels diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 68b6d0cf808e..8130eb54753e 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -152,7 +152,8 @@ def __init__(self, self.unpadded_vocab_size = config.text_config.vocab_size logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) + config.text_config.vocab_size, + logit_scale) self.sampler = Sampler() def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index f8fc1cd8ef1f..ced846cbe335 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -213,10 +213,10 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.vocab_size = config.vocab_size + self.vocab_size = config.text_config.vocab_size - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + self.embed_tokens = VocabParallelEmbedding( + config.text_config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ PersimmonDecoderLayer(config, cache_config=cache_config, @@ -257,14 +257,14 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.vocab_size = config.vocab_size + self.vocab_size = config.text_config.vocab_size self.model = PersimmonModel(config, cache_config=cache_config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, + self.lm_head = ParallelLMHead(config.text_config.vocab_size, config.hidden_size, bias=False) - self.logits_processor = LogitsProcessor(config.vocab_size) + self.logits_processor = LogitsProcessor(config.text_config.vocab_size) self.sampler = Sampler() def forward( diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 6f17f571ccae..245381518a7f 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -307,7 +307,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90 -def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16): +def _calc_hd_transform_size(*, width: int, height: int, hd_num: int): transposed = False if width < height: width, height = height, width @@ -337,8 +337,10 @@ def get_phi3v_image_feature_size( *, input_height: int, input_width: int, + num_crops: int, ) -> int: - num_crops = hf_config.get("num_crops", 16) + if num_crops is None: + num_crops = hf_config.get("num_crops", 16) new_width, new_height = _calc_hd_transform_size(width=input_width, height=input_height, hd_num=num_crops) @@ -347,20 +349,26 @@ def get_phi3v_image_feature_size( + (new_height // 336 + 1) * 12 -def get_max_phi3v_image_tokens(ctx: InputContext): +def get_max_phi3v_image_tokens(ctx: InputContext, + *, + num_crops: Optional[int] = None): return get_phi3v_image_feature_size( ctx.get_hf_image_processor_config(), input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + num_crops=num_crops, ) -def dummy_data_for_phi3v(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): +def dummy_data_for_phi3v(ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + num_crops: Optional[int] = None): num_images = mm_counts["image"] - image_feature_size = get_max_phi3v_image_tokens(ctx) + image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) seq_data = dummy_seq_data_for_clip( CLIP_VIT_LARGE_PATCH14_336_CONFIG, @@ -398,7 +406,10 @@ def _get_image_placeholder_token_ids(model_config: ModelConfig, return image_placeholder_token_ids -def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_phi3v(ctx: InputContext, + llm_inputs: LLMInputs, + *, + num_crops: Optional[int] = None): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs @@ -412,7 +423,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size = [ get_phi3v_image_feature_size(hf_config, input_width=w, - input_height=h) + input_height=h, + num_crops=num_crops) ] image_data = [image_data] elif is_list_of(image_data, Image.Image): @@ -422,7 +434,8 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): image_feature_size.append( get_phi3v_image_feature_size(hf_config, input_width=w, - input_height=h)) + input_height=h, + num_crops=num_crops)) elif isinstance(image_data, torch.Tensor): num_images, image_feature_size, hidden_size = image_data.shape elif is_list_of(image_data, torch.Tensor): diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index a64e08c422bc..5e6737ad7fa4 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA -from .utils import is_pp_missing_parameter, make_layers +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers class Qwen2MLP(nn.Module): @@ -235,11 +235,16 @@ def __init__( self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - ) + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Qwen2DecoderLayer(config=config, @@ -248,7 +253,10 @@ def __init__( prefix=f"{prefix}.layers", ) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 1011c9256793..9f72210c60bf 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -45,7 +45,7 @@ from vllm.attention.selector import (_Backend, backend_name_to_enum, get_global_forced_attn_backend) from vllm.config import CacheConfig, MultiModalConfig -from vllm.distributed import parallel_state +from vllm.distributed import get_pp_group, parallel_state from vllm.distributed import utils as dist_utils from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger @@ -68,6 +68,9 @@ from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.processor import get_processor +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory) + logger = init_logger(__name__) # === Vision Inputs === # @@ -856,15 +859,21 @@ def __init__(self, self.model = Qwen2Model(config, cache_config, quant_config) - if config.tie_word_embeddings: - self.lm_head = self.model.embed_tokens + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) else: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def _validate_and_reshape_mm_tensor(self, mm_input: Union[torch.Tensor, @@ -979,7 +988,8 @@ def forward( image_input = self._parse_and_validate_image_input(**kwargs) video_input = self._parse_and_validate_video_input(**kwargs) - if image_input is None and video_input is None: + if (image_input is None + and video_input is None) or not get_pp_group().is_first_rank: inputs_embeds = None else: if getattr(self.config, "rope_scaling", {}).get("type", @@ -1015,6 +1025,7 @@ def forward( positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states @@ -1055,6 +1066,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -1081,6 +1094,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] except KeyError: print(params_dict.keys()) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 32a0e895005c..71808eb4c271 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.utils import (flatten_bn, @@ -334,14 +335,23 @@ def __init__(self, self.multi_modal_config = multimodal_config assert self.multi_modal_config + self.secondary_weights = [] + self.audio_tower = ModifiedWhisperEncoder(config.audio_config) if config.audio_model_id is not None: - self.audio_tower = ModifiedWhisperEncoder.from_pretrained( - config.audio_model_id) - else: - self.audio_tower = ModifiedWhisperEncoder(config.audio_config) + self.secondary_weights.append( + DefaultModelLoader.Source( + model_or_path=config.audio_model_id, + revision=None, + prefix="audio_tower.", + )) self.multi_modal_projector = UltravoxProjector(config) self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) + if config.text_model_id is not None: + self.secondary_weights.append( + DefaultModelLoader.Source(model_or_path=config.text_model_id, + revision=None, + prefix="language_model.")) def _audio_features_to_embeddings( self, input_features: torch.Tensor) -> torch.Tensor: @@ -466,6 +476,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # prepare weight iterators for components weights_group = group_weights_with_prefix(weights) + # load audio tower weights + audio_tower_weights = weights_group["audio_tower"] + audio_tower_params_dict = dict( + self.audio_tower.named_parameters( + prefix=self.audio_tower.base_model_prefix)) + for name, loaded_weight in audio_tower_weights: + if name in audio_tower_params_dict: + param = audio_tower_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + # load projector weights projector_weights = weights_group["multi_modal_projector"] projector_params_dict = dict( diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 9ffb339ffeab..7a6d7c90f34d 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -328,6 +328,64 @@ def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): marlin_tile_size=self.marlin_tile_size) +def permute_param_layout_(param: BasevLLMParameter, input_dim: int, + output_dim: int, **kwargs) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2,\ + "permute_param_layout_ only supports 2D parameters when either "\ + "input_dim or output_dim is not set" + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None,\ + "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None,\ + "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) + if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert hasattr(param, "packed_dim") and\ + param.packed_dim == perm[kwargs["packed_dim"]],\ + "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index e8589525a558..8bcb38ef241e 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -14,7 +14,8 @@ from vllm.config import ModelConfig from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import JSONTree, is_list_of, json_map_leaves +from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, + json_map_leaves) logger = init_logger(__name__) @@ -262,11 +263,20 @@ def map_input(self, model_config: ModelConfig, model_cls, _ = get_model_architecture(model_config) mapper = self._input_mappers.get(model_cls) + # Only get processor kwargs at mapping time if we are not using the + # input mapper; no overrides are used on the default here because they + # should be passed to the huggingface resource at initialization time. + if mapper is not None and mapper != self._default_input_mapper: + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + mapper, overrides=model_config.mm_processor_kwargs) + else: + mm_processor_kwargs = {} + if mapper is None: raise KeyError(f"No input mapper in {self} is registered for " f"model class {model_cls.__name__}.") - return mapper(InputContext(model_config), data) + return mapper(InputContext(model_config), data, **mm_processor_kwargs) @abstractmethod def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: @@ -339,7 +349,10 @@ def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: f"for model class {model_cls.__name__} in {self}.") if callable(max_mm_tokens): - max_mm_tokens = max_mm_tokens(InputContext(model_config)) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + max_mm_tokens, overrides=model_config.mm_processor_kwargs) + max_mm_tokens = max_mm_tokens(InputContext(model_config), + **mm_processor_kwargs) self._validate_max_multimodal_tokens(max_mm_tokens) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 9969336c61d0..d3a230e40477 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -7,7 +7,7 @@ from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger -from vllm.transformers_utils.image_processor import get_image_processor +from vllm.transformers_utils.processor import get_image_processor from vllm.utils import is_list_of from .base import MultiModalData, MultiModalInputs, MultiModalPlugin @@ -24,9 +24,14 @@ def get_data_key(self) -> str: return "image" def _get_hf_image_processor(self, model_config: ModelConfig): + mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None + else model_config.mm_processor_kwargs) + # We don't explicitly check kwarg overrides to the HF class + # since the automodel just takes kwargs, so we can't inspect it return cached_get_image_processor( model_config.model, - trust_remote_code=model_config.trust_remote_code) + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs) def _default_input_mapper( self, @@ -42,6 +47,7 @@ def _default_input_mapper( # PIL image if isinstance(data, Image.Image) or is_list_of(data, Image.Image): image_processor = self._get_hf_image_processor(model_config) + if image_processor is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 745fc715caf4..3940e1671b57 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -138,6 +138,15 @@ def create_input_mapper(self, model_config: ModelConfig): """ Create an input mapper (see :meth:`map_input`) for a specific model. """ + # NOTE - we currently make the assumption that if a model has multiple + # supported modalities, they take the same kwargs. For the default, + # this could be an issue in the future if it falls back to two HF + # resources and we can't inspect the signature easily since it's + # getting initialized through the autoclass. + # + # If this is a problem in the future, we should revisit it, but since + # it potentially introduces a lot of complexity for a currently + # uncommon case, we do not for simplicity of both use & implementation return functools.partial(self.map_input, model_config) def register_max_multimodal_tokens( diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py index 4401d1315792..39e75dbaf687 100644 --- a/vllm/multimodal/video.py +++ b/vllm/multimodal/video.py @@ -6,7 +6,7 @@ from vllm.config import ModelConfig from vllm.inputs.registry import InputContext from vllm.logger import init_logger -from vllm.transformers_utils.image_processor import get_video_processor +from vllm.transformers_utils.processor import get_video_processor from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import is_list_of @@ -37,9 +37,14 @@ def get_data_key(self) -> str: return "video" def _get_hf_video_processor(self, model_config: ModelConfig): + mm_processor_kwargs = ({} if model_config.mm_processor_kwargs is None + else model_config.mm_processor_kwargs) + # We don't explicitly check kwarg overrides to the HF class + # since the automodel just takes kwargs, so we can't inspect it return cached_get_video_processor( model_config.model, - trust_remote_code=model_config.trust_remote_code) + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs) def _default_input_mapper( self, diff --git a/vllm/outputs.py b/vllm/outputs.py index 85ea9196b25d..44cde6b561d8 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -114,17 +114,28 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod - def from_seq_group(cls, - seq_group: SequenceGroup) -> Optional["RequestOutput"]: + def from_seq_group(cls, seq_group: SequenceGroup, + use_cache: bool) -> Optional["RequestOutput"]: sampling_params = seq_group.sampling_params if sampling_params is None: raise ValueError( "Sampling parameters are missing for a CompletionRequest.") + finished = seq_group.is_finished() if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( not finished): return None + # Init cache (if needed) + if use_cache and seq_group.cached_request_output is None: + seq_group.cached_request_output = RequestOutput( # type: ignore + request_id="", + prompt=None, + prompt_token_ids=[], + prompt_logprobs=None, + outputs=[], + finished=False) + seqs = seq_group.get_seqs() if len(seqs) == 1: top_n_seqs = seqs @@ -149,29 +160,66 @@ def from_seq_group(cls, outputs = [] include_prompt = True - for seq in top_n_seqs: + for i, seq in enumerate(top_n_seqs): output_text = seq.get_output_text_to_return( text_buffer_length, delta) + output_token_ids = seq.get_output_token_ids_to_return(delta) + num_output_tokens = 1 if isinstance(output_token_ids, + int) else len(output_token_ids) + output_logprobs = seq.output_logprobs if include_logprobs else None if delta: # Slice logprobs delta if applicable if output_logprobs: - output_logprobs = output_logprobs[-len(output_token_ids):] + output_logprobs = output_logprobs[-num_output_tokens:] # Don't include prompt if this is after the first output # containing decode token ids - if include_prompt and seq.get_output_len() > len( - output_token_ids): + if include_prompt and seq.get_output_len() > num_output_tokens: include_prompt = False - outputs.append( - CompletionOutput( - seqs.index(seq), output_text, output_token_ids, + if use_cache: + # Get cached output object + cached_outputs = seq_group.cached_request_output.outputs # type: ignore + if i >= len(cached_outputs): + cached_outputs.append( + CompletionOutput(index=i, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + stop_reason=None)) + output = cached_outputs[i] + + # Init cached output object + assert output.index == i + output.text = output_text + + if isinstance(output_token_ids, int): + output.token_ids.clear() + output.token_ids.append(output_token_ids) + else: + output.token_ids = output_token_ids + + output.cumulative_logprob = seq.get_cumulative_logprob() \ + if include_logprobs else None + output.logprobs = output_logprobs + output.finish_reason = SequenceStatus.get_finished_reason( + seq.status) + output.stop_reason = seq.stop_reason + + else: + output = CompletionOutput( + seqs.index(seq), output_text, [output_token_ids] + if isinstance(output_token_ids, int) else output_token_ids, seq.get_cumulative_logprob() if include_logprobs else None, output_logprobs, SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason)) + seq.stop_reason) + + outputs.append(output) # Every sequence in the sequence group should have the same prompt. if include_prompt: @@ -188,16 +236,20 @@ def from_seq_group(cls, prompt_logprobs = None finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) - return cls(seq_group.request_id, - prompt, - prompt_token_ids, - prompt_logprobs, - outputs, - finished, - seq_group.metrics, - lora_request=seq_group.lora_request, - encoder_prompt=encoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids) + + init_args = (seq_group.request_id, prompt, prompt_token_ids, + prompt_logprobs, outputs, finished, seq_group.metrics, + seq_group.lora_request, encoder_prompt, + encoder_prompt_token_ids) + + if use_cache: + request_output = seq_group.cached_request_output + request_output.__init__(*init_args) # type: ignore + + else: + request_output = cls(*init_args) + + return request_output def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " @@ -261,10 +313,10 @@ def __repr__(self): class RequestOutputFactory: @staticmethod - def create(seq_group): + def create(seq_group: SequenceGroup, use_cache: bool = False): # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: return EmbeddingRequestOutput.from_seq_group(seq_group) else: - return RequestOutput.from_seq_group(seq_group) + return RequestOutput.from_seq_group(seq_group, use_cache) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 86e80ae5e224..f9ba4b4777e4 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -8,6 +8,7 @@ import torch from typing_extensions import Annotated +import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) @@ -260,6 +261,10 @@ def __post_init__(self) -> None: self._verify_args() if self.use_beam_search: + if not envs.VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: + raise ValueError( + "Using beam search as a sampling parameter is deprecated, and will be removed in the future release. Please use the `vllm.LLM.use_beam_search` method for dedicated beam search instead, or set the environment variable `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` to suppress this error. For more details, see https://github.com/vllm-project/vllm/issues/8306 ." # noqa + ) self._verify_beam_search() else: self._verify_non_beam_search() diff --git a/vllm/sequence.py b/vllm/sequence.py index 8d486395fe5d..49a198df045b 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -26,6 +26,8 @@ VLLM_TOKEN_ID_ARRAY_TYPE = "l" +VLLM_INVALID_TOKEN_ID = -1 + # We use dataclass for now because it is used for # openai server output, and msgspec is not serializable. @@ -436,7 +438,7 @@ def __init__( self.stop_reason: Union[int, str, None] = None # These are used to keep track of delta outputs - self._last_token_ids_offset: int = 0 + self._last_output_token_ids_offset: int = 0 self._last_output_text_offset: int = 0 # Used for incremental detokenization @@ -507,18 +509,26 @@ def get_output_text_to_return(self, buffer_length: int, return self.output_text[last_offset:length] return "" - def get_output_token_ids_to_return(self, - delta: bool) -> GenericSequence[int]: + def get_output_token_ids_to_return( + self, delta: bool) -> Union[GenericSequence[int], int]: """If delta is True, only new tokens since the last call to this method are returned""" if not delta: return self.get_output_token_ids() - length = self.get_output_len() - last_offset = self._last_token_ids_offset - if last_offset < length: - self._last_token_ids_offset = length - return self.data._output_token_ids[last_offset:] - return () + + output_len = self.get_output_len() + + # Get the number of new tokens + num_new_tokens = output_len - self._last_output_token_ids_offset + self._last_output_token_ids_offset = output_len + + # Return new tokens + if num_new_tokens == 1: + # Optimization for single decode token case + # (which is what we have most of the time) + return self.data._cached_all_token_ids[-1] + + return self.data._cached_all_token_ids[-num_new_tokens:] def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size @@ -644,6 +654,7 @@ class SequenceGroup: unless you are working with an encoder/decoder model. trace_headers: OpenTelemetry trace headers. prompt_adapter_request: Prompt Adapter request. + priority: User-defined priority of the request. """ def __init__( @@ -658,9 +669,11 @@ def __init__( encoder_seq: Optional[Sequence] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, ) -> None: self.request_id = request_id self.seqs = seqs + self.arrival_time = arrival_time self.is_single_seq = len(seqs) == 1 self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -678,6 +691,9 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers + self.priority = priority + + self.cached_request_output = None @property def prompt(self) -> Optional[str]: diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index b2204e8b27af..9eb8bbfc5407 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -6,9 +6,9 @@ from vllm import SamplingParams from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, - SequenceData, SequenceGroupMetadata, - get_all_seq_ids) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE, + ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len @@ -69,10 +69,10 @@ def score_proposals( proposal_lens_list = proposals.proposal_lens.tolist() proposal_token_ids_list = proposals.proposal_token_ids.tolist() - # Filter the list to ignore -1 proposals. + # Filter the list to ignore invalid proposals. proposal_token_ids_list_without_skips = [ proposals for proposals in proposal_token_ids_list - if -1 not in proposals + if VLLM_INVALID_TOKEN_ID not in proposals ] (spec_indices, non_spec_indices, target_seq_group_metadata_list, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 9e645a49f699..dbf880a8f475 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -13,9 +13,10 @@ SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) -from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, ExecuteModelRequest, HiddenStates, SequenceGroupMetadata, - get_all_seq_ids, get_all_seq_ids_and_request_ids) + get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, @@ -28,7 +29,8 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker from vllm.spec_decode.target_model_runner import TargetModelRunner -from vllm.spec_decode.util import (Timer, create_sequence_group_output, +from vllm.spec_decode.util import (Timer, create_logprobs_output, + create_sequence_group_output, get_all_num_logprobs, get_sampled_token_logprobs, nvtx_range, split_batch_by_proposal_len) @@ -436,8 +438,8 @@ def _serialize_sampler_output_no_logprobs( self, execute_model_req: ExecuteModelRequest, sampler_output: SamplerOutput) -> SamplerOutput: """ - Creates and returns a `SamplerOutput` with only the sampled token IDs - being serialized to CPU & populated in `CompletionSequenceGroupOutput`. + Creates and returns a `SamplerOutput` with only the token IDs being + serialized to CPU and populated in `CompletionSequenceGroupOutput`. All other parameters in `CompletionSequenceGroupOutput` related to log probabilities are skipped. @@ -449,14 +451,46 @@ def _serialize_sampler_output_no_logprobs( Returns: SamplerOutput: A new `SamplerOutput` instance containing a list of - `CompletionSequenceGroupOutput` objects with only sampled token - IDs populated. + `CompletionSequenceGroupOutput` objects with only token IDs + populated. """ - seq_ids = get_all_seq_ids(execute_model_req.seq_group_metadata_list) - sampled_token_ids_list = sampler_output.sampled_token_ids.tolist() + seq_output_prompt_logprobs = [ + seq.is_prompt and seq.sampling_params.prompt_logprobs is not None + and seq.sampling_params.prompt_logprobs > 0 + for seq in execute_model_req.seq_group_metadata_list + ] + # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID + sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where( + # subtracting is faster than testing for equality + sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \ + if any(seq_output_prompt_logprobs) else \ + sampler_output.sampled_token_ids).tolist() + + seq_data_entries = ( + (seq_id, seq_data) for sg in \ + execute_model_req.seq_group_metadata_list \ + for seq_id, seq_data in sg.seq_data.items() + ) completion_seq_group_output_list: List[ CompletionSequenceGroupOutput] = [] - for index, seq_id in enumerate(seq_ids): + for index, ((seq_id, seq_data), needs_prompt_logprobs) in \ + enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)): + if needs_prompt_logprobs: + prompt_token_ids = seq_data.get_prompt_token_ids() + prompt_logprobs = [ + create_logprobs_output( + token_id=p_token_id, + token_id_logprob_rank=-1, + token_id_logprob=0.0, + topk_token_ids=[], + topk_logprobs=[], + ) + # no prompt logprobs for the first token + for p_token_id in prompt_token_ids[1:] + ] + else: + prompt_logprobs = None + completion_seq_group_output_list.append( create_sequence_group_output( token_id=sampled_token_ids_list[index][0], @@ -465,7 +499,7 @@ def _serialize_sampler_output_no_logprobs( seq_id=seq_id, topk_token_ids=[], topk_logprobs=[], - )) + prompt_logprobs=prompt_logprobs)) return SamplerOutput(outputs=completion_seq_group_output_list) @nvtx_range("spec_decode_worker._run_no_spec") @@ -485,6 +519,12 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest, # Store hidden states from target model execution. hidden_states = sampler_output.hidden_states if hidden_states is not None: + # remove hidden_states for prompt tokens + if any(seq.is_prompt + for seq in execute_model_req.seq_group_metadata_list): + hidden_states = hidden_states[ + torch.where(sampler_output.sampled_token_ids - + VLLM_INVALID_TOKEN_ID)[0]] if self.previous_hidden_states is None: self.previous_hidden_states = HiddenStates( hidden_states, execute_model_req.seq_group_metadata_list) diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index 54e718bc4901..193ef870dfce 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -6,7 +6,8 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SequenceGroupMetadata, SequenceOutput) + PromptLogprobs, SequenceGroupMetadata, + SequenceOutput) SeqId = int @@ -49,21 +50,19 @@ def get_sampled_token_logprobs( return sampled_token_ids_ranks, selected_logprobs -def create_sequence_group_output( +def create_logprobs_output( token_id: int, token_id_logprob_rank: int, token_id_logprob: float, - seq_id: SeqId, topk_token_ids: List[Optional[int]], topk_logprobs: List[Optional[float]], -) -> CompletionSequenceGroupOutput: - """Create a SequenceGroupOutput given the sampling results. +) -> Dict[int, Logprob]: + """Create a Logprob Dict for a token given the sampling results. Args: token_id (int): The sampled token for the sequence. token_id_logprob_rank (int): The logprob rank of the sampled token. token_id_logprob (float): The logprob value of the sampled token. - seq_id (int): The sequence id. topk_token_ids (List[Optional[int]]): The list of top-k token ids. topk_logprobs (List[Optional[float]]): The list of top-k logprobs. """ @@ -85,14 +84,44 @@ def create_sequence_group_output( if topk_token_id is not None }) + return logprobs + + +def create_sequence_group_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + seq_id: SeqId, + topk_token_ids: List[Optional[int]], + topk_logprobs: List[Optional[float]], + prompt_logprobs: Optional[PromptLogprobs] = None, +) -> CompletionSequenceGroupOutput: + """Create a SequenceGroupOutput given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + seq_id (int): The sequence id. + topk_token_ids (List[Optional[int]]): The list of top-k token ids. + topk_logprobs (List[Optional[float]]): The list of top-k logprobs. + """ + + logprobs = create_logprobs_output( + token_id, + token_id_logprob_rank, + token_id_logprob, + topk_token_ids, + topk_logprobs, + ) + return CompletionSequenceGroupOutput( samples=[ SequenceOutput(parent_seq_id=seq_id, output_token=token_id, logprobs=logprobs) ], - # TODO add prompt logprobs support. - prompt_logprobs=None, + prompt_logprobs=prompt_logprobs, ) diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index d27d7ba9e67b..2b418f3603a0 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -1,13 +1,11 @@ from typing import Dict, List, Optional, Tuple -from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, + Sequence, SequenceGroup) from .tokenizer import AnyTokenizer from .tokenizer_group import BaseTokenizerGroup -# Used eg. for marking rejected tokens in spec decoding. -INVALID_TOKEN_ID = -1 - class Detokenizer: """Provides methods to decode the output of a model into text.""" @@ -61,7 +59,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, continue for token_id, sample_logprob in prompt_logprobs_for_token.items(): if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): + and token_id != VLLM_INVALID_TOKEN_ID): prompt_token_ids_with_token = ( prompt_token_ids[:token_position] + [token_id]) (new_tokens, new_text, new_prefix_offset, @@ -143,7 +141,7 @@ def decode_sequence_inplace(self, seq: Sequence, continue if (sample_logprob.decoded_token is None - and token_id != INVALID_TOKEN_ID): + and token_id != VLLM_INVALID_TOKEN_ID): all_input_ids_with_logprob = previous_tokens + [token_id] (_, new_text, _, _) = detokenize_incrementally( tokenizer=tokenizer, @@ -282,14 +280,14 @@ def detokenize_incrementally( assert prev_tokens is not None # If the new token id is out of bounds, return an empty string. - if new_token_id >= len(tokenizer): - new_tokens = [""] - else: + if 0 <= new_token_id < len(tokenizer): # Put new_token_id in a list so skip_special_tokens is respected new_tokens = tokenizer.convert_ids_to_tokens( [new_token_id], skip_special_tokens=skip_special_tokens) if isinstance(new_tokens, str): new_tokens = [new_tokens] + else: + new_tokens = [""] output_tokens = prev_tokens + new_tokens # If this is the first iteration, return all tokens. diff --git a/vllm/transformers_utils/image_processor.py b/vllm/transformers_utils/image_processor.py deleted file mode 100644 index 4cffac3724ba..000000000000 --- a/vllm/transformers_utils/image_processor.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import cast - - -def get_video_processor( - processor_name: str, - trust_remote_code: bool = False, -): - """ - Gets a processor for the given model name via HuggingFace. - """ - from transformers import AutoProcessor - - try: - processor = AutoProcessor.from_pretrained(processor_name) - video_processor = processor.video_processor - - except ValueError as e: - if not trust_remote_code: - err_msg = ( - "Failed to load the processor. If the processor is " - "a custom processor not yet available in the HuggingFace " - "transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - return video_processor - - -def get_image_processor( - processor_name: str, - *args, - trust_remote_code: bool = False, - **kwargs, -): - """Gets an image processor for the given model name via HuggingFace.""" - # don't put this import at the top level - # it will call torch.cuda.device_count() - from transformers import AutoImageProcessor - from transformers.image_processing_utils import BaseImageProcessor - - try: - processor = AutoImageProcessor.from_pretrained( - processor_name, - *args, - trust_remote_code=trust_remote_code, - **kwargs) - except ValueError as e: - # If the error pertains to the processor class not existing or not - # currently being imported, suggest using the --trust-remote-code flag. - # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors - if not trust_remote_code: - err_msg = ( - "Failed to load the image processor. If the image processor is " - "a custom processor not yet available in the HuggingFace " - "transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - - return cast(BaseImageProcessor, processor) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 2001746c5f7f..98663f7f0bd0 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -1,13 +1,13 @@ -from typing import cast +from typing import Any, cast def get_processor( processor_name: str, - *args, + *args: Any, trust_remote_code: bool = False, - **kwargs, + **kwargs: Any, ): - """Gets a processor for the given model name via HuggingFace.""" + """Load a processor for the given model name via HuggingFace.""" # don't put this import at the top level # it will call torch.cuda.device_count() from transformers import AutoProcessor @@ -35,3 +35,60 @@ def get_processor( raise e return cast(ProcessorMixin, processor) + + +def get_image_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + **kwargs: Any, +): + """Load an image processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoImageProcessor + from transformers.image_processing_utils import BaseImageProcessor + + try: + processor = AutoImageProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the image processor. If the image processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return cast(BaseImageProcessor, processor) + + +def get_video_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + **kwargs: Any, +): + """Load a video processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers.image_processing_utils import BaseImageProcessor + + processor = get_processor( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + return cast(BaseImageProcessor, processor.video_processor) diff --git a/vllm/utils.py b/vllm/utils.py index b1513b91a06c..b73e3b9bbf68 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -4,6 +4,8 @@ import datetime import enum import gc +import inspect +import ipaddress import os import random import socket @@ -532,6 +534,14 @@ def get_ip() -> str: return "0.0.0.0" +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + def get_distributed_init_method(ip: str, port: int) -> str: # Brackets are not permitted in ipv4 addresses, # see https://github.com/python/cpython/issues/103848 @@ -1237,6 +1247,53 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, return await task(*args, **kwargs) +def get_allowed_kwarg_only_overrides( + callable: Callable[..., object], + overrides: Optional[Dict[str, Any]], +) -> Dict[str, Any]: + """ + Given a callable which has one or more keyword only params and a dict + mapping param names to values, drop values that can be not be kwarg + expanded to overwrite one or more keyword-only args. This is used in a + few places to handle custom processor overrides for multimodal models, + e.g., for profiling when processor options provided by the user + may affect the number of mm tokens per instance. + + Args: + callable: Callable which takes 0 or more keyword only arguments. + overrides: Potential overrides to be used when invoking the callable. + + Returns: + Dictionary containing the kwargs to be leveraged which may be used + to overwrite one or more keyword only arguments when invoking the + callable. + """ + if not overrides: + return {} + + allowed_override_names = [ + name for name, param in inspect.signature(callable).parameters.items() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ] + + # Drop any mm_processor_kwargs provided by the user that are + # not kwarg names accepted by the provided input processor. + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if kwarg_name in allowed_override_names + } + + # If anything is dropped, log a warning + dropped_keys = overrides.keys() - filtered_overrides.keys() + if dropped_keys: + logger.warning( + "The following intended overrides are not keyword-only args " + "and and will be dropped: %s", dropped_keys) + + return filtered_overrides + + # Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. # In particular, the FakeScalarType is not supported for earlier versions of # PyTorch which breaks dynamo for any ops registered using ScalarType. diff --git a/vllm/version.py b/vllm/version.py index 0ddc7fb99ad4..66e189dcedf7 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -1,13 +1,11 @@ -import warnings - try: - import vllm.commit_id - - __commit__ = vllm.commit_id.__commit__ + from ._version import __version__, __version_tuple__ except Exception as e: + import warnings + warnings.warn(f"Failed to read commit hash:\n{e}", RuntimeWarning, stacklevel=2) - __commit__ = "COMMIT_HASH_PLACEHOLDER" -__version__ = "0.6.1.post2" + __version__ = "dev" + __version_tuple__ = (0, 0, __version__) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7b2caf497358..d7d7d65659b7 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,3 +1,5 @@ +import dataclasses +import weakref from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union @@ -17,7 +19,7 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad from vllm.worker.model_runner_base import ( - ModelRunnerBase, ModelRunnerInputBase, + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, @@ -32,16 +34,17 @@ @dataclass(frozen=True) -class CPUModelInput(ModelRunnerInputBase): +class ModelInputForCPU(ModelRunnerInputBase): """ - Used by the CPUModelRunner. + Base class contains metadata needed for the base model forward pass on CPU """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None - sampling_metadata: Optional["SamplingMetadata"] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None virtual_engine: Optional[int] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -51,88 +54,96 @@ def as_broadcastable_tensor_dict( "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) - _add_sampling_metadata_broadcastable_dict(tensor_dict, - self.sampling_metadata) + return tensor_dict @classmethod def from_broadcasted_tensor_dict( - cls: Type["CPUModelInput"], - tensor_dict: Dict[str, Any], - attn_backend: Optional["AttentionBackend"] = None - ) -> "CPUModelInput": - tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + cls: Type["ModelInputForCPU"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None + ) -> "ModelInputForCPU": if attn_backend is not None: tensor_dict = _init_attn_metadata_from_tensor_dict( attn_backend, tensor_dict) return cls(**tensor_dict) -class CPUModelRunner(ModelRunnerBase[CPUModelInput]): +@dataclass(frozen=True) +class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - is_driver_worker: bool = False, - *args, - **kwargs, - ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - # Currently, CPU worker doesn't support chunked prefill. - assert self.scheduler_config.chunked_prefill_enabled is False - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config - self.load_config = load_config - self.is_driver_worker = is_driver_worker + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict - self.device = self.device_config.device + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForCPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) - self.kv_cache_dtype = kv_cache_dtype - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), - self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), - self.model_config.get_sliding_window(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, - ) - # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.multi_modal_input_mapper = self.mm_registry \ - .create_input_mapper(self.model_config) - self.mm_registry.init_mm_limits_per_prompt(self.model_config) +class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): - # Lazy initialization. - self.model: nn.Module # Set after init_Model + def __init__(self, + runner: "CPUModelRunner", + finished_requests_ids: Optional[List[str]] = None) -> None: + super().__init__() + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper - if self.model_config.is_encoder_decoder_model: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + self.seq_group_metadata_list.append(seq_group_metadata) - def load_model(self) -> None: - self.model = get_model(model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + def build(self) -> ModelInputForCPU: + multi_modal_kwargs = None + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = self.seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) = self._prepare_prompt( + self.seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode( + self.seq_group_metadata_list) + seq_lens = [] + + return self.model_input_cls( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + # query_lens is not needed if chunked prefill is not + # supported. Since CPU worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens=seq_lens, + query_lens=seq_lens, + ) def _prepare_prompt( self, @@ -165,8 +176,7 @@ def _prepare_prompt( # is always the first token in the sequence. input_positions.extend(list(range(computed_len, seq_len))) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: + if (mm_data := seq_group_metadata.multi_modal_data): mm_kwargs = self.multi_modal_input_mapper(mm_data) multi_modal_inputs_list.append(mm_kwargs) @@ -302,56 +312,130 @@ def _prepare_decode( attn_metadata, ) + +class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): + _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( + ModelInputForCPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + # Currently, CPU worker doesn't support chunked prefill. + assert self.scheduler_config.chunked_prefill_enabled is False + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.multi_modal_input_mapper = self.mm_registry \ + .create_input_mapper(self.model_config) + self.mm_registry.init_mm_limits_per_prompt(self.model_config) + + # Lazy initialization. + self.model: nn.Module # Set after init_Model + + if self.model_config.is_encoder_decoder_model: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + + def load_model(self) -> None: + self.model = get_model(model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) + def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any], - ) -> CPUModelInput: - return CPUModelInput.from_broadcasted_tensor_dict( + ) -> ModelInputForCPU: + return ModelInputForCPU.from_broadcasted_tensor_dict( tensor_dict, attn_backend=self.attn_backend, ) + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithSamplingMetadata: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + """ + builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) + for seq_group_metadata in seq_group_metadata_list: + builder.add_seq_group(seq_group_metadata) + + return builder.build() # type: ignore + def prepare_model_input( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None - ) -> CPUModelInput: - multi_modal_kwargs = None - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs - ) = self._prepare_prompt(seq_group_metadata_list) - else: - (input_tokens, input_positions, - attn_metadata) = self._prepare_decode(seq_group_metadata_list) - seq_lens = [] - sampling_metadata = SamplingMetadata.prepare( - seq_group_metadata_list, - seq_lens, - # query_lens is not needed if chunked prefill is not - # supported. Since CPU worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens, - self.device, - pin_memory=False, - generators=self.get_generators(finished_requests_ids)) - return CPUModelInput( - input_tokens=input_tokens, - input_positions=input_positions, - attn_metadata=attn_metadata, - sampling_metadata=sampling_metadata, - multi_modal_kwargs=multi_modal_kwargs, - ) + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators) + + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + virtual_engine=virtual_engine) @torch.no_grad() def execute_model( self, - model_input: CPUModelInput, + model_input: ModelInputForCPUWithSamplingMetadata, kv_caches: List[torch.Tensor], intermediate_tensors: Optional[IntermediateTensors] = None, num_steps: int = 1, @@ -372,6 +456,8 @@ def execute_model( model_input.attn_metadata, **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), + "intermediate_tensors": + intermediate_tensors, } hidden_states = model_executable(**execute_model_kwargs) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 975b88c0e79a..86883cf15244 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -137,7 +137,15 @@ def _wrapper(*args, **kwargs): for t in kv_caches if is_tensor(t)] - pickle.dump(dumped_inputs, filep) + try: + pickle.dump(dumped_inputs, filep) + except Exception as pickle_err: + logger.warning( + "Failed to pickle inputs of failed execution: %s", + str(pickle_err)) + raise type(err)(f"Error in model execution: " + f"{str(err)}") from err + logger.info( "Completed writing input of failed execution to %s.", filename) From d7750d3348186d45ef761772a9b2cedab5b672cc Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 24 Sep 2024 23:33:33 -0700 Subject: [PATCH 64/75] update doc and hf model id --- docs/source/models/supported_models.rst | 5 +++++ examples/offline_inference_vision_language.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index d86d0860f7f2..bf690726a637 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -254,6 +254,11 @@ Multimodal Language Models - Image\ :sup:`+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - + * - :code:`MllamaForConditionalGeneration` + - Llama 3.2 + - Image + - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. + - * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - Image\ :sup:`E` diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index d4cb5e285a87..40a0fede7463 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -246,7 +246,7 @@ def run_qwen2_vl(question, modality): def run_mllama(question, modality): assert modality == "image" - model_name = "nltpt/Llama-3.2-11B-Vision-Instruct" + model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" llm = LLM( model=model_name, From 1ebd6dc82d28a0876f61fce8bb1ba236b0a88dc3 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 24 Sep 2024 23:34:46 -0700 Subject: [PATCH 65/75] update hf model id again --- tests/models/encoder_decoder/vision_language/test_mllama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index 44c657bc317d..531d347638a3 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -26,7 +26,7 @@ ] models = [ - "nltpt/Llama-3.2-11B-Vision-Instruct", + "meta-llama/Llama-3.2-11B-Vision-Instruct", ] # TODO: Update model path to huggingface model From c8577358c6de271bceffff854880d2c595ce902d Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 25 Sep 2024 11:07:26 -0700 Subject: [PATCH 66/75] fix format problem --- vllm/model_executor/models/qwen2_vl.py | 3 - vllm/worker/cpu_model_runner.py | 218 +++++++++++++++++-------- 2 files changed, 152 insertions(+), 69 deletions(-) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 2b65fff27b70..889ebc6c2e1f 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -69,9 +69,6 @@ from vllm.transformers_utils.processor import get_processor from vllm.utils import is_cpu -from .utils import (PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory) - from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index f822733d16c3..cebb0f36a2b2 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -103,69 +103,81 @@ def from_broadcasted_tensor_dict( class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): - def __init__( - self, - model_config: ModelConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - cache_config: CacheConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - kv_cache_dtype: Optional[str] = "auto", - prompt_adapter_config: Optional[PromptAdapterConfig] = None, - is_driver_worker: bool = False, - *args, - **kwargs, - ): - self.model_config = model_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - # Currently, CPU worker doesn't support chunked prefill. - assert self.scheduler_config.chunked_prefill_enabled is False - self.device_config = device_config - self.cache_config = cache_config - self.lora_config = lora_config - self.prompt_adapter_config = prompt_adapter_config - self.load_config = load_config - self.is_driver_worker = is_driver_worker - - self.device = self.device_config.device - - self.kv_cache_dtype = kv_cache_dtype - self.sliding_window = model_config.get_sliding_window() - self.block_size = cache_config.block_size - self.attn_backend = get_attn_backend( - self.model_config.get_num_attention_heads(self.parallel_config), - self.model_config.get_head_size(), - self.model_config.get_num_kv_heads(self.parallel_config), - self.model_config.get_sliding_window(), - self.model_config.dtype, - self.kv_cache_dtype, - self.block_size, + def __init__(self, + runner: "CPUModelRunner", + finished_requests_ids: Optional[List[str]] = None) -> None: + super().__init__() + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + self.seq_group_metadata_list.append(seq_group_metadata) + + def build(self) -> ModelInputForCPU: + multi_modal_kwargs = None + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = self.seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) = self._prepare_prompt( + self.seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode( + self.seq_group_metadata_list) + seq_lens = [] + + return self.model_input_cls( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + # query_lens is not needed if chunked prefill is not + # supported. Since CPU worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens=seq_lens, + query_lens=seq_lens, ) - # Multi-modal data support - self.mm_registry = MULTIMODAL_REGISTRY - self.multi_modal_input_mapper = self.mm_registry \ - .create_input_mapper(self.model_config) - self.mm_registry.init_mm_limits_per_prompt(self.model_config) - - # Lazy initialization. - self.model: nn.Module # Set after init_Model - - if self.model_config.is_encoder_decoder_model: - raise NotImplementedError( - STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) - - def load_model(self) -> None: - self.model = get_model(model_config=self.model_config, - load_config=self.load_config, - device_config=self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, + computed_len: int): + mm_kwargs = self.multi_modal_input_mapper(mm_data) + + # special processing for mrope position deltas. + mrope_positions = None + if self.runner.model_is_mrope: + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + assert image_grid_thw is not None or video_grid_thw is not None, ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw'.") + + hf_config = self.runner.model_config.hf_config + token_ids = seq_data.get_token_ids() + + mrope_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. + spatial_merge_size, + context_len=computed_len, + ) + seq_data.mrope_position_delta = mrope_position_delta + return mm_kwargs, mrope_positions def _prepare_prompt( self, @@ -210,11 +222,6 @@ def _prepare_prompt( else: input_positions.extend(list(range(computed_len, seq_len))) - mm_data = seq_group_metadata.multi_modal_data - if mm_data: - mm_kwargs = self.multi_modal_input_mapper(mm_data) - multi_modal_inputs_list.append(mm_kwargs) - # Compute the slot mapping. block_table = seq_group_metadata.block_tables[seq_id] # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, @@ -370,6 +377,85 @@ def _prepare_decode( attn_metadata, ) + +class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): + _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( + ModelInputForCPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + # Currently, CPU worker doesn't support chunked prefill. + assert self.scheduler_config.chunked_prefill_enabled is False + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.attn_backend = get_attn_backend( + self.model_config.get_num_attention_heads(self.parallel_config), + self.model_config.get_head_size(), + self.model_config.get_num_kv_heads(self.parallel_config), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + ) + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.multi_modal_input_mapper = self.mm_registry \ + .create_input_mapper(self.model_config) + self.mm_registry.init_mm_limits_per_prompt(self.model_config) + + # Lazy initialization. + self.model: nn.Module # Set after init_Model + + if self.model_config.is_encoder_decoder_model: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + + @property + def model_is_mrope(self) -> bool: + """Detect if the model has "mrope" rope_scaling type. + mrope requires keep "rope_deltas" between prompt and decoding phases.""" + rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + if rope_scaling is None: + return False + return rope_scaling.get("type", None) == "mrope" + + def load_model(self) -> None: + self.model = get_model(model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) + def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any], From e4bf8038d8e047e1f4c83b3f8fceed874faee721 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 25 Sep 2024 11:34:54 -0700 Subject: [PATCH 67/75] Apply suggestions from code review Co-authored-by: Simon Mo --- tests/models/encoder_decoder/vision_language/test_mllama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py index 531d347638a3..cda0926d0baf 100644 --- a/tests/models/encoder_decoder/vision_language/test_mllama.py +++ b/tests/models/encoder_decoder/vision_language/test_mllama.py @@ -28,7 +28,6 @@ models = [ "meta-llama/Llama-3.2-11B-Vision-Instruct", ] -# TODO: Update model path to huggingface model def vllm_to_hf_output(vllm_output: Tuple[List[int], str, From 4d7fe0abce9692e27255030098c10fd9e17a2540 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 25 Sep 2024 11:35:23 -0700 Subject: [PATCH 68/75] Update vllm/worker/enc_dec_model_runner.py Co-authored-by: Simon Mo --- vllm/worker/enc_dec_model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 14be7d1bd2b3..bd716ac3e7ec 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -294,7 +294,7 @@ def profile_run(self) -> None: max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( self.model_config) if max_mm_tokens > 0: - logger.warning("profile run for multi-modal models") + logger.info("Starting profile run for multi-modal models.") batch_size = 0 for group_id in range(max_num_seqs): From 4cdc6b54ea1ee5c8bf8c29cf2fd0d24a97d94c87 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 25 Sep 2024 11:35:47 -0700 Subject: [PATCH 69/75] Update vllm/worker/worker.py Co-authored-by: Simon Mo --- vllm/worker/worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 6ad14e921c3e..acd9a3f5a3e2 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -222,8 +222,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # of the model. self.model_runner.profile_run() - # # Calculate the number of blocks that can be allocated with the - # # profiled peak memory. + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. torch.cuda.synchronize() free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() From a6ad79fe4e1273f1d0b789a198c8c9e718c1d1df Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 25 Sep 2024 11:35:59 -0700 Subject: [PATCH 70/75] Update vllm/worker/worker.py Co-authored-by: Simon Mo --- vllm/worker/worker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index acd9a3f5a3e2..3851843afc96 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -226,7 +226,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # profiled peak memory. torch.cuda.synchronize() free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() - # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. peak_memory = self.init_gpu_memory - free_gpu_memory From 8364093e137fb6af019b9e3b3aae8051f619768b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 25 Sep 2024 11:37:56 -0700 Subject: [PATCH 71/75] upgrade huggingface --- requirements-common.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-common.txt b/requirements-common.txt index c113ff363042..2fc89c026901 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -4,7 +4,7 @@ numpy < 2.0.0 requests tqdm py-cpuinfo -transformers >= 4.43.2 # Required for Chameleon and Llama 3.1 hotfox. +transformers >= 4.45.0 # Required for Llama 3.2. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. fastapi < 0.113.0; python_version < '3.9' From a12c8d332e05555a2f40d3abac358d9e756904b8 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 25 Sep 2024 11:41:46 -0700 Subject: [PATCH 72/75] Update vllm/transformers_utils/configs/__init__.py Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com> --- vllm/transformers_utils/configs/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index a67944ce0b93..d5b13adb58a0 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -35,5 +35,4 @@ # Granite can be removed from here once we have upgraded to # transformers 4.45+ "GraniteConfig", - "" ] From 4065047e72ed62f580af0556523ebf19a0f9fc97 Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Wed, 25 Sep 2024 12:03:42 -0700 Subject: [PATCH 73/75] update code based on code review --- vllm/entrypoints/chat_utils.py | 48 +++++--------- vllm/model_executor/models/mllama.py | 95 ++++++++++++++-------------- 2 files changed, 64 insertions(+), 79 deletions(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 291239f43358..4a575ae8f853 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -375,34 +375,7 @@ def _parse_chat_message_content_parts( mm_tracker._model_config.hf_config.model_type in \ MODEL_KEEP_MULTI_MODAL_CONTENT - if keep_multimodal_content: - is_image = False - for part in parts: - part_type = part["type"] - if part_type == "text": - text = _TextParser(part)["text"] - texts.append(text) - elif part_type == "image_url": - image_url = _ImageParser(part)["image_url"] - - if image_url.get("detail", "auto") != "auto": - logger.warning( - "'image_url.detail' is currently not supported and " - "will be ignored.") - - mm_parser.parse_image(image_url["url"]) - is_image = True - else: - raise NotImplementedError(f"Unknown part type: {part_type}") - - text_prompt = "\n".join(texts) - role_content = [{'type': 'text', 'text': text_prompt}] - - if is_image: - role_content = [{'type': 'image'}] + role_content - return [ConversationMessage(role=role, - content=role_content)] # type: ignore - + has_image = False for part in parts: part_type = part["type"] if part_type == "text": @@ -417,6 +390,7 @@ def _parse_chat_message_content_parts( "will be ignored.") mm_parser.parse_image(image_url["url"]) + has_image = True elif part_type == "audio_url": audio_url = _AudioParser(part)["audio_url"] @@ -428,12 +402,20 @@ def _parse_chat_message_content_parts( raise NotImplementedError(f"Unknown part type: {part_type}") text_prompt = "\n".join(texts) - mm_placeholder_counts = mm_parser.mm_placeholder_counts() - if mm_placeholder_counts: - text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, - text_prompt) + if keep_multimodal_content: + text_prompt = "\n".join(texts) + role_content = [{'type': 'text', 'text': text_prompt}] - return [ConversationMessage(role=role, content=text_prompt)] + if has_image: + role_content = [{'type': 'image'}] + role_content + return [ConversationMessage(role=role, + content=role_content)] # type: ignore + else: + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt( + mm_placeholder_counts, text_prompt) + return [ConversationMessage(role=role, content=text_prompt)] # No need to validate using Pydantic again diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 6dac117cdc6c..aa868a3b8da2 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -54,7 +54,8 @@ from .llama import LlamaDecoderLayer, LlamaMLP logger = init_logger(__name__) -LLAMA_IMAGE_TOKEN_ID = 128256 +MLLAMA_IMAGE_TOKEN_ID = 128256 +MLLAMA_IMAGE_TOKEN = "<|image|>" class MllamaImagePixelInputs(TypedDict): @@ -72,7 +73,7 @@ class MllamaImagePixelInputs(TypedDict): def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): - # move prompt to encoder_prompt + # move encoder_prompt to prompt if llm_inputs.get("prompt") is None: llm_inputs["prompt"] = llm_inputs["encoder_prompt"] llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] @@ -113,8 +114,8 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): "chunk size should be multiple of 14" token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 num_tokens = num_tiles * token_per_chunk - llm_inputs["encoder_prompt"] = "<|image|>" * num_tokens - llm_inputs["encoder_prompt_token_ids"] = [LLAMA_IMAGE_TOKEN_ID + llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens + llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID ] * num_tokens return llm_inputs @@ -131,7 +132,7 @@ def dummy_decoder_seq_data(seq_len: int, num_images: int): assert seq_len >= num_images, \ "seq_len should be greater than or equal to num_images" token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [LLAMA_IMAGE_TOKEN_ID]) * num_images + [MLLAMA_IMAGE_TOKEN_ID]) * num_images token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images) return SequenceData(token_ids) @@ -139,7 +140,7 @@ def dummy_decoder_seq_data(seq_len: int, num_images: int): def dummy_encoder_seq_data(ctx: InputContext, num_images: int): num_tokens = get_max_mllama_image_tokens(ctx) * num_images token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, - [LLAMA_IMAGE_TOKEN_ID]) * num_tokens + [MLLAMA_IMAGE_TOKEN_ID]) * num_tokens return SequenceData(token_ids) @@ -298,6 +299,7 @@ def forward(self, hidden_state: torch.Tensor, return hidden_state +# TODO: support other attention backends for attention in vision model class MllamaVisionSdpaAttention(nn.Module): def __init__(self, config: config_mllama.MllamaVisionConfig): @@ -354,7 +356,9 @@ def forward( class MllamaVisionEncoderLayer(nn.Module): - def __init__(self, config, is_gated: bool = False): + def __init__(self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = False): super().__init__() self.hidden_size = config.hidden_size @@ -399,13 +403,6 @@ def forward( class MllamaVisionEncoder(nn.Module): - """ - Transformer encoder consisting of `config.num_hidden_layers` self attention - layers. Each layer is a [`MllamaEncoderLayer`]. - - Args: - config: MllamaConfig - """ def __init__(self, config: config_mllama.MllamaVisionConfig, @@ -768,9 +765,6 @@ def forward( class MllamaTextModel(nn.Module): config_class = config_mllama.MllamaTextConfig base_model_prefix = "model" - _no_split_modules = [ - "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" - ] def __init__(self, config: config_mllama.MllamaTextConfig, cache_config: Optional[CacheConfig], @@ -907,7 +901,7 @@ def __init__(self, config.pad_token_id if config.pad_token_id is not None else -1 self.image_size = config.vision_config.image_size - self.vision_model = MllamaVisionModel(config.vision_config, ) + self.vision_model = MllamaVisionModel(config.vision_config) self.language_model = MllamaForCausalLM( config.text_config, cache_config=cache_config, @@ -1024,6 +1018,38 @@ def _parse_and_validate_image_input(self, **kwargs: object): raise AssertionError("This line should be unreachable.") + def flat_encoder_result(self, cross_attention_states: torch.Tensor, + attn_metadata: AttentionMetadata): + + cross_attention_states_flat = torch.zeros( + sum(attn_metadata.encoder_seq_lens), + cross_attention_states.shape[-1], + device=cross_attention_states.device, + dtype=cross_attention_states.dtype) + start_pos = 0 + for seq_len, vision_token_in_batch in zip( + attn_metadata.encoder_seq_lens, cross_attention_states): + end_pos = start_pos + seq_len + cross_attention_states_flat[ + start_pos:end_pos] = vision_token_in_batch[:seq_len] + start_pos = end_pos + cross_attention_states = cross_attention_states_flat + + full_text_row_masked_out_mask = torch.ones( + (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) + start_pos = 0 + for seq_len, encoder_seq_len in zip( + attn_metadata.seq_lens_tensor.cpu(), + attn_metadata.encoder_seq_lens): + if encoder_seq_len == 0: + full_text_row_masked_out_mask[start_pos:start_pos + + seq_len] = False + start_pos += seq_len + full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( + cross_attention_states.device) + + return cross_attention_states, full_text_row_masked_out_mask + def forward( self, input_ids: torch.Tensor, @@ -1044,7 +1070,7 @@ def forward( cross_attention_states = None skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0 else: - # llama's reference implementation runs the vision model on CPU + # NOTE: llama's reference implementation runs vision model on CPU pixel_values = image_inputs['data'] aspect_ratio_ids = image_inputs['aspect_ratio_ids'] aspect_ratio_mask = image_inputs['aspect_ratio_mask'] @@ -1058,34 +1084,11 @@ def forward( cross_attention_states = cross_attention_states.view( bsz, -1, image_token_dim) - cross_attention_states_flat = torch.zeros( - sum(attn_metadata.encoder_seq_lens), - image_token_dim, - device=cross_attention_states.device, - dtype=cross_attention_states.dtype) - start_pos = 0 - for seq_len, vision_token_in_batch in zip( - attn_metadata.encoder_seq_lens, cross_attention_states): - end_pos = start_pos + seq_len - cross_attention_states_flat[ - start_pos:end_pos] = vision_token_in_batch[:seq_len] - start_pos = end_pos - cross_attention_states = cross_attention_states_flat - cross_attention_mask = None # TODO - - full_text_row_masked_out_mask = torch.ones( - (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) - start_pos = 0 - for seq_len, encoder_seq_len in zip( - attn_metadata.seq_lens_tensor.cpu(), - attn_metadata.encoder_seq_lens): - if encoder_seq_len == 0: - full_text_row_masked_out_mask[start_pos:start_pos + - seq_len] = False - start_pos += seq_len - full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( - cross_attention_states.device) + cross_attention_states, full_text_row_masked_out_mask = \ + self.flat_encoder_result(cross_attention_states, attn_metadata) skip_cross_attention = False + # TODO: support multi-image by this mask + cross_attention_mask = None outputs = self.language_model( input_ids=input_ids, From 293f07f2d4b7ee54b0472e1ccb29c18115e9b349 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 25 Sep 2024 13:11:26 -0700 Subject: [PATCH 74/75] add note --- examples/offline_inference_vision_language.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 40a0fede7463..a9db8f2b576d 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -248,6 +248,12 @@ def run_mllama(question, modality): model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" + # Note: The default setting of max_num_seqs (256) and + # max_model_len (131072) for this model may cause OOM. + # You may lower either to run this example on lower-end GPUs. + + # The configuration below has been confirmed to launch on a + # single H100 GPU. llm = LLM( model=model_name, max_num_seqs=16, From 3db294bc39768649f152e3a264985633d465b882 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 25 Sep 2024 13:13:52 -0700 Subject: [PATCH 75/75] format --- examples/offline_inference_vision_language.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index a9db8f2b576d..6d34621a8a9b 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -252,7 +252,7 @@ def run_mllama(question, modality): # max_model_len (131072) for this model may cause OOM. # You may lower either to run this example on lower-end GPUs. - # The configuration below has been confirmed to launch on a + # The configuration below has been confirmed to launch on a # single H100 GPU. llm = LLM( model=model_name,