
Commit 2551817

Fix GGUF pan-and-scan attention and CUDA graph mask preservation
Fixes four critical issues in GGUF multimodal inference:

1. Attention scaling parameter bug (gemma3.py):
   - Fix F.scaled_dot_product_attention to use named parameters
   - Change positional args to attn_mask=attn_mask, scale=self.scaling
   - Prevents incorrect dropout application (was 6.25% instead of 0%)

2. Custom attention mask persistence (gpu_model_runner.py):
   - Store custom_model_kwargs after mask generation
   - Merge custom_model_kwargs in _dummy_run
   - Prevents loss of attention masks during CUDA graph re-initialization

3. Pan-and-scan attention pattern (gemma3_mm.py):
   - Detect pan-and-scan mode via multimodal_config.do_pan_and_scan
   - Keep pure causal attention for pan-and-scan crops
   - Prevents crop isolation artifacts in sequential processing

4. GGUF unquantized weight loading (weight_utils.py):
   - Add proper dtype conversion for BF16/F16/F32 stored as uint8
   - Handle byte-to-dtype conversion (BF16: 2 bytes, F16: 2 bytes, F32: 4 bytes)
   - Add fallback handling for unexpected dtype/type combinations
   - Fixes weight loading for unquantized GGUF multimodal projector weights

Signed-off-by: Luciano Martins <[email protected]>
1 parent bb47210 commit 2551817
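
For anyone who wants to exercise the affected path end to end, a rough sketch of loading a GGUF Gemma 3 multimodal checkpoint with pan-and-scan enabled follows. The file name, the HF repo used for the tokenizer/processor, the image path, and the prompt format are assumptions for illustration, not taken from this commit; do_pan_and_scan is assumed to be passed through mm_processor_kwargs.

from PIL import Image
from vllm import LLM, SamplingParams

# Hypothetical paths/IDs: point these at your local GGUF file and the matching
# HF repo that provides the tokenizer and image processor.
llm = LLM(
    model="./gemma-3-4b-it.gguf",
    tokenizer="google/gemma-3-4b-it",
    mm_processor_kwargs={"do_pan_and_scan": True},  # exercises the new causal path
)

image = Image.open("./example.jpg")  # any local test image
outputs = llm.generate(
    {
        "prompt": "<start_of_image> Describe this image.",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)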

4 files changed: +96 -7 lines


vllm/model_executor/model_loader/weight_utils.py

Lines changed: 71 additions & 1 deletion
@@ -919,9 +919,79 @@ def gguf_quant_weights_iterator(
             weight = tensor.data
             weight_type = tensor.tensor_type
             name = gguf_to_hf_name_map[tensor.name]
+
             if weight_type.name not in ("F32", "BF16", "F16"):
+                # Quantized tensors: handled by quantization layers
                 name = name.replace("weight", "qweight")
-            param = torch.tensor(weight)
+                param = torch.tensor(weight)
+            else:
+                # Unquantized tensors: may need dtype conversion
+                # GGUF stores BF16/F16 as uint8 bytes but F32 as float32
+
+                # Check if already in target dtype
+                if weight.dtype == np.float32 and weight_type.name == "F32":
+                    # F32 tensors are stored directly as float32
+                    param = torch.from_numpy(np.array(weight))
+
+                elif weight.dtype == np.float16 and weight_type.name == "F16":
+                    # F16 tensors are stored directly as float16
+                    param = torch.from_numpy(np.array(weight))
+
+                elif weight.dtype == np.uint8:
+                    # Stored as bytes: convert to target dtype
+                    if weight_type.name == "BF16":
+                        # BF16: 2 bytes per value
+                        # Input: [..., hidden_dim * 2] uint8
+                        # Output: [..., hidden_dim] bfloat16
+                        weight_uint16 = np.frombuffer(
+                            weight.tobytes(), dtype=np.uint16
+                        )
+                        target_shape = weight.shape[:-1] + (weight.shape[-1] // 2,)
+                        weight_uint16 = weight_uint16.reshape(target_shape)
+                        param = torch.from_numpy(weight_uint16).view(torch.bfloat16)
+
+                    elif weight_type.name == "F16":
+                        # F16 (float16): 2 bytes per value
+                        # Input: [..., hidden_dim * 2] uint8
+                        # Output: [..., hidden_dim] float16
+                        weight_uint16 = np.frombuffer(
+                            weight.tobytes(), dtype=np.uint16
+                        )
+                        target_shape = weight.shape[:-1] + (weight.shape[-1] // 2,)
+                        weight_uint16 = weight_uint16.reshape(target_shape)
+                        param = torch.from_numpy(weight_uint16).view(torch.float16)
+
+                    elif weight_type.name == "F32":
+                        # F32 (float32): 4 bytes per value
+                        # Input: [..., hidden_dim * 4] uint8
+                        # Output: [..., hidden_dim] float32
+                        weight_float32 = np.frombuffer(
+                            weight.tobytes(), dtype=np.float32
+                        )
+                        target_shape = weight.shape[:-1] + (weight.shape[-1] // 4,)
+                        weight_float32 = weight_float32.reshape(target_shape)
+                        param = torch.from_numpy(weight_float32)
+
+                    else:
+                        # Unknown format
+                        logger.warning(
+                            "Unknown uint8-stored weight type '%s' for tensor '%s'.",
+                            weight_type.name,
+                            name,
+                        )
+                        param = torch.tensor(weight)
+
+                else:
+                    # Unexpected dtype/type combination
+                    logger.warning(
+                        "Unexpected dtype '%s' for weight type '%s' in tensor '%s'. "
+                        "Falling back to torch.tensor().",
+                        weight.dtype,
+                        weight_type.name,
+                        name,
+                    )
+                    param = torch.tensor(weight)
+
             yield name, param
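
The key trick in the new else branch is a byte reinterpretation rather than a numeric cast: the uint8 buffer is reread as 16-bit (or 32-bit) words and then viewed as the target dtype, so no values change. A minimal standalone sketch of the BF16 path; the tensor contents and shapes are invented, the extra .copy() only avoids torch's read-only-array warning, and a little-endian host plus a PyTorch recent enough to accept uint16 arrays in from_numpy (which the patch itself relies on) are assumed.

import numpy as np
import torch

# Fabricate what a GGUF reader hands back for a BF16 tensor: raw bytes as
# uint8, so the last dimension is 2x the logical width.
logical = torch.randn(3, 4, dtype=torch.bfloat16)
raw = np.frombuffer(logical.view(torch.int16).numpy().tobytes(), dtype=np.uint8)
weight = raw.reshape(3, 4 * 2)

# Same steps as the patch: bytes -> 16-bit words -> reinterpret as bfloat16.
weight_uint16 = np.frombuffer(weight.tobytes(), dtype=np.uint16)
target_shape = weight.shape[:-1] + (weight.shape[-1] // 2,)
param = torch.from_numpy(weight_uint16.reshape(target_shape).copy()).view(torch.bfloat16)

assert torch.equal(param, logical)  # pure reinterpretation: values survive exactly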

vllm/model_executor/models/gemma3.py

Lines changed: 2 additions & 2 deletions
@@ -286,8 +286,8 @@ def naive_attn_with_masks(
                 query,
                 key,
                 value,
-                attn_mask,
-                self.scaling,
+                attn_mask=attn_mask,
+                scale=self.scaling,
             )
             output = output.transpose(1, 2).flatten(-2, -1)
             out[start_idx:end_idx] = output
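
Why naming the parameters matters: in torch.nn.functional.scaled_dot_product_attention the positional order after value is attn_mask, dropout_p, is_causal, scale, so a positionally passed scale binds to dropout_p. A small self-contained sketch; the shapes are illustrative, and head_dim=256 is chosen only because 256 ** -0.5 == 0.0625, matching the 6.25% figure in the commit message.

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 256)           # (batch, heads, seq, head_dim), illustrative
k, v = torch.randn_like(q), torch.randn_like(q)
attn_mask = torch.zeros(1, 1, 16, 16)    # additive mask of zeros: nothing blocked
scaling = 256 ** -0.5                    # 0.0625

# Buggy binding: `scaling` lands in dropout_p, so ~6.25% of attention weights are
# randomly dropped and the intended scale override is silently ignored.
buggy = F.scaled_dot_product_attention(q, k, v, attn_mask, scaling)

# Fixed binding: keyword arguments put the scale where it belongs; dropout stays 0.
fixed = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scaling)

print(torch.allclose(buggy, fixed))      # almost certainly False because of the dropout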

vllm/model_executor/models/gemma3_mm.py

Lines changed: 20 additions & 4 deletions
@@ -708,10 +708,26 @@ def generate_attention_masks(
             # Fill the lower triangle with 0 (causal attention)
             global_attn_mask = global_attn_mask.triu(diagonal=1)

-            # Enable bidirectional attention between image tokens
-            # Use advanced indexing for better performance
-            img_indices = torch.where(img_pos)[0]
-            global_attn_mask[:, :, img_indices[:, None], img_indices] = 0
+            # Conditionally apply bidirectional attention based on pan-and-scan
+            # mode. Pan-and-scan crops require pure causal attention to build
+            # sequential context across crops, matching HF transformers behavior.
+            # Non-pan-and-scan images use bidirectional attention for richer
+            # cross-token interactions within each image.
+            is_pan_and_scan = getattr(self.multimodal_config, "do_pan_and_scan", False)
+
+            if is_pan_and_scan:
+                # Pan-and-scan: Keep pure causal attention (mask unchanged).
+                # Crops are processed sequentially, allowing later crops to
+                # attend to earlier ones, building coherent context across the
+                # entire image. This prevents crop isolation artifacts.
+                pass
+            else:
+                # Non-pan-and-scan: Enable bidirectional attention for image
+                # tokens. This allows all tokens within each image to attend
+                # to each other, improving representation quality.
+                img_indices = torch.where(img_pos)[0]
+                global_attn_mask[:, :, img_indices[:, None], img_indices] = 0
+
             global_attn_masks.append(global_attn_mask)

             # GGUF compatibility: config might be Gemma3TextConfig directly
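
To make the two mask regimes concrete, here is a tiny standalone sketch on a 6-token sequence where positions 1-3 are image tokens; the sizes and the img_pos layout are invented for illustration, and the 0/-inf convention matches the surrounding model code.

import torch

seq_len = 6
img_pos = torch.tensor([False, True, True, True, False, False])  # invented layout

# Causal baseline: 0 means "may attend", -inf means "blocked".
mask = torch.full((1, 1, seq_len, seq_len), float("-inf"))
mask = mask.triu(diagonal=1)

do_pan_and_scan = False  # True would leave the pure causal mask untouched
if not do_pan_and_scan:
    # Bidirectional attention among image tokens: unblock every (image, image)
    # pair, including entries above the diagonal.
    img_indices = torch.where(img_pos)[0]
    mask[:, :, img_indices[:, None], img_indices] = 0

print(mask[0, 0])  # rows 1-3 can now see columns 1-3 in both directions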

vllm/v1/worker/gpu_model_runner.py

Lines changed: 3 additions & 0 deletions
@@ -2484,6 +2484,8 @@ def _preprocess(
                 mask_dtype=self.model.dtype,
             )
             model_kwargs.update(mask_kwargs)
+            # Store for _dummy_run to prevent loss during re-initialization.
+            self.custom_model_kwargs = mask_kwargs
         elif self.enable_prompt_embeds and is_first_rank:
             # Get the input embeddings for the tokens that are not input embeds,
             # then put them into the appropriate positions.
@@ -3952,6 +3954,7 @@ def _dummy_run(
                 model_kwargs = {
                     **model_kwargs,
                     **self._dummy_mm_kwargs(num_reqs),
+                    **getattr(self, "custom_model_kwargs", {}),
                 }
            elif self.enable_prompt_embeds:
                input_ids = None
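
The getattr default is what keeps _dummy_run safe for runners that never generate custom masks: the merge simply contributes an empty dict. A toy sketch of the save-then-merge pattern, independent of the real runner class; every name here is an illustrative stand-in.

class RunnerSketch:
    """Toy stand-in for the model runner; names below are illustrative only."""

    def preprocess(self, has_custom_masks: bool) -> dict:
        model_kwargs = {"input_ids": "..."}
        if has_custom_masks:
            mask_kwargs = {"global_attn_masks": "...", "local_attn_masks": "..."}
            model_kwargs.update(mask_kwargs)
            # Remember the masks so a later dummy run (e.g. CUDA graph capture)
            # sees the same keyword arguments as a real forward pass.
            self.custom_model_kwargs = mask_kwargs
        return model_kwargs

    def dummy_run(self) -> dict:
        # getattr with a default keeps this working when preprocess never
        # stored anything (the attribute simply does not exist yet).
        return {"input_ids": None, **getattr(self, "custom_model_kwargs", {})}


r = RunnerSketch()
r.preprocess(has_custom_masks=True)
print(r.dummy_run())  # the stored mask kwargs survive into the dummy run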
