
Commit 71d1761

DarkLight1337 authored and Isotr0py committed
[VLM] Enable overriding whether post layernorm is used in vision encoder + fix quant args (vllm-project#9217)
Co-authored-by: Isotr0py <[email protected]>
Signed-off-by: Erkin Sagiroglu <[email protected]>
1 parent e067bba commit 71d1761

File tree: 18 files changed, +551 -253 lines


vllm/model_executor/layers/quantization/awq.py

Lines changed: 16 additions & 4 deletions
@@ -3,7 +3,8 @@
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -21,10 +22,12 @@ def __init__(
         weight_bits: int,
         group_size: int,
         zero_point: bool,
+        modules_to_not_convert: Optional[List[str]] = None,
     ) -> None:
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.zero_point = zero_point
+        self.modules_to_not_convert = modules_to_not_convert or []
 
         if self.weight_bits != 4:
             raise ValueError(
@@ -35,7 +38,8 @@ def __init__(
     def __repr__(self) -> str:
         return (f"AWQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
-                f"zero_point={self.zero_point})")
+                f"zero_point={self.zero_point}, "
+                f"modules_to_not_convert={self.modules_to_not_convert})")
 
     def get_name(self) -> str:
         return "awq"
@@ -61,18 +65,26 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
         weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
         group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
         zero_point = cls.get_from_keys(config, ["zero_point"])
-        return cls(weight_bits, group_size, zero_point)
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None)
+        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["AWQLinearMethod"]:
+                         prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
             return AWQLinearMethod(self)
         return None
 
     def get_scaled_act_names(self) -> List[str]:
         return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
 
 
+def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
 class AWQLinearMethod(LinearMethodBase):
     """Linear method for AWQ.
 
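The new modules_to_not_convert field lets an AWQ checkpoint keep selected submodules (typically the vision tower of a multimodal model) in full precision. Below is a minimal sketch of how the field is consumed, assuming a hypothetical HF-style quantization config; the module names are illustrative, not taken from this commit.

# Sketch only: the quantization config dict and module names are hypothetical.
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                          is_layer_skipped_awq)

hf_quant_config = {
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
    # Keep every linear layer under the vision tower unquantized.
    "modules_to_not_convert": ["vision_model"],
}

awq_config = AWQConfig.from_config(hf_quant_config)

# get_quant_method() is called with each layer's prefix; prefixes matching an
# entry in modules_to_not_convert fall back to UnquantizedLinearMethod, while
# everything else gets AWQLinearMethod.
assert is_layer_skipped_awq("vision_model.encoder.layers.0.mlp.fc1",
                            awq_config.modules_to_not_convert)
assert not is_layer_skipped_awq("language_model.model.layers.0.mlp.gate_up_proj",
                                awq_config.modules_to_not_convert)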
vllm/model_executor/models/blip.py

Lines changed: 59 additions & 28 deletions
@@ -122,7 +122,7 @@ def input_processor_for_blip(
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa
 class BlipVisionEmbeddings(nn.Module):
 
-    def __init__(self, config: BlipVisionConfig):
+    def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
         super().__init__()
 
         self.config = config
@@ -167,9 +167,10 @@ class BlipParallelAttention(nn.Module):
 
     def __init__(
         self,
-        config: BlipVisionConfig,
+        config: Union[BlipVisionConfig, Blip2VisionConfig],
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -189,11 +190,13 @@ def __init__(
             self.num_heads,
             bias=config.qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
         )
         self.projection = RowParallelLinear(
             self.embed_dim,
             self.embed_dim,
             quant_config=quant_config,
+            prefix=f"{prefix}.projection",
         )
 
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -235,9 +238,12 @@ def forward(
 
 class BlipMLP(nn.Module):
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         self.config = config
@@ -246,11 +252,13 @@ def __init__(self,
         self.fc1 = ColumnParallelLinear(config.hidden_size,
                                         config.intermediate_size,
                                         bias=True,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1")
         self.fc2 = RowParallelLinear(config.intermediate_size,
                                      config.hidden_size,
                                      bias=True,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -262,24 +270,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class BlipEncoderLayer(nn.Module):
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         # fallback to sdpa attention if tp unavailable
         num_heads = config.num_attention_heads
         tp_size = get_tensor_model_parallel_world_size()
         if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = BlipParallelAttention(config,
-                                                   quant_config=quant_config)
+            self.self_attn = BlipParallelAttention(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.self_attn",
+            )
         else:
             # Blip doesn't have SDPA attention implemented in transformers
             # use eager attention instead for cpu backend
             self.self_attn = BlipAttention(config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
-        self.mlp = BlipMLP(config, quant_config=quant_config)
+        self.mlp = BlipMLP(config,
                           quant_config=quant_config,
                           prefix=f"{prefix}.mlp")
         self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
 
@@ -307,10 +323,13 @@ class BlipEncoder(nn.Module):
         config: BlipConfig
     """
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         self.config = config
@@ -321,8 +340,10 @@ def __init__(self,
             num_hidden_layers = num_hidden_layers_override
 
         self.layers = nn.ModuleList([
-            BlipEncoderLayer(config=config, quant_config=quant_config)
-            for _ in range(num_hidden_layers)
+            BlipEncoderLayer(config=config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(num_hidden_layers)
         ])
 
     def forward(self, inputs_embeds: torch.Tensor):
@@ -337,10 +358,15 @@ class BlipVisionModel(nn.Module):
     config_class = BlipVisionConfig
     main_input_name = "pixel_values"
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         tp_size = get_tensor_model_parallel_world_size()
@@ -354,19 +380,24 @@ def __init__(self,
             config=config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.encoder",
         )
 
+        num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
             raise ValueError(
-                f"The original encoder only has {config.num_hidden_layers} "
+                f"The original encoder only has {num_hidden_layers} "
                 f"layers, but you requested {len(self.encoder.layers)} layers."
             )
-        elif len(self.encoder.layers) == config.num_hidden_layers:
+
+        # If possible, skip post_layernorm to conserve memory
+        if require_post_norm is None:
+            require_post_norm = len(self.encoder.layers) == num_hidden_layers
+
+        if require_post_norm:
             self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                                eps=config.layer_norm_eps)
         else:
-            # post_layernorm is unused when we extract intermediate features
-            # In this case, we can skip it to conserve memory
             self.post_layernorm = None
 
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
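With require_post_norm, the post layernorm becomes an explicit, overridable choice: when left as None it is created only if the encoder keeps its full depth, and a caller can force it either way. Below is a minimal sketch of how a model might wire up the new keyword-only arguments, assuming the tensor-parallel group is already initialized (as it is inside the vLLM engine); build_vision_tower is a hypothetical helper, not part of this commit.

# Hypothetical helper for illustration only.
from vllm.model_executor.models.blip import BlipVisionModel


def build_vision_tower(hf_config, quant_config, prefix: str = "vision_model"):
    vision_config = hf_config.vision_config
    return BlipVisionModel(
        vision_config,
        quant_config,
        # Drop the last encoder layer for feature extraction; on its own this
        # would make post_layernorm default to None ...
        num_hidden_layers_override=vision_config.num_hidden_layers - 1,
        # ... but the caller can still force it to be created.
        require_post_norm=True,
        prefix=prefix,
    )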

vllm/model_executor/models/blip2.py

Lines changed: 1 addition & 1 deletion
@@ -490,7 +490,7 @@ def __init__(self,
         self.multimodal_config = multimodal_config
 
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_model = BlipVisionModel(config.vision_config)
+        self.vision_model = BlipVisionModel(config.vision_config, quant_config)
 
         self.query_tokens = nn.Parameter(
             torch.zeros(1, config.num_query_tokens,
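Forwarding quant_config here ties the two halves of the change together: the vision tower's linear layers now consult the checkpoint's quantization config and, thanks to the prefix plumbing above, can be excluded via modules_to_not_convert instead of always being built unquantized. A rough end-to-end sketch; the checkpoint name is made up.

# The model name below is hypothetical: any BLIP-2 AWQ checkpoint whose
# quantization_config lists the vision tower in modules_to_not_convert would
# load with the language model quantized and the vision model in full precision.
from vllm import LLM

llm = LLM(model="someone/blip2-opt-2.7b-awq", quantization="awq")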
