Commit 4a84a37

[https://nvbugs/5429689][fix] Fix mllama model structure update with transformers issue (NVIDIA#6699)
Signed-off-by: Wangshanshan <[email protected]>
1 parent 00a95a2 commit 4a84a37

File tree

1 file changed (+12 −2 lines)


tensorrt_llm/tools/multimodal_builder.py

Lines changed: 12 additions & 2 deletions
@@ -1190,8 +1190,18 @@ def forward(self, pixel_values, aspect_ratio_ids, aspect_ratio_mask):
     model = MllamaForConditionalGeneration.from_pretrained(args.model_path,
                                                            torch_dtype='auto',
                                                            device_map='auto')
-    wrapper = MLLaMAVisionWrapper(model.vision_model,
-                                  model.multi_modal_projector)
+
+    # Check if the model structure is updated to transformers >= 4.52.0
+    if hasattr(model, 'model') and hasattr(model.model, 'vision_model'):
+        vision_model = model.model.vision_model
+        multi_modal_projector = model.model.multi_modal_projector
+    else:
+        # transformers < 4.52.0
+        vision_model = model.vision_model
+        multi_modal_projector = model.multi_modal_projector
+
+    wrapper = MLLaMAVisionWrapper(vision_model, multi_modal_projector)
+
     model_dtype = model.dtype
     image = Image.new('RGB', [2048, 2688])  # dummy image
     inputs = processor(images=image,
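For context, the patch is a small compatibility shim: per the added comment, transformers >= 4.52.0 nests the vision tower and projector under the model's .model attribute, while earlier releases expose them at the top level. The same version check could be factored into a reusable helper; a minimal sketch follows (the helper name resolve_vision_modules is hypothetical and not part of this commit):

def resolve_vision_modules(model):
    """Return (vision_model, multi_modal_projector) across transformers versions."""
    # transformers >= 4.52.0 nests the submodules under `model.model`;
    # older releases expose them directly on the top-level model.
    if hasattr(model, 'model') and hasattr(model.model, 'vision_model'):
        base = model.model
    else:
        base = model
    return base.vision_model, base.multi_modal_projector

# Usage, mirroring the patched code path:
# vision_model, multi_modal_projector = resolve_vision_modules(model)
# wrapper = MLLaMAVisionWrapper(vision_model, multi_modal_projector)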
