
Commit 864ddb3

[https://nvbugs/5429689][fix] Fix mllama model structure update with transformers issue (#6699)
Signed-off-by: Wangshanshan <[email protected]>
1 parent: 72eda45

File tree

1 file changed: +12 −2 lines changed


tensorrt_llm/tools/multimodal_builder.py

Lines changed: 12 additions & 2 deletions
@@ -1188,8 +1188,18 @@ def forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
     model = MllamaForConditionalGeneration.from_pretrained(args.model_path,
                                                            torch_dtype='auto',
                                                            device_map='auto')
-    wrapper = MLLaMAVisionWrapper(model.vision_model,
-                                  model.multi_modal_projector)
+
+    # Check if the model structure is updated to transformers >= 4.52.0
+    if hasattr(model, 'model') and hasattr(model.model, 'vision_model'):
+        vision_model = model.model.vision_model
+        multi_modal_projector = model.model.multi_modal_projector
+    else:
+        # transformers < 4.52.0
+        vision_model = model.vision_model
+        multi_modal_projector = model.multi_modal_projector
+
+    wrapper = MLLaMAVisionWrapper(vision_model, multi_modal_projector)
+
     model_dtype = model.dtype
     image = Image.new('RGB', [2048, 2688])  # dummy image
     inputs = processor(images=image,
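The fix guards against the Transformers refactor (from release 4.52.0) that nests the vision submodules under an inner `model` attribute instead of exposing them on the top-level model. A minimal, self-contained sketch of the same duck-typing pattern, using `SimpleNamespace` stand-ins in place of the real `MllamaForConditionalGeneration` (the helper name `resolve_vision_parts` is illustrative, not part of TensorRT-LLM):

```python
from types import SimpleNamespace


def resolve_vision_parts(model):
    """Return (vision_model, multi_modal_projector) for either model layout.

    transformers >= 4.52.0 nests the submodules under model.model;
    older releases expose them directly on the top-level model.
    """
    if hasattr(model, 'model') and hasattr(model.model, 'vision_model'):
        return model.model.vision_model, model.model.multi_modal_projector
    return model.vision_model, model.multi_modal_projector


# Stand-ins for the two layouts (illustrative only, no real weights).
old_style = SimpleNamespace(vision_model='vision', multi_modal_projector='proj')
new_style = SimpleNamespace(
    model=SimpleNamespace(vision_model='vision', multi_modal_projector='proj'))

assert resolve_vision_parts(old_style) == ('vision', 'proj')
assert resolve_vision_parts(new_style) == ('vision', 'proj')
```

Probing with `hasattr` rather than comparing `transformers.__version__` keeps the code working even if the attribute layout changes again, at the cost of silently accepting any object that happens to expose the expected attributes.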
