diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 743542ec8dfa..2591bbb8ab9d 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -263,6 +263,11 @@ def _call_hf_processor(
             mm_data,
             mm_kwargs,
         )
+        if "pixel_values" in processed_outputs:
+            # Cast pixel values to model dtype already here,
+            # so we need to transfer less data to the GPU
+            processed_outputs["pixel_values"] = processed_outputs[
+                "pixel_values"].to(self.info.ctx.model_config.dtype)

         # HF processor pops the `num_crops` kwarg, which is needed by vLLM
         if (images := mm_data.get("images")) is not None:
@@ -549,9 +554,7 @@ def _image_pixels_to_features(
         vision_tower: SiglipVisionModel,
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-        target_dtype = vision_tower.get_input_embeddings().weight.dtype
-        image_features = vision_tower(pixel_values.to(dtype=target_dtype))
-        return image_features
+        return vision_tower(pixel_values)

     def _process_image_input(
         self,