|
19 | 19 | # See the License for the specific language governing permissions and |
20 | 20 | # limitations under the License. |
21 | 21 | """Inference-only Qwen2-Audio model compatible with HuggingFace weights.""" |
22 | | -from functools import lru_cache |
| 22 | +from functools import cached_property, lru_cache |
23 | 23 | from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict, |
24 | 24 | Union) |
25 | 25 |
|
|
34 | 34 | from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, |
35 | 35 | InputContext, token_inputs) |
36 | 36 | from vllm.logger import init_logger |
37 | | -from vllm.model_executor.layers.logits_processor import LogitsProcessor |
38 | 37 | from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler |
39 | | -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead |
40 | | -from vllm.model_executor.model_loader.weight_utils import ( |
41 | | - default_weight_loader, maybe_remap_kv_scale_name) |
42 | | -from vllm.model_executor.models.qwen2 import Qwen2Model |
43 | 38 | from vllm.model_executor.sampling_metadata import SamplingMetadata |
44 | 39 | from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs |
45 | 40 | from vllm.multimodal.inputs import NestedTensors |
46 | 41 | from vllm.multimodal.utils import consecutive_placeholder_ranges |
47 | 42 | from vllm.sequence import IntermediateTensors, SequenceData |
48 | 43 |
|
49 | 44 | from .interfaces import SupportsMultiModal, SupportsPP |
50 | | -from .utils import merge_multimodal_embeddings |
| 45 | +from .utils import (AutoWeightsLoader, init_vllm_registered_model, |
| 46 | + maybe_prefix, merge_multimodal_embeddings) |
51 | 47 |
|
52 | 48 | logger = init_logger(__name__) |
53 | 49 |
|
54 | | -_KEYS_TO_MODIFY_MAPPING = { |
55 | | - "language_model.lm_head": "lm_head", |
56 | | - "language_model.model": "language_model", |
57 | | -} |
58 | | - |
59 | 50 |
|
60 | 51 | # # === Audio Inputs === # |
61 | 52 | class Qwen2AudioInputs(TypedDict): |
@@ -281,25 +272,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
281 | 272 |
|
282 | 273 | self.quant_config = quant_config |
283 | 274 |
|
284 | | - self.language_model = Qwen2Model( |
285 | | - vllm_config=vllm_config.with_hf_config(config.text_config), |
286 | | - prefix=prefix) |
287 | | - self.unpadded_vocab_size = config.text_config.vocab_size |
288 | | - if config.text_config.tie_word_embeddings: |
289 | | - self.lm_head = self.language_model.embed_tokens |
290 | | - else: |
291 | | - self.lm_head = ParallelLMHead(config.text_config.vocab_size, |
292 | | - config.text_config.hidden_size, |
293 | | - quant_config=quant_config) |
294 | | - logit_scale = getattr(config, "logit_scale", 1.0) |
295 | | - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, |
296 | | - config.text_config.vocab_size, |
297 | | - logit_scale) |
298 | | - self.sampler = get_sampler() |
| 275 | + self.language_model = init_vllm_registered_model( |
| 276 | + vllm_config=vllm_config, |
| 277 | + hf_config=config.text_config, |
| 278 | + prefix=maybe_prefix(prefix, "language_model"), |
| 279 | + architectures=["Qwen2ForCausalLM"], |
| 280 | + ) |
299 | 281 |
|
300 | 282 | self.make_empty_intermediate_tensors = ( |
301 | 283 | self.language_model.make_empty_intermediate_tensors) |
302 | 284 |
|
| 285 | + @cached_property |
| 286 | + def sampler(self): |
| 287 | + if hasattr(self.language_model, "sampler"): |
| 288 | + return self.language_model.sampler |
| 289 | + |
| 290 | + return get_sampler() |
| 291 | + |
303 | 292 | def _validate_and_reshape_mm_tensor(self, |
304 | 293 | mm_input: Union[torch.Tensor, |
305 | 294 | List[torch.Tensor]], |
@@ -414,72 +403,30 @@ def forward( |
414 | 403 | multimodal_embeddings) |
415 | 404 | input_ids = None |
416 | 405 |
|
417 | | - hidden_states = self.language_model(input_ids, |
418 | | - positions, |
419 | | - kv_caches, |
420 | | - attn_metadata, |
421 | | - intermediate_tensors, |
422 | | - inputs_embeds=inputs_embeds) |
| 406 | + hidden_states = self.language_model.model(input_ids, |
| 407 | + positions, |
| 408 | + kv_caches, |
| 409 | + attn_metadata, |
| 410 | + intermediate_tensors, |
| 411 | + inputs_embeds=inputs_embeds) |
423 | 412 | return hidden_states |
424 | 413 |
|
425 | | - def compute_logits(self, hidden_states: torch.Tensor, |
426 | | - sampling_metadata: SamplingMetadata) -> torch.Tensor: |
427 | | - logits = self.logits_processor(self.lm_head, hidden_states, |
428 | | - sampling_metadata) |
429 | | - return logits |
| 414 | + def compute_logits( |
| 415 | + self, |
| 416 | + hidden_states: torch.Tensor, |
| 417 | + sampling_metadata: SamplingMetadata, |
| 418 | + ) -> Optional[torch.Tensor]: |
| 419 | + return self.language_model.compute_logits(hidden_states, |
| 420 | + sampling_metadata) |
430 | 421 |
|
431 | 422 | def sample( |
432 | 423 | self, |
433 | 424 | logits: torch.Tensor, |
434 | 425 | sampling_metadata: SamplingMetadata, |
435 | 426 | ) -> Optional[SamplerOutput]: |
436 | | - next_tokens = self.sampler(logits, sampling_metadata) |
437 | | - return next_tokens |
| 427 | + return self.language_model.sample(logits, sampling_metadata) |
438 | 428 |
|
439 | 429 | def load_weights(self, weights: Iterable[Tuple[str, |
440 | 430 | torch.Tensor]]) -> Set[str]: |
441 | | - stacked_params_mapping = [ |
442 | | - # (param_name, shard_name, shard_id) |
443 | | - ("qkv_proj", "q_proj", "q"), |
444 | | - ("qkv_proj", "k_proj", "k"), |
445 | | - ("qkv_proj", "v_proj", "v"), |
446 | | - ("gate_up_proj", "gate_proj", 0), |
447 | | - ("gate_up_proj", "up_proj", 1), |
448 | | - ] |
449 | | - params_dict = dict(self.named_parameters(remove_duplicate=False)) |
450 | | - loaded_params: Set[str] = set() |
451 | | - for name, loaded_weight in weights: |
452 | | - if "rotary_emb.inv_freq" in name: |
453 | | - continue |
454 | | - if (self.config.text_config.tie_word_embeddings |
455 | | - and "lm_head.weight" in name): |
456 | | - continue |
457 | | - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): |
458 | | - if key_to_modify in name: |
459 | | - name = name.replace(key_to_modify, new_key) |
460 | | - for (param_name, weight_name, shard_id) in stacked_params_mapping: |
461 | | - if weight_name not in name or 'audio' in name: |
462 | | - continue |
463 | | - name = name.replace(weight_name, param_name) |
464 | | - # Skip loading extra bias for GPTQ models. |
465 | | - if name.endswith(".bias") and name not in params_dict: |
466 | | - continue |
467 | | - param = params_dict[name] |
468 | | - weight_loader = param.weight_loader |
469 | | - weight_loader(param, loaded_weight, shard_id) |
470 | | - break |
471 | | - else: |
472 | | - # Skip loading extra bias for GPTQ models. |
473 | | - if name.endswith(".bias") and name not in params_dict: |
474 | | - continue |
475 | | - # Remapping the name of FP8 kv-scale. |
476 | | - name = maybe_remap_kv_scale_name(name, params_dict) |
477 | | - if name is None: |
478 | | - continue |
479 | | - |
480 | | - param = params_dict[name] |
481 | | - weight_loader = getattr(param, "weight_loader", |
482 | | - default_weight_loader) |
483 | | - weight_loader(param, loaded_weight) |
484 | | - loaded_params.add(name) |
485 | | - return loaded_params |
| 431 | + loader = AutoWeightsLoader(self) |
| 432 | + return loader.load_weights(weights) |
0 commit comments