From f31d0f288e6d38760a2f35544390b0348e302d4a Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:37 +0900 Subject: [PATCH 01/15] classification compatible with debugginglogs. Signed-off-by: Kevin-Yang --- vllm/attention/backends/flash_attn.py | 3 + vllm/model_executor/models/qwen2_cls.py | 158 ++++++++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + vllm/worker/embedding_model_runner.py | 1 + 4 files changed, 163 insertions(+) create mode 100644 vllm/model_executor/models/qwen2_cls.py diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ffa05e80623a..4cb29e3462d8 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -425,6 +425,9 @@ def build(self, seq_lens: List[int], query_lens: List[int], block_tables = self._get_graph_runner_block_tables( num_seqs, self.block_tables) else: + print(f"block tables: {self.block_tables}") + if self.block_tables[0] is None: + self.block_tables = [list() for _ in range(num_seqs)] block_tables = make_tensor_with_pad( self.block_tables, pad=0, diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py new file mode 100644 index 000000000000..cf4cb8467d5c --- /dev/null +++ b/vllm/model_executor/models/qwen2_cls.py @@ -0,0 +1,158 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.qwen2 import Qwen2Model +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .utils import is_pp_missing_parameter + + +class ReLU(nn.Module): + + def __init__(self): + super().__init__() + self.activation = nn.ReLU() + + def forward(self, input): + input, _ = input + return self.activation(input) + + +class Qwen2ForSequenceClassification(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = %s is less than " + "`num_hidden_layers` = %s. 
Please open an issue " + "to discuss this feature." % ( + config.max_window_layers, + config.num_hidden_layers, + )) + + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(config, cache_config, quant_config) + + self.score = ColumnParallelLinear(config.hidden_size, + config.num_labels, + quant_config=quant_config) + self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + hidden_states = hidden_states[0] + logits, _ = self.score(hidden_states) + return logits + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + # Skip loading lm_head for embedding model + if name == "lm_head.weight": + continue + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 717615988a90..6c070604a191 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -96,6 +96,7 @@ "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), "MistralModel": ("llama", "LlamaEmbeddingModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), + "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # [Multimodal] "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a7f5b2d4fdd1..a8d078daaf0d 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -173,6 +173,7 @@ def prepare_model_input( finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForGPUWithPoolingMetadata: assert seq_group_metadata_list is not None + print(f"seq_group_metadata_list: {seq_group_metadata_list}") model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) # Prepare PoolingMetadata. From 1bcdf25608deccb1a0d46517b9e84e288ddf1c50 Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:37 +0900 Subject: [PATCH 02/15] fixed prefill error Signed-off-by: Kevin-Yang --- vllm/attention/backends/flash_attn.py | 10 ++++++++-- vllm/model_executor/layers/pooler.py | 5 ++++- vllm/model_executor/models/qwen2_cls.py | 10 +++++++--- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 4cb29e3462d8..91e70ada2e59 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -349,6 +349,11 @@ def _add_seq_group( else: block_table = block_tables[seq_id][ -curr_sliding_window_block:] + + print(f"prefix cache hit: {prefix_cache_hit}") + print(f"chunked prefill enabled: {chunked_prefill_enabled}") + print(f"prompt: {is_prompt}") + print(f"block table: {block_table}") self.block_tables.append(block_table) # Compute slot mapping. @@ -400,6 +405,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], for inter_data in self.input_builder.inter_data_list ]) for inter_data in self.input_builder.inter_data_list: + print(f"inter_data: {inter_data}") self._add_seq_group(inter_data, self.input_builder.chunked_prefill_enabled, prefix_cache_hit) @@ -426,8 +432,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_seqs, self.block_tables) else: print(f"block tables: {self.block_tables}") - if self.block_tables[0] is None: - self.block_tables = [list() for _ in range(num_seqs)] + # if self.block_tables[0] is None: + # self.block_tables = [list() for _ in range(num_seqs)] block_tables = make_tensor_with_pad( self.block_tables, pad=0, diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 3455a4ccf282..fcb18767aaf8 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -28,11 +28,12 @@ class Pooler(nn.Module): normalize: Whether to normalize the pooled data. 
""" - def __init__(self, pooling_type: PoolingType, normalize: bool): + def __init__(self, pooling_type: PoolingType, normalize: bool, softmax: bool = False): super().__init__() self.pooling_type = pooling_type self.normalize = normalize + self.softmax = softmax def forward( self, @@ -63,6 +64,8 @@ def forward( if self.normalize: pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + if self.softmax: + pooled_data = nn.functional.softmax(pooled_data, dim=-1) pooled_outputs = [ EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index cf4cb8467d5c..5286b70b1cb3 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -85,12 +85,13 @@ def __init__( self.lora_config = lora_config self.quant_config = quant_config + print(f"config: {config}\ncache_config: {cache_config}\nquant_config: {quant_config}") self.model = Qwen2Model(config, cache_config, quant_config) - self.score = ColumnParallelLinear(config.hidden_size, + self.score = RowParallelLinear(config.hidden_size, config.num_labels, quant_config=quant_config) - self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False, softmax=True) def forward( self, @@ -100,10 +101,11 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: + print(f"{input_ids}\n{positions}\n{kv_caches}\n{attn_metadata}\n{intermediate_tensors}") hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) - hidden_states = hidden_states[0] logits, _ = self.score(hidden_states) + print(logits) return logits def pooler( @@ -135,6 +137,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: + print(f"bias is ignored: {name}") continue if is_pp_missing_parameter(name, self): continue @@ -145,6 +148,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: + print(f"bias is ignored: {name}") continue # Remapping the name of FP8 kv-scale. 
name = maybe_remap_kv_scale_name(name, params_dict) From faff871d3074f3a4602aed7ac213ff0cf082023d Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:37 +0900 Subject: [PATCH 03/15] add test code Signed-off-by: Kevin-Yang --- tests/conftest.py | 32 ++++++++++- .../decoder_only/language/test_cls_models.py | 54 +++++++++++++++++++ 2 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 tests/models/decoder_only/language/test_cls_models.py diff --git a/tests/conftest.py b/tests/conftest.py index 6adff5e2328c..b18468c9a2dc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ import torch.nn.functional as F from huggingface_hub import snapshot_download from PIL import Image -from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding, +from transformers import (AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, BatchEncoding, BatchFeature) from transformers.models.auto.auto_factory import _BaseAutoModelClass @@ -277,6 +277,16 @@ def __init__( ).to(dtype=torch_dtype)) else: model_kwargs = model_kwargs if model_kwargs is not None else {} + config = AutoConfig.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + arch = config.architectures + if len(arch) > 0: + cls_type = arch[0].split("For")[-1] + auto_cls = eval(f"AutoModelFor{cls_type}") + self.model = self.wrap_device( auto_cls.from_pretrained( model_name, @@ -343,6 +353,18 @@ def get_inputs( return all_inputs + def classify(self, prompts: List[str]) -> List[str]: + # output is final logits + all_inputs = self.get_inputs(prompts) + outputs = [] + print(f"model: {self.model}") + for inputs in all_inputs: + output = self.model(**self.wrap_device(inputs)) + logits = output.logits.softmax(dim=-1)[0].tolist() + outputs.append(logits) + + return outputs + def generate( self, prompts: List[str], @@ -687,6 +709,14 @@ def get_inputs( inputs[i]["multi_modal_data"] = {"audio": audio} return inputs + + def classify(self, prompts: List[str]) -> List[str]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs.embedding + outputs.append(embedding) + return outputs def generate( self, diff --git a/tests/models/decoder_only/language/test_cls_models.py b/tests/models/decoder_only/language/test_cls_models.py new file mode 100644 index 000000000000..f2dc3c3c0d79 --- /dev/null +++ b/tests/models/decoder_only/language/test_cls_models.py @@ -0,0 +1,54 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +This test only tests small models. Big models such as 7B should be tested from +test_big_models.py because it could use a larger instance to run tests. + +Run `pytest tests/models/test_models.py`. 
+""" +import pytest +import torch + +from ...utils import check_logprobs_close, check_outputs_equal + +CLASSIFICATION_MODELS = [ + "jason9693/Qwen2.5-1.5B-apeach" +] + + +@pytest.mark.parametrize("model", CLASSIFICATION_MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +def test_classification_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.classify(example_prompts) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.classify(example_prompts) + + print(hf_outputs, vllm_outputs) + + # check logits difference + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + + assert torch.allclose(hf_output, vllm_output, 1e-3) + + +@pytest.mark.parametrize("model", CLASSIFICATION_MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_classification_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) From 568c2f9ff9351c8cc8fa247909626696f2a69197 Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:37 +0900 Subject: [PATCH 04/15] remove unnecessary print and codes Signed-off-by: Kevin-Yang --- tests/models/decoder_only/language/test_cls_models.py | 4 +--- vllm/model_executor/models/qwen2_cls.py | 5 ----- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/tests/models/decoder_only/language/test_cls_models.py b/tests/models/decoder_only/language/test_cls_models.py index f2dc3c3c0d79..352f8a4e7418 100644 --- a/tests/models/decoder_only/language/test_cls_models.py +++ b/tests/models/decoder_only/language/test_cls_models.py @@ -8,15 +8,13 @@ import pytest import torch -from ...utils import check_logprobs_close, check_outputs_equal - CLASSIFICATION_MODELS = [ "jason9693/Qwen2.5-1.5B-apeach" ] @pytest.mark.parametrize("model", CLASSIFICATION_MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("dtype", ["float"]) def test_classification_models( hf_runner, vllm_runner, diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 5286b70b1cb3..84b5e72e6220 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -85,7 +85,6 @@ def __init__( self.lora_config = lora_config self.quant_config = quant_config - print(f"config: {config}\ncache_config: {cache_config}\nquant_config: {quant_config}") self.model = Qwen2Model(config, cache_config, quant_config) self.score = RowParallelLinear(config.hidden_size, @@ -101,11 +100,9 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - print(f"{input_ids}\n{positions}\n{kv_caches}\n{attn_metadata}\n{intermediate_tensors}") hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) logits, _ = self.score(hidden_states) - print(logits) return logits def pooler( @@ -137,7 +134,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: - print(f"bias is ignored: {name}") continue if is_pp_missing_parameter(name, self): continue @@ -148,7 +144,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: - print(f"bias is ignored: {name}") continue # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) From 20faf4678931ce0c60f9b9aea952518cc4e09897 Mon Sep 17 00:00:00 2001 From: kevin-us Date: Sun, 27 Oct 2024 00:45:38 +0900 Subject: [PATCH 05/15] remove unnecessary print, modifiied pooling logic. Signed-off-by: Kevin-Yang --- tests/conftest.py | 14 +++++++------- vllm/model_executor/layers/pooler.py | 5 +---- vllm/model_executor/models/qwen2_cls.py | 12 ++++++------ 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index b18468c9a2dc..0f5a6e29bad8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,8 +14,9 @@ import torch.nn.functional as F from huggingface_hub import snapshot_download from PIL import Image -from transformers import (AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, AutoConfig, BatchEncoding, - BatchFeature) +from transformers import (AutoModelForCausalLM, + AutoModelForSequenceClassification, AutoTokenizer, + AutoConfig, BatchEncoding, BatchFeature) from transformers.models.auto.auto_factory import _BaseAutoModelClass from tests.models.utils import (TokensTextLogprobs, @@ -282,11 +283,11 @@ def __init__( torch_dtype=torch_dtype, trust_remote_code=True, ) - arch = config.architectures + arch = config.architectures if len(arch) > 0: cls_type = arch[0].split("For")[-1] auto_cls = eval(f"AutoModelFor{cls_type}") - + self.model = self.wrap_device( auto_cls.from_pretrained( model_name, @@ -357,13 +358,12 @@ def classify(self, prompts: List[str]) -> List[str]: # output is final logits all_inputs = self.get_inputs(prompts) outputs = [] - print(f"model: {self.model}") for inputs in all_inputs: output = self.model(**self.wrap_device(inputs)) logits = output.logits.softmax(dim=-1)[0].tolist() outputs.append(logits) - return outputs + return outputs def generate( self, @@ -709,7 +709,7 @@ def get_inputs( inputs[i]["multi_modal_data"] = {"audio": audio} return inputs - + def classify(self, prompts: List[str]) -> List[str]: req_outputs = self.model.encode(prompts) outputs = [] diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index fcb18767aaf8..3455a4ccf282 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -28,12 +28,11 @@ class Pooler(nn.Module): normalize: Whether to normalize the pooled data. 
""" - def __init__(self, pooling_type: PoolingType, normalize: bool, softmax: bool = False): + def __init__(self, pooling_type: PoolingType, normalize: bool): super().__init__() self.pooling_type = pooling_type self.normalize = normalize - self.softmax = softmax def forward( self, @@ -64,8 +63,6 @@ def forward( if self.normalize: pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) - if self.softmax: - pooled_data = nn.functional.softmax(pooled_data, dim=-1) pooled_outputs = [ EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 84b5e72e6220..f0a50a9543af 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -12,8 +12,7 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import RowParallelLinear from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) @@ -88,9 +87,9 @@ def __init__( self.model = Qwen2Model(config, cache_config, quant_config) self.score = RowParallelLinear(config.hidden_size, - config.num_labels, - quant_config=quant_config) - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False, softmax=True) + config.num_labels, + quant_config=quant_config) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False) def forward( self, @@ -110,7 +109,8 @@ def pooler( hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: - return self._pooler(hidden_states, pooling_metadata) + pooled = self._pooler(hidden_states, pooling_metadata) + return nn.functional.softmax(pooled, dim=-1) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ From 4afa7e1f9250b0d1f92dfed7ae76d22897e2a0b2 Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:38 +0900 Subject: [PATCH 06/15] modified auto_cls logic, and lint check Signed-off-by: Kevin-Yang --- tests/conftest.py | 15 ++------------- .../decoder_only/language/test_cls_models.py | 9 +++++---- vllm/attention/backends/flash_attn.py | 2 +- vllm/model_executor/models/registry.py | 3 ++- 4 files changed, 10 insertions(+), 19 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0f5a6e29bad8..2fce2d772c6e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,9 +14,8 @@ import torch.nn.functional as F from huggingface_hub import snapshot_download from PIL import Image -from transformers import (AutoModelForCausalLM, - AutoModelForSequenceClassification, AutoTokenizer, - AutoConfig, BatchEncoding, BatchFeature) +from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding, + BatchFeature) from transformers.models.auto.auto_factory import _BaseAutoModelClass from tests.models.utils import (TokensTextLogprobs, @@ -278,16 +277,6 @@ def __init__( ).to(dtype=torch_dtype)) else: model_kwargs = model_kwargs if model_kwargs is not None else {} - config = AutoConfig.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ) - arch = config.architectures - if len(arch) > 0: - cls_type = arch[0].split("For")[-1] - auto_cls = eval(f"AutoModelFor{cls_type}") - self.model = self.wrap_device( auto_cls.from_pretrained( model_name, diff --git 
a/tests/models/decoder_only/language/test_cls_models.py b/tests/models/decoder_only/language/test_cls_models.py index 352f8a4e7418..daf6461b38dc 100644 --- a/tests/models/decoder_only/language/test_cls_models.py +++ b/tests/models/decoder_only/language/test_cls_models.py @@ -7,10 +7,9 @@ """ import pytest import torch +from transformers import AutoModelForSequenceClassification -CLASSIFICATION_MODELS = [ - "jason9693/Qwen2.5-1.5B-apeach" -] +CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"] @pytest.mark.parametrize("model", CLASSIFICATION_MODELS) @@ -22,7 +21,9 @@ def test_classification_models( model: str, dtype: str, ) -> None: - with hf_runner(model, dtype=dtype) as hf_model: + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForSequenceClassification) as hf_model: hf_outputs = hf_model.classify(example_prompts) with vllm_runner(model, dtype=dtype) as vllm_model: diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 91e70ada2e59..2a8eaa8314b0 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -349,7 +349,7 @@ def _add_seq_group( else: block_table = block_tables[seq_id][ -curr_sliding_window_block:] - + print(f"prefix cache hit: {prefix_cache_hit}") print(f"chunked prefill enabled: {chunked_prefill_enabled}") print(f"prompt: {is_prompt}") diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 6c070604a191..f6713ab0898f 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -96,7 +96,8 @@ "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), "MistralModel": ("llama", "LlamaEmbeddingModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), - "Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), + "Qwen2ForSequenceClassification": ( + "qwen2_cls", "Qwen2ForSequenceClassification"), # [Multimodal] "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), From 8546773efe7b1485903266164c0d2a19c00eeebd Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:38 +0900 Subject: [PATCH 07/15] remve unnecessary print Signed-off-by: Kevin-Yang --- vllm/attention/backends/flash_attn.py | 9 --------- vllm/worker/embedding_model_runner.py | 1 - 2 files changed, 10 deletions(-) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 2a8eaa8314b0..ffa05e80623a 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -349,11 +349,6 @@ def _add_seq_group( else: block_table = block_tables[seq_id][ -curr_sliding_window_block:] - - print(f"prefix cache hit: {prefix_cache_hit}") - print(f"chunked prefill enabled: {chunked_prefill_enabled}") - print(f"prompt: {is_prompt}") - print(f"block table: {block_table}") self.block_tables.append(block_table) # Compute slot mapping. 
@@ -405,7 +400,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], for inter_data in self.input_builder.inter_data_list ]) for inter_data in self.input_builder.inter_data_list: - print(f"inter_data: {inter_data}") self._add_seq_group(inter_data, self.input_builder.chunked_prefill_enabled, prefix_cache_hit) @@ -431,9 +425,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], block_tables = self._get_graph_runner_block_tables( num_seqs, self.block_tables) else: - print(f"block tables: {self.block_tables}") - # if self.block_tables[0] is None: - # self.block_tables = [list() for _ in range(num_seqs)] block_tables = make_tensor_with_pad( self.block_tables, pad=0, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index a8d078daaf0d..a7f5b2d4fdd1 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -173,7 +173,6 @@ def prepare_model_input( finished_requests_ids: Optional[List[str]] = None ) -> ModelInputForGPUWithPoolingMetadata: assert seq_group_metadata_list is not None - print(f"seq_group_metadata_list: {seq_group_metadata_list}") model_input = self._prepare_model_input_tensors( seq_group_metadata_list, finished_requests_ids) # Prepare PoolingMetadata. From bde37c2b24d9868d31f388a31de2201c2bab0ad5 Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:38 +0900 Subject: [PATCH 08/15] make docstring accurate Signed-off-by: Kevin-Yang --- tests/models/decoder_only/language/test_cls_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/decoder_only/language/test_cls_models.py b/tests/models/decoder_only/language/test_cls_models.py index daf6461b38dc..d8ca6d361f0e 100644 --- a/tests/models/decoder_only/language/test_cls_models.py +++ b/tests/models/decoder_only/language/test_cls_models.py @@ -3,7 +3,7 @@ This test only tests small models. Big models such as 7B should be tested from test_big_models.py because it could use a larger instance to run tests. -Run `pytest tests/models/test_models.py`. +Run `pytest tests/models/test_cls_models.py`. """ import pytest import torch From 658176f4fa610df4dd81c8d79de624fde4152564 Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:38 +0900 Subject: [PATCH 09/15] add supported models Signed-off-by: Kevin-Yang --- docs/source/models/supported_models.rst | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 98d804052b57..ab43c12629b4 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -361,6 +361,28 @@ Reward Modeling .. note:: As an interim measure, these models are supported via Embeddings API. See `this RFC `_ for upcoming changes. +Classification +--------------- + +.. list-table:: + :widths: 25 25 50 5 5 + :header-rows: 1 + + * - Architecture + - Models + - Example HF Models + - :ref:`LoRA ` + - :ref:`PP ` + * - :code:`Qwen2ForSequenceClassification` + - Qwen2-based + - :code:`jason9693/Qwen2.5-1.5B-apeach`, etc. + - + - ✅︎ + +.. note:: + As an interim measure, these models are supported via Embeddings API. See `this RFC `_ for upcoming changes. 
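To make the interim Embeddings API route described in the note above concrete, here is a minimal usage sketch. It mirrors what `VllmRunner.classify()` in `tests/conftest.py` does in this series (call `LLM.encode()` and read the pooled output). The model name is taken from the test file; the prompts and the per-class interpretation of the returned vector are illustrative assumptions, not a guaranteed interface.

from vllm import LLM

# Sketch only: mirrors VllmRunner.classify() from tests/conftest.py in this series.
# Qwen2ForSequenceClassification pools with PoolingType.LAST and applies softmax,
# so each returned "embedding" holds per-class probabilities.
llm = LLM(model="jason9693/Qwen2.5-1.5B-apeach", dtype="float32")

prompts = ["This comment is friendly.", "This comment is abusive."]  # made-up examples
outputs = llm.encode(prompts)

for prompt, output in zip(prompts, outputs):
    probs = output.outputs.embedding      # list[float], one entry per label
    label = probs.index(max(probs))
    print(f"{prompt!r} -> class {label} (p={probs[label]:.3f})")
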
+ + Multimodal Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^^ From f2ee1e2488a2dd691903398a5862c020e57a3fdd Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:38 +0900 Subject: [PATCH 10/15] modified docs Signed-off-by: Kevin-Yang --- docs/source/models/supported_models.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index ab43c12629b4..ff893b613f15 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -380,7 +380,7 @@ Classification - ✅︎ .. note:: - As an interim measure, these models are supported via Embeddings API. See `this RFC `_ for upcoming changes. + As an interim measure, these models are supported via Embeddings API. It will be supported via Classification API in the future (no reference APIs exist now). Multimodal Language Models From 18cf2690d378b4a9f89b6a9162f58e04e9d735ba Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:38 +0900 Subject: [PATCH 11/15] add AutoWeightsLoader loading Signed-off-by: Kevin-Yang --- vllm/model_executor/models/qwen2_cls.py | 49 ++----------------------- 1 file changed, 4 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index f0a50a9543af..f6aec7409b2c 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -16,13 +16,11 @@ from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput -from .utils import is_pp_missing_parameter +from .utils import AutoWeightsLoader class ReLU(nn.Module): @@ -113,45 +111,6 @@ def pooler( return nn.functional.softmax(pooled, dim=-1) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in weights: - # Skip loading lm_head for embedding model - if name == "lm_head.weight": - continue - if "rotary_emb.inv_freq" in name: - continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - # Remapping the name of FP8 kv-scale. 
- name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - if is_pp_missing_parameter(name, self): - continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + loader = AutoWeightsLoader(self, + ignore_unexpected_prefixes=["lm_head."]) + loader.load_weights(weights) From 65d3f50c572d6692936aa2b5f2238db6f5f01d9d Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:38 +0900 Subject: [PATCH 12/15] move test code under embedding Signed-off-by: Kevin-Yang --- .../{decoder_only => embedding}/language/test_cls_models.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/models/{decoder_only => embedding}/language/test_cls_models.py (100%) diff --git a/tests/models/decoder_only/language/test_cls_models.py b/tests/models/embedding/language/test_cls_models.py similarity index 100% rename from tests/models/decoder_only/language/test_cls_models.py rename to tests/models/embedding/language/test_cls_models.py From 3374de669845c545a11b9a57a8c6e7e6817609a6 Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:38 +0900 Subject: [PATCH 13/15] remove unnecessary code and update info Signed-off-by: Kevin-Yang --- vllm/model_executor/models/qwen2_cls.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index f6aec7409b2c..cb12cce3d478 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -1,9 +1,10 @@ # coding=utf-8 # Adapted from # https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py +# Copyright 2024 Kakao Corp(Kanana-X Team). # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. -"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" +"""Inference-only Qwen2-Classification model compatible with HuggingFace weights.""" from typing import Iterable, List, Optional, Tuple import torch @@ -23,17 +24,6 @@ from .utils import AutoWeightsLoader -class ReLU(nn.Module): - - def __init__(self): - super().__init__() - self.activation = nn.ReLU() - - def forward(self, input): - input, _ = input - return self.activation(input) - - class Qwen2ForSequenceClassification(nn.Module): packed_modules_mapping = { "qkv_proj": [ From cc2a9ad4f1358cab838f5567c285979b9f4480cd Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:38 +0900 Subject: [PATCH 14/15] modified for linting Signed-off-by: Kevin-Yang --- vllm/model_executor/models/qwen2_cls.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index cb12cce3d478..67ad4337477f 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -1,10 +1,10 @@ # coding=utf-8 # Adapted from # https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py -# Copyright 2024 Kakao Corp(Kanana-X Team). +# Copyright 2024 Kakao Corp. (Kanana-X Team) # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. 
-"""Inference-only Qwen2-Classification model compatible with HuggingFace weights.""" +"""Inference-only Qwen2-Classification model compatible with HF weights.""" from typing import Iterable, List, Optional, Tuple import torch From 81bad154b1ec9b92d16ba5f10eb7bdf2a19fa520 Mon Sep 17 00:00:00 2001 From: Kevin-Yang Date: Sun, 27 Oct 2024 00:45:39 +0900 Subject: [PATCH 15/15] revert softmax inside the pooledr Signed-off-by: Kevin-Yang --- vllm/model_executor/layers/pooler.py | 9 ++++++++- vllm/model_executor/models/qwen2_cls.py | 7 ++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 3455a4ccf282..0a1df9cb699a 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -28,11 +28,15 @@ class Pooler(nn.Module): normalize: Whether to normalize the pooled data. """ - def __init__(self, pooling_type: PoolingType, normalize: bool): + def __init__(self, + pooling_type: PoolingType, + normalize: bool, + softmax: bool = False): super().__init__() self.pooling_type = pooling_type self.normalize = normalize + self.softmax = softmax def forward( self, @@ -64,6 +68,9 @@ def forward( if self.normalize: pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + if self.softmax: + pooled_data = nn.functional.softmax(pooled_data, dim=-1) + pooled_outputs = [ EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data ] diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py index 67ad4337477f..e10c6dbbb647 100644 --- a/vllm/model_executor/models/qwen2_cls.py +++ b/vllm/model_executor/models/qwen2_cls.py @@ -77,7 +77,9 @@ def __init__( self.score = RowParallelLinear(config.hidden_size, config.num_labels, quant_config=quant_config) - self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False) + self._pooler = Pooler(pooling_type=PoolingType.LAST, + normalize=False, + softmax=True) def forward( self, @@ -97,8 +99,7 @@ def pooler( hidden_states: torch.Tensor, pooling_metadata: PoolingMetadata, ) -> Optional[PoolerOutput]: - pooled = self._pooler(hidden_states, pooling_metadata) - return nn.functional.softmax(pooled, dim=-1) + return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loader = AutoWeightsLoader(self,