Commit bf0e382

[Model] Composite weight loading for multimodal Qwen2 (#10944)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent b26b4cd commit bf0e382

7 files changed: 147 additions & 205 deletions


vllm/config.py

Lines changed: 9 additions & 1 deletion
@@ -2472,7 +2472,15 @@ def _get_quantization_config(
             return quant_config
         return None

-    def with_hf_config(self, hf_config: PretrainedConfig) -> "VllmConfig":
+    def with_hf_config(
+        self,
+        hf_config: PretrainedConfig,
+        architectures: Optional[list[str]] = None,
+    ) -> "VllmConfig":
+        if architectures is not None:
+            hf_config = copy.deepcopy(hf_config)
+            hf_config.architectures = architectures
+
         model_config = copy.deepcopy(self.model_config)
         model_config.hf_config = hf_config
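The new architectures keyword lets a caller re-resolve a sub-config under a different registered architecture without mutating the shared HF config (note the deepcopy). A minimal usage sketch, assuming a multimodal config whose text_config should load as Qwen2ForCausalLM; the variable names here are illustrative, not part of the diff:

    # Sketch: derive a text-only VllmConfig from a multimodal config.
    # "vllm_config" and "config" are assumed to come from the caller.
    text_vllm_config = vllm_config.with_hf_config(
        config.text_config,
        architectures=["Qwen2ForCausalLM"],
    )
    assert text_vllm_config.model_config.hf_config.architectures == [
        "Qwen2ForCausalLM"
    ]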

vllm/model_executor/model_loader/loader.py

Lines changed: 1 addition & 3 deletions
@@ -101,12 +101,10 @@ def _initialize_model(
     vllm_config: VllmConfig,
     *,
     prefix: str = "",
-    architectures: Optional[list[str]] = None,
 ) -> nn.Module:
     """Initialize a model with the given configurations."""
     model_config = vllm_config.model_config
-    model_class, _ = get_model_architecture(model_config,
-                                            architectures=architectures)
+    model_class, _ = get_model_architecture(model_config)

     signatures = inspect.signature(model_class.__init__)
     all_params = [param.name for param in signatures.parameters.values()]

vllm/model_executor/model_loader/utils.py

Lines changed: 3 additions & 7 deletions
@@ -1,6 +1,6 @@
 """Utilities for selecting and loading models."""
 import contextlib
-from typing import Optional, Tuple, Type
+from typing import Tuple, Type

 import torch
 from torch import nn
@@ -20,12 +20,8 @@ def set_default_torch_dtype(dtype: torch.dtype):


 def get_model_architecture(
-    model_config: ModelConfig,
-    *,
-    architectures: Optional[list[str]] = None,
-) -> Tuple[Type[nn.Module], str]:
-    if architectures is None:
-        architectures = getattr(model_config.hf_config, "architectures", [])
+        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
+    architectures = getattr(model_config.hf_config, "architectures", [])

     # Special handling for quantized Mixtral.
     # FIXME(woosuk): This is a temporary hack.
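With the keyword gone from both _initialize_model and get_model_architecture, an architecture override has to be baked into the config before the loader runs. A hedged sketch of the resulting call pattern; _initialize_model is internal to vLLM, so treat this as illustration rather than a supported API:

    # Sketch: the override now travels inside the config itself instead
    # of being threaded through the loader as a keyword argument.
    sub_config = vllm_config.with_hf_config(
        config.text_config, architectures=["Qwen2ForCausalLM"])
    language_model = _initialize_model(vllm_config=sub_config,
                                       prefix="language_model")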

vllm/model_executor/models/qwen2.py

Lines changed: 10 additions & 7 deletions
@@ -444,14 +444,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.model = Qwen2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))

-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(config.vocab_size,
+                                              config.hidden_size,
+                                              quant_config=quant_config,
+                                              prefix=maybe_prefix(
+                                                  prefix, "lm_head"))
         else:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config,
-                                          prefix=maybe_prefix(
-                                              prefix, "lm_head"))
+            self.lm_head = PPMissingLayer()

         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = get_sampler()
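This hunk applies the standard pipeline-parallel head pattern: only the last PP rank materializes the LM head, while earlier ranks get a placeholder so attribute access stays uniform. A condensed, self-contained sketch of that pattern; the import paths follow vLLM's layout but are assumptions here, and a plain Linear stands in for ParallelLMHead:

    import torch.nn as nn

    from vllm.distributed import get_pp_group  # assumed import path
    from vllm.model_executor.models.utils import PPMissingLayer  # assumed


    class RankGuardedHead(nn.Module):
        """Sketch: allocate head weights only on the last PP stage."""

        def __init__(self, hidden_size: int, vocab_size: int):
            super().__init__()
            if get_pp_group().is_last_rank:
                # The final stage owns the projection and its checkpoint
                # weights (plain Linear stands in for ParallelLMHead).
                self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
            else:
                # Placeholder keeps self.lm_head defined on earlier ranks
                # without allocating parameters.
                self.lm_head = PPMissingLayer()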

vllm/model_executor/models/qwen2_audio.py

Lines changed: 32 additions & 85 deletions
@@ -19,7 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
-from functools import lru_cache
+from functools import cached_property, lru_cache
 from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
                     Union)

@@ -34,28 +34,19 @@
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.logger import init_logger
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, maybe_remap_kv_scale_name)
-from vllm.model_executor.models.qwen2 import Qwen2Model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import NestedTensors
 from vllm.multimodal.utils import consecutive_placeholder_ranges
 from vllm.sequence import IntermediateTensors, SequenceData

 from .interfaces import SupportsMultiModal, SupportsPP
-from .utils import merge_multimodal_embeddings
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
+                    maybe_prefix, merge_multimodal_embeddings)

 logger = init_logger(__name__)

-_KEYS_TO_MODIFY_MAPPING = {
-    "language_model.lm_head": "lm_head",
-    "language_model.model": "language_model",
-}
-

 # # === Audio Inputs === #
 class Qwen2AudioInputs(TypedDict):
@@ -281,25 +272,23 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         self.quant_config = quant_config

-        self.language_model = Qwen2Model(
-            vllm_config=vllm_config.with_hf_config(config.text_config),
-            prefix=prefix)
-        self.unpadded_vocab_size = config.text_config.vocab_size
-        if config.text_config.tie_word_embeddings:
-            self.lm_head = self.language_model.embed_tokens
-        else:
-            self.lm_head = ParallelLMHead(config.text_config.vocab_size,
-                                          config.text_config.hidden_size,
-                                          quant_config=quant_config)
-        logit_scale = getattr(config, "logit_scale", 1.0)
-        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.text_config.vocab_size,
-                                                logit_scale)
-        self.sampler = get_sampler()
+        self.language_model = init_vllm_registered_model(
+            vllm_config=vllm_config,
+            hf_config=config.text_config,
+            prefix=maybe_prefix(prefix, "language_model"),
+            architectures=["Qwen2ForCausalLM"],
+        )

         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)

+    @cached_property
+    def sampler(self):
+        if hasattr(self.language_model, "sampler"):
+            return self.language_model.sampler
+
+        return get_sampler()
+
     def _validate_and_reshape_mm_tensor(self,
                                         mm_input: Union[torch.Tensor,
                                                         List[torch.Tensor]],
@@ -414,72 +403,30 @@ def forward(
                                                       multimodal_embeddings)
             input_ids = None

-        hidden_states = self.language_model(input_ids,
-                                            positions,
-                                            kv_caches,
-                                            attn_metadata,
-                                            intermediate_tensors,
-                                            inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(input_ids,
+                                                  positions,
+                                                  kv_caches,
+                                                  attn_metadata,
+                                                  intermediate_tensors,
+                                                  inputs_embeds=inputs_embeds)
         return hidden_states

-    def compute_logits(self, hidden_states: torch.Tensor,
-                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        return self.language_model.compute_logits(hidden_states,
+                                                  sampling_metadata)

     def sample(
         self,
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
+        return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
-        stacked_params_mapping = [
-            # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
-        ]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
-        loaded_params: Set[str] = set()
-        for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if (self.config.text_config.tie_word_embeddings
-                    and "lm_head.weight" in name):
-                continue
-            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
-                if key_to_modify in name:
-                    name = name.replace(key_to_modify, new_key)
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
-                if weight_name not in name or 'audio' in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
-            loaded_params.add(name)
-        return loaded_params
+        loader = AutoWeightsLoader(self)
+        return loader.load_weights(weights)
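Taken together, these hunks reduce Qwen2-Audio to the composite pattern named in the commit title: build the text backbone through the model registry, delegate logits computation and sampling to it, and let AutoWeightsLoader dispatch checkpoint tensors to each submodule's own load_weights. A condensed, hedged sketch of the resulting shape (the class name and stripped-down signatures are illustrative, not the real model):

    from functools import cached_property

    from vllm.model_executor.layers.sampler import get_sampler
    from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                                  init_vllm_registered_model,
                                                  maybe_prefix)


    class CompositeAudioLM:  # sketch, not the real class
        def __init__(self, *, vllm_config, prefix=""):
            config = vllm_config.model_config.hf_config
            # Built via the registry, the backbone brings its own lm_head,
            # logits processor, sampler, and load_weights.
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        @cached_property
        def sampler(self):
            # Reuse the backbone's sampler when it exposes one.
            if hasattr(self.language_model, "sampler"):
                return self.language_model.sampler
            return get_sampler()

        def load_weights(self, weights):
            # AutoWeightsLoader walks named submodules ("language_model.*",
            # "audio_tower.*", ...) and forwards matching weights to each
            # submodule's load_weights, replacing the manual mapping loop.
            loader = AutoWeightsLoader(self)
            return loader.load_weights(weights)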
