Skip to content

Commit 098b8a0

Browse files
nussejzz authored and princepride committed
fix detail & format
Co-authored-by: Wang Zhipeng <[email protected]>
Signed-off-by: Ding Zuhao <[email protected]>
1 parent 5159878 commit 098b8a0

2 files changed

Lines changed: 130 additions & 96 deletions

File tree

vllm_omni/diffusion/models/bagel/bagel_transformer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import torch
1616
from torch import nn
1717
from torch.nn.attention.flex_attention import flex_attention
18-
from tqdm import tqdm
1918
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
2019
from transformers.models.qwen2.modeling_qwen2 import (
2120
Qwen2PreTrainedModel,
@@ -1301,7 +1300,7 @@ def generate_image(
13011300
timestep_shift: float = 1.0,
13021301
cfg_renorm_min: float = 0.0,
13031302
cfg_renorm_type: str = "global",
1304-
cfg_interval: tuple[float, float] | None = [0, 1],
1303+
cfg_interval: tuple[float, float] = [0, 1],
13051304
# cfg_text
13061305
cfg_text_scale: float = 1.0,
13071306
cfg_text_packed_query_indexes: torch.LongTensor | None = None,
@@ -1325,7 +1324,7 @@ def generate_image(
13251324
dts = timesteps[:-1] - timesteps[1:]
13261325
timesteps = timesteps[:-1]
13271326

1328-
for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
1327+
for i, t in enumerate(timesteps):
13291328
timestep = torch.tensor([t] * x_t.shape[0], device=x_t.device)
13301329
if t > cfg_interval[0] and t <= cfg_interval[1]:
13311330
cfg_text_scale_ = cfg_text_scale

vllm_omni/diffusion/models/bagel/pipeline_bagel.py

Lines changed: 128 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from torch import nn
2121
from transformers import AutoTokenizer, SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel
2222
from vllm.logger import init_logger
23-
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
23+
from vllm.model_executor.models.utils import AutoWeightsLoader
2424
from vllm.transformers_utils.configs.bagel import BagelConfig
2525

2626
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
@@ -256,97 +256,6 @@ def __init__(self, *, od_config: OmniDiffusionConfig, prefix: str = ""):
256256

257257
self.to(self.device)
258258

259-
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
260-
stacked_params_mapping = [
261-
(".qkv_proj_moe_gen", ".q_proj_moe_gen", "q"),
262-
(".qkv_proj_moe_gen", ".k_proj_moe_gen", "k"),
263-
(".qkv_proj_moe_gen", ".v_proj_moe_gen", "v"),
264-
(".qkv_proj", ".q_proj", "q"),
265-
(".qkv_proj", ".k_proj", "k"),
266-
(".qkv_proj", ".v_proj", "v"),
267-
]
268-
# Common prefixes that need to be mapped to `bagel.` namespace
269-
bagel_prefixes = (
270-
"language_model.",
271-
"time_embedder.",
272-
"latent_pos_embed.",
273-
"vae2llm.",
274-
"llm2vae.",
275-
"vit_model.",
276-
"vision_model.",
277-
"connector.",
278-
"vit_pos_embed.",
279-
)
280-
281-
params_dict = dict(self.named_parameters())
282-
loaded_params: set[str] = set()
283-
284-
for name, loaded_weight in weights:
285-
# Generate Candidate Names
286-
candidates = []
287-
288-
# Direct match
289-
candidates.append(name)
290-
291-
# Bagel Prefix match
292-
if name.startswith(bagel_prefixes):
293-
candidates.append("bagel." + name)
294-
295-
# VAE match (from ae.safetensors or unet checkpoints)
296-
if name.startswith(("encoder.", "decoder.")):
297-
candidates.append("vae." + name)
298-
299-
# Try loading candidates
300-
loaded = False
301-
for cand in candidates:
302-
# 1. Try QKV Mapping first (most specific)
303-
for param_name, weight_name, shard_id in stacked_params_mapping:
304-
if weight_name in cand:
305-
mapped_cand = cand.replace(weight_name, param_name)
306-
param = params_dict.get(mapped_cand)
307-
if param is not None:
308-
getattr(param, "weight_loader", default_weight_loader)(param, loaded_weight, shard_id)
309-
loaded = True
310-
break
311-
if loaded:
312-
break
313-
314-
# 2. Try direct parameter match
315-
param = params_dict.get(cand)
316-
if param is not None:
317-
# Special handling for resize/reshape
318-
319-
# Case A: Latent Pos Embed Resize
320-
if cand.endswith("bagel.latent_pos_embed.pos_embed") and loaded_weight.ndim == 2:
321-
npos, hdim = loaded_weight.shape
322-
if param.shape != loaded_weight.shape:
323-
param.data = param.data.new_empty((npos, hdim))
324-
# Update config
325-
side = isqrt(npos)
326-
self.bagel.max_latent_size = side
327-
if hasattr(self.bagel, "config"):
328-
setattr(self.bagel.config, "max_latent_size", side)
329-
if hasattr(self.bagel.latent_pos_embed, "max_num_patch_per_side"):
330-
self.bagel.latent_pos_embed.max_num_patch_per_side = side
331-
332-
# Case B: SigLIP Patch Embedding Reshape
333-
if cand.endswith("embeddings.patch_embedding.weight") and loaded_weight.ndim == 2:
334-
# Checkpoint has (Hidden, C*P*P), model expects (Hidden, C, P, P)
335-
if param.ndim == 4 and loaded_weight.numel() == param.numel():
336-
loaded_weight = loaded_weight.view(param.shape)
337-
338-
if param.shape != loaded_weight.shape:
339-
pass
340-
341-
getattr(param, "weight_loader", default_weight_loader)(param, loaded_weight)
342-
loaded = True
343-
break
344-
345-
if loaded:
346-
loaded_params.add(name)
347-
348-
return loaded_params
349-
350259
@staticmethod
351260
def _decode_image_from_latent(
352261
bagel: Bagel, vae: AutoEncoder, latent: torch.Tensor, image_shape: tuple[int, int]
@@ -545,7 +454,6 @@ def vae_transforms(img):
545454
for k, v in generation_input.items():
546455
if torch.is_tensor(v):
547456
generation_input[k] = v.to(self.device)
548-
549457
with torch.autocast(
550458
device_type=self.device.type,
551459
enabled=self.device.type != "cpu",
@@ -687,3 +595,130 @@ def vae_transforms(img):
687595

688596
img = self._decode_image_from_latent(self.bagel, self.vae, latents[0], image_shape)
689597
return DiffusionOutput(output=img)
598+
599+
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
600+
state = self.state_dict()
601+
allowed = set(state.keys())
602+
shapes = {k: tuple(v.shape) for k, v in state.items()}
603+
604+
tp_aware_params = {name for name, p in self.named_parameters() if hasattr(p, "weight_loader")}
605+
606+
# Expand allowed/tp_aware_params with stacked param source names.
607+
# QKVParallelLinear merges q_proj+k_proj+v_proj into qkv_proj; the
608+
# checkpoint stores the original separate names. We must recognise
609+
# those names so _filtered_weights does not drop them.
610+
_stacked_expansions = [
611+
(".qkv_proj", ".q_proj"),
612+
(".qkv_proj", ".k_proj"),
613+
(".qkv_proj", ".v_proj"),
614+
(".qkv_proj_moe_gen", ".q_proj_moe_gen"),
615+
(".qkv_proj_moe_gen", ".k_proj_moe_gen"),
616+
(".qkv_proj_moe_gen", ".v_proj_moe_gen"),
617+
]
618+
stacked_source_names: set[str] = set()
619+
for name in list(allowed):
620+
for target_suffix, source_suffix in _stacked_expansions:
621+
if target_suffix in name:
622+
stacked_source_names.add(name.replace(target_suffix, source_suffix))
623+
allowed.update(stacked_source_names)
624+
tp_aware_params.update(stacked_source_names)
625+
626+
def _normalize_name(name: str) -> str:
627+
# Common wrappers/prefixes in checkpoints.
628+
for pfx in ("module.", "model."):
629+
if name.startswith(pfx):
630+
name = name[len(pfx) :]
631+
# Common component renames across repos.
632+
if name.startswith("vae_model."):
633+
name = "vae." + name[len("vae_model.") :]
634+
# Bagel `ae.safetensors` commonly stores AE weights without a top-level prefix.
635+
# Map them into this pipeline's `vae.*` namespace.
636+
if name.startswith("encoder.") or name.startswith("decoder."):
637+
name = "vae." + name
638+
return name
639+
640+
def _iter_candidate_names(name: str) -> Iterable[str]:
641+
"""Yield candidate parameter names in this pipeline for a checkpoint key.
642+
643+
The upstream Bagel repo typically stores Bagel-core layers (time_embedder,
644+
latent_pos_embed, vae2llm, llm2vae, etc.) at the top-level of the model,
645+
while this vllm-omni integration nests them under `self.bagel`.
646+
"""
647+
n = _normalize_name(name)
648+
yield n
649+
650+
# Map Bagel core layers from top-level -> `bagel.*` namespace.
651+
for pfx in ("time_embedder.", "latent_pos_embed.", "vae2llm.", "llm2vae."):
652+
if n.startswith(pfx):
653+
yield "bagel." + n
654+
break
655+
656+
# Map connector and vit_pos_embed to `bagel.*`
657+
for pfx in ("connector.", "vit_pos_embed."):
658+
if n.startswith(pfx):
659+
yield "bagel." + n
660+
break
661+
662+
if n.startswith("vit_model."):
663+
yield "bagel." + n # matches self.bagel.vit_model
664+
elif n.startswith("vision_model."):
665+
yield "bagel.vit_model." + n
666+
elif n.startswith("model.vision_model."):
667+
yield "bagel.vit_model." + n[len("model.") :]
668+
669+
def _filtered_weights():
670+
total = 0
671+
kept = 0
672+
shape_mismatch = 0
673+
for name, tensor in weights:
674+
total += 1
675+
picked = None
676+
for cand in _iter_candidate_names(name):
677+
if cand in allowed:
678+
# Only accept if tensor shape matches target param/buffer shape.
679+
if tuple(tensor.shape) == shapes.get(cand) or cand in tp_aware_params:
680+
picked = cand
681+
break
682+
else:
683+
if cand.endswith("bagel.latent_pos_embed.pos_embed") and tensor.ndim == 2:
684+
npos, hdim = tensor.shape
685+
side = isqrt(int(npos))
686+
if side * side == int(npos) and hdim == int(self.bagel.hidden_size):
687+
param = self.bagel.latent_pos_embed.pos_embed
688+
# Resize in-place to keep the same Parameter object.
689+
param.data = param.data.new_empty((npos, hdim))
690+
# Update model bookkeeping so position-id generation matches.
691+
self.bagel.max_latent_size = int(side)
692+
if hasattr(self.bagel, "config"):
693+
setattr(self.bagel.config, "max_latent_size", int(side))
694+
if hasattr(self.bagel.latent_pos_embed, "max_num_patch_per_side"):
695+
self.bagel.latent_pos_embed.max_num_patch_per_side = int(side)
696+
shapes[cand] = (npos, hdim)
697+
picked = cand
698+
break
699+
# Handle flattened patch embedding for SigLIP
700+
if cand.endswith("embeddings.patch_embedding.weight") and tensor.ndim == 2:
701+
# Checkpoint has (Hidden, C*P*P), model expects (Hidden, C, P, P)
702+
if shapes.get(cand) is not None:
703+
target_shape = shapes[cand]
704+
if tensor.numel() == torch.prod(torch.tensor(target_shape)):
705+
# Reshape tensor to match target
706+
tensor = tensor.view(target_shape)
707+
picked = cand
708+
break
709+
710+
shape_mismatch += 1
711+
# Keep this quiet; shape mismatches are expected for ignored modules.
712+
if picked is not None:
713+
kept += 1
714+
yield picked, tensor
715+
# else: ignore extra weights (e.g. connector/vision/und)
716+
logger.info_once(
717+
"BagelPipeline weight filter kept %d/%d tensors (shape mismatches seen: %d)",
718+
kept,
719+
total,
720+
shape_mismatch,
721+
)
722+
723+
loader = AutoWeightsLoader(self)
724+
return loader.load_weights(_filtered_weights())

0 commit comments

Comments (0)