@@ -216,6 +216,30 @@ def _parse_and_validate_image_input(
216216
217217 return None
218218
219+ def _select_image_features (self , image_features : torch .Tensor , * ,
220+ strategy : str ) -> torch .Tensor :
221+ # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa
222+ if strategy == "default" :
223+ return image_features [:, 1 :]
224+ elif strategy == "full" :
225+ return image_features
226+
227+ raise ValueError (f"Unexpected select feature strategy: { strategy } " )
228+
def _image_pixels_to_features(self, vision_tower: CLIPVisionModel,
                              pixel_values: torch.Tensor) -> torch.Tensor:
    """Encode pixel values into image features via the CLIP vision tower.

    Runs ``pixel_values`` through ``vision_tower`` (moved onto the tower's
    device first), takes the hidden states at
    ``self.config.vision_feature_layer``, and applies the configured
    feature-selection strategy.

    Args:
        vision_tower: HF ``CLIPVisionModel`` used as the image encoder.
        pixel_values: Batched image pixels accepted by the vision tower.

    Returns:
        The selected image feature tensor.
    """
    # TODO(xwjiang): Maybe port minimal CLIPVisionModel over.
    outputs = vision_tower(pixel_values.to(vision_tower.device),
                           output_hidden_states=True)

    # Pick the configured hidden layer, then trim it per the
    # vision_feature_select_strategy from the model config.
    layer_features = outputs.hidden_states[self.config.vision_feature_layer]
    return self._select_image_features(
        layer_features,
        strategy=self.config.vision_feature_select_strategy,
    )
219243 def _merge_image_patch_embeddings (self , image_size : torch .Tensor ,
220244 patch_embeddings : torch .Tensor , * ,
221245 strategy : str ) -> torch .Tensor :
0 commit comments