@@ -58,12 +58,31 @@ def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
5858 nn .Linear (llm_intermediate_size , llm_hidden_size , bias = False ),
5959 )
6060
def _init_vision_model(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    is_mono: bool,
    prefix: str,
):
    """Construct the vision model used when monolith mode is off.

    The number of hidden layers handed to ``InternVisionModel`` is derived
    from ``config.select_layer``: a negative value is counted back from
    ``config.vision_config.num_hidden_layers``, while a non-negative value
    is used as a zero-based layer index. In both cases one is added so the
    selected layer itself is included.

    Args:
        config: Model configuration; ``select_layer`` and ``vision_config``
            are read from it.
        quant_config: Quantization settings forwarded unchanged to the
            vision model (may be ``None``).
        is_mono: Whether monolith mode was requested.
        prefix: Forwarded unchanged to the vision model as ``prefix``.

    Returns:
        The constructed ``InternVisionModel``.

    Raises:
        NotImplementedError: If ``is_mono`` is true.
    """
    # Reject the unsupported mode up front so the main path stays flat.
    if is_mono:
        raise NotImplementedError("Monolith mode is not applicable to NVLM_D")

    select_layer = config.select_layer
    # Negative indices are relative to the end of the encoder stack.
    base = config.vision_config.num_hidden_layers if select_layer < 0 else 0
    layer_count = base + select_layer + 1

    # We added additional dummy heads to the original num of heads to make
    # the number of heads divisible by 8.
    return InternVisionModel(
        config.vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=layer_count,
        num_dummy_heads=7,
        prefix=prefix,
    )
0 commit comments