From 58eaf78d3a6756733bbb43fefb603a2f1de46a47 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Sep 2024 18:10:10 +0200 Subject: [PATCH 01/15] add --- docs/source/models/supported_models.rst | 5 + requirements-common.txt | 2 +- tests/models/test_pixtral.py | 46 +++ vllm/entrypoints/chat_utils.py | 2 +- vllm/model_executor/models/__init__.py | 2 + vllm/model_executor/models/pixtral.py | 517 ++++++++++++++++++++++++ vllm/transformers_utils/config.py | 25 +- 7 files changed, 590 insertions(+), 9 deletions(-) create mode 100644 tests/models/test_pixtral.py create mode 100644 vllm/model_executor/models/pixtral.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 29fa5d812deb..e8f2c0875a10 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -257,6 +257,11 @@ Multimodal Language Models - Audio\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - + * - :code:`PixtralForConditionalGeneration` + - Pixtral + - Image\ :sup:`E+` + - :code:`mistralai/Pixtral-12B-2409` + - | :sup:`E` Pre-computed embeddings can be inputted for this modality. | :sup:`+` Multiple items can be inputted per text prompt for this modality. diff --git a/requirements-common.txt b/requirements-common.txt index 49a290317f81..808f96281359 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -25,6 +25,6 @@ pyzmq msgspec gguf == 0.9.1 importlib_metadata -mistral_common >= 1.3.4 +mistral_common >= 1.4.0 pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 diff --git a/tests/models/test_pixtral.py b/tests/models/test_pixtral.py new file mode 100644 index 000000000000..ded4883badf2 --- /dev/null +++ b/tests/models/test_pixtral.py @@ -0,0 +1,46 @@ +"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. + +Run `pytest tests/models/test_mistral.py`. +""" +import pytest +from vllm.sampling_params import SamplingParams + +MODELS = [ + # "mistralai/Pixtral-12B-2409" + "bullerwins/pixtral-12b-240910" +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + image_urls = ["https://picsum.photos/id/237/200/300", "https://picsum.photos/seed/picsum/200/300"] + expected = ["The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset."] + prompt = "Describe the image in one short sentence." 
+ + sampling_params = SamplingParams(max_tokens=512, temperature=0.0) + + with vllm_runner(model, dtype=dtype, + tokenizer_mode="mistral") as vllm_model: + + tokenizer = vllm_model.model.llm_engine.tokenizer.tokenizer + + for i, image_url in enumerate(image_urls): + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": prompt}, {"type": "image_url", "image_url": {"url": image_url}}] + }, + ] + + outputs = vllm_model.model.chat(messages, sampling_params=sampling_params) + assert outputs[0].outputs[0].text == expected[i] diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index a42ad81b3eef..e9c70a93eb19 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -148,7 +148,7 @@ def _placeholder_str(self, modality: ModalityStr, return f"<|image_{current_count}|>" if model_type == "minicpmv": return "(./)" - if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): + if model_type in ("blip-2", "chatglm", "fuyu", "paligemma", "pixtral"): # These models do not use image tokens in the prompt return None if model_type == "qwen": diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index da907e8a7506..353b56e12526 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -90,6 +90,8 @@ "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "UltravoxModel": ("ultravox", "UltravoxModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), + "PixtralForConditionalGeneration": ("pixtral", + "PixtralForConditionalGeneration"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py new file mode 100644 index 000000000000..c6c71eaccc69 --- /dev/null +++ b/vllm/model_executor/models/pixtral.py @@ -0,0 +1,517 @@ +import math +from array import array +from dataclasses import dataclass, fields +from itertools import tee +from typing import (Iterable, List, Mapping, Optional, Tuple, Union) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mistral_common.protocol.instruct.messages import ImageChunk +from PIL import Image +from transformers import PretrainedConfig +from xformers.ops.fmha import memory_efficient_attention +from xformers.ops.fmha.attn_bias import BlockDiagonalMask + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) + +from .interfaces import SupportsMultiModal +from .utils import (init_vllm_registered_model) + + +def get_max_pixtral_image_tokens(ctx: InputContext): + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer, tokenizer_mode=ctx.model_config.tokenizer_mode) + mm_encoder = tokenizer.instruct.mm_encoder + + max_image_size = mm_encoder.mm_config.max_image_size + image_patch_size = 
mm_encoder.mm_config.image_patch_size + + return ((max_image_size // image_patch_size) ** 2) + +def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer, tokenizer_mode=ctx.model_config.tokenizer_mode) + mm_encoder = tokenizer.instruct.mm_encoder + + max_num_images_per_request = ctx.model_config.multimodal_config.limit_per_prompt.get("image", 1) + + # approximate image size + size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size) + + image = Image.new("RGB", (size, size), color=0) + img_chunk = ImageChunk(image=image) + + tokens = mm_encoder(img_chunk).tokens + token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE, tokens) + + seq_data = SequenceData(token_ids) + mm_data = {"image": max_num_images_per_request * [image]} + return seq_data, mm_data + +def input_mapper_for_pixtral(ctx: InputContext, data: object) -> MultiModalInputs: + """Maps the input data to its MultiModalInputs (if any). + + Args: + ctx: Context of the loaded model. + data: data potentially containing image/image embeddings to be mapped + to pixel_values in .forward() for a visual QWenLMHeadModel model. + + Returns: + MultiModalInputs containing the stacked normalized images tensor or + image embeddings. + """ + # Early exit if we have provided an image to a language only Qwen model + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode) + + data_list = data if isinstance(data, list) else [data] + + images = [] + for image_data in data_list: + image = ImageChunk(image=image_data) + encoding = tokenizer.instruct.mm_encoder(image) + image = torch.from_numpy(encoding.image).to(device="cuda", dtype=torch.float16) + images.append(image) + + return MultiModalInputs({"images": images}) + + +def merge_multimodal_embeddings( + input_ids: torch.Tensor, inputs_embeds: torch.Tensor, image_features: Optional[List[torch.Tensor]], image_id: int +) -> torch.Tensor: + text_locations = input_ids != image_id + image_locations = input_ids == image_id + + seq_len = input_ids.shape[0] + + N_txt = text_locations.sum().item() + _, D_txt = inputs_embeds.shape + N_img, D_img = image_features.shape + + assert ( + D_txt == D_img + ), f"Text features dim {D_txt} should be equal to image features dim {D_img}" + assert ( + seq_len == N_txt + N_img + ), f"seq_len {seq_len} should be equal to N_txt + N_img {(N_txt, N_img, image_locations.sum().item())}" + + inputs_embeds[image_locations, :] = image_features + return inputs_embeds + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral) +class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__(self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} + vision_args = {key: value for key, value in self.config.vision_config.to_dict().items() if key in dataclass_fields} + + self.vision_args = VisionEncoderArgs(**vision_args) + + # init MistralForCausalLM + self.language_model = init_vllm_registered_model( + config.text_config, 
cache_config, quant_config) + + self.vision_encoder = VisionTransformer(self.vision_args) + self.vision_language_adapter = VisionLanguageAdapter(self.vision_args, dim=config.text_config.hidden_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> SamplerOutput: + """Run forward pass for pixtral. + + TODO + + """ + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.vision_args.image_token_id) + + input_ids = None + else: + inputs_embeds = None + + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + None, + inputs_embeds=inputs_embeds) + + return hidden_states + + + def _parse_and_validate_image_input(self, images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = None) -> Optional[List[torch.Tensor]]: + if images is None: + return None + + if isinstance(images, torch.Tensor): + # always take last images + images = [images[-1][i] for i in range(images.size(1))] + elif isinstance(images, list): + # always take last images + images = [images[-1][i] for i in range(len(images[0]))] + + return images + + + def _process_image_input(self, image_input: List[torch.Tensor]) -> torch.Tensor: + return self.vision_language_adapter(self.vision_encoder(image_input)) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]): + return weight[0].startswith("vision_encoder") + + def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]): + return weight[0].startswith("vision_language_adapter") + + def is_vision_weights(weight: Tuple[str, torch.Tensor]): + return is_vision_encoder_weights(weight) or is_vision_lang_adapter_weights(weight) + + llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(weights, 3) + + # llm + llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights) + self.language_model.load_weights(llm_weights) + + # vision encoder + vision_encoder_weights = filter(is_vision_encoder_weights, vision_encoder_weights) + vision_encoder_dict = dict(self.vision_encoder.named_parameters()) + for name, loaded_weight in vision_encoder_weights: + # cut 'vision_encoder.' + name = '.'.join(name.split(".")[1:]) + param = vision_encoder_dict[name] + + default_weight_loader(param, loaded_weight) + + # adapter + vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights, vision_lang_adapter_weights) + vision_lang_adpter_dict = dict(self.vision_language_adapter.named_parameters()) + for name, loaded_weight in vision_lang_adapter_weights: + # cut 'vision_language_adapter.' 
+ name = '.'.join(name.split(".")[1:]) + param = vision_lang_adpter_dict[name] + default_weight_loader(param, loaded_weight) + + +# Vision encoder +@dataclass +class VisionEncoderArgs: + hidden_size: int + num_channels: int + image_size: int + patch_size: int + intermediate_size: int + num_hidden_layers: int + num_attention_heads: int + rope_theta: float # for rope-2D + image_token_id: int + + +def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x.ndim + assert ndim > 1 + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( + freqs_cis.shape, + (x.shape[1], x.shape[-1]), + ) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def precompute_freqs_cis_2d( + dim: int, + height: int, + width: int, + theta: float, +) -> torch.Tensor: + """ + freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by + (height, width) position tuples + """ + # (dim / 2) frequency bases + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + + h = torch.arange(height, device=freqs.device) + w = torch.arange(width, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + freqs_2d = torch.cat( + [ + freqs_h[:, None, :].repeat(1, width, 1), + freqs_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ) + return torch.polar(torch.ones_like(freqs_2d), freqs_2d) + + +def apply_rotary_emb_vit( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + assert freqs_cis.dtype == torch.complex64 + freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class FeedForward(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + assert args.intermediate_size is not None + self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) + self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False) + self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class Attention(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.args = args + assert not args.hidden_size % args.num_attention_heads + self.n_heads = args.num_attention_heads + self.head_dim = args.hidden_size // args.num_attention_heads + + self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + + def forward( + self, + x: torch.Tensor, + mask: BlockDiagonalMask, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + batch, patches, _ = x.shape + + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = q.reshape(batch, patches, self.n_heads, self.head_dim) + k = k.reshape(batch, patches, self.n_heads, self.head_dim) + v = v.reshape(batch, patches, self.n_heads, self.head_dim) + + q, k = 
apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis) + out = memory_efficient_attention(q, k, v, attn_bias=mask) + out = out.reshape(batch, patches, self.n_heads * self.head_dim) + return self.wo(out) + + +class TransformerBlock(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.attention = Attention(args) + self.feed_forward = FeedForward(args) + self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5) + self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5) + + def forward( + self, + x: torch.Tensor, + mask: BlockDiagonalMask, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + r = self.attention.forward( + self.attention_norm(x), mask=mask, freqs_cis=freqs_cis + ) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +class Transformer(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.layers = torch.nn.ModuleList() + for _ in range(args.num_hidden_layers): + self.layers.append(TransformerBlock(args)) + + def forward( + self, + x: torch.Tensor, + mask: BlockDiagonalMask, + freqs_cis: Optional[torch.Tensor], + ) -> torch.Tensor: + for layer in self.layers: + x = layer(x, mask=mask, freqs_cis=freqs_cis) + return x + + +def position_meshgrid( + patch_embeds_list: list[torch.Tensor], +) -> torch.Tensor: + positions = torch.cat( + [ + torch.stack( + torch.meshgrid( + torch.arange(p.shape[-2]), + torch.arange(p.shape[-1]), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) + for p in patch_embeds_list + ] + ) + return positions + + +class VisionTransformer(nn.Module): + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.args = args + self.patch_conv = nn.Conv2d( + in_channels=args.num_channels, + out_channels=args.hidden_size, + kernel_size=args.patch_size, + stride=args.patch_size, + bias=False, + ) + self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) + self.transformer = Transformer(args) + + head_dim = self.args.hidden_size // self.args.num_attention_heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + self._freqs_cis: Optional[torch.Tensor] = None + + @property + def max_patches_per_side(self) -> int: + return self.args.image_size // self.args.patch_size + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.device: + return next(self.parameters()).dtype + + + @property + def freqs_cis(self) -> torch.Tensor: + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis_2d( + dim=self.args.hidden_size // self.args.num_attention_heads, + height=self.max_patches_per_side, + width=self.max_patches_per_side, + theta=self.args.rope_theta, + ) + + if self._freqs_cis.device != self.device: + self._freqs_cis = self._freqs_cis.to(device=self.device) + + return self._freqs_cis + + def forward( + self, + images: List[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + images: list of N_img images of variable sizes, each of shape (C, H, W) + Returns: + image_features: tensor of token features for all tokens of all images of + shape (N_toks, D) + """ + # pass images through initial convolution independently + patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images] + + # flatten to a single sequence + patch_embeds = torch.cat( + [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1 + ) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + positions = position_meshgrid(patch_embeds_list).to(self.device) + freqs_cis = 
self.freqs_cis[positions[:, 0], positions[:, 1]] + + # pass through Transformer with a block diagonal mask delimiting images + mask = BlockDiagonalMask.from_seqlens( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], + ) + out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) + + # remove batch dimension of the single sequence + return out.squeeze(0) + + +class VisionLanguageAdapter(nn.Module): + def __init__(self, args: VisionEncoderArgs, dim: int): + super().__init__() + assert isinstance(args, VisionEncoderArgs) + self.w_in = nn.Linear( + args.hidden_size, + dim, + bias=True, + ) + self.gelu = nn.GELU() + self.w_out = nn.Linear(dim, dim, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w_out(self.gelu(self.w_in(x))) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 13fcf6b91860..5ad6f6802d04 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -70,7 +70,7 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision, if Path(model).exists(): return (Path(model) / config_name).is_file() - return file_exists(model, HF_CONFIG_NAME, revision=revision, token=token) + return file_exists(model, config_name, revision=revision, token=token) def get_config( @@ -205,14 +205,25 @@ def recurse_elems(elem: Any): config_dict["hidden_act"] = config_dict.get("activation", "silu") config_dict["tie_word_embeddings"] = config_dict.get( "tie_embeddings", False) + config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000) - if config_dict["model_type"] == "transformer": - if "moe" in config_dict: - config_dict["architectures"] = ["MixtralForCausalLM"] - else: - config_dict["architectures"] = ["MistralForCausalLM"] + if config_dict.get("moe") is not None: + config_dict["architectures"] = ["MixtralForCausalLM"] + else: + config_dict["architectures"] = ["MistralForCausalLM"] + + if config_dict.get("vision_encoder") is not None: + multimodal_config = config_dict.pop("vision_encoder") - return recurse_elems(config_dict) + config_dict = { + "text_config": config_dict, + "vision_config": multimodal_config + } + config_dict["architectures"] = ["PixtralForConditionalGeneration"] + config_dict["model_type"] = "pixtral" + + config = recurse_elems(config_dict) + return config def get_hf_image_processor_config( From a10f47b45590b4ef0ec7ca484704e1d387a6d89d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Sep 2024 18:17:46 +0200 Subject: [PATCH 02/15] add example --- examples/offline_inference_pixtral.py | 137 ++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 examples/offline_inference_pixtral.py diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py new file mode 100644 index 000000000000..feeb99d2a48b --- /dev/null +++ b/examples/offline_inference_pixtral.py @@ -0,0 +1,137 @@ +from vllm import LLM +from vllm.sampling_params import SamplingParams +import argparse + +""" +This script is an offline demo for running Pixtral. 
+ +If you want to run a server/client setup, please follow this code: + +- Server: + +```bash +vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4' --max_num_batched_tokens 16384 +``` + +- Client: + +```bash +curl --location ':8000/v1/chat/completions' \ +--header 'Content-Type: application/json' \ +--header 'Authorization: Bearer token' \ +--data '{ + "model": "mistralai/Pixtral-12B-2409", + "messages": [ + { + "role": "user", + "content": [ + {"type" : "text", "text": "Describe this image in detail please."}, + {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}}, + {"type" : "text", "text": "and this one as well. Answer in French."}, + {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}} + ] + } + ] + }' +``` + +Usage: + python demo.py simple + python demo.py advanced +""" + + +def run_simple_demo(): + model_name = "mistralai/Pixtral-12B-0910" + sampling_params = SamplingParams(max_tokens=8192) + + llm = LLM(model=model_name, tokenizer_mode="mistral") + + prompt = "Describe this image in one sentence." + image_url = "https://picsum.photos/id/237/200/300" + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": image_url}}, + ], + }, + ] + outputs = llm.chat(messages, sampling_params=sampling_params) + + print(outputs[0].outputs[0].text) + + +def run_advanced_demo(): + model_name = "mistralai/Pixtral-12B-0910" + max_img_per_msg = 5 + max_tokens_per_img = 4096 + + sampling_params = SamplingParams(max_tokens=8192, temperature=0.7) + llm = LLM( + model=model_name, + tokenizer_mode="mistral", + limit_mm_per_prompt={"image": max_img_per_msg}, + max_num_batched_tokens=max_img_per_msg * max_tokens_per_img, + ) + + prompt = "Describe the following image." + + url_1 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png" + url_2 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/slack.png" + url_3 = "https://picsum.photos/200/300" + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": url_1}}, + ], + }, + { + "role": "assistant", + "content": "The image shows nature.", + }, + { + "role": "user", + "content": "In more detail for all the following images.", + }, + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": url_2}}, + {"type": "image_url", "image_url": {"url": url_3}}, + ], + }, + ] + + outputs = llm.chat(messages=messages, sampling_params=sampling_params) + print(outputs[0].outputs[0].text) + + +def main(): + parser = argparse.ArgumentParser( + description="Run a demo in simple or advanced mode." 
+ ) + + parser.add_argument( + "mode", + choices=["simple", "advanced"], + help="Specify the demo mode: 'simple' or 'advanced'", + ) + + args = parser.parse_args() + + if args.mode == "simple": + print("Running simple demo...") + run_simple_demo() + elif args.mode == "advanced": + print("Running advanced demo...") + run_advanced_demo() + + +if __name__ == "__main__": + main() From 7c9681d46dc9285cf4f789735cead3b6faae3d3c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Sep 2024 18:21:20 +0200 Subject: [PATCH 03/15] format --- examples/offline_inference_pixtral.py | 54 +++++++--- tests/models/test_pixtral.py | 28 +++-- vllm/entrypoints/chat_utils.py | 3 +- vllm/model_executor/models/pixtral.py | 150 ++++++++++++++++---------- 4 files changed, 158 insertions(+), 77 deletions(-) diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py index feeb99d2a48b..a20294e210bd 100644 --- a/examples/offline_inference_pixtral.py +++ b/examples/offline_inference_pixtral.py @@ -1,6 +1,8 @@ +# ruff: noqa +import argparse + from vllm import LLM from vllm.sampling_params import SamplingParams -import argparse """ This script is an offline demo for running Pixtral. @@ -52,10 +54,19 @@ def run_simple_demo(): messages = [ { - "role": "user", + "role": + "user", "content": [ - {"type": "text", "text": prompt}, - {"type": "image_url", "image_url": {"url": image_url}}, + { + "type": "text", + "text": prompt + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, ], }, ] @@ -85,10 +96,19 @@ def run_advanced_demo(): messages = [ { - "role": "user", + "role": + "user", "content": [ - {"type": "text", "text": prompt}, - {"type": "image_url", "image_url": {"url": url_1}}, + { + "type": "text", + "text": prompt + }, + { + "type": "image_url", + "image_url": { + "url": url_1 + } + }, ], }, { @@ -100,10 +120,21 @@ def run_advanced_demo(): "content": "In more detail for all the following images.", }, { - "role": "user", + "role": + "user", "content": [ - {"type": "image_url", "image_url": {"url": url_2}}, - {"type": "image_url", "image_url": {"url": url_3}}, + { + "type": "image_url", + "image_url": { + "url": url_2 + } + }, + { + "type": "image_url", + "image_url": { + "url": url_3 + } + }, ], }, ] @@ -114,8 +145,7 @@ def run_advanced_demo(): def main(): parser = argparse.ArgumentParser( - description="Run a demo in simple or advanced mode." - ) + description="Run a demo in simple or advanced mode.") parser.add_argument( "mode", diff --git a/tests/models/test_pixtral.py b/tests/models/test_pixtral.py index ded4883badf2..7be98adabecd 100644 --- a/tests/models/test_pixtral.py +++ b/tests/models/test_pixtral.py @@ -3,6 +3,7 @@ Run `pytest tests/models/test_mistral.py`. 
""" import pytest + from vllm.sampling_params import SamplingParams MODELS = [ @@ -23,8 +24,14 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - image_urls = ["https://picsum.photos/id/237/200/300", "https://picsum.photos/seed/picsum/200/300"] - expected = ["The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset."] + image_urls = [ + "https://picsum.photos/id/237/200/300", + "https://picsum.photos/seed/picsum/200/300" + ] + expected = [ + "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa + "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa + ] prompt = "Describe the image in one short sentence." sampling_params = SamplingParams(max_tokens=512, temperature=0.0) @@ -32,15 +39,24 @@ def test_models( with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral") as vllm_model: - tokenizer = vllm_model.model.llm_engine.tokenizer.tokenizer for i, image_url in enumerate(image_urls): messages = [ { - "role": "user", - "content": [{"type": "text", "text": prompt}, {"type": "image_url", "image_url": {"url": image_url}}] + "role": + "user", + "content": [{ + "type": "text", + "text": prompt + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }] }, ] - outputs = vllm_model.model.chat(messages, sampling_params=sampling_params) + outputs = vllm_model.model.chat(messages, + sampling_params=sampling_params) assert outputs[0].outputs[0].text == expected[i] diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index e9c70a93eb19..943775451d5d 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -148,7 +148,8 @@ def _placeholder_str(self, modality: ModalityStr, return f"<|image_{current_count}|>" if model_type == "minicpmv": return "(./)" - if model_type in ("blip-2", "chatglm", "fuyu", "paligemma", "pixtral"): + if model_type in ("blip-2", "chatglm", "fuyu", "paligemma", + "pixtral"): # These models do not use image tokens in the prompt return None if model_type == "qwen": diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index c6c71eaccc69..d5d8bba7c19d 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -2,7 +2,7 @@ from array import array from dataclasses import dataclass, fields from itertools import tee -from typing import (Iterable, List, Mapping, Optional, Tuple, Union) +from typing import Iterable, List, Mapping, Optional, Tuple, Union import torch import torch.nn as nn @@ -16,9 +16,9 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext +from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -28,24 +28,30 @@ SequenceData) from .interfaces import SupportsMultiModal -from .utils import (init_vllm_registered_model) +from .utils import 
init_vllm_registered_model def get_max_pixtral_image_tokens(ctx: InputContext): - tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer, tokenizer_mode=ctx.model_config.tokenizer_mode) + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + tokenizer_mode=ctx.model_config.tokenizer_mode) mm_encoder = tokenizer.instruct.mm_encoder max_image_size = mm_encoder.mm_config.max_image_size image_patch_size = mm_encoder.mm_config.image_patch_size - return ((max_image_size // image_patch_size) ** 2) + return ((max_image_size // image_patch_size)**2) + def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer, tokenizer_mode=ctx.model_config.tokenizer_mode) + mm_counts: Mapping[str, int]): + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + tokenizer_mode=ctx.model_config.tokenizer_mode) mm_encoder = tokenizer.instruct.mm_encoder - max_num_images_per_request = ctx.model_config.multimodal_config.limit_per_prompt.get("image", 1) + max_num_images_per_request = ctx.model_config.multimodal_config.limit_per_prompt.get( + "image", 1) # approximate image size size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size) @@ -54,13 +60,16 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, img_chunk = ImageChunk(image=image) tokens = mm_encoder(img_chunk).tokens - token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE, tokens) + token_ids = max_num_images_per_request * array(VLLM_TOKEN_ID_ARRAY_TYPE, + tokens) seq_data = SequenceData(token_ids) mm_data = {"image": max_num_images_per_request * [image]} return seq_data, mm_data -def input_mapper_for_pixtral(ctx: InputContext, data: object) -> MultiModalInputs: + +def input_mapper_for_pixtral(ctx: InputContext, + data: object) -> MultiModalInputs: """Maps the input data to its MultiModalInputs (if any). 
Args: @@ -74,7 +83,8 @@ def input_mapper_for_pixtral(ctx: InputContext, data: object) -> MultiModalInput """ # Early exit if we have provided an image to a language only Qwen model model_config = ctx.model_config - tokenizer = cached_get_tokenizer(model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode) data_list = data if isinstance(data, list) else [data] @@ -82,15 +92,17 @@ def input_mapper_for_pixtral(ctx: InputContext, data: object) -> MultiModalInput for image_data in data_list: image = ImageChunk(image=image_data) encoding = tokenizer.instruct.mm_encoder(image) - image = torch.from_numpy(encoding.image).to(device="cuda", dtype=torch.float16) + image = torch.from_numpy(encoding.image).to(device="cuda", + dtype=torch.float16) images.append(image) return MultiModalInputs({"images": images}) -def merge_multimodal_embeddings( - input_ids: torch.Tensor, inputs_embeds: torch.Tensor, image_features: Optional[List[torch.Tensor]], image_id: int -) -> torch.Tensor: +def merge_multimodal_embeddings(input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + image_features: Optional[List[torch.Tensor]], + image_id: int) -> torch.Tensor: text_locations = input_ids != image_id image_locations = input_ids == image_id @@ -127,7 +139,11 @@ def __init__(self, self.multimodal_config = multimodal_config dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} - vision_args = {key: value for key, value in self.config.vision_config.to_dict().items() if key in dataclass_fields} + vision_args = { + key: value + for key, value in self.config.vision_config.to_dict().items() + if key in dataclass_fields + } self.vision_args = VisionEncoderArgs(**vision_args) @@ -136,7 +152,8 @@ def __init__(self, config.text_config, cache_config, quant_config) self.vision_encoder = VisionTransformer(self.vision_args) - self.vision_language_adapter = VisionLanguageAdapter(self.vision_args, dim=config.text_config.hidden_size) + self.vision_language_adapter = VisionLanguageAdapter( + self.vision_args, dim=config.text_config.hidden_size) def forward( self, @@ -167,7 +184,6 @@ def forward( else: inputs_embeds = None - hidden_states = self.language_model.model(input_ids, positions, kv_caches, @@ -177,8 +193,11 @@ def forward( return hidden_states - - def _parse_and_validate_image_input(self, images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor]] = None) -> Optional[List[torch.Tensor]]: + def _parse_and_validate_image_input( + self, + images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], + torch.Tensor]] = None + ) -> Optional[List[torch.Tensor]]: if images is None: return None @@ -190,9 +209,9 @@ def _parse_and_validate_image_input(self, images: Optional[Union[List[List[torch images = [images[-1][i] for i in range(len(images[0]))] return images - - def _process_image_input(self, image_input: List[torch.Tensor]) -> torch.Tensor: + def _process_image_input(self, + image_input: List[torch.Tensor]) -> torch.Tensor: return self.vision_language_adapter(self.vision_encoder(image_input)) def compute_logits( @@ -211,6 +230,7 @@ def sample( return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]): return weight[0].startswith("vision_encoder") @@ -218,16 +238,19 @@ def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]): 
return weight[0].startswith("vision_language_adapter") def is_vision_weights(weight: Tuple[str, torch.Tensor]): - return is_vision_encoder_weights(weight) or is_vision_lang_adapter_weights(weight) + return is_vision_encoder_weights( + weight) or is_vision_lang_adapter_weights(weight) - llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(weights, 3) + llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee( + weights, 3) # llm llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights) self.language_model.load_weights(llm_weights) # vision encoder - vision_encoder_weights = filter(is_vision_encoder_weights, vision_encoder_weights) + vision_encoder_weights = filter(is_vision_encoder_weights, + vision_encoder_weights) vision_encoder_dict = dict(self.vision_encoder.named_parameters()) for name, loaded_weight in vision_encoder_weights: # cut 'vision_encoder.' @@ -237,8 +260,10 @@ def is_vision_weights(weight: Tuple[str, torch.Tensor]): default_weight_loader(param, loaded_weight) # adapter - vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights, vision_lang_adapter_weights) - vision_lang_adpter_dict = dict(self.vision_language_adapter.named_parameters()) + vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights, + vision_lang_adapter_weights) + vision_lang_adpter_dict = dict( + self.vision_language_adapter.named_parameters()) for name, loaded_weight in vision_lang_adapter_weights: # cut 'vision_language_adapter.' name = '.'.join(name.split(".")[1:]) @@ -260,7 +285,8 @@ class VisionEncoderArgs: image_token_id: int -def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: +def _reshape_for_broadcast(freqs_cis: torch.Tensor, + x: torch.Tensor) -> torch.Tensor: """ freqs_cis: complex - (seq_len, head_dim / 2) x: complex - (bsz, seq_len, head_dim / 2) @@ -271,7 +297,9 @@ def _reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Te freqs_cis.shape, (x.shape[1], x.shape[-1]), ) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + shape = [ + d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape) + ] return freqs_cis.view(*shape) @@ -286,7 +314,7 @@ def precompute_freqs_cis_2d( (height, width) position tuples """ # (dim / 2) frequency bases - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) + freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim)) h = torch.arange(height, device=freqs.device) w = torch.arange(width, device=freqs.device) @@ -318,18 +346,26 @@ def apply_rotary_emb_vit( class FeedForward(nn.Module): + def __init__(self, args: VisionEncoderArgs): super().__init__() assert args.intermediate_size is not None - self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) - self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False) - self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) + self.w1 = nn.Linear(args.hidden_size, + args.intermediate_size, + bias=False) + self.w2 = nn.Linear(args.intermediate_size, + args.hidden_size, + bias=False) + self.w3 = nn.Linear(args.hidden_size, + args.intermediate_size, + bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) class Attention(nn.Module): + def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args @@ -362,6 +398,7 @@ def forward( class TransformerBlock(nn.Module): + def __init__(self, args: VisionEncoderArgs): super().__init__() 
self.attention = Attention(args) @@ -375,9 +412,9 @@ def forward( mask: BlockDiagonalMask, freqs_cis: torch.Tensor, ) -> torch.Tensor: - r = self.attention.forward( - self.attention_norm(x), mask=mask, freqs_cis=freqs_cis - ) + r = self.attention.forward(self.attention_norm(x), + mask=mask, + freqs_cis=freqs_cis) h = x + r r = self.feed_forward.forward(self.ffn_norm(h)) out = h + r @@ -385,6 +422,7 @@ def forward( class Transformer(nn.Module): + def __init__(self, args: VisionEncoderArgs): super().__init__() self.layers = torch.nn.ModuleList() @@ -402,26 +440,22 @@ def forward( return x -def position_meshgrid( - patch_embeds_list: list[torch.Tensor], -) -> torch.Tensor: - positions = torch.cat( - [ - torch.stack( - torch.meshgrid( - torch.arange(p.shape[-2]), - torch.arange(p.shape[-1]), - indexing="ij", - ), - dim=-1, - ).reshape(-1, 2) - for p in patch_embeds_list - ] - ) +def position_meshgrid(patch_embeds_list: list[torch.Tensor], ) -> torch.Tensor: + positions = torch.cat([ + torch.stack( + torch.meshgrid( + torch.arange(p.shape[-2]), + torch.arange(p.shape[-1]), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) for p in patch_embeds_list + ]) return positions class VisionTransformer(nn.Module): + def __init__(self, args: VisionEncoderArgs): super().__init__() self.args = args @@ -451,7 +485,6 @@ def device(self) -> torch.device: def dtype(self) -> torch.device: return next(self.parameters()).dtype - @property def freqs_cis(self) -> torch.Tensor: if self._freqs_cis is None: @@ -479,12 +512,13 @@ def forward( shape (N_toks, D) """ # pass images through initial convolution independently - patch_embeds_list = [self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images] + patch_embeds_list = [ + self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images + ] # flatten to a single sequence patch_embeds = torch.cat( - [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1 - ) + [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) patch_embeds = self.ln_pre(patch_embeds) # positional embeddings @@ -493,8 +527,7 @@ def forward( # pass through Transformer with a block diagonal mask delimiting images mask = BlockDiagonalMask.from_seqlens( - [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], - ) + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) # remove batch dimension of the single sequence @@ -502,6 +535,7 @@ def forward( class VisionLanguageAdapter(nn.Module): + def __init__(self, args: VisionEncoderArgs, dim: int): super().__init__() assert isinstance(args, VisionEncoderArgs) From c4ffaa46caacf38a0f1cdeb92c825c969cfc98fd Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Sep 2024 18:21:34 +0200 Subject: [PATCH 04/15] format --- examples/offline_inference_pixtral.py | 1 - tests/models/test_pixtral.py | 1 - 2 files changed, 2 deletions(-) diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py index a20294e210bd..dc0eb5eb1773 100644 --- a/examples/offline_inference_pixtral.py +++ b/examples/offline_inference_pixtral.py @@ -3,7 +3,6 @@ from vllm import LLM from vllm.sampling_params import SamplingParams - """ This script is an offline demo for running Pixtral. 
diff --git a/tests/models/test_pixtral.py b/tests/models/test_pixtral.py index 7be98adabecd..d85b811a047c 100644 --- a/tests/models/test_pixtral.py +++ b/tests/models/test_pixtral.py @@ -39,7 +39,6 @@ def test_models( with vllm_runner(model, dtype=dtype, tokenizer_mode="mistral") as vllm_model: - for i, image_url in enumerate(image_urls): messages = [ { From 5c2391d5a3513338c2ad4d81cb1042046da693c8 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Sep 2024 18:23:39 +0200 Subject: [PATCH 05/15] format --- vllm/model_executor/models/pixtral.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index d5d8bba7c19d..21b3c0aa937e 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -50,7 +50,8 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, tokenizer_mode=ctx.model_config.tokenizer_mode) mm_encoder = tokenizer.instruct.mm_encoder - max_num_images_per_request = ctx.model_config.multimodal_config.limit_per_prompt.get( + mm_config = ctx.model_config.multimodal_config + max_num_images_per_request = mm_config.limit_per_prompt.get( "image", 1) # approximate image size @@ -114,10 +115,16 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, assert ( D_txt == D_img - ), f"Text features dim {D_txt} should be equal to image features dim {D_img}" + ), ( + f"Text features dim {D_txt} should be equal " + "to image features dim {D_img}" + ) assert ( seq_len == N_txt + N_img - ), f"seq_len {seq_len} should be equal to N_txt + N_img {(N_txt, N_img, image_locations.sum().item())}" + ), ( + f"seq_len {seq_len} should be equal to N_txt + N_img " + f"{(N_txt, N_img, image_locations.sum().item())}" + ) inputs_embeds[image_locations, :] = image_features return inputs_embeds @@ -310,8 +317,8 @@ def precompute_freqs_cis_2d( theta: float, ) -> torch.Tensor: """ - freqs_cis: 2D complex tensor of shape (height, width, dim // 2) to be indexed by - (height, width) position tuples + freqs_cis: 2D complex tensor of shape (height, width, dim // 2) + to be indexed by (height, width) position tuples """ # (dim / 2) frequency bases freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim)) @@ -506,10 +513,11 @@ def forward( ) -> torch.Tensor: """ Args: - images: list of N_img images of variable sizes, each of shape (C, H, W) + images: list of N_img images of variable sizes, + each of shape (C, H, W) Returns: - image_features: tensor of token features for all tokens of all images of - shape (N_toks, D) + image_features: tensor of token features for + all tokens of all images of shape (N_toks, D) """ # pass images through initial convolution independently patch_embeds_list = [ From 9722cd9d7cdb9cec543b4982a8fcb4a93571009b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Sep 2024 16:53:47 +0000 Subject: [PATCH 06/15] Better examples --- examples/offline_inference_pixtral.py | 24 ++++++++++++------------ tests/models/test_pixtral.py | 5 +---- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py index dc0eb5eb1773..ff954e3aad31 100644 --- a/examples/offline_inference_pixtral.py +++ b/examples/offline_inference_pixtral.py @@ -43,7 +43,7 @@ def run_simple_demo(): - model_name = "mistralai/Pixtral-12B-0910" + model_name = "mistralai/Pixtral-12B-2409" sampling_params = SamplingParams(max_tokens=8192) llm = LLM(model=model_name, 
tokenizer_mode="mistral") @@ -75,7 +75,7 @@ def run_simple_demo(): def run_advanced_demo(): - model_name = "mistralai/Pixtral-12B-0910" + model_name = "mistralai/Pixtral-12B-2409" max_img_per_msg = 5 max_tokens_per_img = 4096 @@ -90,8 +90,8 @@ def run_advanced_demo(): prompt = "Describe the following image." url_1 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png" - url_2 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/slack.png" - url_3 = "https://picsum.photos/200/300" + url_2 = "https://picsum.photos/seed/picsum/200/300" + url_3 = "https://picsum.photos/id/32/512/512" messages = [ { @@ -108,26 +108,26 @@ def run_advanced_demo(): "url": url_1 } }, + { + "type": "image_url", + "image_url": { + "url": url_2 + } + }, ], }, { "role": "assistant", - "content": "The image shows nature.", + "content": "The images show nature.", }, { "role": "user", - "content": "In more detail for all the following images.", + "content": "More details please and answer only in French!.", }, { "role": "user", "content": [ - { - "type": "image_url", - "image_url": { - "url": url_2 - } - }, { "type": "image_url", "image_url": { diff --git a/tests/models/test_pixtral.py b/tests/models/test_pixtral.py index d85b811a047c..a2aa217ee31d 100644 --- a/tests/models/test_pixtral.py +++ b/tests/models/test_pixtral.py @@ -6,10 +6,7 @@ from vllm.sampling_params import SamplingParams -MODELS = [ - # "mistralai/Pixtral-12B-2409" - "bullerwins/pixtral-12b-240910" -] +MODELS = ["mistralai/Pixtral-12B-2409"] @pytest.mark.parametrize("model", MODELS) From b697fa384d3918282657d876f628d0e194cda19a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Sep 2024 16:59:53 +0000 Subject: [PATCH 07/15] WIP --- examples/offline_inference_pixtral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py index ff954e3aad31..d88af0e7da28 100644 --- a/examples/offline_inference_pixtral.py +++ b/examples/offline_inference_pixtral.py @@ -17,7 +17,7 @@ - Client: ```bash -curl --location ':8000/v1/chat/completions' \ +curl --location 'http://:8000/v1/chat/completions' \ --header 'Content-Type: application/json' \ --header 'Authorization: Bearer token' \ --data '{ From dcdcb42ccb8fd196e3d20e7fffa90384f27344e5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Sep 2024 19:00:35 +0200 Subject: [PATCH 08/15] last format --- examples/offline_inference_pixtral.py | 4 ++-- vllm/model_executor/models/pixtral.py | 20 ++++++-------------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py index d88af0e7da28..d48b60178ce5 100644 --- a/examples/offline_inference_pixtral.py +++ b/examples/offline_inference_pixtral.py @@ -3,6 +3,7 @@ from vllm import LLM from vllm.sampling_params import SamplingParams + """ This script is an offline demo for running Pixtral. 
@@ -125,8 +126,7 @@ def run_advanced_demo(): "content": "More details please and answer only in French!.", }, { - "role": - "user", + "role": "user", "content": [ { "type": "image_url", diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 21b3c0aa937e..010cf85f45e0 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -51,8 +51,7 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, mm_encoder = tokenizer.instruct.mm_encoder mm_config = ctx.model_config.multimodal_config - max_num_images_per_request = mm_config.limit_per_prompt.get( - "image", 1) + max_num_images_per_request = mm_config.limit_per_prompt.get("image", 1) # approximate image size size = int(math.sqrt(seq_len) * mm_encoder.mm_config.image_patch_size) @@ -113,18 +112,11 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, _, D_txt = inputs_embeds.shape N_img, D_img = image_features.shape - assert ( - D_txt == D_img - ), ( - f"Text features dim {D_txt} should be equal " - "to image features dim {D_img}" - ) - assert ( - seq_len == N_txt + N_img - ), ( - f"seq_len {seq_len} should be equal to N_txt + N_img " - f"{(N_txt, N_img, image_locations.sum().item())}" - ) + assert (D_txt == D_img), (f"Text features dim {D_txt} should be equal " + "to image features dim {D_img}") + assert (seq_len == N_txt + + N_img), (f"seq_len {seq_len} should be equal to N_txt + N_img " + f"{(N_txt, N_img, image_locations.sum().item())}") inputs_embeds[image_locations, :] = image_features return inputs_embeds From bc63222aa5c9a09b5072a2f09afcafd012bc81f1 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Sep 2024 20:43:23 +0200 Subject: [PATCH 09/15] I <3 yapf --- examples/offline_inference_pixtral.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py index d48b60178ce5..64357018b19b 100644 --- a/examples/offline_inference_pixtral.py +++ b/examples/offline_inference_pixtral.py @@ -3,7 +3,6 @@ from vllm import LLM from vllm.sampling_params import SamplingParams - """ This script is an offline demo for running Pixtral. From f0581f490f9107c09c2bde204b2ce14acead6400 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Sep 2024 20:44:28 +0200 Subject: [PATCH 10/15] finish --- examples/offline_inference_pixtral.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py index 64357018b19b..d48b60178ce5 100644 --- a/examples/offline_inference_pixtral.py +++ b/examples/offline_inference_pixtral.py @@ -3,6 +3,7 @@ from vllm import LLM from vllm.sampling_params import SamplingParams + """ This script is an offline demo for running Pixtral. From f90cb0fec57bb619ebcfd4170e436216a927e4de Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 11 Sep 2024 20:45:45 +0200 Subject: [PATCH 11/15] again --- examples/offline_inference_pixtral.py | 72 +++++++++++++-------------- 1 file changed, 35 insertions(+), 37 deletions(-) diff --git a/examples/offline_inference_pixtral.py b/examples/offline_inference_pixtral.py index d48b60178ce5..738d890607e3 100644 --- a/examples/offline_inference_pixtral.py +++ b/examples/offline_inference_pixtral.py @@ -4,43 +4,41 @@ from vllm import LLM from vllm.sampling_params import SamplingParams -""" -This script is an offline demo for running Pixtral. 
- -If you want to run a server/client setup, please follow this code: - -- Server: - -```bash -vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4' --max_num_batched_tokens 16384 -``` - -- Client: - -```bash -curl --location 'http://:8000/v1/chat/completions' \ ---header 'Content-Type: application/json' \ ---header 'Authorization: Bearer token' \ ---data '{ - "model": "mistralai/Pixtral-12B-2409", - "messages": [ - { - "role": "user", - "content": [ - {"type" : "text", "text": "Describe this image in detail please."}, - {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}}, - {"type" : "text", "text": "and this one as well. Answer in French."}, - {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}} - ] - } - ] - }' -``` - -Usage: - python demo.py simple - python demo.py advanced -""" +# This script is an offline demo for running Pixtral. +# +# If you want to run a server/client setup, please follow this code: +# +# - Server: +# +# ```bash +# vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4' --max_num_batched_tokens 16384 +# ``` +# +# - Client: +# +# ```bash +# curl --location 'http://:8000/v1/chat/completions' \ +# --header 'Content-Type: application/json' \ +# --header 'Authorization: Bearer token' \ +# --data '{ +# "model": "mistralai/Pixtral-12B-2409", +# "messages": [ +# { +# "role": "user", +# "content": [ +# {"type" : "text", "text": "Describe this image in detail please."}, +# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}}, +# {"type" : "text", "text": "and this one as well. Answer in French."}, +# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}} +# ] +# } +# ] +# }' +# ``` +# +# Usage: +# python demo.py simple +# python demo.py advanced def run_simple_demo(): From 705006eb692c6c7012669b6bfa885c0253e8f09c Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 11 Sep 2024 12:18:34 -0700 Subject: [PATCH 12/15] move up pixtral --- docs/source/models/supported_models.rst | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 3dea2bbe2624..9404700da44c 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -247,6 +247,11 @@ Multimodal Language Models - Image\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - + * - :code:`PixtralForConditionalGeneration` + - Pixtral + - Image\ :sup:`E+` + - :code:`mistralai/Pixtral-12B-2409` + - * - :code:`QWenLMHeadModel` - Qwen-VL - Image\ :sup:`E` @@ -262,11 +267,6 @@ Multimodal Language Models - Audio\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - - * - :code:`PixtralForConditionalGeneration` - - Pixtral - - Image\ :sup:`E+` - - :code:`mistralai/Pixtral-12B-2409` - - | :sup:`E` Pre-computed embeddings can be inputted for this modality. | :sup:`+` Multiple items can be inputted per text prompt for this modality. 
From 2ec31bd6f402d62e64bae309b12a3a6b7d30106e Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 11 Sep 2024 12:19:48 -0700 Subject: [PATCH 13/15] remove embedding support --- docs/source/models/supported_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 9404700da44c..be81c3883340 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -249,7 +249,7 @@ Multimodal Language Models - * - :code:`PixtralForConditionalGeneration` - Pixtral - - Image\ :sup:`E+` + - Image\ :sup:`+` - :code:`mistralai/Pixtral-12B-2409` - * - :code:`QWenLMHeadModel` From 1f47eac9510881fde38e6bad4f7559154262872a Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 11 Sep 2024 12:44:03 -0700 Subject: [PATCH 14/15] add vlm marker --- tests/models/test_pixtral.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/models/test_pixtral.py b/tests/models/test_pixtral.py index a2aa217ee31d..853201f78f7e 100644 --- a/tests/models/test_pixtral.py +++ b/tests/models/test_pixtral.py @@ -6,6 +6,8 @@ from vllm.sampling_params import SamplingParams +pytestmark = pytest.mark.vlm + MODELS = ["mistralai/Pixtral-12B-2409"] From 05da6f64a5cd8d36e07b825aabec9351f7bbb714 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Wed, 11 Sep 2024 14:38:44 -0700 Subject: [PATCH 15/15] patch --- tests/models/test_pixtral.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/test_pixtral.py b/tests/models/test_pixtral.py index 853201f78f7e..dc60cf7eae8b 100644 --- a/tests/models/test_pixtral.py +++ b/tests/models/test_pixtral.py @@ -11,6 +11,10 @@ MODELS = ["mistralai/Pixtral-12B-2409"] +@pytest.mark.skip( + reason= + "Model is too big, test passed on A100 locally but will OOM on CI machine." +) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64])
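For reference, the curl request documented in the comments of examples/offline_inference_pixtral.py maps onto the OpenAI-compatible Python client as sketched below. This is a minimal sketch, not a definitive client: it assumes the `vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral` command shown in that example is already running on localhost:8000, that the `openai` Python package is installed, and it reuses the picsum image URL from tests/models/test_pixtral.py purely for illustration.

```python
# Sketch of a Python client for the server described in the example's comments.
# Assumptions: vLLM's OpenAI-compatible server is reachable at localhost:8000
# and accepts any bearer token; the image URL is illustrative.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="token")

response = client.chat.completions.create(
    model="mistralai/Pixtral-12B-2409",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image in one short sentence."},
            {"type": "image_url",
             "image_url": {"url": "https://picsum.photos/id/237/200/300"}},
        ],
    }],
    max_tokens=256,
    temperature=0.0,
)
print(response.choices[0].message.content)
```

The request body is the same JSON payload as the curl example, so any other OpenAI-compatible client should work the same way.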