Skip to content
26 changes: 22 additions & 4 deletions vllm/model_executor/models/siglip.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ def __init__(
config: SiglipVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
num_hidden_layers_override: Optional[int] = None,
need_post_layernorm: bool = True,
):
super().__init__()
self.config = config
Expand All @@ -449,8 +450,11 @@ def __init__(
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
)
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
if need_post_layernorm:
self.post_layernorm = nn.LayerNorm(embed_dim,
eps=config.layer_norm_eps)
else:
self.post_layernorm = nn.Identity()
self.use_head = (True if not hasattr(config, "vision_use_head") else
config.vision_use_head)
if self.use_head:
Expand All @@ -470,7 +474,6 @@ def forward(
encoder_outputs = self.encoder(inputs_embeds=hidden_states)

last_hidden_state = self.post_layernorm(encoder_outputs)

# TODO: add this back when pooled_output is used in inference
# if self.use_head:
# pooled_output = self.head(last_hidden_state)
Expand All @@ -489,6 +492,16 @@ def __init__(
num_hidden_layers_override: Optional[int] = None,
):
super().__init__()
if (num_hidden_layers_override is None
or num_hidden_layers_override == config.num_hidden_layers):
self.need_post_layernorm = True
elif num_hidden_layers_override > config.num_hidden_layers:
raise ValueError(
"num_hidden_layers_override cannot be greater than "
"num_hidden_layers")
else:
self.need_post_layernorm = False

num_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0
Expand All @@ -497,7 +510,7 @@ def __init__(
config,
quant_config,
num_hidden_layers_override=num_hidden_layers_override,
)
need_post_layernorm=self.need_post_layernorm)

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
Expand All @@ -517,6 +530,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
layer_count = len(self.vision_model.encoder.layers)

for name, loaded_weight in weights:
# post_layernorm is optional in SiglipVisionModel
if ("vision_model.post_layernorm" in name
and not self.need_post_layernorm):
continue

# omit layers when num_hidden_layers_override is set
if "vision_model.encoder.layers." in name:
layer_idx = int(name.split(".")[3])
Expand Down