20 | 20 | from torch import nn |
21 | 21 | from transformers import AutoTokenizer, SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel |
22 | 22 | from vllm.logger import init_logger |
23 | | -from vllm.model_executor.model_loader.weight_utils import default_weight_loader |
| 23 | +from vllm.model_executor.models.utils import AutoWeightsLoader |
24 | 24 | from vllm.transformers_utils.configs.bagel import BagelConfig |
25 | 25 |
26 | 26 | from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig |
@@ -256,97 +256,6 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""): |
256 | 256 |
257 | 257 | self.to(self.device) |
258 | 258 |
259 | | - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: |
260 | | - stacked_params_mapping = [ |
261 | | - (".qkv_proj_moe_gen", ".q_proj_moe_gen", "q"), |
262 | | - (".qkv_proj_moe_gen", ".k_proj_moe_gen", "k"), |
263 | | - (".qkv_proj_moe_gen", ".v_proj_moe_gen", "v"), |
264 | | - (".qkv_proj", ".q_proj", "q"), |
265 | | - (".qkv_proj", ".k_proj", "k"), |
266 | | - (".qkv_proj", ".v_proj", "v"), |
267 | | - ] |
268 | | - # Common prefixes that need to be mapped to `bagel.` namespace |
269 | | - bagel_prefixes = ( |
270 | | - "language_model.", |
271 | | - "time_embedder.", |
272 | | - "latent_pos_embed.", |
273 | | - "vae2llm.", |
274 | | - "llm2vae.", |
275 | | - "vit_model.", |
276 | | - "vision_model.", |
277 | | - "connector.", |
278 | | - "vit_pos_embed.", |
279 | | - ) |
280 | | - |
281 | | - params_dict = dict(self.named_parameters()) |
282 | | - loaded_params: set[str] = set() |
283 | | - |
284 | | - for name, loaded_weight in weights: |
285 | | - # Generate Candidate Names |
286 | | - candidates = [] |
287 | | - |
288 | | - # Direct match |
289 | | - candidates.append(name) |
290 | | - |
291 | | - # Bagel Prefix match |
292 | | - if name.startswith(bagel_prefixes): |
293 | | - candidates.append("bagel." + name) |
294 | | - |
295 | | - # VAE match (from ae.safetensors or unet checkpoints) |
296 | | - if name.startswith(("encoder.", "decoder.")): |
297 | | - candidates.append("vae." + name) |
298 | | - |
299 | | - # Try loading candidates |
300 | | - loaded = False |
301 | | - for cand in candidates: |
302 | | - # 1. Try QKV Mapping first (most specific) |
303 | | - for param_name, weight_name, shard_id in stacked_params_mapping: |
304 | | - if weight_name in cand: |
305 | | - mapped_cand = cand.replace(weight_name, param_name) |
306 | | - param = params_dict.get(mapped_cand) |
307 | | - if param is not None: |
308 | | - getattr(param, "weight_loader", default_weight_loader)(param, loaded_weight, shard_id) |
309 | | - loaded = True |
310 | | - break |
311 | | - if loaded: |
312 | | - break |
313 | | - |
314 | | - # 2. Try direct parameter match |
315 | | - param = params_dict.get(cand) |
316 | | - if param is not None: |
317 | | - # Special handling for resize/reshape |
318 | | - |
319 | | - # Case A: Latent Pos Embed Resize |
320 | | - if cand.endswith("bagel.latent_pos_embed.pos_embed") and loaded_weight.ndim == 2: |
321 | | - npos, hdim = loaded_weight.shape |
322 | | - if param.shape != loaded_weight.shape: |
323 | | - param.data = param.data.new_empty((npos, hdim)) |
324 | | - # Update config |
325 | | - side = isqrt(npos) |
326 | | - self.bagel.max_latent_size = side |
327 | | - if hasattr(self.bagel, "config"): |
328 | | - setattr(self.bagel.config, "max_latent_size", side) |
329 | | - if hasattr(self.bagel.latent_pos_embed, "max_num_patch_per_side"): |
330 | | - self.bagel.latent_pos_embed.max_num_patch_per_side = side |
331 | | - |
332 | | - # Case B: SigLIP Patch Embedding Reshape |
333 | | - if cand.endswith("embeddings.patch_embedding.weight") and loaded_weight.ndim == 2: |
334 | | - # Checkpoint has (Hidden, C*P*P), model expects (Hidden, C, P, P) |
335 | | - if param.ndim == 4 and loaded_weight.numel() == param.numel(): |
336 | | - loaded_weight = loaded_weight.view(param.shape) |
337 | | - |
338 | | - if param.shape != loaded_weight.shape: |
339 | | - pass |
340 | | - |
341 | | - getattr(param, "weight_loader", default_weight_loader)(param, loaded_weight) |
342 | | - loaded = True |
343 | | - break |
344 | | - |
345 | | - if loaded: |
346 | | - loaded_params.add(name) |
347 | | - |
348 | | - return loaded_params |
349 | | - |
350 | 259 | @staticmethod |
351 | 260 | def _decode_image_from_latent( |
352 | 261 | bagel: Bagel, vae: AutoEncoder, latent: torch.Tensor, image_shape: tuple[int, int] |
@@ -545,7 +454,6 @@ def vae_transforms(img): |
545 | 454 | for k, v in generation_input.items(): |
546 | 455 | if torch.is_tensor(v): |
547 | 456 | generation_input[k] = v.to(self.device) |
548 | | - |
549 | 457 | with torch.autocast( |
550 | 458 | device_type=self.device.type, |
551 | 459 | enabled=self.device.type != "cpu", |
@@ -687,3 +595,130 @@ def vae_transforms(img): |
687 | 595 |
688 | 596 | img = self._decode_image_from_latent(self.bagel, self.vae, latents[0], image_shape) |
689 | 597 | return DiffusionOutput(output=img) |
| 598 | + |
| 599 | + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: |
| 600 | + state = self.state_dict() |
| 601 | + allowed = set(state.keys()) |
| 602 | + shapes = {k: tuple(v.shape) for k, v in state.items()} |
| 603 | + |
| 604 | + tp_aware_params = {name for name, p in self.named_parameters() if hasattr(p, "weight_loader")} |
| 605 | + |
| 606 | + # Expand allowed/tp_aware_params with stacked param source names. |
| 607 | + # QKVParallelLinear merges q_proj+k_proj+v_proj into qkv_proj; the |
| 608 | + # checkpoint stores the original separate names. We must recognise |
| 609 | + # those names so _filtered_weights does not drop them. |
| 610 | + _stacked_expansions = [ |
| 611 | + (".qkv_proj", ".q_proj"), |
| 612 | + (".qkv_proj", ".k_proj"), |
| 613 | + (".qkv_proj", ".v_proj"), |
| 614 | + (".qkv_proj_moe_gen", ".q_proj_moe_gen"), |
| 615 | + (".qkv_proj_moe_gen", ".k_proj_moe_gen"), |
| 616 | + (".qkv_proj_moe_gen", ".v_proj_moe_gen"), |
| 617 | + ] |
| 618 | + stacked_source_names: set[str] = set() |
| 619 | + for name in list(allowed): |
| 620 | + for target_suffix, source_suffix in _stacked_expansions: |
| 621 | + if target_suffix in name: |
| 622 | + stacked_source_names.add(name.replace(target_suffix, source_suffix)) |
| 623 | + allowed.update(stacked_source_names) |
| 624 | + tp_aware_params.update(stacked_source_names) |
| 625 | + |
| 626 | + def _normalize_name(name: str) -> str: |
| 627 | + # Common wrappers/prefixes in checkpoints. |
| 628 | + for pfx in ("module.", "model."): |
| 629 | + if name.startswith(pfx): |
| 630 | + name = name[len(pfx) :] |
| 631 | + # Common component renames across repos. |
| 632 | + if name.startswith("vae_model."): |
| 633 | + name = "vae." + name[len("vae_model.") :] |
| 634 | + # Bagel `ae.safetensors` commonly stores AE weights without a top-level prefix. |
| 635 | + # Map them into this pipeline's `vae.*` namespace. |
| 636 | + if name.startswith("encoder.") or name.startswith("decoder."): |
| 637 | + name = "vae." + name |
| 638 | + return name |
| 639 | + |
| 640 | + def _iter_candidate_names(name: str) -> Iterable[str]: |
| 641 | + """Yield candidate parameter names in this pipeline for a checkpoint key. |
| 642 | +
| 643 | + The upstream Bagel repo typically stores Bagel-core layers (language_model,
| 644 | + time_embedder, latent_pos_embed, vae2llm, llm2vae, etc.) at the top level of the model,
| 645 | + while this vllm-omni integration nests them under `self.bagel`. |
| 646 | + """ |
| 647 | + n = _normalize_name(name) |
| 648 | + yield n |
| 649 | + |
| 650 | + # Map Bagel core layers from top-level -> `bagel.*` namespace. |
| 651 | + for pfx in ("language_model.", "time_embedder.", "latent_pos_embed.", "vae2llm.", "llm2vae."):
| 652 | + if n.startswith(pfx): |
| 653 | + yield "bagel." + n |
| 654 | + break |
| 655 | + |
| 656 | + # Map connector and vit_pos_embed to `bagel.*` |
| 657 | + for pfx in ("connector.", "vit_pos_embed."): |
| 658 | + if n.startswith(pfx): |
| 659 | + yield "bagel." + n |
| 660 | + break |
| 661 | + |
| 662 | + if n.startswith("vit_model."): |
| 663 | + yield "bagel." + n # matches self.bagel.vit_model |
| 664 | + elif n.startswith("vision_model."): |
| 665 | + yield "bagel.vit_model." + n |
| 666 | + elif n.startswith("model.vision_model."): |
| 667 | + yield "bagel.vit_model." + n[len("model.") :] |
| 668 | + |
| 669 | + def _filtered_weights(): |
| 670 | + total = 0 |
| 671 | + kept = 0 |
| 672 | + shape_mismatch = 0 |
| 673 | + for name, tensor in weights: |
| 674 | + total += 1 |
| 675 | + picked = None |
| 676 | + for cand in _iter_candidate_names(name): |
| 677 | + if cand in allowed: |
| 678 | + # Only accept if tensor shape matches target param/buffer shape. |
| 679 | + if tuple(tensor.shape) == shapes.get(cand) or cand in tp_aware_params: |
| 680 | + picked = cand |
| 681 | + break |
| 682 | + else: |
| 683 | + if cand.endswith("bagel.latent_pos_embed.pos_embed") and tensor.ndim == 2: |
| 684 | + npos, hdim = tensor.shape |
| 685 | + side = isqrt(int(npos)) |
| 686 | + if side * side == int(npos) and hdim == int(self.bagel.hidden_size): |
| 687 | + param = self.bagel.latent_pos_embed.pos_embed |
| 688 | + # Resize in-place to keep the same Parameter object. |
| 689 | + param.data = param.data.new_empty((npos, hdim)) |
| 690 | + # Update model bookkeeping so position-id generation matches. |
| 691 | + self.bagel.max_latent_size = int(side) |
| 692 | + if hasattr(self.bagel, "config"): |
| 693 | + setattr(self.bagel.config, "max_latent_size", int(side)) |
| 694 | + if hasattr(self.bagel.latent_pos_embed, "max_num_patch_per_side"): |
| 695 | + self.bagel.latent_pos_embed.max_num_patch_per_side = int(side) |
| 696 | + shapes[cand] = (npos, hdim) |
| 697 | + picked = cand |
| 698 | + break |
| 699 | + # Handle flattened patch embedding for SigLIP |
| 700 | + if cand.endswith("embeddings.patch_embedding.weight") and tensor.ndim == 2: |
| 701 | + # Checkpoint has (Hidden, C*P*P), model expects (Hidden, C, P, P) |
| 702 | + if shapes.get(cand) is not None: |
| 703 | + target_shape = shapes[cand] |
| 704 | + if tensor.numel() == torch.prod(torch.tensor(target_shape)): |
| 705 | + # Reshape tensor to match target |
| 706 | + tensor = tensor.view(target_shape) |
| 707 | + picked = cand |
| 708 | + break |
| 709 | + |
| 710 | + shape_mismatch += 1 |
| 711 | + # Keep this quiet; shape mismatches are expected for ignored modules. |
| 712 | + if picked is not None: |
| 713 | + kept += 1 |
| 714 | + yield picked, tensor |
| 715 | + # else: ignore extra weights (e.g. connector/vision/und) |
| 716 | + logger.info_once( |
| 717 | + "BagelPipeline weight filter kept %d/%d tensors (shape mismatches seen: %d)", |
| 718 | + kept, |
| 719 | + total, |
| 720 | + shape_mismatch, |
| 721 | + ) |
| 722 | + |
| 723 | + loader = AutoWeightsLoader(self) |
| 724 | + return loader.load_weights(_filtered_weights()) |
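
For context, the new `load_weights` follows a filter-and-delegate pattern: a generator remaps checkpoint keys into this pipeline's namespace, drops anything that has no matching parameter, and hands the surviving (name, tensor) pairs to `AutoWeightsLoader`, which recursively dispatches them to submodules. The snippet below is a minimal, self-contained sketch of that pattern only; `ToyModel` and the `core.` prefix are made-up illustration names, and it copies tensors directly instead of delegating to vLLM's `AutoWeightsLoader`.

from collections.abc import Iterable

import torch
from torch import nn


class ToyModel(nn.Module):
    """Nests its only layer under `core.`, while the checkpoint stores it at the top level."""

    def __init__(self) -> None:
        super().__init__()
        self.core = nn.ModuleDict({"proj": nn.Linear(4, 4, bias=False)})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        state = self.state_dict()
        loaded: set[str] = set()

        def _filtered() -> Iterable[tuple[str, torch.Tensor]]:
            for name, tensor in weights:
                # Remap top-level checkpoint keys into the nested `core.` namespace.
                cand = name if name.startswith("core.") else "core." + name
                # Keep only keys that match a known parameter, shape included.
                if cand in state and tuple(tensor.shape) == tuple(state[cand].shape):
                    yield cand, tensor

        # The real pipeline passes the generator to AutoWeightsLoader(self);
        # this sketch copies the tensors directly to stay dependency-free.
        with torch.no_grad():
            for cand, tensor in _filtered():
                state[cand].copy_(tensor)
                loaded.add(cand)
        return loaded


model = ToyModel()
ckpt = {"proj.weight": torch.randn(4, 4)}  # top-level key, no `core.` prefix
print(model.load_weights(iter(ckpt.items())))  # -> {'core.proj.weight'}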