@@ -122,7 +122,7 @@ def input_processor_for_blip(
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164  # noqa
 class BlipVisionEmbeddings(nn.Module):
 
-    def __init__(self, config: BlipVisionConfig):
+    def __init__(self, config: Union[BlipVisionConfig, Blip2VisionConfig]):
         super().__init__()
 
         self.config = config
@@ -167,9 +167,10 @@ class BlipParallelAttention(nn.Module):
 
     def __init__(
         self,
-        config: BlipVisionConfig,
+        config: Union[BlipVisionConfig, Blip2VisionConfig],
         quant_config: Optional[QuantizationConfig] = None,
-    ):
+        prefix: str = "",
+    ) -> None:
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
@@ -189,11 +190,13 @@ def __init__(
             self.num_heads,
             bias=config.qkv_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv",
         )
         self.projection = RowParallelLinear(
             self.embed_dim,
             self.embed_dim,
             quant_config=quant_config,
+            prefix=f"{prefix}.projection",
         )
 
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -235,9 +238,12 @@ def forward(
 
 class BlipMLP(nn.Module):
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         self.config = config
@@ -246,11 +252,13 @@ def __init__(self,
         self.fc1 = ColumnParallelLinear(config.hidden_size,
                                         config.intermediate_size,
                                         bias=True,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.fc1")
         self.fc2 = RowParallelLinear(config.intermediate_size,
                                      config.hidden_size,
                                      bias=True,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.fc2")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc1(hidden_states)
@@ -262,24 +270,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 class BlipEncoderLayer(nn.Module):
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         # fallback to sdpa attention if tp unavailable
         num_heads = config.num_attention_heads
         tp_size = get_tensor_model_parallel_world_size()
         if USE_XFORMERS_OPS and num_heads % tp_size == 0:
-            self.self_attn = BlipParallelAttention(config,
-                                                   quant_config=quant_config)
+            self.self_attn = BlipParallelAttention(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.self_attn",
+            )
         else:
             # Blip doesn't have SDPA attention implemented in transformers
             # use eager attention instead for cpu backend
             self.self_attn = BlipAttention(config)
         self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
-        self.mlp = BlipMLP(config, quant_config=quant_config)
+        self.mlp = BlipMLP(config,
+                           quant_config=quant_config,
+                           prefix=f"{prefix}.mlp")
         self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
 
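Note: the hunks above thread a `prefix` string down the module tree so each tensor-parallel linear layer is constructed with its fully qualified weight name. A minimal standalone sketch (plain Python, not part of the diff) of how those prefixes compose; the `vision_model.encoder.layers.0` root used here is an illustrative assumption, while the `.self_attn`, `.qkv`, `.mlp`, and `.fc1` suffixes are the ones added in this diff:

# Standalone illustration of the f"{prefix}.<child>" pattern used above.
# The root prefix is an assumed example, not taken from this diff.
def child(prefix: str, name: str) -> str:
    return f"{prefix}.{name}"

layer_prefix = "vision_model.encoder.layers.0"  # assumed root for layer 0
qkv = child(child(layer_prefix, "self_attn"), "qkv")
fc1 = child(child(layer_prefix, "mlp"), "fc1")

print(qkv)  # vision_model.encoder.layers.0.self_attn.qkv
print(fc1)  # vision_model.encoder.layers.0.mlp.fc1
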
@@ -307,10 +323,13 @@ class BlipEncoder(nn.Module):
         config: BlipConfig
     """
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        num_hidden_layers_override: Optional[int] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         self.config = config
@@ -321,8 +340,10 @@ def __init__(self,
             num_hidden_layers = num_hidden_layers_override
 
         self.layers = nn.ModuleList([
-            BlipEncoderLayer(config=config, quant_config=quant_config)
-            for _ in range(num_hidden_layers)
+            BlipEncoderLayer(config=config,
+                             quant_config=quant_config,
+                             prefix=f"{prefix}.layers.{layer_idx}")
+            for layer_idx in range(num_hidden_layers)
         ])
 
     def forward(self, inputs_embeds: torch.Tensor):
@@ -337,10 +358,15 @@ class BlipVisionModel(nn.Module):
     config_class = BlipVisionConfig
     main_input_name = "pixel_values"
 
-    def __init__(self,
-                 config: BlipVisionConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 num_hidden_layers_override: Optional[int] = None):
+    def __init__(
+        self,
+        config: BlipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        *,
+        num_hidden_layers_override: Optional[int] = None,
+        require_post_norm: Optional[bool] = None,
+        prefix: str = "",
+    ) -> None:
         super().__init__()
 
         tp_size = get_tensor_model_parallel_world_size()
@@ -354,19 +380,24 @@ def __init__(self,
             config=config,
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
+            prefix=f"{prefix}.encoder",
         )
 
+        num_hidden_layers = config.num_hidden_layers
         if len(self.encoder.layers) > config.num_hidden_layers:
             raise ValueError(
-                f"The original encoder only has {config.num_hidden_layers} "
+                f"The original encoder only has {num_hidden_layers} "
                 f"layers, but you requested {len(self.encoder.layers)} layers."
             )
-        elif len(self.encoder.layers) == config.num_hidden_layers:
+
+        # If possible, skip post_layernorm to conserve memory
+        if require_post_norm is None:
+            require_post_norm = len(self.encoder.layers) == num_hidden_layers
+
+        if require_post_norm:
             self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                                eps=config.layer_norm_eps)
         else:
-            # post_layernorm is unused when we extract intermediate features
-            # In this case, we can skip it to conserve memory
            self.post_layernorm = None
 
     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
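The last hunk also changes how `post_layernorm` is decided: instead of keying only on whether the full layer stack was built, callers can now force the behaviour through the keyword-only `require_post_norm` argument, with `None` falling back to the old rule. A small pure-Python restatement of that default (a sketch outside the diff; the layer counts below are illustrative, not taken from any particular checkpoint):

from typing import Optional

# Restates the require_post_norm resolution added in the hunk above.
def resolve_require_post_norm(num_built_layers: int,
                              num_hidden_layers: int,
                              require_post_norm: Optional[bool] = None) -> bool:
    if num_built_layers > num_hidden_layers:
        raise ValueError(
            f"The original encoder only has {num_hidden_layers} "
            f"layers, but you requested {num_built_layers} layers.")
    if require_post_norm is None:
        # Default: keep post_layernorm only when the full encoder is built,
        # so truncated feature-extraction towers skip it to conserve memory.
        require_post_norm = num_built_layers == num_hidden_layers
    return require_post_norm

assert resolve_require_post_norm(12, 12) is True        # full tower
assert resolve_require_post_norm(11, 12) is False       # truncated tower
assert resolve_require_post_norm(11, 12, True) is True  # caller override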