From de65d9424ffae2b8b3eaf60620ec497070226965 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 19 Aug 2024 16:45:58 -0400 Subject: [PATCH 01/35] Enable stub mm image inputs for qwen models Signed-off-by: Alex-Brooks --- vllm/model_executor/models/__init__.py | 2 +- vllm/model_executor/models/qwen.py | 33 ++++++++++++++++++++++---- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index f4c3e43c8f2a..787ee691f0e1 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -51,7 +51,6 @@ "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), @@ -87,6 +86,7 @@ "PaliGemmaForConditionalGeneration"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "UltravoxModel": ("ultravox", "UltravoxModel"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 8298e3bac446..e8dabbe85011 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -11,7 +11,7 @@ from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -29,10 +29,13 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once +# Multimodal imports +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.multimodal import MULTIMODAL_REGISTRY from .utils import is_pp_missing_parameter, make_layers - class QWenMLP(nn.Module): def __init__( @@ -236,17 +239,38 @@ def forward( hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states - -class QWenLMHeadModel(nn.Module): +### Stubs that need to be implemented... 
+def get_max_qwen_image_tokens(ctx: InputContext): + # TODO: calculate this + print("STUB: using hardcoded max image tokens for qwen") + return 2048 + +def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + print("STUB - skipping multimodal image data for qwen") + prompt = llm_inputs.get("prompt") + prompt_token_ids = llm_inputs["prompt_token_ids"] + return LLMInputs(prompt=prompt, + prompt_token_ids=prompt_token_ids, + multi_modal_data=None) + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen_image_tokens) +@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) +class QWenLMHeadModel(nn.Module, SupportsMultiModal): def __init__( self, config: PretrainedConfig, + multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config + self.multimodal_config = multimodal_config self.quant_config = quant_config self.transformer = QWenModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, @@ -264,6 +288,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + pixel_values: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) From debcc8c533c33f4b5b8cc03ba92195549a1bc9af Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 20 Aug 2024 01:52:30 -0400 Subject: [PATCH 02/35] Add calc for max number of qwen image tokens Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e8dabbe85011..3394cb1147ea 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -36,6 +36,8 @@ from .utils import is_pp_missing_parameter, make_layers + + class QWenMLP(nn.Module): def __init__( @@ -239,11 +241,18 @@ def forward( hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states -### Stubs that need to be implemented... 
+ def get_max_qwen_image_tokens(ctx: InputContext): - # TODO: calculate this - print("STUB: using hardcoded max image tokens for qwen") - return 2048 + """Calculates the max number of image tokens for qwen, i.e., the number of patches.""" + config = ctx.get_hf_config() + vision_config = config.visual + # Images and patches are square and are usually 448/14, respectively + image_size = vision_config["image_size"] + patch_size = vision_config["patch_size"] + # Features will be of size (grid_height, grid_height) + # so usually our max tokens will be 1024 for qwen-vl/chat + grid_height = image_size // patch_size + return grid_height ** 2 def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("multi_modal_data") From 9803fab12ebe2d4abd883fca586455434f042180 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 20 Aug 2024 02:12:28 -0400 Subject: [PATCH 03/35] Add unmodified visual qwen code, enable visual weight loading Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 433 ++++++++++++++++++++++++++++- 1 file changed, 424 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 3394cb1147ea..394d14192760 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -36,8 +36,429 @@ from .utils import is_pp_missing_parameter, make_layers +### This is a directly copy paste of the qwen visual modeling code for now +from collections import OrderedDict +import math +import requests +from io import BytesIO +from functools import partial +from PIL import Image +from typing import Callable, Optional, Sequence, Tuple, List +import numpy as np + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.init import trunc_normal_ +from torchvision import transforms +from torchvision.transforms import InterpolationMode + + +def get_abs_pos(abs_pos, tgt_size): + # abs_pos: L, C + # tgt_size: M + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + tgt_size = int(math.sqrt(tgt_size)) + dtype = abs_pos.dtype + + if src_size != tgt_size: + return F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size, tgt_size), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) + else: + return abs_pos + +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + 
return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class Resampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb + Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + def __init__( + self, + grid_size, + embed_dim, + num_heads, + kv_dim=None, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.num_queries = grid_size ** 2 + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.pos_embed = nn.Parameter( + torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float() + ).requires_grad_(False) + + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + trunc_normal_(self.query, std=.02) + + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) + else: + self.kv_proj = nn.Identity() + + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + + # self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x, attn_mask=None): + + pos_embed = get_abs_pos(self.pos_embed, x.size(1)) + + x = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask)[0] + return out.permute(1, 0, 2) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class VisualAttention(nn.Module): + """self-attention layer class. + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, embed_dim, num_heads, + bias=True, kdim=None, vdim=None): + super(VisualAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + + # Per attention head and per partition values. + assert embed_dim % num_heads == 0 + self.hidden_size_per_attention_head = embed_dim // num_heads + self.num_attention_heads_per_partition = num_heads + self.hidden_size_per_partition = embed_dim + + # Strided linear layer. 
+ assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently' + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + + def forward(self, query, key, value, attn_mask = None): + # query/key/value: [sq, b, h] + sq, b, _ = query.size() + + assert torch.allclose(query, key), 'Only Support Self-Attention Currently' + sk = sq + mixed_x_layer = self.in_proj(query) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + query_layer, key_layer, value_layer = mixed_x_layer.split( + self.hidden_size_per_attention_head, dim=-1) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(sq, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(sk, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + q_scaled = query_layer / self.norm_factor + if attn_mask is not None: + attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1)) + else: + attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) + attention_probs = attention_probs.softmax(dim=-1) + + value_layer = value_layer.view(sk, + b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view(b, + self.num_attention_heads_per_partition, + sq, self.hidden_size_per_attention_head) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + output = self.out_proj(context_layer) + + return output + + +class VisualAttentionBlock(nn.Module): + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + is_cross_attention: bool = False, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + if is_cross_attention: + self.ln_1_kv = norm_layer(d_model) + + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.attn = VisualAttention(d_model, n_head) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)) + ])) + + def attention( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = k_x if k_x is not None else q_x + v_x = v_x if v_x is not None else q_x + + attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None + return self.attn(q_x, k_x, v_x, attn_mask=attn_mask) + + def forward( + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else 
None + v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + + x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x +class TransformerBlock(nn.Module): + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + super().__init__() + self.width = width + self.layers = layers + + self.resblocks = nn.ModuleList([ + VisualAttentionBlock( + width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def get_cast_device(self) -> torch.device: + return self.resblocks[0].mlp.c_fc.weight.device + + def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + + def __init__( + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + **kwargs + ): + super().__init__() + image_height, image_width = self.image_size = (image_size, image_size) + patch_height, patch_width = self.patch_size = (patch_size, patch_size) + self.grid_size = (image_height // patch_height, image_width // patch_width) + self.output_dim = output_dim + + mean = (0.48145466, 0.4578275, 0.40821073) + std = (0.26862954, 0.26130258, 0.27577711) + self.image_transform = transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ]) + + self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + + # class embeddings and positional embeddings + scale = width ** -0.5 + self.positional_embedding = nn.Parameter(scale * torch.randn(256, width)) + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + act_layer = nn.GELU + + self.ln_pre = norm_layer(width) + self.transformer = TransformerBlock( + width, + layers, + heads, + mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer, + ) + + self.attn_pool = Resampler( + grid_size=int(math.sqrt(n_queries)), + embed_dim=output_dim, + num_heads=output_dim // 128, + kv_dim=width, + norm_layer=norm_layer, + ) + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim)) + + def forward(self, x: torch.Tensor): + x = x.to( + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), + ) + # to patches + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = x + get_abs_pos(self.positional_embedding, x.size(1)) + + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.attn_pool(x) + x = self.ln_post(x) + x = x @ self.proj + + return x + + def encode(self, image_paths: List[str]): + images = [] + for image_path in image_paths: + if image_path.startswith("http://") or image_path.startswith("https://"): + image = Image.open(requests.get(image_path, stream=True).raw) + else: + image = Image.open(image_path) + image = image.convert("RGB") + 
images.append(self.image_transform(image)) + images = torch.stack(images, dim=0) + return self(images) +### + class QWenMLP(nn.Module): def __init__( @@ -208,6 +629,7 @@ def __init__( lambda prefix: QWenBlock(config, cache_config, quant_config), prefix=f"{prefix}.h") self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.visual = VisionTransformer(**config.visual) def forward( self, @@ -254,6 +676,7 @@ def get_max_qwen_image_tokens(ctx: InputContext): grid_height = image_size // patch_size return grid_height ** 2 + def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: @@ -265,6 +688,7 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): prompt_token_ids=prompt_token_ids, multi_modal_data=None) + @MULTIMODAL_REGISTRY.register_image_input_mapper() @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen_image_tokens) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) @@ -362,15 +786,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip loading visual weights to support Qwen-VL models - # in cases with text-only inputs - # TODO: add support for Qwen-VL - if (name not in params_dict - and name.startswith("transformer.visual.")): - print_warning_once( - "Only text inputs are allowed. Images won't be handled " - "until Qwen-VL models are fully supported.") - continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue From 364c110216a8bfb5c49fce780584ee1e8ab03c07 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Mon, 26 Aug 2024 13:19:07 -0400 Subject: [PATCH 04/35] Implement model processor for qwen, fix max tokens Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 68 ++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 394d14192760..36931a77e19f 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -33,9 +33,20 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.image import cached_get_tokenizer from .utils import is_pp_missing_parameter, make_layers +IMG_START = "" +IMG_END = "" +IMG_PAD = "" +# Qwen models have a few other special tags, e.g., ref, bbox, quad; +# for the time being, these tags are not considered as special at encoding +# time. This may change as VLLMs multimodal API changes in the future. + +# Qwen images are encoded into a fixed token length of 256, not include IMG_START/IMG_END +MAX_QWEN_IMG_TOKENS = 256 + ### This is a directly copy paste of the qwen visual modeling code for now from collections import OrderedDict import math @@ -663,34 +674,55 @@ def forward( hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states +def get_image_text(image_num: int, padding: bool) -> str: + """Retrieves a placeholder text that when tokenized, will be expanded with image pads. 
-def get_max_qwen_image_tokens(ctx: InputContext): - """Calculates the max number of image tokens for qwen, i.e., the number of patches.""" - config = ctx.get_hf_config() - vision_config = config.visual - # Images and patches are square and are usually 448/14, respectively - image_size = vision_config["image_size"] - patch_size = vision_config["patch_size"] - # Features will be of size (grid_height, grid_height) - # so usually our max tokens will be 1024 for qwen-vl/chat - grid_height = image_size // patch_size - return grid_height ** 2 - + NOTE: The reason that the reason we don't directly encode the image padding here is that + it will break the re-encoding of the tokens tokenizer, because the contents between the + start / end are treated as bytes containing a URL that then get padded up to the image context + size. + """ + if not padding: + return f"Picture {image_num}: {IMG_START}{IMG_END}\n" + return f"Picture {image_num}: {IMG_START}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{IMG_END}\n" def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: return llm_inputs - print("STUB - skipping multimodal image data for qwen") prompt = llm_inputs.get("prompt") prompt_token_ids = llm_inputs["prompt_token_ids"] - return LLMInputs(prompt=prompt, - prompt_token_ids=prompt_token_ids, - multi_modal_data=None) + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) + image_data = multi_modal_data["image"] + + if prompt is None: + prompt = tokenizer.decode(prompt_token_ids) + + if not isinstance(image_data, Image.Image): + raise NotImplementedError("Image supported not yet implemented for directly provided image features yet") + + # Replace the image tag with the image prompt with no img pads. We currently do this in + # two steps to sidestep some tokenization substitution stuff with URLs behind handled as bytes + # that do not like existing image pads strings, but it would be nice to find a better way to + # handle it. 
+ image_prompt_without_padding = get_image_text(0, padding=False) + image_prompt_with_padding = get_image_text(0, padding=True) + + new_prompt_no_img_pads = prompt.replace('', image_prompt_without_padding, 1) + new_prompt_with_img_pads = prompt.replace('', image_prompt_with_padding, 1) + new_prompt_token_ids = tokenizer.encode(new_prompt_no_img_pads) + + return LLMInputs(prompt=new_prompt_with_img_pads, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) + +def input_mapper_for_qwen(ctx: InputContext, data: object): + raise NotImplementedError("Need to implement the input mapper") -@MULTIMODAL_REGISTRY.register_image_input_mapper() -@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen_image_tokens) +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) +@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) class QWenLMHeadModel(nn.Module, SupportsMultiModal): From 9b4eb9ac9a219a3cdbe26517bf8c33724daa5a4b Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 27 Aug 2024 15:11:49 -0400 Subject: [PATCH 05/35] Implement qwen input mapper and visual feature forward call Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 101 +++++++++++++++-------------- 1 file changed, 52 insertions(+), 49 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 36931a77e19f..909716a46037 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -13,6 +13,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -26,14 +27,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors -from vllm.utils import print_warning_once -# Multimodal imports -from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import cached_get_tokenizer +from vllm.sequence import IntermediateTensors from .utils import is_pp_missing_parameter, make_layers @@ -155,10 +154,11 @@ def __init__( self.num_heads = num_heads self.pos_embed = nn.Parameter( - torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float() + # TODO - fix the hacks for device / dtype here & in the positional embedding retrieval + torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).half().to("cuda"), ).requires_grad_(False) - self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim).to("cuda")) trunc_normal_(self.query, std=.02) if kv_dim is not None and kv_dim != embed_dim: @@ -169,20 +169,9 @@ def __init__( self.attn = 
nn.MultiheadAttention(embed_dim, num_heads) self.ln_q = norm_layer(embed_dim) self.ln_kv = norm_layer(embed_dim) - - # self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - def forward(self, x, attn_mask=None): + def forward(self, x, attn_mask=None): pos_embed = get_abs_pos(self.pos_embed, x.size(1)) x = self.kv_proj(x) @@ -392,18 +381,6 @@ def __init__( patch_height, patch_width = self.patch_size = (patch_size, patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) self.output_dim = output_dim - - mean = (0.48145466, 0.4578275, 0.40821073) - std = (0.26862954, 0.26130258, 0.27577711) - self.image_transform = transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC - ), - transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), - ]) - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) # class embeddings and positional embeddings @@ -435,9 +412,10 @@ def __init__( def forward(self, x: torch.Tensor): x = x.to( - dtype=self.transformer.get_cast_dtype(), - device=self.transformer.get_cast_device(), + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), ) + # to patches x = self.conv1(x) # shape = [*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] @@ -457,18 +435,6 @@ def forward(self, x: torch.Tensor): return x - def encode(self, image_paths: List[str]): - images = [] - for image_path in image_paths: - if image_path.startswith("http://") or image_path.startswith("https://"): - image = Image.open(requests.get(image_path, stream=True).raw) - else: - image = Image.open(image_path) - image = image.convert("RGB") - images.append(self.image_transform(image)) - images = torch.stack(images, dim=0) - return self(images) -### class QWenMLP(nn.Module): @@ -649,9 +615,15 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + pixel_values: Optional[torch.Tensor]=None, ) -> torch.Tensor: + if pixel_values is not None: + # TODO: get the positions of the image tags from the text + images = self.visual(pixel_values) + if get_pp_group().is_first_rank: hidden_states = self.wte(input_ids) + # TODO - merge image / with wte embeddings residual = None else: assert intermediate_tensors is not None @@ -718,8 +690,39 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data=multi_modal_data) def input_mapper_for_qwen(ctx: InputContext, data: object): - raise NotImplementedError("Need to implement the input mapper") - + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + image_pair_tok = tokenizer.encode(IMG_START+IMG_END, add_special_tokens=False, return_tensors="pt").squeeze() + image_start_id = image_pair_tok[0] + image_end_id = image_pair_tok[-1] + assert (image_start_id + 1) == image_end_id + assert len(image_pair_tok) == (MAX_QWEN_IMG_TOKENS + 2) + + # Apply the normalization transform to the PIL Image + hf_config = ctx.get_hf_config() + image_size = hf_config.visual["image_size"] + transform = build_normalization_transform(image_size) + 
transformed_images = [transform(data)] + + return MultiModalInputs({ + "pixel_values": torch.stack(transformed_images, dim=0) + }) + +def build_normalization_transform(image_size): + # Currently, normalized image tensors are of shape: (batch, 3, image_size, image_size), + # which is usually [1, 3, 448, 448], where batch is single dimension since we don't handle + # multiimage inputs yet. + mean = (0.48145466, 0.4578275, 0.40821073) + std = (0.26862954, 0.26130258, 0.27577711) + return transforms.Compose([ + transforms.Resize( + (image_size, image_size), + interpolation=InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ]) @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) @@ -753,10 +756,10 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - pixel_values: torch.Tensor = None, + pixel_values = None, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, pixel_values) return hidden_states def make_empty_intermediate_tensors( From a709dd950b56f5f6330d716fde356ca4cdcd8776 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 27 Aug 2024 17:51:13 -0400 Subject: [PATCH 06/35] Hacky integration of img pos / merging for qwen-vl Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 909716a46037..baf7cb277d26 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -617,13 +617,19 @@ def forward( intermediate_tensors: Optional[IntermediateTensors], pixel_values: Optional[torch.Tensor]=None, ) -> torch.Tensor: + img_pos, images = None, None if pixel_values is not None: - # TODO: get the positions of the image tags from the text images = self.visual(pixel_values) + img_pos = get_image_positions(input_ids) + assert img_pos is not None # TODO - compare with image / batch len if get_pp_group().is_first_rank: hidden_states = self.wte(input_ids) - # TODO - merge image / with wte embeddings + # TODO - make sure batch size etc is properly handled, + # refactor to use common multimodal embedding merging utils + if images is not None and img_pos is not None: + for idx, (img_bos, img_eos) in enumerate(img_pos): + hidden_states[img_bos + 1 : img_eos] = images[idx] residual = None else: assert intermediate_tensors is not None @@ -646,6 +652,20 @@ def forward( hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states +def get_image_positions(input_ids, image_start_id=151857): + # HACK - hardcoded IDs for now to test qwen-vl/chat + image_pad_id = image_start_id + 2 + image_end_id = image_start_id + 1 + if torch.any(input_ids == image_start_id): + bos_pos = torch.where(input_ids == image_start_id) + eos_pos = torch.where(input_ids == image_end_id) + print("BOS: {}".format(bos_pos)) # BOS: (tensor([11], device='cuda:0'),) + print("EOS: {}".format(eos_pos)) # EOS: (tensor([268], device='cuda:0'),) + print("tok stack: {}".format( torch.stack((bos_pos[0], bos_pos[0]), dim=1))) + return torch.stack((bos_pos[0], eos_pos[0]), dim=1) + return None + + def get_image_text(image_num: int, padding: bool) -> str: """Retrieves a placeholder text that when tokenized, 
will be expanded with image pads. From 59339d25c18a155ef4d8807fa952d4ae587b1efe Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 27 Aug 2024 18:52:12 -0400 Subject: [PATCH 07/35] Add multimodal dummy data for qwen models Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 44 ++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index baf7cb277d26..eaa7ff0df843 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -4,7 +4,7 @@ # Copyright (c) Alibaba Cloud. # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Mapping import torch from torch import nn @@ -32,7 +32,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import cached_get_tokenizer -from vllm.sequence import IntermediateTensors +from vllm.sequence import IntermediateTensors, SequenceData from .utils import is_pp_missing_parameter, make_layers @@ -46,7 +46,6 @@ # Qwen images are encoded into a fixed token length of 256, not include IMG_START/IMG_END MAX_QWEN_IMG_TOKENS = 256 -### This is a directly copy paste of the qwen visual modeling code for now from collections import OrderedDict import math import requests @@ -659,9 +658,6 @@ def get_image_positions(input_ids, image_start_id=151857): if torch.any(input_ids == image_start_id): bos_pos = torch.where(input_ids == image_start_id) eos_pos = torch.where(input_ids == image_end_id) - print("BOS: {}".format(bos_pos)) # BOS: (tensor([11], device='cuda:0'),) - print("EOS: {}".format(eos_pos)) # EOS: (tensor([268], device='cuda:0'),) - print("tok stack: {}".format( torch.stack((bos_pos[0], bos_pos[0]), dim=1))) return torch.stack((bos_pos[0], eos_pos[0]), dim=1) return None @@ -744,8 +740,44 @@ def build_normalization_transform(image_size): transforms.Normalize(mean=mean, std=std), ]) + +def dummy_data_for_qwen(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + hf_config = ctx.get_hf_config() + + # The presence of a visual config indicates this is a multimodal model. + # If we don't have it, the model is considered an LLM for warmup purposes. + if not hasattr(hf_config, "visual"): + print("Using text data to warmup") + seq_data = SequenceData([0] * seq_len) + mm_data = None + return seq_data, mm_data + + # We have a visual component! Use images to warm up + num_images = mm_counts["image"] + image_feature_size = MAX_QWEN_IMG_TOKENS + model_config = ctx.model_config + + print("Using multimodal data to warmup") + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + # Encode an image pair for each image. During the encoding, qwen tokenizers will add + # image pads between the start/end. We leave this to the tokenizer, because we need + # to rely on the number of added pads at inference time. + seq_data = SequenceData(tokenizer.encode( + (IMG_START+IMG_END) * num_images, add_special_tokens=False, return_tensors="pt" + )[0].tolist()) + assert seq_data.get_len() == ((2 + MAX_QWEN_IMG_TOKENS) * num_images) + + # Build the input images; width/height doesn't actually matter here since the + # data will get resized, and the number of tokens per image is constant per model. 
+ image = Image.new("RGB", (224, 224), color=0) + mm_data = {"image": image if num_images == 1 else [image] * num_images} + return seq_data, mm_data + @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) class QWenLMHeadModel(nn.Module, SupportsMultiModal): From d6e3ca462fdc07e4bfb6e1111e95cba9c3bc97d9 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 28 Aug 2024 02:28:45 -0400 Subject: [PATCH 08/35] Conditionally enable visual component to support qwen llm-only models Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index eaa7ff0df843..e2c237902bb4 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -35,7 +35,9 @@ from vllm.sequence import IntermediateTensors, SequenceData from .utils import is_pp_missing_parameter, make_layers +from vllm.logger import init_logger +logger = init_logger(__name__) IMG_START = "" IMG_END = "" IMG_PAD = "" @@ -605,7 +607,7 @@ def __init__( lambda prefix: QWenBlock(config, cache_config, quant_config), prefix=f"{prefix}.h") self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.visual = VisionTransformer(**config.visual) + self.visual = VisionTransformer(**config.visual) if hasattr(config, "visual") else None def forward( self, @@ -617,6 +619,7 @@ def forward( pixel_values: Optional[torch.Tensor]=None, ) -> torch.Tensor: img_pos, images = None, None + # If pixel values are provided, this is a visual model, because filter in the input processor if pixel_values is not None: images = self.visual(pixel_values) img_pos = get_image_positions(input_ids) @@ -676,8 +679,11 @@ def get_image_text(image_num: int, padding: bool) -> str: def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: + # Only process images if we have multimodal data and a visual config + hf_config = ctx.get_hf_config() + if multi_modal_data is None or "image" not in multi_modal_data or not hasattr(hf_config, "visual"): return llm_inputs + prompt = llm_inputs.get("prompt") prompt_token_ids = llm_inputs["prompt_token_ids"] model_config = ctx.model_config @@ -706,6 +712,13 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data=multi_modal_data) def input_mapper_for_qwen(ctx: InputContext, data: object): + # Early exit if we have provided an image to a language only Qwen model + hf_config = ctx.get_hf_config() + if not hasattr(hf_config, "visual"): + logger.warning("Images were provided but this model has no visual config; " + "multimodal inputs will not be forwarded to the model.") + return MultiModalInputs() + model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) @@ -748,19 +761,17 @@ def dummy_data_for_qwen(ctx: InputContext, seq_len: int, # The presence of a visual config indicates this is a multimodal model. # If we don't have it, the model is considered an LLM for warmup purposes. 
if not hasattr(hf_config, "visual"): - print("Using text data to warmup") seq_data = SequenceData([0] * seq_len) mm_data = None return seq_data, mm_data - # We have a visual component! Use images to warm up + # We have a visual component - use images to warm up num_images = mm_counts["image"] image_feature_size = MAX_QWEN_IMG_TOKENS model_config = ctx.model_config - - print("Using multimodal data to warmup") tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) + # Encode an image pair for each image. During the encoding, qwen tokenizers will add # image pads between the start/end. We leave this to the tokenizer, because we need # to rely on the number of added pads at inference time. From 8e27aa2a384570a66784c3d2efd69b6b790b5395 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 28 Aug 2024 02:42:55 -0400 Subject: [PATCH 09/35] Fix hardcoded image start ID in image position extraction Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index e2c237902bb4..3651d2a5752e 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -375,6 +375,7 @@ def __init__( mlp_ratio: float, n_queries: int = 256, output_dim: int = 512, + image_start_id: int = 151857, **kwargs ): super().__init__() @@ -410,6 +411,8 @@ def __init__( ) self.ln_post = norm_layer(output_dim) self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim)) + self.image_start_id = image_start_id + self.image_end_id = image_start_id + 1 def forward(self, x: torch.Tensor): x = x.to( @@ -436,6 +439,13 @@ def forward(self, x: torch.Tensor): return x + def get_image_positions(self, input_ids): + if torch.any(input_ids == self.image_start_id): + bos_pos = torch.where(input_ids == self.image_start_id) + eos_pos = torch.where(input_ids == self.image_end_id) + return torch.stack((bos_pos[0], eos_pos[0]), dim=1) + return None + class QWenMLP(nn.Module): @@ -622,7 +632,7 @@ def forward( # If pixel values are provided, this is a visual model, because filter in the input processor if pixel_values is not None: images = self.visual(pixel_values) - img_pos = get_image_positions(input_ids) + img_pos = self.visual.get_image_positions(input_ids) assert img_pos is not None # TODO - compare with image / batch len if get_pp_group().is_first_rank: @@ -654,16 +664,6 @@ def forward( hidden_states, _ = self.ln_f(hidden_states, residual) return hidden_states -def get_image_positions(input_ids, image_start_id=151857): - # HACK - hardcoded IDs for now to test qwen-vl/chat - image_pad_id = image_start_id + 2 - image_end_id = image_start_id + 1 - if torch.any(input_ids == image_start_id): - bos_pos = torch.where(input_ids == image_start_id) - eos_pos = torch.where(input_ids == image_end_id) - return torch.stack((bos_pos[0], eos_pos[0]), dim=1) - return None - def get_image_text(image_num: int, padding: bool) -> str: """Retrieves a placeholder text that when tokenized, will be expanded with image pads. 
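
A minimal standalone sketch of the placeholder expansion that get_image_text and input_processor_for_qwen implement in the patches above. It is illustrative only (not part of any patch), and it assumes Qwen-VL's special strings <img>, </img>, <imgpad> and vLLM's generic <image> placeholder; the helper name below is made up for the example.

    # Sketch of the prompt expansion performed by input_processor_for_qwen.
    IMG_START, IMG_END, IMG_PAD = "<img>", "</img>", "<imgpad>"
    MAX_QWEN_IMG_TOKENS = 256  # fixed visual context per image

    def image_placeholder(image_num: int, padding: bool) -> str:
        # Without padding, the Qwen tokenizer fills the span between <img> and
        # </img> with 256 <imgpad> tokens itself; with padding the pads are
        # spelled out so the prompt string already carries the final length.
        pads = IMG_PAD * MAX_QWEN_IMG_TOKENS if padding else ""
        return f"Picture {image_num}: {IMG_START}{pads}{IMG_END}\n"

    prompt = "<image>What is shown in this picture?"
    expanded = prompt.replace("<image>", image_placeholder(1, padding=True), 1)
    # Each image contributes 256 pad tokens plus the <img>/<img-end> pair.
    assert expanded.count(IMG_PAD) == MAX_QWEN_IMG_TOKENS

Re-encoding the unpadded variant therefore yields 2 + 256 tokens per image, which is the count that dummy_data_for_qwen asserts when building warmup sequences.
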
From f535d618786c98301845b821d0d81279592994a9 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 28 Aug 2024 04:08:08 -0400 Subject: [PATCH 10/35] Enable chat for qwen-vl Signed-off-by: Alex-Brooks --- vllm/entrypoints/chat_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c70c6d9330b1..62131d8ec420 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -133,7 +133,9 @@ def add(self, modality: Literal["image", "audio"], return MultiModalItemTracker._cached_token_str( self._tokenizer, self._model_config.hf_config.image_token_index) - if model_type in ("chameleon", "internvl_chat"): + # NOTE: qwen models do not use normally, but input + # processor will expand it to the expected format + if model_type in ("chameleon", "internvl_chat", "qwen"): return "" raise TypeError(f"Unknown model type: {model_type}") From b10b73cd1bb23b8afdb756b51c5ffd4ac62fffed Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 28 Aug 2024 13:32:48 -0400 Subject: [PATCH 11/35] Improve validation, add qwen single image embed support Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 126 +++++++++++++++++++++++------ 1 file changed, 100 insertions(+), 26 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 3651d2a5752e..847462f87d17 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -4,7 +4,7 @@ # Copyright (c) Alibaba Cloud. # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Mapping +from typing import Any, Dict, Iterable, List, Optional, Tuple, Mapping, TypedDict, Literal, Union import torch from torch import nn @@ -65,6 +65,31 @@ from torchvision.transforms import InterpolationMode + +class QwenImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """ + Shape: `(# images, 3, image_size, image_size)` + + Note that image_size is the value in the vision config to which we resize + the image to in the normalization transform. Currently multi-image support + can only be leveraged by passing image embeddings directly. + """ + + +class QwenImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(# images, 256, hidden_size)` + + `hidden_size` must match the hidden size of the language model backbone + and is stored in the visual config of the model if we have one. 
+ """ + + +QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] + def get_abs_pos(abs_pos, tgt_size): # abs_pos: L, C # tgt_size: M @@ -626,22 +651,29 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], - pixel_values: Optional[torch.Tensor]=None, + pixel_values: Optional[QwenImageInputs], ) -> torch.Tensor: - img_pos, images = None, None - # If pixel values are provided, this is a visual model, because filter in the input processor + img_pos = None + # If pixel / visual embeddings are provided, this is a visual model since we filter inputs if pixel_values is not None: - images = self.visual(pixel_values) + if pixel_values["type"] != "image_embeds": + image_embeds = self.visual(pixel_values["data"]) + else: + image_embeds = pixel_values["data"] + + # features should be of shape (# images, 256, hidden_dim) img_pos = self.visual.get_image_positions(input_ids) - assert img_pos is not None # TODO - compare with image / batch len + if img_pos.shape[0] != image_embeds.shape[0]: + raise ValueError(f"Number of placeholders: {img_pos.shape[0]} " + f"does not match the number of images {image_embeds.shape[0]}.") if get_pp_group().is_first_rank: hidden_states = self.wte(input_ids) - # TODO - make sure batch size etc is properly handled, - # refactor to use common multimodal embedding merging utils - if images is not None and img_pos is not None: + # Merge the image embeddings into the hidden states if actually have + # visual features and the corresponding image tokens + if img_pos is not None: for idx, (img_bos, img_eos) in enumerate(img_pos): - hidden_states[img_bos + 1 : img_eos] = images[idx] + hidden_states[img_bos + 1 : img_eos] = image_embeds[idx] residual = None else: assert intermediate_tensors is not None @@ -679,6 +711,7 @@ def get_image_text(image_num: int, padding: bool) -> str: def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("multi_modal_data") + # Only process images if we have multimodal data and a visual config hf_config = ctx.get_hf_config() if multi_modal_data is None or "image" not in multi_modal_data or not hasattr(hf_config, "visual"): @@ -689,20 +722,22 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) image_data = multi_modal_data["image"] + if isinstance(image_data, torch.Tensor): + print("Processing image embed of shape {}".format(image_data.shape)) + else: + print("Processing image embed of type {}".format(type(image_data))) + if prompt is None: prompt = tokenizer.decode(prompt_token_ids) - if not isinstance(image_data, Image.Image): - raise NotImplementedError("Image supported not yet implemented for directly provided image features yet") - # Replace the image tag with the image prompt with no img pads. We currently do this in # two steps to sidestep some tokenization substitution stuff with URLs behind handled as bytes # that do not like existing image pads strings, but it would be nice to find a better way to # handle it. 
+ # TODO - handle multi-image embeddings image_prompt_without_padding = get_image_text(0, padding=False) image_prompt_with_padding = get_image_text(0, padding=True) - new_prompt_no_img_pads = prompt.replace('', image_prompt_without_padding, 1) new_prompt_with_img_pads = prompt.replace('', image_prompt_with_padding, 1) new_prompt_token_ids = tokenizer.encode(new_prompt_no_img_pads) @@ -722,26 +757,50 @@ def input_mapper_for_qwen(ctx: InputContext, data: object): model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) - image_pair_tok = tokenizer.encode(IMG_START+IMG_END, add_special_tokens=False, return_tensors="pt").squeeze() + + image_pair_tok = tokenizer.encode( + IMG_START+IMG_END, + add_special_tokens=False, + return_tensors="pt").squeeze() image_start_id = image_pair_tok[0] image_end_id = image_pair_tok[-1] - assert (image_start_id + 1) == image_end_id - assert len(image_pair_tok) == (MAX_QWEN_IMG_TOKENS + 2) + if (image_start_id + 1) != image_end_id: + raise ValueError(f"Found image end ID {image_end_id}, but expected ID {IMG_START} + 1") + if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2): + raise ValueError( + f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, but got {image_pair_tok - 2}" + ) - # Apply the normalization transform to the PIL Image hf_config = ctx.get_hf_config() image_size = hf_config.visual["image_size"] - transform = build_normalization_transform(image_size) - transformed_images = [transform(data)] - + img_emb_size = hf_config.visual["output_dim"] + + if isinstance(data, torch.Tensor): + # It's expected that our values have already been processed + # by the visual transformer; shape is expected to be: + # (# images, 256, hidden_size) + if len(data.shape) == 2: + # Assume only one image embed was provided; unsqueeze the extra dim + data = data.unsqueeze(0) + if len(data.shape) != 3 or data.shape[1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size: + raise ValueError("Expected img_embeds to be a tensor of shape" + f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but received " + f"shape [{pixel_values.shape}]") + pixel_values = data + + else: + transform = build_normalization_transform(image_size) + # TODO - handle multiple image inputs once the API is solidified + transformed_images = [transform(data)] + pixel_values = torch.stack(transformed_images, dim=0) return MultiModalInputs({ - "pixel_values": torch.stack(transformed_images, dim=0) + "pixel_values": pixel_values }) def build_normalization_transform(image_size): - # Currently, normalized image tensors are of shape: (batch, 3, image_size, image_size), - # which is usually [1, 3, 448, 448], where batch is single dimension since we don't handle - # multiimage inputs yet. + """Builds a normalization transform which can be applied to one or more input images + from which we want to extract visual features. 
+ """ mean = (0.48145466, 0.4578275, 0.40821073) std = (0.26862954, 0.26130258, 0.27577711) return transforms.Compose([ @@ -812,6 +871,20 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + def _get_image_input_type(self, pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]: + if pixel_values is not None and self.transformer.visual is not None: + if len(pixel_values.shape) == 3 and pixel_values.shape[1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[2] == self.config.visual["output_dim"]: + return QwenImageEmbeddingInputs( + type="image_embeds", + data=pixel_values, + ) + else: + # if we don't have the right embedding shape, assume we need to process still + return QwenImagePixelInputs( + type="pixel_values", + data=pixel_values, + ) + def forward( self, input_ids: torch.Tensor, @@ -819,8 +892,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - pixel_values = None, + pixel_values: Optional[torch.Tensor] = None ) -> torch.Tensor: + pixel_values = self._get_image_input_type(pixel_values) hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, pixel_values) return hidden_states From 0611f196d343c9675f954650971009cfc5b9f8ad Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 28 Aug 2024 17:14:29 -0400 Subject: [PATCH 12/35] Tentative support for multi-image embeddings in qwen Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 34 +++++++++++++++++++----------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 847462f87d17..7227c9a037e2 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -723,23 +723,33 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) image_data = multi_modal_data["image"] if isinstance(image_data, torch.Tensor): - print("Processing image embed of shape {}".format(image_data.shape)) + if len(image_data.shape) < 2 or len(image_data.shape) > 3: + raise ValueError( + f"Expected image embeds to be have 3 dimensions but got {len(image_data.shape)}" + ) + num_images = 1 if len(image_data.shape) == 2 else image_data.shape[0] else: - print("Processing image embed of type {}".format(type(image_data))) + num_images = 1 if prompt is None: prompt = tokenizer.decode(prompt_token_ids) - # Replace the image tag with the image prompt with no img pads. We currently do this in - # two steps to sidestep some tokenization substitution stuff with URLs behind handled as bytes - # that do not like existing image pads strings, but it would be nice to find a better way to - # handle it. 
- # TODO - handle multi-image embeddings - image_prompt_without_padding = get_image_text(0, padding=False) - image_prompt_with_padding = get_image_text(0, padding=True) - new_prompt_no_img_pads = prompt.replace('', image_prompt_without_padding, 1) - new_prompt_with_img_pads = prompt.replace('', image_prompt_with_padding, 1) + # Iteratively replace image tags for every image that we expect + num_img_tags = prompt.count("") + + if num_img_tags != num_images: + logger.warning("Number of tokens does not match the number of provided images!") + + # Only replace as many image tags as we are going to be able to process correctly + # Sequentially replace image tags; padding shenanigans are mostly to sidestep + # url encoding logic in the tokenizer + new_prompt_no_img_pads = new_prompt_with_img_pads = prompt + for img_num in range(min(num_images, num_img_tags)): + image_prompt_without_padding = get_image_text(img_num, padding=False) + image_prompt_with_padding = get_image_text(img_num, padding=True) + new_prompt_no_img_pads = new_prompt_no_img_pads.replace('', image_prompt_without_padding, 1) + new_prompt_with_img_pads = new_prompt_with_img_pads.replace('', image_prompt_with_padding, 1) new_prompt_token_ids = tokenizer.encode(new_prompt_no_img_pads) return LLMInputs(prompt=new_prompt_with_img_pads, @@ -783,7 +793,7 @@ def input_mapper_for_qwen(ctx: InputContext, data: object): # Assume only one image embed was provided; unsqueeze the extra dim data = data.unsqueeze(0) if len(data.shape) != 3 or data.shape[1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size: - raise ValueError("Expected img_embeds to be a tensor of shape" + raise ValueError("Expected image embeds to be a tensor of shape" f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but received " f"shape [{pixel_values.shape}]") pixel_values = data From 541b7b54d44aa3f251d09ed5541ba496f788ddd4 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 28 Aug 2024 17:32:04 -0400 Subject: [PATCH 13/35] Add example for qwen vl offline inference Signed-off-by: Alex-Brooks --- examples/offline_inference_vision_language.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 9a0e9d4bc536..a05a4a394284 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -159,6 +159,15 @@ def run_blip2(question): return llm, prompt, stop_token_ids +# Qwen +def run_qwen_vl(question): + + llm = LLM(model="Qwen/Qwen-VL", trust_remote_code=True) + prompt = f"{question}" + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -169,6 +178,7 @@ def run_blip2(question): "minicpmv": run_minicpmv, "blip-2": run_blip2, "internvl_chat": run_internvl, + "qwen_vl": run_qwen_vl, } From 1db0e6df7683e13f2490b255e48201b8cea81616 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 28 Aug 2024 17:52:43 -0400 Subject: [PATCH 14/35] run formatting and linting Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 415 ++++++++++++++++------------- 1 file changed, 224 insertions(+), 191 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 7227c9a037e2..b52df1a1337e 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -4,16 +4,28 @@ # Copyright (c) Alibaba Cloud. 
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple, Mapping, TypedDict, Literal, Union +import math +from collections import OrderedDict +from functools import partial +from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, + Optional, Tuple, TypedDict, Union) + +import numpy as np import torch +from PIL import Image from torch import nn +from torch.nn import functional as F +from torch.nn.init import trunc_normal_ +from torchvision import transforms +from torchvision.transforms import InterpolationMode from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -35,7 +47,6 @@ from vllm.sequence import IntermediateTensors, SequenceData from .utils import is_pp_missing_parameter, make_layers -from vllm.logger import init_logger logger = init_logger(__name__) IMG_START = "" @@ -45,26 +56,9 @@ # for the time being, these tags are not considered as special at encoding # time. This may change as VLLMs multimodal API changes in the future. -# Qwen images are encoded into a fixed token length of 256, not include IMG_START/IMG_END +# Qwen images are encoded into a fixed context of 256 MAX_QWEN_IMG_TOKENS = 256 -from collections import OrderedDict -import math -import requests -from io import BytesIO -from functools import partial -from PIL import Image -from typing import Callable, Optional, Sequence, Tuple, List -import numpy as np - -import torch -from torch import nn -from torch.nn import functional as F -from torch.nn.init import trunc_normal_ -from torchvision import transforms -from torchvision.transforms import InterpolationMode - - class QwenImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -90,6 +84,7 @@ class QwenImageEmbeddingInputs(TypedDict): QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] + def get_abs_pos(abs_pos, tgt_size): # abs_pos: L, C # tgt_size: M @@ -100,7 +95,8 @@ def get_abs_pos(abs_pos, tgt_size): if src_size != tgt_size: return F.interpolate( - abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + abs_pos.float().reshape(1, src_size, src_size, + -1).permute(0, 3, 1, 2), size=(tgt_size, tgt_size), mode="bicubic", align_corners=False, @@ -108,12 +104,14 @@ def get_abs_pos(abs_pos, tgt_size): else: return abs_pos + # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ grid_size: int of the grid height and width return: - pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) @@ -123,7 +121,8 @@ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): grid = grid.reshape([2, 1, grid_size, 
grid_size]) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) return pos_embed @@ -131,10 +130,12 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, + grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, + grid[1]) # (H*W, D/2) - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb @@ -152,8 +153,8 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): pos = pos.reshape(-1) # (M,) out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb @@ -166,25 +167,26 @@ class Resampler(nn.Module): Outputs: A tensor with the shape of (grid_size**2, embed_dim) """ - def __init__( - self, - grid_size, - embed_dim, - num_heads, - kv_dim=None, - norm_layer=nn.LayerNorm - ): + + def __init__(self, + grid_size, + embed_dim, + num_heads, + kv_dim=None, + norm_layer=nn.LayerNorm): super().__init__() - self.num_queries = grid_size ** 2 + self.num_queries = grid_size**2 self.embed_dim = embed_dim self.num_heads = num_heads self.pos_embed = nn.Parameter( - # TODO - fix the hacks for device / dtype here & in the positional embedding retrieval - torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).half().to("cuda"), - ).requires_grad_(False) + # TODO - fix the hacks for device / dtype here & in the + # positional embedding retrieval + torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size) + ).half().to("cuda"), ).requires_grad_(False) - self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim).to("cuda")) + self.query = nn.Parameter( + torch.zeros(self.num_queries, embed_dim).to("cuda")) trunc_normal_(self.query, std=.02) if kv_dim is not None and kv_dim != embed_dim: @@ -196,7 +198,6 @@ def __init__( self.ln_q = norm_layer(embed_dim) self.ln_kv = norm_layer(embed_dim) - def forward(self, x, attn_mask=None): pos_embed = get_abs_pos(self.pos_embed, x.size(1)) @@ -205,11 +206,10 @@ def forward(self, x, attn_mask=None): N = x.shape[1] q = self.ln_q(self.query) - out = self.attn( - self._repeat(q, N) + self.pos_embed.unsqueeze(1), - x + pos_embed.unsqueeze(1), - x, - attn_mask=attn_mask)[0] + out = self.attn(self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask)[0] return out.permute(1, 0, 2) def _repeat(self, query, N: int): @@ -222,13 +222,13 @@ class VisualAttention(nn.Module): and returns output of the same size. 
""" - def __init__(self, embed_dim, num_heads, - bias=True, kdim=None, vdim=None): + def __init__(self, embed_dim, num_heads, bias=True, kdim=None, vdim=None): super(VisualAttention, self).__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim \ + and self.vdim == embed_dim self.num_heads = num_heads @@ -244,11 +244,12 @@ def __init__(self, embed_dim, num_heads, self.out_proj = nn.Linear(embed_dim, embed_dim) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - def forward(self, query, key, value, attn_mask = None): + def forward(self, query, key, value, attn_mask=None): # query/key/value: [sq, b, h] sq, b, _ = query.size() - assert torch.allclose(query, key), 'Only Support Self-Attention Currently' + assert torch.allclose(query, + key), 'Only Support Self-Attention Currently' sk = sq mixed_x_layer = self.in_proj(query) @@ -263,32 +264,33 @@ def forward(self, query, key, value, attn_mask = None): self.hidden_size_per_attention_head, dim=-1) # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(sq, - b * self.num_attention_heads_per_partition, + query_layer = query_layer.view( + sq, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head).transpose(0, 1) # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(sk, - b * self.num_attention_heads_per_partition, + key_layer = key_layer.view( + sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head).transpose(0, 1) q_scaled = query_layer / self.norm_factor if attn_mask is not None: - attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1)) + attention_probs = torch.baddbmm(attn_mask, q_scaled, + key_layer.transpose(-2, -1)) else: attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) attention_probs = attention_probs.softmax(dim=-1) - value_layer = value_layer.view(sk, - b * self.num_attention_heads_per_partition, + value_layer = value_layer.view( + sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head).transpose(0, 1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer) # change view [b, np, sq, hn] - context_layer = context_layer.view(b, - self.num_attention_heads_per_partition, - sq, self.hidden_size_per_attention_head) + context_layer = context_layer.view( + b, self.num_attention_heads_per_partition, sq, + self.hidden_size_per_attention_head) # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() @@ -304,14 +306,15 @@ def forward(self, query, key, value, attn_mask = None): class VisualAttentionBlock(nn.Module): + def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, - is_cross_attention: bool = False, + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + is_cross_attention: bool = False, ): super().__init__() @@ -322,18 +325,17 @@ def __init__( self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.attn = VisualAttention(d_model, n_head) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, mlp_width)), - ("gelu", act_layer()), - ("c_proj", nn.Linear(mlp_width, d_model)) - ])) + 
self.mlp = nn.Sequential( + OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model))])) def attention( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ): k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x @@ -342,38 +344,44 @@ def attention( return self.attn(q_x, k_x, v_x, attn_mask=attn_mask) def forward( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ): - k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None - v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None + k_x = self.ln_1_kv(k_x) if hasattr( + self, "ln_1_kv") and k_x is not None else None + v_x = self.ln_1_kv(v_x) if hasattr( + self, "ln_1_kv") and v_x is not None else None - x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) + x = q_x + self.attention( + q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) x = x + self.mlp(self.ln_2(x)) return x class TransformerBlock(nn.Module): + def __init__( - self, - width: int, - layers: int, - heads: int, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, ): super().__init__() self.width = width self.layers = layers self.resblocks = nn.ModuleList([ - VisualAttentionBlock( - width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) - for _ in range(layers) + VisualAttentionBlock(width, + heads, + mlp_ratio, + act_layer=act_layer, + norm_layer=norm_layer) for _ in range(layers) ]) def get_cast_dtype(self) -> torch.dtype: @@ -382,7 +390,9 @@ def get_cast_dtype(self) -> torch.dtype: def get_cast_device(self) -> torch.device: return self.resblocks[0].mlp.c_fc.weight.device - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): + def forward(self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None): for r in self.resblocks: x = r(x, attn_mask=attn_mask) return x @@ -390,29 +400,33 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): class VisionTransformer(nn.Module): - def __init__( - self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - n_queries: int = 256, - output_dim: int = 512, - image_start_id: int = 151857, - **kwargs - ): + def __init__(self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + image_start_id: int = 151857, + **kwargs): super().__init__() image_height, image_width = self.image_size = (image_size, image_size) patch_height, patch_width = self.patch_size = (patch_size, patch_size) - self.grid_size = (image_height // patch_height, image_width // patch_width) + self.grid_size = (image_height // patch_height, + image_width // patch_width) self.output_dim = output_dim - self.conv1 = 
nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) + self.conv1 = nn.Conv2d(in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) # class embeddings and positional embeddings - scale = width ** -0.5 - self.positional_embedding = nn.Parameter(scale * torch.randn(256, width)) + scale = width**-0.5 + self.positional_embedding = nn.Parameter(scale * + torch.randn(256, width)) norm_layer = partial(nn.LayerNorm, eps=1e-6) act_layer = nn.GELU @@ -435,19 +449,21 @@ def __init__( norm_layer=norm_layer, ) self.ln_post = norm_layer(output_dim) - self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim)) + self.proj = nn.Parameter( + (output_dim**-0.5) * torch.randn(output_dim, output_dim)) self.image_start_id = image_start_id self.image_end_id = image_start_id + 1 def forward(self, x: torch.Tensor): x = x.to( - dtype=self.transformer.get_cast_dtype(), - device=self.transformer.get_cast_device(), + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), ) # to patches x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] x = x + get_abs_pos(self.positional_embedding, x.size(1)) @@ -642,7 +658,8 @@ def __init__( lambda prefix: QWenBlock(config, cache_config, quant_config), prefix=f"{prefix}.h") self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.visual = VisionTransformer(**config.visual) if hasattr(config, "visual") else None + self.visual = VisionTransformer( + **config.visual) if hasattr(config, "visual") else None def forward( self, @@ -654,7 +671,7 @@ def forward( pixel_values: Optional[QwenImageInputs], ) -> torch.Tensor: img_pos = None - # If pixel / visual embeddings are provided, this is a visual model since we filter inputs + # If pixel / visual embeddings are provided, this is a visual model if pixel_values is not None: if pixel_values["type"] != "image_embeds": image_embeds = self.visual(pixel_values["data"]) @@ -664,8 +681,10 @@ def forward( # features should be of shape (# images, 256, hidden_dim) img_pos = self.visual.get_image_positions(input_ids) if img_pos.shape[0] != image_embeds.shape[0]: - raise ValueError(f"Number of placeholders: {img_pos.shape[0]} " - f"does not match the number of images {image_embeds.shape[0]}.") + raise ValueError( + f"Number of placeholders: {img_pos.shape[0]} " + f"does not match number of images {image_embeds.shape[0]}." + ) if get_pp_group().is_first_rank: hidden_states = self.wte(input_ids) @@ -673,7 +692,7 @@ def forward( # visual features and the corresponding image tokens if img_pos is not None: for idx, (img_bos, img_eos) in enumerate(img_pos): - hidden_states[img_bos + 1 : img_eos] = image_embeds[idx] + hidden_states[img_bos + 1:img_eos] = image_embeds[idx] residual = None else: assert intermediate_tensors is not None @@ -698,40 +717,45 @@ def forward( def get_image_text(image_num: int, padding: bool) -> str: - """Retrieves a placeholder text that when tokenized, will be expanded with image pads. + """Retrieves a placeholder text that when tokenized, will be expanded with + image pads. 
- NOTE: The reason that the reason we don't directly encode the image padding here is that - it will break the re-encoding of the tokens tokenizer, because the contents between the - start / end are treated as bytes containing a URL that then get padded up to the image context - size. + NOTE: The reason that the reason we don't directly encode the imagepadding + here is that it will break the re-encoding of the tokens tokenizer, + because the contents between the start / end are treated as bytes + containing a URL that then get padded up to the image context size. """ + image_start = f"Picture {image_num}: {IMG_START}" + image_end = f"{IMG_END}\n" if not padding: - return f"Picture {image_num}: {IMG_START}{IMG_END}\n" - return f"Picture {image_num}: {IMG_START}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{IMG_END}\n" + return f"{image_start}{image_end}" + return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}" + def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data = llm_inputs.get("multi_modal_data") # Only process images if we have multimodal data and a visual config hf_config = ctx.get_hf_config() - if multi_modal_data is None or "image" not in multi_modal_data or not hasattr(hf_config, "visual"): + if (multi_modal_data is None or "image" not in multi_modal_data + or not hasattr(hf_config, "visual")): return llm_inputs prompt = llm_inputs.get("prompt") prompt_token_ids = llm_inputs["prompt_token_ids"] model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) image_data = multi_modal_data["image"] if isinstance(image_data, torch.Tensor): - if len(image_data.shape) < 2 or len(image_data.shape) > 3: + num_dims = len(image_data.shape) + if num_dims < 2 or num_dims > 3: raise ValueError( - f"Expected image embeds to be have 3 dimensions but got {len(image_data.shape)}" - ) - num_images = 1 if len(image_data.shape) == 2 else image_data.shape[0] + f"Expected img embeds to be have 3 dimensions, got {num_dims}") + num_images = 1 if num_dims == 2 else image_data.shape[0] else: num_images = 1 - if prompt is None: prompt = tokenizer.decode(prompt_token_ids) @@ -739,47 +763,52 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): num_img_tags = prompt.count("") if num_img_tags != num_images: - logger.warning("Number of tokens does not match the number of provided images!") + logger.warning( + "Number of tokens does not match the number of images") - # Only replace as many image tags as we are going to be able to process correctly - # Sequentially replace image tags; padding shenanigans are mostly to sidestep - # url encoding logic in the tokenizer + # Only replace as many image tags as we are going to be able to process + # correctly. 
Sequentially replace image tags; padding shenanigans are + # mostly to sidestep url encoding logic in the tokenizer new_prompt_no_img_pads = new_prompt_with_img_pads = prompt for img_num in range(min(num_images, num_img_tags)): image_prompt_without_padding = get_image_text(img_num, padding=False) image_prompt_with_padding = get_image_text(img_num, padding=True) - new_prompt_no_img_pads = new_prompt_no_img_pads.replace('', image_prompt_without_padding, 1) - new_prompt_with_img_pads = new_prompt_with_img_pads.replace('', image_prompt_with_padding, 1) + new_prompt_no_img_pads = new_prompt_no_img_pads.replace( + '', image_prompt_without_padding, 1) + new_prompt_with_img_pads = new_prompt_with_img_pads.replace( + '', image_prompt_with_padding, 1) new_prompt_token_ids = tokenizer.encode(new_prompt_no_img_pads) return LLMInputs(prompt=new_prompt_with_img_pads, prompt_token_ids=new_prompt_token_ids, multi_modal_data=multi_modal_data) + def input_mapper_for_qwen(ctx: InputContext, data: object): # Early exit if we have provided an image to a language only Qwen model hf_config = ctx.get_hf_config() if not hasattr(hf_config, "visual"): - logger.warning("Images were provided but this model has no visual config; " - "multimodal inputs will not be forwarded to the model.") + logger.warning( + "Images were provided but this model has no visual config; " + "multimodal inputs will not be forwarded to the model.") return MultiModalInputs() model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) - image_pair_tok = tokenizer.encode( - IMG_START+IMG_END, - add_special_tokens=False, - return_tensors="pt").squeeze() + image_pair_tok = tokenizer.encode(IMG_START + IMG_END, + add_special_tokens=False, + return_tensors="pt").squeeze() image_start_id = image_pair_tok[0] image_end_id = image_pair_tok[-1] if (image_start_id + 1) != image_end_id: - raise ValueError(f"Found image end ID {image_end_id}, but expected ID {IMG_START} + 1") + raise ValueError( + f"Found image end ID {image_end_id}, but expected {IMG_START} + 1") if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2): raise ValueError( - f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, but got {image_pair_tok - 2}" - ) + f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, " + f"but got {image_pair_tok - 2}") hf_config = ctx.get_hf_config() image_size = hf_config.visual["image_size"] @@ -792,39 +821,38 @@ def input_mapper_for_qwen(ctx: InputContext, data: object): if len(data.shape) == 2: # Assume only one image embed was provided; unsqueeze the extra dim data = data.unsqueeze(0) - if len(data.shape) != 3 or data.shape[1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size: - raise ValueError("Expected image embeds to be a tensor of shape" - f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but received " - f"shape [{pixel_values.shape}]") + if len(data.shape) != 3 or data.shape[ + 1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size: + raise ValueError( + "Expected image embeds to be a tensor of shape" + f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but " + f"received shape [{data.shape}]") pixel_values = data - + else: transform = build_normalization_transform(image_size) # TODO - handle multiple image inputs once the API is solidified transformed_images = [transform(data)] pixel_values = torch.stack(transformed_images, dim=0) - return MultiModalInputs({ - "pixel_values": pixel_values - }) + return MultiModalInputs({"pixel_values": pixel_values}) + def 
build_normalization_transform(image_size): - """Builds a normalization transform which can be applied to one or more input images - from which we want to extract visual features. + """Builds a normalization transform which can be applied to one or + more input images from which we want to extract visual features. """ mean = (0.48145466, 0.4578275, 0.40821073) std = (0.26862954, 0.26130258, 0.27577711) return transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC - ), + transforms.Resize((image_size, image_size), + interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), ]) def dummy_data_for_qwen(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): + mm_counts: Mapping[str, int]): hf_config = ctx.get_hf_config() # The presence of a visual config indicates this is a multimodal model. @@ -836,25 +864,27 @@ def dummy_data_for_qwen(ctx: InputContext, seq_len: int, # We have a visual component - use images to warm up num_images = mm_counts["image"] - image_feature_size = MAX_QWEN_IMG_TOKENS model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) - # Encode an image pair for each image. During the encoding, qwen tokenizers will add - # image pads between the start/end. We leave this to the tokenizer, because we need - # to rely on the number of added pads at inference time. - seq_data = SequenceData(tokenizer.encode( - (IMG_START+IMG_END) * num_images, add_special_tokens=False, return_tensors="pt" - )[0].tolist()) + # Encode an image pair for each image. During the encoding, qwen tokenizers + # will add image pads between the start/end. We leave this to the + # tokenizer, because we need to rely on the number of added pads at + # inference time. + seq_data = SequenceData( + tokenizer.encode((IMG_START + IMG_END) * num_images, + add_special_tokens=False, + return_tensors="pt")[0].tolist()) assert seq_data.get_len() == ((2 + MAX_QWEN_IMG_TOKENS) * num_images) - # Build the input images; width/height doesn't actually matter here since the - # data will get resized, and the number of tokens per image is constant per model. 
+ # Build the input images; width/height doesn't actually matter here since + # the data will get resized and the # of tokens per image is constant image = Image.new("RGB", (224, 224), color=0) - mm_data = {"image": image if num_images == 1 else [image] * num_images} + mm_data = {"image": image if num_images == 1 else [image] * num_images} return seq_data, mm_data + @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) @@ -881,32 +911,35 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() - def _get_image_input_type(self, pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]: + def _get_image_input_type( + self, + pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]: if pixel_values is not None and self.transformer.visual is not None: - if len(pixel_values.shape) == 3 and pixel_values.shape[1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[2] == self.config.visual["output_dim"]: + if len(pixel_values.shape) == 3 and pixel_values.shape[ + 1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[ + 2] == self.config.visual["output_dim"]: return QwenImageEmbeddingInputs( type="image_embeds", data=pixel_values, ) else: - # if we don't have the right embedding shape, assume we need to process still + # If we have the wrong shape, assume we still need to process return QwenImagePixelInputs( type="pixel_values", data=pixel_values, ) - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - pixel_values: Optional[torch.Tensor] = None - ) -> torch.Tensor: + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor: pixel_values = self._get_image_input_type(pixel_values) hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors, pixel_values) + attn_metadata, intermediate_tensors, + pixel_values) return hidden_states def make_empty_intermediate_tensors( From ed3e15d0063800ca9a61aebcba2ea4945121a85a Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 29 Aug 2024 17:47:11 -0400 Subject: [PATCH 15/35] Fix bug in image token input processing Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index b52df1a1337e..513f05846553 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -770,7 +770,7 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): # correctly. 
Sequentially replace image tags; padding shenanigans are # mostly to sidestep url encoding logic in the tokenizer new_prompt_no_img_pads = new_prompt_with_img_pads = prompt - for img_num in range(min(num_images, num_img_tags)): + for img_num in range(1, min(num_images, num_img_tags) + 1): image_prompt_without_padding = get_image_text(img_num, padding=False) image_prompt_with_padding = get_image_text(img_num, padding=True) new_prompt_no_img_pads = new_prompt_no_img_pads.replace( From cadabb7e01103ca868e77406d145da452b082150 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Thu, 29 Aug 2024 18:57:05 -0400 Subject: [PATCH 16/35] Qwen - add comments, error handling in warmup Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 136 +++++++++++++++++++++++------ 1 file changed, 107 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 513f05846553..4ea75c27e693 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -49,15 +49,18 @@ from .utils import is_pp_missing_parameter, make_layers logger = init_logger(__name__) + +# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad; +# for the time being, these tags are not considered as special at encoding +# time. This may change as VLLMs multimodal API changes in the future. IMG_START = "" IMG_END = "" IMG_PAD = "" -# Qwen models have a few other special tags, e.g., ref, bbox, quad; -# for the time being, these tags are not considered as special at encoding -# time. This may change as VLLMs multimodal API changes in the future. - -# Qwen images are encoded into a fixed context of 256 +# Image context is fixed at 256 for all images MAX_QWEN_IMG_TOKENS = 256 +# Image normalization params +CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) +CLIP_STD = (0.26862954, 0.26130258, 0.27577711) class QwenImagePixelInputs(TypedDict): @@ -85,6 +88,11 @@ class QwenImageEmbeddingInputs(TypedDict): QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] +### Visual Transformer def / helpers +# These are only used if the model has a visual component in its config. +# The visual components and helpers have all been copied and adapted from +# the implementations in qwen-vl / qwen-vl-chat unless otherwise stated. +# TODO - visual component is not currently using VLLM parallel implementations. def get_abs_pos(abs_pos, tgt_size): # abs_pos: L, C # tgt_size: M @@ -101,10 +109,10 @@ def get_abs_pos(abs_pos, tgt_size): mode="bicubic", align_corners=False, ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) - else: - return abs_pos + return abs_pos +# sin/cos positional embedding helpers are copied from: # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ @@ -239,7 +247,8 @@ def __init__(self, embed_dim, num_heads, bias=True, kdim=None, vdim=None): self.hidden_size_per_partition = embed_dim # Strided linear layer. 
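# Illustrative sketch (not part of the committed diff): the fused in_proj a
# few lines below packs q, k and v into one matmul, and forward() later
# splits them back out per attention head. A standalone toy version of that
# split, with made-up sizes rather than Qwen's real configuration:
import torch
from torch import nn

toy_embed_dim, toy_num_heads, toy_sq, toy_b = 64, 4, 8, 2
toy_head_dim = toy_embed_dim // toy_num_heads
toy_in_proj = nn.Linear(toy_embed_dim, 3 * toy_embed_dim)

toy_x = torch.randn(toy_sq, toy_b, toy_embed_dim)          # [sq, b, h]
toy_qkv = toy_in_proj(toy_x).view(toy_sq, toy_b, toy_num_heads,
                                  3 * toy_head_dim)        # [sq, b, np, 3 * hn]
toy_q, toy_k, toy_v = toy_qkv.split(toy_head_dim, dim=-1)  # each [sq, b, np, hn]
assert toy_q.shape == (toy_sq, toy_b, toy_num_heads, toy_head_dim)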
- assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently' + assert self._qkv_same_embed_dim, \ + 'Visual Attention implementation only supports self-attention' self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) @@ -248,8 +257,8 @@ def forward(self, query, key, value, attn_mask=None): # query/key/value: [sq, b, h] sq, b, _ = query.size() - assert torch.allclose(query, - key), 'Only Support Self-Attention Currently' + assert torch.allclose(query, key), \ + 'Visual Attention implementation only supports self-attention' sk = sq mixed_x_layer = self.in_proj(query) @@ -480,7 +489,15 @@ def forward(self, x: torch.Tensor): return x - def get_image_positions(self, input_ids): + def get_image_positions(self, + input_ids: torch.Tensor) -> Optional[torch.Tensor]: + """Given the input IDs, extracts start/stop points corresponding to + images. + + args: + Returns: + Optional torch tensor corresponding to start/stop pairs of images. + """ if torch.any(input_ids == self.image_start_id): bos_pos = torch.where(input_ids == self.image_start_id) eos_pos = torch.where(input_ids == self.image_end_id) @@ -720,10 +737,13 @@ def get_image_text(image_num: int, padding: bool) -> str: """Retrieves a placeholder text that when tokenized, will be expanded with image pads. - NOTE: The reason that the reason we don't directly encode the imagepadding - here is that it will break the re-encoding of the tokens tokenizer, - because the contents between the start / end are treated as bytes - containing a URL that then get padded up to the image context size. + Args: + image_num: The number of the image that we want a text prompt for. + Images should be indexed starting at 1. + padding: Whether or not padding should be manually added. + + Returns: + Text placeholder prompt for the image being considered. """ image_start = f"Picture {image_num}: {IMG_START}" image_end = f"{IMG_END}\n" @@ -733,6 +753,19 @@ def get_image_text(image_num: int, padding: bool) -> str: def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): + """Processes the inputs, which may or may not be multimodal. + Multimodal inputs will only be processed if the model has a "visual" + component in its model config, otherwise they'll be ignored. + + Args: + ctx: Context of the loaded model. + llm_inputs: LLM inputs which may have a multi_modal_data attribute. + + Returns: + If the model is language only or not multimodal inputs were provided, + returns llm_inputs unmodified. Otherwise, processes the multimodal + images / image embeddings and adds the fixed-length image placeholders. + """ multi_modal_data = llm_inputs.get("multi_modal_data") # Only process images if we have multimodal data and a visual config @@ -754,12 +787,14 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): f"Expected img embeds to be have 3 dimensions, got {num_dims}") num_images = 1 if num_dims == 2 else image_data.shape[0] else: + # TODO - handle multiple image inputs once the API is solidified num_images = 1 if prompt is None: prompt = tokenizer.decode(prompt_token_ids) # Iteratively replace image tags for every image that we expect + # Currently we only allow multiple images input as embeddings. 
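# Illustrative sketch (not part of the committed diff): the loop a few lines
# below replaces one placeholder per image with str.replace(..., 1), which is
# what keeps the i-th tag paired with the i-th image. Toy version, with a
# made-up "[IMG]" marker standing in for the real placeholder string:
toy_prompt = "Compare [IMG] with [IMG]."
toy_num_images = 2
for toy_img_num in range(1, toy_num_images + 1):
    toy_prompt = toy_prompt.replace("[IMG]",
                                    f"Picture {toy_img_num}: [IMG_TAGS]", 1)
print(toy_prompt)  # Compare Picture 1: [IMG_TAGS] with Picture 2: [IMG_TAGS].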
num_img_tags = prompt.count("") if num_img_tags != num_images: @@ -785,6 +820,17 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): def input_mapper_for_qwen(ctx: InputContext, data: object): + """Maps the input data to its MultiModalInputs (if any). + + Args: + ctx: Context of the loaded model. + data: data potentially containing image/image embeddings to be mapped + to pixel_values in .forward() for a visual QWenLMHeadModel model. + + Returns: + MultiModalInputs containing the stacked normalized images tensor or + image embeddings. + """ # Early exit if we have provided an image to a language only Qwen model hf_config = ctx.get_hf_config() if not hasattr(hf_config, "visual"): @@ -837,22 +883,39 @@ def input_mapper_for_qwen(ctx: InputContext, data: object): return MultiModalInputs({"pixel_values": pixel_values}) -def build_normalization_transform(image_size): +def build_normalization_transform(image_size: int) -> transforms.Compose: """Builds a normalization transform which can be applied to one or more input images from which we want to extract visual features. + + Args: + image_size: size of the image to be processed for visual embeddings. + + Returns: + Callable transform for normalizing and resizing one RGB image. """ - mean = (0.48145466, 0.4578275, 0.40821073) - std = (0.26862954, 0.26130258, 0.27577711) return transforms.Compose([ transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), + transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), ]) def dummy_data_for_qwen(ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]): + """Build dummy data for warming up Qwen models; this will only contain text + matching the defaults for VLLM unless the model has a visual config. + + Args: + ctx: Context of the loaded model. + seq_len: Number of tokens in the text sequence. If this is a visual + model, sequence length will be ignored, and the input sequence + will be determined by the number of images. + mm_counts: multimodal data counts. + + Returns: + Tuple containing sequential and multimodal data. + """ hf_config = ctx.get_hf_config() # The presence of a visual config indicates this is a multimodal model. @@ -868,21 +931,26 @@ def dummy_data_for_qwen(ctx: InputContext, seq_len: int, tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) - # Encode an image pair for each image. During the encoding, qwen tokenizers - # will add image pads between the start/end. We leave this to the - # tokenizer, because we need to rely on the number of added pads at - # inference time. - seq_data = SequenceData( - tokenizer.encode((IMG_START + IMG_END) * num_images, - add_special_tokens=False, - return_tensors="pt")[0].tolist()) - assert seq_data.get_len() == ((2 + MAX_QWEN_IMG_TOKENS) * num_images) + # Build the image prompts with no imgpads; the tokenizer will add img pads + image_prompt = ''.join( + [get_image_text(idx, False) for idx in range(1, num_images + 1)]) + toks = tokenizer.encode(image_prompt, + add_special_tokens=False, + return_tensors="pt")[0].tolist() + + # Make sure we actually get the fixed context size per tok padding + num_pads = toks.count(tokenizer.encode(IMG_PAD)[0]) + if num_pads != (num_images * MAX_QWEN_IMG_TOKENS): + raise ValueError( + f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads" + f" per image, but got {num_pads} pads for {num_images} image(s)" + " in total. 
Are you using a qwen tokenizer?") # Build the input images; width/height doesn't actually matter here since # the data will get resized and the # of tokens per image is constant image = Image.new("RGB", (224, 224), color=0) mm_data = {"image": image if num_images == 1 else [image] * num_images} - return seq_data, mm_data + return SequenceData(toks), mm_data @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) @@ -914,6 +982,16 @@ def __init__( def _get_image_input_type( self, pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]: + """Determines if the provided pixel_values are normalized pixel values + or image embeddings. + + Args: + pixel_values: Optional data to processed into visual embeddings. + + Returns: + None of the QwenImageInputs type used to determine whether or not + the visual transformer needs to process the pixel_values. + """ if pixel_values is not None and self.transformer.visual is not None: if len(pixel_values.shape) == 3 and pixel_values.shape[ 1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[ From 4238c2f1e20b9f8dd3306d0fa9fabb2f6174d525 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Fri, 30 Aug 2024 02:46:40 -0400 Subject: [PATCH 17/35] Fix device and dtype hack in Qwen resampler Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 4ea75c27e693..6bb85b2a15f1 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -180,21 +180,24 @@ def __init__(self, grid_size, embed_dim, num_heads, + device, + dtype, kv_dim=None, norm_layer=nn.LayerNorm): super().__init__() self.num_queries = grid_size**2 self.embed_dim = embed_dim self.num_heads = num_heads - + # NOTE - we need to directly initialize the device / dtype since we + # init this parameter out of a numpy array, so it defaults to CPU + # even though the model itself is initialized in a torch device context + # manager by default. 
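# Illustrative sketch (not part of the committed diff): what the NOTE above
# guards against. A tensor built from a NumPy array keeps NumPy's dtype
# (float64) and lives on the CPU regardless of the surrounding device
# context, so the dtype/device cast below must be explicit:
import numpy as np
import torch

toy_table = np.zeros((16, 8))                 # stand-in for the sincos table
toy_default = torch.from_numpy(toy_table)
print(toy_default.dtype, toy_default.device)  # torch.float64 cpu
toy_cast = toy_default.to(dtype=torch.float16, device="cpu")
print(toy_cast.dtype)                         # torch.float16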
self.pos_embed = nn.Parameter( - # TODO - fix the hacks for device / dtype here & in the - # positional embedding retrieval - torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size) - ).half().to("cuda"), ).requires_grad_(False) + torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).to( + dtype=dtype, device=device).requires_grad_(False)) self.query = nn.Parameter( - torch.zeros(self.num_queries, embed_dim).to("cuda")) + torch.zeros(self.num_queries, embed_dim).to(device)) trunc_normal_(self.query, std=.02) if kv_dim is not None and kv_dim != embed_dim: @@ -454,6 +457,8 @@ def __init__(self, grid_size=int(math.sqrt(n_queries)), embed_dim=output_dim, num_heads=output_dim // 128, + device=self.positional_embedding.device, + dtype=self.positional_embedding.dtype, kv_dim=width, norm_layer=norm_layer, ) From 0258345b2233b2d10f8d4ce88c6480fdfc041b88 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Fri, 30 Aug 2024 04:02:39 -0400 Subject: [PATCH 18/35] Update sequence data initialization Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 6bb85b2a15f1..0b02b37947a8 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -5,9 +5,10 @@ # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -import math +from array import array from collections import OrderedDict from functools import partial +import math from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -43,8 +44,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.image import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SequenceData +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import (IntermediateTensors, + SequenceData, + VLLM_TOKEN_ID_ARRAY_TYPE) from .utils import is_pp_missing_parameter, make_layers @@ -926,7 +929,7 @@ def dummy_data_for_qwen(ctx: InputContext, seq_len: int, # The presence of a visual config indicates this is a multimodal model. # If we don't have it, the model is considered an LLM for warmup purposes. 
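# Illustrative sketch (not part of the committed diff): the SequenceData
# change just below swaps a plain Python list for a typed array of token ids.
# Minimal standalone version; 'l' is only a stand-in typecode for whatever
# VLLM_TOKEN_ID_ARRAY_TYPE resolves to:
from array import array

toy_seq_len = 8
toy_dummy_ids = array("l", [0] * toy_seq_len)
assert list(toy_dummy_ids) == [0] * toy_seq_len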
if not hasattr(hf_config, "visual"): - seq_data = SequenceData([0] * seq_len) + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)) mm_data = None return seq_data, mm_data @@ -955,7 +958,7 @@ def dummy_data_for_qwen(ctx: InputContext, seq_len: int, # the data will get resized and the # of tokens per image is constant image = Image.new("RGB", (224, 224), color=0) mm_data = {"image": image if num_images == 1 else [image] * num_images} - return SequenceData(toks), mm_data + return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) From d0b89622ea7dd5620fa7890058671f3d7448be7f Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Fri, 30 Aug 2024 04:30:19 -0400 Subject: [PATCH 19/35] Flatten bn dimension for qwen Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 0b02b37947a8..c74b9dd4837f 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -49,7 +49,7 @@ SequenceData, VLLM_TOKEN_ID_ARRAY_TYPE) -from .utils import is_pp_missing_parameter, make_layers +from .utils import flatten_bn, is_pp_missing_parameter, make_layers logger = init_logger(__name__) @@ -70,7 +70,7 @@ class QwenImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor """ - Shape: `(# images, 3, image_size, image_size)` + Shape: `(batch_size * num_images, 3, image_size, image_size)` Note that image_size is the value in the vision config to which we resize the image to in the normalization transform. Currently multi-image support @@ -81,7 +81,7 @@ class QwenImagePixelInputs(TypedDict): class QwenImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: torch.Tensor - """Shape: `(# images, 256, hidden_size)` + """Shape: `(batch_size * num_images, 256, hidden_size)` `hidden_size` must match the hidden size of the language model backbone and is stored in the visual config of the model if we have one. @@ -1001,6 +1001,7 @@ def _get_image_input_type( the visual transformer needs to process the pixel_values. """ if pixel_values is not None and self.transformer.visual is not None: + pixel_values = flatten_bn(pixel_values) if len(pixel_values.shape) == 3 and pixel_values.shape[ 1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[ 2] == self.config.visual["output_dim"]: From 030b5358ced80ce485456d7dc09e25da82630126 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Fri, 30 Aug 2024 04:41:32 -0400 Subject: [PATCH 20/35] Switch qwen test to text only model Signed-off-by: Alex-Brooks --- tests/models/test_qwen.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index 0f974fcc1885..fea1531e580e 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -5,13 +5,17 @@ from ..conftest import HfRunner, VllmRunner from .utils import check_logprobs_close -models = ["qwen/qwen-vl"] - +text_only_models = [ + "Qwen/Qwen-7B-Chat" # Has no visual component +] +# Text only tests; the primary purpose of this test is to ensure that we can +# load Qwen models, e.g., Qwen/Qwen-7B-Chat, that do not have a visual config, +# without any problems. 
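# Illustrative sketch (not part of the committed diff): the flatten_bn call
# added to _get_image_input_type in the preceding patch collapses the batch
# and per-request image dimensions before the visual encoder runs; the effect
# is roughly the following, with made-up sizes:
import torch

toy_pixel_values = torch.randn(2, 3, 3, 224, 224)  # (batch, num_images, C, H, W)
toy_flat = toy_pixel_values.flatten(0, 1)          # (batch * num_images, C, H, W)
assert toy_flat.shape == (6, 3, 224, 224)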
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("model", text_only_models) def test_text_only_qwen_model( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], From 27f819ab7c61e1849ef131126249cc8b382137a4 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Fri, 30 Aug 2024 07:23:34 -0400 Subject: [PATCH 21/35] Run code formatting Signed-off-by: Alex-Brooks --- tests/models/test_qwen.py | 3 ++- vllm/model_executor/models/qwen.py | 7 +++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index fea1531e580e..06e7a7809b76 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -6,9 +6,10 @@ from .utils import check_logprobs_close text_only_models = [ - "Qwen/Qwen-7B-Chat" # Has no visual component + "Qwen/Qwen-7B-Chat" # Has no visual component ] + # Text only tests; the primary purpose of this test is to ensure that we can # load Qwen models, e.g., Qwen/Qwen-7B-Chat, that do not have a visual config, # without any problems. diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index c74b9dd4837f..d2f86910caa3 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -5,10 +5,10 @@ # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" +import math from array import array from collections import OrderedDict from functools import partial -import math from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -45,9 +45,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import (IntermediateTensors, - SequenceData, - VLLM_TOKEN_ID_ARRAY_TYPE) +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) from .utils import flatten_bn, is_pp_missing_parameter, make_layers From fcdd6f1455b8dd14aff6a87b0b67500f63cb49a4 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 1 Sep 2024 09:09:41 -0400 Subject: [PATCH 22/35] Add image tag standardization, multimodal qwen tests Signed-off-by: Alex-Brooks --- examples/offline_inference_vision_language.py | 4 + tests/models/test_qwen.py | 177 ++++++++++++++++-- vllm/model_executor/models/qwen.py | 39 +++- 3 files changed, 201 insertions(+), 19 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index a05a4a394284..e7cbf7f0000c 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -163,6 +163,10 @@ def run_blip2(question): def run_qwen_vl(question): llm = LLM(model="Qwen/Qwen-VL", trust_remote_code=True) + # NOTE: In this case, we could pass either '' or + # 'Picture {idx} '; currently tags get + # unified and resolved to the corresponding indices as part + # of the Qwen model input processor. 
prompt = f"{question}" stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index 06e7a7809b76..54b508bc7c00 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -1,18 +1,170 @@ -from typing import Type +import pathlib +from typing import List, Optional, Type import pytest +from transformers import AutoTokenizer -from ..conftest import HfRunner, VllmRunner +from vllm.model_executor.models.qwen import get_qwen_llm_inputs +from vllm.multimodal.utils import rescale_image_size + +from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from .utils import check_logprobs_close +pytestmark = pytest.mark.vlm + text_only_models = [ "Qwen/Qwen-7B-Chat" # Has no visual component ] +multimodal_models = ["Qwen/Qwen-VL"] + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "Picture 1: \nWhat's the content of the image?: ", + "cherry_blossom": + "Picture 1: \nWhat is the season?: ", +}) + + +### Tests for multimodal Qwen models +@pytest.mark.parametrize("hf_input_text,vllm_input_text,num_images", [ + ("I have no image tags", "I have no image tags", 0), + ("Picture 1: \n", "Picture 1: \n", 1), + ("Picture 1: \n", "", 1), + ("Picture 1: \n Picture 2: \n", " ", + 2), +]) +def test_qwen_input_processor_tag_unification(hf_input_text, vllm_input_text, + num_images): + tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL", + trust_remote_code=True) + hf_tok_ids = tokenizer.encode(hf_input_text) + vllm_tok_ids = get_qwen_llm_inputs( + vllm_input_text, + tokenizer, + num_images, + multi_modal_data=None, + )["prompt_token_ids"] + assert len(vllm_tok_ids) == len(hf_tok_ids) + assert vllm_tok_ids == hf_tok_ids + + +def run_test( + tmp_path: pathlib.PosixPath, + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: List[float], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + images = [asset.pil_image for asset in image_assets] + + # Export the images to a tempdir and substitute it into the hf prompt; + # the contents between / will be ignored by VLLM, but the + # transformers implementation for the visual transformer parses this to + # reload it in the forward call; the contents are treated as a URL or a + # local path. + for idx, asset in enumerate(image_assets): + image_tmp_path = tmp_path / f"{asset.name}.jpg" + asset.pil_image.save(image_tmp_path) + HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace( + "", f"{image_tmp_path}") + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
-# Text only tests; the primary purpose of this test is to ensure that we can -# load Qwen models, e.g., Qwen/Qwen-7B-Chat, that do not have a visual config, -# without any problems. + # max_model_len should be greater than image_feature_size + with vllm_runner(model, + max_model_len=2048, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("model", multimodal_models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [8]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets, + model, size_factors, dtype, max_tokens, + num_logprobs) -> None: + run_test( + tmp_path, + hf_runner, + vllm_runner, + image_assets, + model, + size_factors=size_factors, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + +### Tests for language only Qwen models @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) @@ -27,19 +179,18 @@ def test_text_only_qwen_model( max_tokens: int, num_logprobs: int, ): - # This test checks language inputs only, since the visual component - # for qwen-vl is still unsupported in VLLM. In the near-future, the - # implementation and this test will be extended to consider - # visual inputs as well. - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( + # the primary purpose of this test is to ensure that we can + # load Qwen models, e.g., Qwen/Qwen-7B-Chat, that do not have a visual + # config, without any problems. 
+ with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs=num_logprobs, ) - with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs=num_logprobs, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index d2f86910caa3..8fd0e9af9b9e 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -6,6 +6,7 @@ """Inference-only QWen model compatible with HuggingFace weights.""" import math +import re from array import array from collections import OrderedDict from functools import partial @@ -20,7 +21,7 @@ from torch.nn.init import trunc_normal_ from torchvision import transforms from torchvision.transforms import InterpolationMode -from transformers import PretrainedConfig +from transformers import PretrainedConfig, PreTrainedTokenizer from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -42,7 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -800,13 +801,39 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): if prompt is None: prompt = tokenizer.decode(prompt_token_ids) - # Iteratively replace image tags for every image that we expect - # Currently we only allow multiple images input as embeddings. - num_img_tags = prompt.count("") + return get_qwen_llm_inputs(prompt, tokenizer, num_images, multi_modal_data) + + +def get_qwen_llm_inputs( + prompt: str, tokenizer: PreTrainedTokenizer, num_images: int, + multi_modal_data: Optional[MultiModalDataDict]) -> LLMInputs: + """Standardize the image token format. Qwen generally expects images + to be formatted matching the regex below, but currently, we also let + users pass . This offers a couple benefits. + + 1. Usually the picture numbering is automatically done by the tokenizer + utils when converting from a list format. Expecting users to do it + correctly when they may not have the tokenizer on the client side is + error-prone, e.g., users may accidentally 0-index their images, which + can cause weird results + + 2. Chat can use this to encode images for Qwen without having to consider + image indices at the moment. + Args: + prompt: Prompt whose image tags will be standardized. + tokenizer: Qwen tokenizer for this model. + num_images: Number of images passed in the multimodal data. + multi_modal_data: Multimodal data for this request. + + Returns: + LLM data to be returned by the input processor. + """ + prompt = re.sub(r"Picture :\d* .+?<\/img>", "", prompt) + num_img_tags = prompt.count("") if num_img_tags != num_images: logger.warning( - "Number of tokens does not match the number of images") + "Number of image placeholders does not match the number of images") # Only replace as many image tags as we are going to be able to process # correctly. 
Sequentially replace image tags; padding shenanigans are From a5c1201338421755b8e75e3a0e412d69bb338f07 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 1 Sep 2024 19:06:20 -0400 Subject: [PATCH 23/35] Remove support for in qwen Signed-off-by: Alex-Brooks --- examples/offline_inference_vision_language.py | 6 +- tests/models/test_qwen.py | 56 +++-------------- vllm/entrypoints/chat_utils.py | 6 +- vllm/model_executor/models/qwen.py | 62 +++++-------------- 4 files changed, 27 insertions(+), 103 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index e7cbf7f0000c..648365630c5f 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -163,11 +163,7 @@ def run_blip2(question): def run_qwen_vl(question): llm = LLM(model="Qwen/Qwen-VL", trust_remote_code=True) - # NOTE: In this case, we could pass either '' or - # 'Picture {idx} '; currently tags get - # unified and resolved to the corresponding indices as part - # of the Qwen model input processor. - prompt = f"{question}" + prompt = f"{question}Picture 1: \n" stop_token_ids = None return llm, prompt, stop_token_ids diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index 54b508bc7c00..27c1ab36a070 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -2,9 +2,7 @@ from typing import List, Optional, Type import pytest -from transformers import AutoTokenizer -from vllm.model_executor.models.qwen import get_qwen_llm_inputs from vllm.multimodal.utils import rescale_image_size from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets @@ -27,28 +25,6 @@ ### Tests for multimodal Qwen models -@pytest.mark.parametrize("hf_input_text,vllm_input_text,num_images", [ - ("I have no image tags", "I have no image tags", 0), - ("Picture 1: \n", "Picture 1: \n", 1), - ("Picture 1: \n", "", 1), - ("Picture 1: \n Picture 2: \n", " ", - 2), -]) -def test_qwen_input_processor_tag_unification(hf_input_text, vllm_input_text, - num_images): - tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL", - trust_remote_code=True) - hf_tok_ids = tokenizer.encode(hf_input_text) - vllm_tok_ids = get_qwen_llm_inputs( - vllm_input_text, - tokenizer, - num_images, - multi_modal_data=None, - )["prompt_token_ids"] - assert len(vllm_tok_ids) == len(hf_tok_ids) - assert vllm_tok_ids == hf_tok_ids - - def run_test( tmp_path: pathlib.PosixPath, hf_runner: Type[HfRunner], @@ -96,8 +72,9 @@ def run_test( # will hurt multiprocessing backend with fork method (the default method). # max_model_len should be greater than image_feature_size + # Qwen encodes images into a fixed content size of 256 with vllm_runner(model, - max_model_len=2048, + max_model_len=300, dtype=dtype, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, @@ -164,13 +141,13 @@ def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets, ) -### Tests for language only Qwen models -@pytest.mark.parametrize("dtype", ["half"]) +# Ensure that a text-only Qwen model can still be loaded and +# used for inference in VLLM without throwing. 
+@pytest.mark.parametrize("model", text_only_models) +@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("model", text_only_models) -def test_text_only_qwen_model( - hf_runner: Type[HfRunner], +def test_text_only_qwen_model_can_be_loaded_and_run( vllm_runner: Type[VllmRunner], example_prompts, model: str, @@ -179,26 +156,9 @@ def test_text_only_qwen_model( max_tokens: int, num_logprobs: int, ): - # the primary purpose of this test is to ensure that we can - # load Qwen models, e.g., Qwen/Qwen-7B-Chat, that do not have a visual - # config, without any problems. with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( - example_prompts, - max_tokens, - num_logprobs=num_logprobs, - ) - - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( + vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs=num_logprobs, ) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 62131d8ec420..c6611e58bf70 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -129,13 +129,13 @@ def add(self, modality: Literal["image", "audio"], if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): # These models do not use image tokens in the prompt return None + if model_type == "qwen": + return f"Picture {current_count}: " if model_type.startswith("llava"): return MultiModalItemTracker._cached_token_str( self._tokenizer, self._model_config.hf_config.image_token_index) - # NOTE: qwen models do not use normally, but input - # processor will expand it to the expected format - if model_type in ("chameleon", "internvl_chat", "qwen"): + if model_type in ("chameleon", "internvl_chat"): return "" raise TypeError(f"Unknown model type: {model_type}") diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 8fd0e9af9b9e..f88f89640c06 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -21,7 +21,7 @@ from torch.nn.init import trunc_normal_ from torchvision import transforms from torchvision.transforms import InterpolationMode -from transformers import PretrainedConfig, PreTrainedTokenizer +from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig @@ -43,7 +43,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict +from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, @@ -801,54 +801,22 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): if prompt is None: prompt = tokenizer.decode(prompt_token_ids) - return get_qwen_llm_inputs(prompt, tokenizer, num_images, multi_modal_data) + # Drops anything between / tags; encoding with the tokenizer + # will automatically add the image pads for the context. 
+ new_prompt, num_matched_images = re.subn( + r"(Picture \d*: ).*?(<\/img>\n)", + r"\1\2", + prompt, + ) + if num_matched_images == num_images: + logger.warning( + "Number of matched image placeholders doesn't match the number " + "of expected images; are your placeholders formatted correctly?") -def get_qwen_llm_inputs( - prompt: str, tokenizer: PreTrainedTokenizer, num_images: int, - multi_modal_data: Optional[MultiModalDataDict]) -> LLMInputs: - """Standardize the image token format. Qwen generally expects images - to be formatted matching the regex below, but currently, we also let - users pass . This offers a couple benefits. - - 1. Usually the picture numbering is automatically done by the tokenizer - utils when converting from a list format. Expecting users to do it - correctly when they may not have the tokenizer on the client side is - error-prone, e.g., users may accidentally 0-index their images, which - can cause weird results - - 2. Chat can use this to encode images for Qwen without having to consider - image indices at the moment. - - Args: - prompt: Prompt whose image tags will be standardized. - tokenizer: Qwen tokenizer for this model. - num_images: Number of images passed in the multimodal data. - multi_modal_data: Multimodal data for this request. + new_prompt_token_ids = tokenizer.encode(new_prompt) - Returns: - LLM data to be returned by the input processor. - """ - prompt = re.sub(r"Picture :\d* .+?<\/img>", "", prompt) - num_img_tags = prompt.count("") - if num_img_tags != num_images: - logger.warning( - "Number of image placeholders does not match the number of images") - - # Only replace as many image tags as we are going to be able to process - # correctly. Sequentially replace image tags; padding shenanigans are - # mostly to sidestep url encoding logic in the tokenizer - new_prompt_no_img_pads = new_prompt_with_img_pads = prompt - for img_num in range(1, min(num_images, num_img_tags) + 1): - image_prompt_without_padding = get_image_text(img_num, padding=False) - image_prompt_with_padding = get_image_text(img_num, padding=True) - new_prompt_no_img_pads = new_prompt_no_img_pads.replace( - '', image_prompt_without_padding, 1) - new_prompt_with_img_pads = new_prompt_with_img_pads.replace( - '', image_prompt_with_padding, 1) - new_prompt_token_ids = tokenizer.encode(new_prompt_no_img_pads) - - return LLMInputs(prompt=new_prompt_with_img_pads, + return LLMInputs(prompt=new_prompt, prompt_token_ids=new_prompt_token_ids, multi_modal_data=multi_modal_data) From 29a3c7f470c124370c526dc4eb2d2e81e38d2f8c Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Sun, 1 Sep 2024 19:19:24 -0400 Subject: [PATCH 24/35] Update docs for multimodal support Signed-off-by: Alex-Brooks --- docs/source/models/supported_models.rst | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 2c20b6e48407..df0b4cabded8 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -155,10 +155,6 @@ Decoder-only Language Models - Persimmon - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc. - - * - :code:`QWenLMHeadModel` - - Qwen - - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. - - * - :code:`Qwen2ForCausalLM` - Qwen2 - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc. 
@@ -242,6 +238,11 @@ Multimodal Language Models - Image - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - + * - :code:`QWenLMHeadModel` + - Qwen + - Image + - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. + - * - :code:`UltravoxModel` - Ultravox - Audio From 7ac8ff9b17c986bf94d2ddd20b781364a97c7997 Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Tue, 3 Sep 2024 11:00:14 -0600 Subject: [PATCH 25/35] Update vllm/model_executor/models/qwen.py Co-authored-by: Cyrus Leung --- vllm/model_executor/models/qwen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index f88f89640c06..18926fd70b4d 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -237,7 +237,7 @@ class VisualAttention(nn.Module): """ def __init__(self, embed_dim, num_heads, bias=True, kdim=None, vdim=None): - super(VisualAttention, self).__init__() + super().__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim From 3fe3a77c211cb9a1c3b0c51a3ab5179bd51c0a02 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 3 Sep 2024 13:01:24 -0400 Subject: [PATCH 26/35] Add qwen back to llm docs Signed-off-by: Alex-Brooks --- docs/source/models/supported_models.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index df0b4cabded8..c3a92634e1b9 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -155,6 +155,10 @@ Decoder-only Language Models - Persimmon - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc. - + * - :code:`QWenLMHeadModel` + - Qwen + - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. + - * - :code:`Qwen2ForCausalLM` - Qwen2 - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc. 
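A minimal end-to-end usage sketch of the feature this series builds up, for orientation; it is not part of any patch. It assumes the series is applied, follows the `Picture <n>: <img></img>` placeholder convention that the input processor and chat utils above target (the literal image tags are garbled in several hunks of this log), and uses a stand-in image path.

    from PIL import Image

    from vllm import LLM, SamplingParams

    # max_num_seqs mirrors the limit applied to Qwen-VL later in this series.
    llm = LLM(model="Qwen/Qwen-VL", trust_remote_code=True, max_num_seqs=5)

    # Anything between the image tags is dropped by the input processor, so an
    # empty <img></img> pair is enough on the client side; each image then
    # expands to a fixed 256-token context when the prompt is tokenized.
    prompt = "Picture 1: <img></img>\nWhat is shown in this image?"
    image = Image.open("example.jpg")  # stand-in path

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": image},
        },
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)
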
From a4f5400f1b00cdb728d6b362fae2485570190da3 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Tue, 3 Sep 2024 19:59:22 -0400 Subject: [PATCH 27/35] Make qwen/minicpmv embedding utils common Signed-off-by: Alex-Brooks --- vllm/model_executor/models/minicpmv.py | 97 +------------------------- vllm/model_executor/models/qwen.py | 89 ++--------------------- 2 files changed, 7 insertions(+), 179 deletions(-) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index dd10729b9ffb..6aa760f3c11d 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -44,6 +44,8 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.resampler import (get_abs_pos, + get_2d_sincos_pos_embed) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype @@ -98,101 +100,6 @@ class MiniCPMVImagePixelInputs(TypedDict): DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) -def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): - # abs_pos: L, C - # tgt_size: (H, W) - # return: M, C - src_size = int(math.sqrt(abs_pos.size(0))) - # tgt_size = int(math.sqrt(tgt_size)) - dtype = abs_pos.dtype - - return (F.interpolate( - abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), - size=(tgt_size[0], tgt_size[1]), - mode="bicubic", - align_corners=False, - ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) - - -# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 -def get_2d_sincos_pos_embed( - embed_dim: int, - grid_size: Union[int, Tuple[int, int]], - cls_token: bool = False, - version: Tuple[int, int] = (2, 0), -): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - if isinstance(grid_size, int): - grid_h_size, grid_w_size = grid_size, grid_size - else: - grid_h_size, grid_w_size = grid_size[0], grid_size[1] - - grid_h = np.arange(grid_h_size, dtype=np.float32) - grid_w = np.arange(grid_w_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - if version == (2, 0): - grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], - axis=0) - else: - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim: int, - grid: np.ndarray, - version: Tuple[int, int] = (2, 0)): - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid( - embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid( - embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) - - if version == (2, 0): - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - else: - emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim: int, - pos: np.ndarray, - 
version: Tuple[int, int] = (2, 0)): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) / (H, W) - out: (M, D) / (H, W, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - if version == (2, 0): - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - else: - out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product - emb_sin = np.sin(out) # (H, W, D/2) - emb_cos = np.cos(out) # (H, W, D/2) - emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) - return emb - - class BaseResampler(nn.Module): """ A 2D perceiver-resampler network with one cross attention layers by diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 18926fd70b4d..7f8a45202de2 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -13,11 +13,9 @@ from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) -import numpy as np import torch from PIL import Image from torch import nn -from torch.nn import functional as F from torch.nn.init import trunc_normal_ from torchvision import transforms from torchvision.transforms import InterpolationMode @@ -36,6 +34,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.resampler import (get_abs_pos, + get_2d_sincos_pos_embed) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -91,86 +91,6 @@ class QwenImageEmbeddingInputs(TypedDict): QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] -### Visual Transformer def / helpers -# These are only used if the model has a visual component in its config. -# The visual components and helpers have all been copied and adapted from -# the implementations in qwen-vl / qwen-vl-chat unless otherwise stated. -# TODO - visual component is not currently using VLLM parallel implementations. 
-def get_abs_pos(abs_pos, tgt_size): - # abs_pos: L, C - # tgt_size: M - # return: M, C - src_size = int(math.sqrt(abs_pos.size(0))) - tgt_size = int(math.sqrt(tgt_size)) - dtype = abs_pos.dtype - - if src_size != tgt_size: - return F.interpolate( - abs_pos.float().reshape(1, src_size, src_size, - -1).permute(0, 3, 1, 2), - size=(tgt_size, tgt_size), - mode="bicubic", - align_corners=False, - ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) - return abs_pos - - -# sin/cos positional embedding helpers are copied from: -# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=np.float32) - grid_w = np.arange(grid_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], - axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, - grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, - grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) - out: (M, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2. - omega = 1. 
/ 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - class Resampler(nn.Module): """ A 2D perceiver-resampler network with one cross attention layers by @@ -213,7 +133,7 @@ def __init__(self, self.ln_kv = norm_layer(embed_dim) def forward(self, x, attn_mask=None): - pos_embed = get_abs_pos(self.pos_embed, x.size(1)) + pos_embed = get_abs_pos(self.pos_embed, int(math.sqrt(x.size(1)))) x = self.kv_proj(x) x = self.ln_kv(x).permute(1, 0, 2) @@ -483,7 +403,8 @@ def forward(self, x: torch.Tensor): -1) # shape = [*, width, grid ** 2] x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = x + get_abs_pos(self.positional_embedding, x.size(1)) + x = x + get_abs_pos(self.positional_embedding, int(math.sqrt( + x.size(1)))) x = self.ln_pre(x) From 1989da815bccd78a673410799d752e15e86d13ae Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 4 Sep 2024 00:22:53 -0400 Subject: [PATCH 28/35] Make qwenvl / minicpmv2.0 resampler common Signed-off-by: Alex-Brooks --- vllm/model_executor/layers/resampler.py | 230 ++++++++++++++++++++++++ vllm/model_executor/models/minicpmv.py | 67 +------ vllm/model_executor/models/qwen.py | 71 +------- 3 files changed, 243 insertions(+), 125 deletions(-) create mode 100644 vllm/model_executor/layers/resampler.py diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py new file mode 100644 index 000000000000..818918a62673 --- /dev/null +++ b/vllm/model_executor/layers/resampler.py @@ -0,0 +1,230 @@ +from functools import partial +import math +from typing import Tuple, Union, Optional, Callable + +import numpy as np +import torch +from torch import nn +from torch.nn.init import trunc_normal_ +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import ReplicatedLinear + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + +def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, int]): + # abs_pos: L, C + # tgt_size: (H, W) + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + dtype = abs_pos.dtype + if isinstance(tgt_size, int): + tgt_size = (tgt_size, tgt_size) + if (src_size == tgt_size[0] and src_size == tgt_size[1]): + return abs_pos + return (F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size[0], tgt_size[1]), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) + + +# sin/cos positional embedding helpers are adapted from: +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, + pos: np.ndarray, + version: Tuple[int, int] = (2, 0)): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) / (H, W) + out: (M, D) / (H, W, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + if version == (2, 0): + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + else: + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product + 
emb_sin = np.sin(out) # (H, W, D/2) + emb_cos = np.cos(out) # (H, W, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid(embed_dim: int, + grid: np.ndarray, + version: Tuple[int, int] = (2, 0)): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) + + if version == (2, 0): + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + else: + emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + if version == (2, 0): + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) + else: + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + return pos_embed + + +class BaseResampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb. 
+ Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + ) -> None: + super().__init__() + + self.num_queries = num_queries + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + trunc_normal_(self.query, std=0.02) + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False) + else: + # Maintain the same return value with ReplicatedLinear.forward + self.kv_proj = lambda *args, **kwargs: ( + nn.Identity()(*args, **kwargs), + None, + ) + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + self.do_post_projection = do_post_projection + self.ln_post = norm_layer(embed_dim) if do_post_projection else None + self.proj = nn.Parameter( + (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim) + ) if do_post_projection else None + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class Resampler2(BaseResampler): + """Resampler-perceiver network to be used for a variety of model types, + e.g., Qwen and Minicpmv 2.0. + """ + + def __init__( + self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + do_post_projection: bool = True, + ) -> None: + super().__init__(grid_size**2, embed_dim, num_heads, kv_dim, + norm_layer, do_post_projection=do_post_projection) + + self.adaptive = adaptive + pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, + grid_size, + version=(2, 0)) + + self.pos_embed = nn.Parameter( + torch.from_numpy(pos_embed_arr).requires_grad_(False)) + + self.apply(self._init_weights) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ): + if tgt_sizes is None: + tgt_sizes = int(math.sqrt(x.size(1))) + if self.adaptive: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + tgt_sizes, + version=(2, 0)) + pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, + dtype=x.dtype) + else: + pos_embed = get_abs_pos(self.pos_embed, tgt_sizes).to( + device=x.device, + dtype=x.dtype) + + x, _ = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask, + )[0] + x = out.permute(1, 0, 2) + if self.do_post_projection: + x = self.ln_post(x) + x = x @ self.proj + return x diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 6aa760f3c11d..2405417e1b4c 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -26,11 +26,9 @@ from array import array from functools import partial from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, - TypedDict, Union) + TypedDict) 
-import numpy as np import torch -import torch.nn.functional as F import torch.types from PIL import Image from torch import nn @@ -44,8 +42,8 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.resampler import (get_abs_pos, - get_2d_sincos_pos_embed) +from vllm.model_executor.layers.resampler import (get_2d_sincos_pos_embed, + Resampler2) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype @@ -152,62 +150,6 @@ def _repeat(self, query, N: int): return query.unsqueeze(1).repeat(1, N, 1) -class Resampler2(BaseResampler): - - def __init__( - self, - grid_size: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - adaptive: bool = False, - ) -> None: - super().__init__(grid_size**2, embed_dim, num_heads, kv_dim, - norm_layer) - - self.adaptive = adaptive - pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, - grid_size, - version=(2, 0)) - self.pos_embed = nn.Parameter( - torch.from_numpy(pos_embed_arr).float()).requires_grad_(False) - - self.apply(self._init_weights) - - def forward( - self, - x: torch.Tensor, - tgt_sizes: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - ): - if self.adaptive: - pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, - tgt_sizes, - version=(2, 0)) - pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, - dtype=x.dtype) - else: - pos_embed = get_abs_pos(self.pos_embed, tgt_sizes) - - x, _ = self.kv_proj(x) - x = self.ln_kv(x).permute(1, 0, 2) - - N = x.shape[1] - q = self.ln_q(self.query) - out = self.attn( - self._repeat(q, N) + self.pos_embed.unsqueeze(1), - x + pos_embed.unsqueeze(1), - x, - attn_mask=attn_mask, - )[0] - x = out.permute(1, 0, 2) - - x = self.ln_post(x) - x = x @ self.proj - return x - - class Resampler2_5(BaseResampler): def __init__( @@ -689,7 +631,8 @@ def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: num_heads=embed_dim // 128, grid_size=int(math.sqrt(self.config.query_num)), kv_dim=vision_dim, - adaptive=True, + adaptive=False, + do_post_projection=True, ) return resampler diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 7f8a45202de2..7649f90ee693 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -35,7 +35,7 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.resampler import (get_abs_pos, - get_2d_sincos_pos_embed) + Resampler2) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -91,65 +91,6 @@ class QwenImageEmbeddingInputs(TypedDict): QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] -class Resampler(nn.Module): - """ - A 2D perceiver-resampler network with one cross attention layers by - (grid_size**2) learnable queries and 2d sincos pos_emb - Outputs: - A tensor with the shape of (grid_size**2, embed_dim) - """ - - def __init__(self, - grid_size, - embed_dim, - num_heads, - device, - dtype, - kv_dim=None, - norm_layer=nn.LayerNorm): - 
super().__init__() - self.num_queries = grid_size**2 - self.embed_dim = embed_dim - self.num_heads = num_heads - # NOTE - we need to directly initialize the device / dtype since we - # init this parameter out of a numpy array, so it defaults to CPU - # even though the model itself is initialized in a torch device context - # manager by default. - self.pos_embed = nn.Parameter( - torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).to( - dtype=dtype, device=device).requires_grad_(False)) - - self.query = nn.Parameter( - torch.zeros(self.num_queries, embed_dim).to(device)) - trunc_normal_(self.query, std=.02) - - if kv_dim is not None and kv_dim != embed_dim: - self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) - else: - self.kv_proj = nn.Identity() - - self.attn = nn.MultiheadAttention(embed_dim, num_heads) - self.ln_q = norm_layer(embed_dim) - self.ln_kv = norm_layer(embed_dim) - - def forward(self, x, attn_mask=None): - pos_embed = get_abs_pos(self.pos_embed, int(math.sqrt(x.size(1)))) - - x = self.kv_proj(x) - x = self.ln_kv(x).permute(1, 0, 2) - - N = x.shape[1] - q = self.ln_q(self.query) - out = self.attn(self._repeat(q, N) + self.pos_embed.unsqueeze(1), - x + pos_embed.unsqueeze(1), - x, - attn_mask=attn_mask)[0] - return out.permute(1, 0, 2) - - def _repeat(self, query, N: int): - return query.unsqueeze(1).repeat(1, N, 1) - - class VisualAttention(nn.Module): """self-attention layer class. Self-attention layer takes input with size [s, b, h] @@ -376,15 +317,19 @@ def __init__(self, norm_layer=norm_layer, ) - self.attn_pool = Resampler( + self.attn_pool = Resampler2( grid_size=int(math.sqrt(n_queries)), embed_dim=output_dim, num_heads=output_dim // 128, - device=self.positional_embedding.device, - dtype=self.positional_embedding.dtype, kv_dim=width, norm_layer=norm_layer, + adaptive=False, + do_post_projection=False, + ).to( + device=self.positional_embedding.device, + dtype=self.positional_embedding.dtype, ) + self.ln_post = norm_layer(output_dim) self.proj = nn.Parameter( (output_dim**-0.5) * torch.randn(output_dim, output_dim)) From 5e8409b76fdc52f56c4aa0a68388c166c918e659 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 4 Sep 2024 06:07:17 -0400 Subject: [PATCH 29/35] Fix qwen warning for image placeholders Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 50 ++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 7649f90ee693..fc686f497ad5 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -16,7 +16,6 @@ import torch from PIL import Image from torch import nn -from torch.nn.init import trunc_normal_ from torchvision import transforms from torchvision.transforms import InterpolationMode from transformers import PretrainedConfig @@ -34,8 +33,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.resampler import (get_abs_pos, - Resampler2) +from vllm.model_executor.layers.resampler import (get_abs_pos, Resampler2) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -97,7 +95,14 @@ class VisualAttention(nn.Module): and returns output of the same size. 
""" - def __init__(self, embed_dim, num_heads, bias=True, kdim=None, vdim=None): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + kdim: Optional[int] = None, + vdim: Optional[int] = None, + ): super().__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim @@ -120,7 +125,13 @@ def __init__(self, embed_dim, num_heads, bias=True, kdim=None, vdim=None): self.out_proj = nn.Linear(embed_dim, embed_dim) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - def forward(self, query, key, value, attn_mask=None): + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: # query/key/value: [sq, b, h] sq, b, _ = query.size() @@ -212,7 +223,7 @@ def attention( k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, - ): + ) -> torch.Tensor: k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x @@ -225,7 +236,7 @@ def forward( k_x: Optional[torch.Tensor] = None, v_x: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, - ): + ) -> torch.Tensor: k_x = self.ln_1_kv(k_x) if hasattr( self, "ln_1_kv") and k_x is not None else None v_x = self.ln_1_kv(v_x) if hasattr( @@ -268,7 +279,7 @@ def get_cast_device(self) -> torch.device: def forward(self, x: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None): + attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: for r in self.resblocks: x = r(x, attn_mask=attn_mask) return x @@ -336,7 +347,7 @@ def __init__(self, self.image_start_id = image_start_id self.image_end_id = image_start_id + 1 - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: x = x.to( dtype=self.transformer.get_cast_dtype(), device=self.transformer.get_cast_device(), @@ -402,7 +413,7 @@ def __init__( "Only silu is supported for now.") self.act_fn = SiluAndMul() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.c_proj(x) @@ -626,7 +637,8 @@ def get_image_text(image_num: int, padding: bool) -> str: return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}" -def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): +def input_processor_for_qwen(ctx: InputContext, + llm_inputs: LLMInputs) -> LLMInputs: """Processes the inputs, which may or may not be multimodal. Multimodal inputs will only be processed if the model has a "visual" component in its model config, otherwise they'll be ignored. @@ -675,10 +687,11 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): prompt, ) - if num_matched_images == num_images: + if num_matched_images != num_images: logger.warning( - "Number of matched image placeholders doesn't match the number " - "of expected images; are your placeholders formatted correctly?") + f"Number of matched image placeholders {num_matched_images} " + f"doesn't match the number of expected images {num_images}; " + "are your placeholders formatted correctly?") new_prompt_token_ids = tokenizer.encode(new_prompt) @@ -687,7 +700,7 @@ def input_processor_for_qwen(ctx: InputContext, llm_inputs: LLMInputs): multi_modal_data=multi_modal_data) -def input_mapper_for_qwen(ctx: InputContext, data: object): +def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: """Maps the input data to its MultiModalInputs (if any). 
Args: @@ -769,8 +782,11 @@ def build_normalization_transform(image_size: int) -> transforms.Compose: ]) -def dummy_data_for_qwen(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): +def dummy_data_for_qwen( + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], +) -> Tuple[SequenceData, Optional[Dict]]: """Build dummy data for warming up Qwen models; this will only contain text matching the defaults for VLLM unless the model has a visual config. From c889a696e7ca98b8191b0b584884171669a70522 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 4 Sep 2024 06:45:02 -0400 Subject: [PATCH 30/35] Fix formatting, missing license, typehints Signed-off-by: Alex-Brooks --- vllm/model_executor/layers/resampler.py | 85 +++++++++++++++++++------ vllm/model_executor/models/minicpmv.py | 4 +- vllm/model_executor/models/qwen.py | 16 +++-- 3 files changed, 76 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py index 818918a62673..8cd938fc85fb 100644 --- a/vllm/model_executor/layers/resampler.py +++ b/vllm/model_executor/layers/resampler.py @@ -1,18 +1,52 @@ -from functools import partial +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +# +# Copyright 2023 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Shared resampler perceiver network used in multimodal models and +related helpers for sincos positional embeddings. 
+ +Example models: Qwen (Qwen-VL), Minicpmv2.0 +""" import math -from typing import Tuple, Union, Optional, Callable +from functools import partial +from typing import Callable, Optional, Tuple, Union import numpy as np import torch +import torch.nn.functional as F from torch import nn from torch.nn.init import trunc_normal_ -import torch.nn.functional as F from vllm.model_executor.layers.linear import ReplicatedLinear DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) -def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, int]): + +def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, + int]) -> torch.Tensor: # abs_pos: L, C # tgt_size: (H, W) # return: M, C @@ -32,9 +66,9 @@ def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, int]): # sin/cos positional embedding helpers are adapted from: # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 -def get_1d_sincos_pos_embed_from_grid(embed_dim: int, - pos: np.ndarray, - version: Tuple[int, int] = (2, 0)): +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: np.ndarray, + version: Tuple[int, int] = (2, 0)) -> torch.Tensor: """ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) / (H, W) @@ -59,9 +93,9 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, return emb -def get_2d_sincos_pos_embed_from_grid(embed_dim: int, - grid: np.ndarray, - version: Tuple[int, int] = (2, 0)): +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: np.ndarray, + version: Tuple[int, int] = (2, 0)) -> torch.Tensor: assert embed_dim % 2 == 0 # use half of dimensions to encode grid_h @@ -82,7 +116,7 @@ def get_2d_sincos_pos_embed( grid_size: Union[int, Tuple[int, int]], cls_token: bool = False, version: Tuple[int, int] = (2, 0), -): +) -> torch.Tensor: """ grid_size: int of the grid height and width return: @@ -98,6 +132,8 @@ def get_2d_sincos_pos_embed( grid_w = np.arange(grid_w_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) # here w goes first grid = np.stack(grid, axis=0) + assert isinstance(grid, np.ndarray) and \ + grid.shape == (2, grid_h_size, grid_w_size) if version == (2, 0): grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) @@ -139,7 +175,7 @@ def __init__( self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False) else: # Maintain the same return value with ReplicatedLinear.forward - self.kv_proj = lambda *args, **kwargs: ( + self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa nn.Identity()(*args, **kwargs), None, ) @@ -149,8 +185,8 @@ def __init__( self.do_post_projection = do_post_projection self.ln_post = norm_layer(embed_dim) if do_post_projection else None self.proj = nn.Parameter( - (embed_dim**-0.5) * torch.randn(embed_dim, embed_dim) - ) if do_post_projection else None + (embed_dim**-0.5) * + torch.randn(embed_dim, embed_dim)) if do_post_projection else None def _init_weights(self, m: nn.Module) -> None: if isinstance(m, nn.Linear): @@ -167,7 +203,10 @@ def _repeat(self, query, N: int): class Resampler2(BaseResampler): """Resampler-perceiver network to be used for a variety of model types, - e.g., Qwen and Minicpmv 2.0. + e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the + do_post_projection arg, which indicates whether or not there should be + a post layer normalization and projector after the attention. This is + present in minicpmv2.0, but not qwen-vl. 
""" def __init__( @@ -180,8 +219,12 @@ def __init__( adaptive: bool = False, do_post_projection: bool = True, ) -> None: - super().__init__(grid_size**2, embed_dim, num_heads, kv_dim, - norm_layer, do_post_projection=do_post_projection) + super().__init__(grid_size**2, + embed_dim, + num_heads, + kv_dim, + norm_layer, + do_post_projection=do_post_projection) self.adaptive = adaptive pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, @@ -198,7 +241,7 @@ def forward( x: torch.Tensor, tgt_sizes: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, - ): + ) -> torch.Tensor: if tgt_sizes is None: tgt_sizes = int(math.sqrt(x.size(1))) if self.adaptive: @@ -208,9 +251,9 @@ def forward( pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, dtype=x.dtype) else: - pos_embed = get_abs_pos(self.pos_embed, tgt_sizes).to( - device=x.device, - dtype=x.dtype) + pos_embed = get_abs_pos(self.pos_embed, + tgt_sizes).to(device=x.device, + dtype=x.dtype) x, _ = self.kv_proj(x) x = self.ln_kv(x).permute(1, 0, 2) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 2405417e1b4c..f8be9490ee55 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -42,8 +42,8 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.resampler import (get_2d_sincos_pos_embed, - Resampler2) +from vllm.model_executor.layers.resampler import (Resampler2, + get_2d_sincos_pos_embed) from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index fc686f497ad5..3aa866c99f61 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -13,6 +13,7 @@ from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) +import numpy as np import torch from PIL import Image from torch import nn @@ -33,7 +34,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.resampler import (get_abs_pos, Resampler2) +from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -574,7 +575,7 @@ def forward( ) -> torch.Tensor: img_pos = None # If pixel / visual embeddings are provided, this is a visual model - if pixel_values is not None: + if pixel_values is not None and self.visual is not None: if pixel_values["type"] != "image_embeds": image_embeds = self.visual(pixel_values["data"]) else: @@ -582,7 +583,9 @@ def forward( # features should be of shape (# images, 256, hidden_dim) img_pos = self.visual.get_image_positions(input_ids) - if img_pos.shape[0] != image_embeds.shape[0]: + if isinstance( + img_pos, + np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]: raise ValueError( f"Number of placeholders: {img_pos.shape[0]} " f"does not match number of images {image_embeds.shape[0]}." 
@@ -689,9 +692,9 @@ def input_processor_for_qwen(ctx: InputContext, if num_matched_images != num_images: logger.warning( - f"Number of matched image placeholders {num_matched_images} " - f"doesn't match the number of expected images {num_images}; " - "are your placeholders formatted correctly?") + "Number of matched image placeholders %s doesn't match the number " + "of expected images %s; check your placeholder formatting.", + num_matched_images, num_images) new_prompt_token_ids = tokenizer.encode(new_prompt) @@ -891,6 +894,7 @@ def _get_image_input_type( type="pixel_values", data=pixel_values, ) + return None def forward(self, input_ids: torch.Tensor, From 2aa954931ead88bf2b0d4868bf615f9477e909f5 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 4 Sep 2024 17:26:31 -0400 Subject: [PATCH 31/35] Remove unreachable optional cross attn in qwenvl Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 45 ++++++++---------------------- 1 file changed, 11 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 3aa866c99f61..31f95db39757 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -96,6 +96,7 @@ class VisualAttention(nn.Module): and returns output of the same size. """ + def __init__( self, embed_dim: int, @@ -128,18 +129,12 @@ def __init__( def forward( self, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, + x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: # query/key/value: [sq, b, h] - sq, b, _ = query.size() - - assert torch.allclose(query, key), \ - 'Visual Attention implementation only supports self-attention' - sk = sq - mixed_x_layer = self.in_proj(query) + sq, b, _ = x.size() + mixed_x_layer = self.in_proj(x) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] new_tensor_shape = mixed_x_layer.size()[:-1] + \ @@ -157,7 +152,7 @@ def forward( self.hidden_size_per_attention_head).transpose(0, 1) # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view( - sk, b * self.num_attention_heads_per_partition, + sq, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head).transpose(0, 1) q_scaled = query_layer / self.norm_factor @@ -169,7 +164,7 @@ def forward( attention_probs = attention_probs.softmax(dim=-1) value_layer = value_layer.view( - sk, b * self.num_attention_heads_per_partition, + sq, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head).transpose(0, 1) # matmul: [b * np, sq, hn] @@ -192,7 +187,6 @@ def forward( return output - class VisualAttentionBlock(nn.Module): def __init__( @@ -202,14 +196,10 @@ def __init__( mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, norm_layer: Callable = nn.LayerNorm, - is_cross_attention: bool = False, ): super().__init__() self.ln_1 = norm_layer(d_model) - if is_cross_attention: - self.ln_1_kv = norm_layer(d_model) - self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.attn = VisualAttention(d_model, n_head) @@ -220,31 +210,18 @@ def __init__( def attention( self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, + x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - k_x = k_x if k_x is not None else q_x - v_x = v_x if v_x is not None else q_x - - attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None - return self.attn(q_x, k_x, v_x, attn_mask=attn_mask) + attn_mask = attn_mask.to(x.dtype) if attn_mask is 
not None else None + return self.attn(x, attn_mask=attn_mask) def forward( self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, + x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: - k_x = self.ln_1_kv(k_x) if hasattr( - self, "ln_1_kv") and k_x is not None else None - v_x = self.ln_1_kv(v_x) if hasattr( - self, "ln_1_kv") and v_x is not None else None - - x = q_x + self.attention( - q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask) + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) x = x + self.mlp(self.ln_2(x)) return x From 531d6285e976d318b32b44194772733916acfe44 Mon Sep 17 00:00:00 2001 From: Alex-Brooks Date: Wed, 4 Sep 2024 18:49:20 -0400 Subject: [PATCH 32/35] Use parallel linear layers in qwenvl mlp Signed-off-by: Alex-Brooks --- vllm/model_executor/models/qwen.py | 78 +++++++++++++++++++++--------- 1 file changed, 55 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 31f95db39757..07c15c82ca0e 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -8,7 +8,6 @@ import math import re from array import array -from collections import OrderedDict from functools import partial from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -26,9 +25,10 @@ from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger -from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -96,7 +96,6 @@ class VisualAttention(nn.Module): and returns output of the same size. 
""" - def __init__( self, embed_dim: int, @@ -187,6 +186,36 @@ def forward( return output + +class QwenVMLP(nn.Module): + """MLP for the visual component of the Qwen model.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.c_fc = ColumnParallelLinear(hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config) + self.act_fn = get_act_fn("gelu", quant_config, intermediate_size) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x): + x, _ = self.c_fc(x) + x = self.act_fn(x) + x, _ = self.c_proj(x) + return x + + class VisualAttentionBlock(nn.Module): def __init__( @@ -194,8 +223,8 @@ def __init__( d_model: int, n_head: int, mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, norm_layer: Callable = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -203,10 +232,11 @@ def __init__( self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.attn = VisualAttention(d_model, n_head) - self.mlp = nn.Sequential( - OrderedDict([("c_fc", nn.Linear(d_model, mlp_width)), - ("gelu", act_layer()), - ("c_proj", nn.Linear(mlp_width, d_model))])) + self.mlp = QwenVMLP( + hidden_size=d_model, + intermediate_size=mlp_width, + quant_config=quant_config, + ) def attention( self, @@ -234,8 +264,8 @@ def __init__( layers: int, heads: int, mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, norm_layer: Callable = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.width = width @@ -245,8 +275,9 @@ def __init__( VisualAttentionBlock(width, heads, mlp_ratio, - act_layer=act_layer, - norm_layer=norm_layer) for _ in range(layers) + norm_layer=norm_layer, + quant_config=quant_config) + for _ in range(layers) ]) def get_cast_dtype(self) -> torch.dtype: @@ -275,6 +306,7 @@ def __init__(self, n_queries: int = 256, output_dim: int = 512, image_start_id: int = 151857, + quant_config: Optional[QuantizationConfig] = None, **kwargs): super().__init__() image_height, image_width = self.image_size = (image_size, image_size) @@ -294,17 +326,14 @@ def __init__(self, torch.randn(256, width)) norm_layer = partial(nn.LayerNorm, eps=1e-6) - act_layer = nn.GELU self.ln_pre = norm_layer(width) - self.transformer = TransformerBlock( - width, - layers, - heads, - mlp_ratio, - act_layer=act_layer, - norm_layer=norm_layer, - ) + self.transformer = TransformerBlock(width, + layers, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config) self.attn_pool = Resampler2( grid_size=int(math.sqrt(n_queries)), @@ -369,6 +398,8 @@ def get_image_positions(self, class QWenMLP(nn.Module): + """MLP for the language component of the Qwen model, which contains a + MergedColumnParallelLinear merging 2 outputs via silu activation.""" def __init__( self, @@ -538,8 +569,9 @@ def __init__( lambda prefix: QWenBlock(config, cache_config, quant_config), prefix=f"{prefix}.h") self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.visual = VisionTransformer( - **config.visual) if hasattr(config, "visual") else None + self.visual = VisionTransformer(**config.visual, + quant_config=quant_config) if hasattr( + config, "visual") else None def forward( self, From d19418da24552b4911c3d99590468c85a0a0347f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 5 Sep 2024 06:32:06 +0000 Subject: [PATCH 33/35] Limit 
`max_num_seqs`

---
 examples/offline_inference_vision_language.py | 7 ++++++-
 tests/models/test_qwen.py                     | 1 +
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py
index 648365630c5f..aa1580343aee 100644
--- a/examples/offline_inference_vision_language.py
+++ b/examples/offline_inference_vision_language.py
@@ -162,7 +162,12 @@ def run_blip2(question):
 
 # Qwen
 def run_qwen_vl(question):
-    llm = LLM(model="Qwen/Qwen-VL", trust_remote_code=True)
+    llm = LLM(
+        model="Qwen/Qwen-VL",
+        trust_remote_code=True,
+        max_num_seqs=5,
+    )
+
     prompt = f"{question}Picture 1: <img></img>\n"
     stop_token_ids = None
     return llm, prompt, stop_token_ids
diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py
index 27c1ab36a070..453250dda5d7 100644
--- a/tests/models/test_qwen.py
+++ b/tests/models/test_qwen.py
@@ -75,6 +75,7 @@ def run_test(
     # Qwen encodes images into a fixed content size of 256
     with vllm_runner(model,
                      max_model_len=300,
+                     max_num_seqs=5,
                      dtype=dtype,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,

From 4f25926c5ad103511d8a7ba49339bd5addb134ba Mon Sep 17 00:00:00 2001
From: DarkLight1337
Date: Thu, 5 Sep 2024 08:39:50 +0000
Subject: [PATCH 34/35] Further limit

---
 tests/models/test_qwen.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py
index 453250dda5d7..05f5cbf8c343 100644
--- a/tests/models/test_qwen.py
+++ b/tests/models/test_qwen.py
@@ -75,7 +75,7 @@ def run_test(
     # Qwen encodes images into a fixed content size of 256
     with vllm_runner(model,
                      max_model_len=300,
-                     max_num_seqs=5,
+                     max_num_seqs=1,
                      dtype=dtype,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,

From 00c2e09f5419583665ddae0aa4c90732a806c8f5 Mon Sep 17 00:00:00 2001
From: Alex-Brooks
Date: Thu, 5 Sep 2024 06:56:14 -0400
Subject: [PATCH 35/35] Fix dummy data seq padding for multimodal qwen

Signed-off-by: Alex-Brooks
---
 vllm/model_executor/models/qwen.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 07c15c82ca0e..a726ec10984c 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -804,9 +804,7 @@ def dummy_data_for_qwen(
 
     Args:
         ctx: Context of the loaded model.
-        seq_len: Number of tokens in the text sequence. If this is a visual
-            model, sequence length will be ignored, and the input sequence
-            will be determined by the number of images.
+        seq_len: Number of tokens in the text sequence.
         mm_counts: multimodal data counts.
 
     Returns:
@@ -830,9 +828,7 @@ def dummy_data_for_qwen(
     # Build the image prompts with no imgpads; the tokenizer will add img pads
     image_prompt = ''.join(
         [get_image_text(idx, False) for idx in range(1, num_images + 1)])
-    toks = tokenizer.encode(image_prompt,
-                            add_special_tokens=False,
-                            return_tensors="pt")[0].tolist()
+    toks = tokenizer.encode(image_prompt, add_special_tokens=False)
 
     # Make sure we actually get the fixed context size per tok padding
     num_pads = toks.count(tokenizer.encode(IMG_PAD)[0])
@@ -842,6 +838,10 @@ def dummy_data_for_qwen(
             f" per image, but got {num_pads} pads for {num_images} image(s)"
             " in total. Are you using a qwen tokenizer?")
 
+    # Ensure the number of tokens is at minimum the sequence length provided
+    if len(toks) < seq_len:
+        toks += [0] * (seq_len - len(toks))
+
     # Build the input images; width/height doesn't actually matter here since
     # the data will get resized and the # of tokens per image is constant
     image = Image.new("RGB", (224, 224), color=0)
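
A note on the dummy-data fix above: profiling data for multimodal Qwen must contain exactly the fixed number of image-pad tokens per image (256, per the test comment earlier in this series) and must never be shorter than `seq_len`. The standalone sketch below restates that contract; `MAX_QWEN_IMG_TOKENS`, the function name, and its arguments are illustrative stand-ins rather than identifiers taken from the patch.

from typing import List

MAX_QWEN_IMG_TOKENS = 256  # fixed per-image content size noted in the tests


def pad_dummy_tokens(toks: List[int], num_images: int, num_pads: int,
                     seq_len: int) -> List[int]:
    """Check the per-image pad count, then right-pad with 0s up to seq_len."""
    if num_pads != num_images * MAX_QWEN_IMG_TOKENS:
        raise ValueError(f"Expected {MAX_QWEN_IMG_TOKENS} pads per image, "
                         f"got {num_pads} for {num_images} image(s)")
    if len(toks) < seq_len:
        toks = toks + [0] * (seq_len - len(toks))
    return toks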
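
Finally, as an end-to-end check of the Qwen-VL path these patches touch, here is a minimal offline-inference sketch. It mirrors `run_qwen_vl` from the example file changed above (same engine arguments and prompt format); the question text, image path, and sampling settings are placeholders rather than values taken from the patches.

from PIL import Image

from vllm import LLM, SamplingParams

# Engine arguments mirror run_qwen_vl above: Qwen-VL ships custom modeling
# code, so trust_remote_code is required, and max_num_seqs is kept small.
llm = LLM(model="Qwen/Qwen-VL", trust_remote_code=True, max_num_seqs=5)

# The <img></img> span marks where the image goes; the tokenizer/input
# processor expands it to the fixed 256 image-pad tokens per image.
question = "What is shown in this picture?"
prompt = f"{question}Picture 1: <img></img>\n"
image = Image.open("example.jpg")  # placeholder path

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    sampling_params=SamplingParams(temperature=0.0, max_tokens=64),
)
print(outputs[0].outputs[0].text)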