Skip to content

Commit 271d4b0

Browse files
2ez4bzNV authored and Shreyas committed
[fix] Fix Mistral3VLM weight-loading & enable in pre-merge (NVIDIA#6105)
Signed-off-by: William Zhang <[email protected]> Signed-off-by: Shreyas Misra <[email protected]>
1 parent 35ee1a5 commit 271d4b0

File tree

4 files changed

+17
-7
lines changed

4 files changed

+17
-7
lines changed

tensorrt_llm/_torch/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .modeling_hyperclovax import HCXVisionForCausalLM
1111
from .modeling_llama import LlamaForCausalLM
1212
from .modeling_llava_next import LlavaNextModel
13-
from .modeling_mistral import MistralForCausalLM
13+
from .modeling_mistral import Mistral3VLM, MistralForCausalLM
1414
from .modeling_mixtral import MixtralForCausalLM
1515
from .modeling_nemotron import NemotronForCausalLM
1616
from .modeling_nemotron_h import NemotronHForCausalLM
@@ -39,6 +39,7 @@
3939
"HCXVisionForCausalLM",
4040
"LlamaForCausalLM",
4141
"LlavaNextModel",
42+
"Mistral3VLM",
4243
"MistralForCausalLM",
4344
"MixtralForCausalLM",
4445
"NemotronForCausalLM",

tensorrt_llm/_torch/models/modeling_mistral.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ def __init__(
296296

297297
llm_model_config = self._get_sub_model_config(model_config,
298298
"text_config")
299+
# This is necessary for the auto weight mapper to figure out what it needs.
300+
llm_model_config.pretrained_config.architectures = config.architectures
299301
self.llm = MistralForCausalLM(llm_model_config)
300302

301303
self._device = "cuda"

tests/integration/defs/local_venv.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55
import copy
66
import os
7+
import shlex
78
import subprocess
89
import tempfile
910
import textwrap as tw
@@ -116,12 +117,17 @@ def run_cmd(self,
116117
new_env = os.environ
117118

118119
if caller.__name__ == 'check_output':
119-
result = subprocess.run(call_args,
120-
env=new_env,
121-
check=True,
122-
capture_output=True,
123-
**kwargs)
124-
return result.stdout.decode('utf-8')
120+
try:
121+
result = subprocess.run(call_args,
122+
env=new_env,
123+
check=True,
124+
capture_output=True,
125+
**kwargs)
126+
return result.stdout.decode('utf-8')
127+
except subprocess.CalledProcessError as e:
128+
raise RuntimeError(f"Failed to run `{shlex.join(e.cmd)}`:\n"
129+
f"Stdout: {e.stdout.decode()}\n"
130+
f"Stderr: {e.stderr.decode()}\n")
125131
else:
126132
print(f"Start subprocess with {caller}({call_args}, env={new_env})")
127133
return caller(call_args, env=new_env, **kwargs)

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ l0_h100:
193193
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency]
194194
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding[llguidance]
195195
- test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
196+
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
196197
- condition:
197198
ranges:
198199
system_gpu_count:

0 commit comments

Comments (0)