Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
"num_reserved_decode_tokens",
"weight_loader_disable_mmap",
"enable_triton_kernel_moe",
"enable_multimodal",
]

# Put some global args for easy access
Expand Down
14 changes: 11 additions & 3 deletions python/sglang/srt/models/mllama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Modality,
MultimodalDataItem,
MultimodalInputs,
global_server_args_dict,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
Expand Down Expand Up @@ -55,13 +56,18 @@ def __init__(
self.quant_config = quant_config

# Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
self.has_vision = self._has_vision_weights(config)
if not self.has_vision:
self.has_vision_weights = self._has_vision_weights(config)
if not self.has_vision_weights:
logger.warning(
"No vision weights found in checkpoint. Model will run in text-only mode. "
"Multimodal capabilities (image processing) will be unavailable."
)

self.has_vision = (
self.has_vision_weights
and global_server_args_dict["enable_multimodal"] == True
)

if self.has_vision:
self.vision_model = Llama4VisionModel(config.vision_config)
self.multi_modal_projector = Llama4MultiModalProjector(config)
Expand Down Expand Up @@ -269,7 +275,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:

def _should_skip_weight(self, name: str) -> bool:
"""Check if we should skip loading this weight."""
return "vision" in name and not self.has_vision
return not self.has_vision and (
"vision" in name or "multi_modal_projector" in name
)

def _transform_weight_name(self, name: str) -> str:
"""Transform weight name by adding language_model prefix if needed."""
Expand Down
Loading