# Channel-wise ImageNet normalization statistics (RGB order), used by the
# image transform pipeline below.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
4040
41- MAX_IMAGE_FEATURE_SIZE_WIDTH = 3000
42- MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500
43-
4441
4542class InternVLImagePixelInputs (TypedDict ):
4643 type : Literal ["pixel_values" ]
@@ -84,11 +81,9 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
8481 return best_ratio
8582
8683
87- def calculate_num_blocks (orig_width : int ,
88- orig_height : int ,
89- min_num = 1 ,
90- max_num = 6 ,
91- image_size = 448 ):
84+ def calculate_num_blocks (orig_width : int , orig_height : int , min_num : int ,
85+ max_num : int ,
86+ image_size : int ) -> Tuple [int , int , int ]:
9287 aspect_ratio = orig_width / orig_height
9388
9489 # calculate the existing image aspect ratio
@@ -110,11 +105,9 @@ def calculate_num_blocks(orig_width: int,
110105
111106
112107# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
113- def dynamic_preprocess (image ,
114- min_num = 1 ,
115- max_num = 6 ,
116- image_size = 448 ,
117- use_thumbnail = False ):
108+ def dynamic_preprocess (image : Image .Image , min_num : int , max_num : int ,
109+ image_size : int ,
110+ use_thumbnail : int ) -> List [Image .Image ]:
118111 orig_width , orig_height = image .size
119112
120113 blocks , target_width , target_height = calculate_num_blocks (
@@ -138,12 +131,14 @@ def dynamic_preprocess(image,
138131
139132
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
    """Convert a PIL image into a stacked tensor of normalized tiles.

    The image is split into between ``min_num`` and ``max_num`` tiles of
    ``input_size`` x ``input_size`` by :func:`dynamic_preprocess` (optionally
    appending a thumbnail tile), each tile is run through the normalization
    transform, and the results are stacked along a new leading dimension.
    """
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image,
                               min_num=min_num,
                               max_num=max_num,
                               image_size=input_size,
                               use_thumbnail=use_thumbnail)
    # torch.stack adds the leading tile dimension: (num_tiles, ...).
    return torch.stack([transform(tile) for tile in tiles])
@@ -159,12 +154,18 @@ def get_internvl_num_patches(image_size: int, patch_size: int,
def get_max_internvl_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens a single image can produce.

    This is the per-tile token count multiplied by the maximum number of
    dynamic tiles, where an enabled thumbnail contributes one extra tile.
    """
    hf_config = ctx.get_hf_config(PretrainedConfig)
    vision_config = hf_config.vision_config

    # The thumbnail, when enabled, is appended as one additional tile.
    max_tiles = hf_config.max_dynamic_patch
    if hf_config.use_thumbnail:
        max_tiles += 1

    tokens_per_tile = get_internvl_num_patches(vision_config.image_size,
                                               vision_config.patch_size,
                                               hf_config.downsample_ratio)
    return tokens_per_tile * max_tiles
168169
169170
170171def input_processor_for_internvl (ctx : InputContext , llm_inputs : LLMInputs ):
@@ -176,30 +177,35 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
176177 hf_config = ctx .get_hf_config (PretrainedConfig )
177178 vision_config = hf_config .vision_config
178179
180+ image_size = vision_config .image_size
181+ patch_size = vision_config .patch_size
182+ downsample_ratio = hf_config .downsample_ratio
183+ num_patches = get_internvl_num_patches (image_size , patch_size ,
184+ downsample_ratio )
185+
179186 image_data = multi_modal_data ["image" ]
180187 if isinstance (image_data , Image .Image ):
181188 width , height = image_data .size
182- num_blocks , _ , _ = calculate_num_blocks (width , height )
189+ min_num = hf_config .min_dynamic_patch
190+ max_num = hf_config .max_dynamic_patch
191+ num_blocks , _ , _ = calculate_num_blocks (width , height , min_num ,
192+ max_num , image_size )
193+ # add thumbnail image if num_blocks > 1
194+ if hf_config .use_thumbnail and num_blocks > 1 :
195+ num_blocks += 1
183196 elif isinstance (image_data , torch .Tensor ):
184197 raise NotImplementedError ("Embeddings input is not supported yet" )
185198 else :
186199 raise TypeError (f"Invalid image type: { type (image_data )} " )
187200
188- image_size = vision_config .image_size
189- patch_size = vision_config .patch_size
190- downsample_ratio = hf_config .downsample_ratio
191- num_patches = get_internvl_num_patches (image_size , patch_size ,
192- downsample_ratio )
193-
194201 tokenizer = cached_get_tokenizer (model_config .tokenizer ,
195202 trust_remote_code = True )
196203
197204 prompt = llm_inputs .get ("prompt" )
198205 prompt_token_ids = llm_inputs ["prompt_token_ids" ]
199206 if prompt is None :
200207 prompt = tokenizer .decode (prompt_token_ids )
201- image_prompt = IMG_START + IMG_CONTEXT * (num_blocks +
202- 1 ) * num_patches + IMG_END
208+ image_prompt = IMG_START + IMG_CONTEXT * num_blocks * num_patches + IMG_END
203209 new_prompt = prompt .replace ('<image>' , image_prompt , 1 )
204210 new_prompt_token_ids = tokenizer .encode (new_prompt )
205211
@@ -209,8 +215,19 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
209215
210216
211217def input_mapper_for_internvl (ctx : InputContext , data : object ):
218+ hf_config = ctx .get_hf_config (PretrainedConfig )
219+
220+ use_thumbnail = hf_config .use_thumbnail
221+ min_num = hf_config .min_dynamic_patch
222+ max_num = hf_config .max_dynamic_patch
223+ image_size = hf_config .vision_config .image_size
224+
212225 if isinstance (data , Image .Image ):
213- data = image_to_pixel_values (data )
226+ data = image_to_pixel_values (data ,
227+ image_size ,
228+ min_num ,
229+ max_num ,
230+ use_thumbnail = use_thumbnail )
214231 model_config = ctx .model_config
215232 tokenizer = cached_get_tokenizer (model_config .tokenizer ,
216233 trust_remote_code = True )
@@ -240,10 +257,17 @@ def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
240257 add_special_tokens = False )[0 ],
241258 image_feature_size_override = image_feature_size ,
242259 )
260+
261+ image_size = vision_config .image_size
262+ min_num = hf_config .min_dynamic_patch
263+ max_num = hf_config .max_dynamic_patch
264+ max_image_width = max_num * image_size
265+ max_image_height = min_num * image_size
266+
243267 mm_data = dummy_image_for_clip (
244268 vision_config ,
245- image_width_override = MAX_IMAGE_FEATURE_SIZE_WIDTH ,
246- image_height_override = MAX_IMAGE_FEATURE_SIZE_HEIGHT ,
269+ image_width_override = max_image_width ,
270+ image_height_override = max_image_height ,
247271 )
248272
249273 return seq_data , mm_data
0 commit comments