-
-
Notifications
You must be signed in to change notification settings - Fork 11.7k
[Model] Support Qwen2 embeddings and use tags to select model tests #10184
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 3 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
2464b30
Introduce Qwen2 embedding model
DarkLight1337 617ec86
Update docs
DarkLight1337 136786d
format
DarkLight1337 4974a49
Fix default pooling
DarkLight1337 e773573
Fix tests
DarkLight1337 369a66a
format
DarkLight1337 8db9fa8
Fix test
DarkLight1337 0176894
Merge branch 'main' into qwen2-embedding
DarkLight1337 e3e2422
lint
DarkLight1337 f526b55
Merge branch 'main' into qwen2-embedding
DarkLight1337 a7e26d1
Select tests using model flags
DarkLight1337 7e789f4
Update timing
DarkLight1337 1048c04
Combine commands
DarkLight1337 6f09f51
Merge branch 'main' into qwen2-embedding
DarkLight1337 f183929
Update timings
DarkLight1337 08c1bc1
Merge branch 'main' into qwen2-embedding
DarkLight1337 f64f560
Update test distribution
DarkLight1337 50680a9
Merge branch 'main' into qwen2-embedding
DarkLight1337 153f234
Fix coverage test
DarkLight1337 319867b
Remove unused models
DarkLight1337 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,15 +37,17 @@ | |
| QKVParallelLinear, | ||
| RowParallelLinear) | ||
| from vllm.model_executor.layers.logits_processor import LogitsProcessor | ||
| from vllm.model_executor.layers.pooler import Pooler, PoolingType | ||
| from vllm.model_executor.layers.quantization import QuantizationConfig | ||
| from vllm.model_executor.layers.rotary_embedding import get_rope | ||
| from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler | ||
| from vllm.model_executor.layers.vocab_parallel_embedding import ( | ||
| ParallelLMHead, VocabParallelEmbedding) | ||
| from vllm.model_executor.model_loader.weight_utils import ( | ||
| default_weight_loader, maybe_remap_kv_scale_name) | ||
| from vllm.model_executor.pooling_metadata import PoolingMetadata | ||
| from vllm.model_executor.sampling_metadata import SamplingMetadata | ||
| from vllm.sequence import IntermediateTensors | ||
| from vllm.sequence import IntermediateTensors, PoolerOutput | ||
|
|
||
| from .interfaces import SupportsLoRA, SupportsPP | ||
| from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, | ||
|
|
@@ -248,6 +250,19 @@ def __init__( | |
| prefix: str = "", | ||
| ) -> None: | ||
| super().__init__() | ||
|
|
||
| # TODO (@robertgshaw2): see if this can be moved out | ||
| if (cache_config.sliding_window is not None | ||
| and hasattr(config, "max_window_layers")): | ||
| raise ValueError("Sliding window for some but all layers is not " | ||
| "supported. This model uses sliding window " | ||
| "but `max_window_layers` = {} is less than " | ||
| "`num_hidden_layers` = {}. Please open an issue " | ||
| "to discuss this feature.".format( | ||
| config.max_window_layers, | ||
| config.num_hidden_layers, | ||
| )) | ||
|
|
||
| self.config = config | ||
| self.padding_idx = config.pad_token_id | ||
| self.vocab_size = config.vocab_size | ||
|
|
@@ -413,17 +428,7 @@ def __init__( | |
| cache_config = vllm_config.cache_config | ||
| quant_config = vllm_config.quant_config | ||
| lora_config = vllm_config.lora_config | ||
| # TODO (@robertgshaw2): see if this can be moved out | ||
| if (cache_config.sliding_window is not None | ||
| and hasattr(config, "max_window_layers")): | ||
| raise ValueError("Sliding window for some but all layers is not " | ||
| "supported. This model uses sliding window " | ||
| "but `max_window_layers` = {} is less than " | ||
| "`num_hidden_layers` = {}. Please open an issue " | ||
| "to discuss this feature.".format( | ||
| config.max_window_layers, | ||
| config.num_hidden_layers, | ||
| )) | ||
| pooler_config = vllm_config.model_config.pooler_config | ||
|
|
||
| self.config = config | ||
| self.lora_config = lora_config | ||
|
|
@@ -445,6 +450,15 @@ def __init__( | |
|
|
||
| self.logits_processor = LogitsProcessor(config.vocab_size) | ||
| self.sampler = get_sampler() | ||
|
|
||
| # The same model class supports both language generation and embedding | ||
| # because the architecture name is the same | ||
| self._pooler = Pooler.from_config_with_defaults( | ||
| pooler_config, | ||
| pooling_type=PoolingType.LAST, | ||
| normalize=True, | ||
| softmax=False) | ||
|
Comment on lines
+448
to
+451
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will this be able to be controlled by the pooling args we spoke about offline?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes - these are the model's default values which can be overridden. |
||
|
|
||
| self.make_empty_intermediate_tensors = ( | ||
| self.model.make_empty_intermediate_tensors) | ||
|
|
||
|
|
@@ -477,10 +491,88 @@ def sample( | |
| next_tokens = self.sampler(logits, sampling_metadata) | ||
| return next_tokens | ||
|
|
||
| def pooler( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| pooling_metadata: PoolingMetadata, | ||
| ) -> Optional[PoolerOutput]: | ||
| return self._pooler(hidden_states, pooling_metadata) | ||
|
|
||
| def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): | ||
| loader = AutoWeightsLoader( | ||
| self, | ||
| skip_prefixes=(["lm_head."] | ||
| if self.config.tie_word_embeddings else None), | ||
| ) | ||
| loader.load_weights(weights) | ||
|
|
||
|
|
||
| class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP): | ||
| packed_modules_mapping = { | ||
| "qkv_proj": [ | ||
| "q_proj", | ||
| "k_proj", | ||
| "v_proj", | ||
| ], | ||
| "gate_up_proj": [ | ||
| "gate_proj", | ||
| "up_proj", | ||
| ], | ||
| } | ||
|
|
||
| # LoRA specific attributes | ||
| supported_lora_modules = [ | ||
| "qkv_proj", | ||
| "o_proj", | ||
| "gate_up_proj", | ||
| "down_proj", | ||
| ] | ||
| embedding_modules = {} | ||
| embedding_padding_modules = [] | ||
|
|
||
| def __init__( | ||
| self, | ||
| vllm_config: VllmConfig, | ||
| prefix: str = "", | ||
| ) -> None: | ||
| super().__init__() | ||
| config = vllm_config.model_config.hf_config | ||
| cache_config = vllm_config.cache_config | ||
| quant_config = vllm_config.quant_config | ||
| lora_config = vllm_config.lora_config | ||
| pooler_config = vllm_config.model_config.pooler_config | ||
|
|
||
| self.config = config | ||
| self.lora_config = lora_config | ||
|
|
||
| self.quant_config = quant_config | ||
| self.model = Qwen2Model(config, cache_config, quant_config) | ||
|
|
||
| self._pooler = Pooler.from_config_with_defaults( | ||
| pooler_config, | ||
| pooling_type=PoolingType.LAST, | ||
| normalize=True, | ||
| softmax=False) | ||
|
|
||
| def forward( | ||
| self, | ||
| input_ids: torch.Tensor, | ||
| positions: torch.Tensor, | ||
| kv_caches: List[torch.Tensor], | ||
| attn_metadata: AttentionMetadata, | ||
| intermediate_tensors: Optional[IntermediateTensors] = None, | ||
| ) -> torch.Tensor: | ||
| return self.model(input_ids, positions, kv_caches, attn_metadata, | ||
| intermediate_tensors) | ||
|
|
||
| def pooler( | ||
| self, | ||
| hidden_states: torch.Tensor, | ||
| pooling_metadata: PoolingMetadata, | ||
| ) -> Optional[PoolerOutput]: | ||
| return self._pooler(hidden_states, pooling_metadata) | ||
|
|
||
| def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): | ||
| loader = AutoWeightsLoader(self, | ||
| ignore_unexpected_prefixes=["lm_head."]) | ||
| loader.load_weights(weights) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.