Commit a65d78c

[https://nvbugs/5441729][test] Fix test_modeling_llama_min_latency.py failures
The test_modeling_llama_min_latency.py::test_llama_allclose_to_hf tests are failing with the latest HF transformers due to a bug in their code. A PR has been submitted to fix it in the upstream repo: huggingface/transformers#40609

Signed-off-by: Po-Han Huang <[email protected]>
1 parent ba6ab62 commit a65d78c

2 files changed: +22 -8 lines

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 17 additions & 5 deletions
@@ -1002,16 +1002,28 @@ def __init__(self, model_config: ModelConfig[Llama4Config], *args,
 
         self.dtype = self.pretrained_config.text_config.torch_dtype
 
-    def load_weights(self):
+    def load_weights(self, weights: Dict):
         module_dict = nn.ModuleDict({
             "vision_model":
             Llama4VisionModel(self.pretrained_config.vision_config),
             "multi_modal_projector":
             Llama4MultiModalProjector(self.pretrained_config),
         })
-        load_sharded_checkpoint(module_dict,
-                                self.pretrained_config._name_or_path,
-                                strict=False)
+
+        # If the named params are present in the weights, load them directly.
+        param_names = [name for name, _ in module_dict.named_parameters()]
+        if all(name in weights for name in param_names):
+            vision_encoder_weights = {
+                name: weights[name]
+                for name in param_names
+            }
+            module_dict.load_state_dict(vision_encoder_weights)
+
+        # Otherwise, load the weights from the checkpoint.
+        else:
+            load_sharded_checkpoint(module_dict,
+                                    self.pretrained_config._name_or_path,
+                                    strict=False)
 
         self.vision_model = module_dict["vision_model"].to(self.device)
         self.mm_projector = module_dict["multi_modal_projector"].to(self.device)
@@ -1294,7 +1306,7 @@ def infer_max_seq_len(self):
 
     def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
         if not DISAGG:
-            self.mm_encoder.load_weights()
+            self.mm_encoder.load_weights(weights)
 
         # Temporarily detach mm_encoder so the TRT-LLM loader doesn't try to load it
         had_mm_encoder = hasattr(self, "mm_encoder")

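The substantive change above is the new fallback in load_weights: when every parameter of the vision_model and multi_modal_projector submodules is already present in the weights dict handed over by the loader, they are loaded directly via load_state_dict; only if some are missing does the code fall back to load_sharded_checkpoint on the HF checkpoint path. Below is a minimal, self-contained sketch of that pattern; the toy module, the weights dict, and the load_from_checkpoint helper are illustrative stand-ins, not the TRT-LLM objects.

# Sketch of the "load from in-memory weights, else from checkpoint" pattern.
from typing import Dict

import torch
from torch import nn


def load_from_checkpoint(module: nn.Module, path: str) -> None:
    # Hypothetical stand-in for load_sharded_checkpoint(); a real implementation
    # would read checkpoint shards from `path` into `module`.
    raise NotImplementedError(f"would load shards from {path}")


def load_submodule_weights(module_dict: nn.ModuleDict,
                           weights: Dict[str, torch.Tensor],
                           checkpoint_dir: str) -> None:
    # Names of every parameter the submodules expect, e.g. "proj.weight".
    param_names = [name for name, _ in module_dict.named_parameters()]
    if all(name in weights for name in param_names):
        # Every parameter is already in memory: build a filtered state dict
        # and load it directly, ignoring unrelated keys in `weights`.
        module_dict.load_state_dict({name: weights[name] for name in param_names})
    else:
        # Some parameters are missing: fall back to the on-disk checkpoint.
        load_from_checkpoint(module_dict, checkpoint_dir)


if __name__ == "__main__":
    modules = nn.ModuleDict({"proj": nn.Linear(4, 4, bias=False)})
    weights = {"proj.weight": torch.zeros(4, 4)}
    load_submodule_weights(modules, weights, checkpoint_dir="/unused")
    print(modules["proj"].weight.abs().sum().item())  # 0.0

The all(...) guard keeps the old checkpoint path as a fallback for callers whose weights dict does not include the vision encoder parameters.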
tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py

Lines changed: 5 additions & 3 deletions
@@ -266,10 +266,12 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None:
         attention_backend = "TRTLLM"
         metadata_cls = get_attention_backend(attention_backend).Metadata
 
-        if transformers.__version__ >= "4.55.0":
+        if transformers.__version__ >= "4.55.0" \
+                and transformers.__version__ < "4.56.1":
             self.skipTest(
-                "The transformers 4.55.0 has accuracy issues while 4.33.1 works fine. "
-                "https://nvbugspro.nvidia.com/bug/5441729")
+                "The transformers between 4.55.0 and 4.56.1 have accuracy "
+                "issues for Llama4. See: "
+                "https://github.com/huggingface/transformers/pull/40609")
 
         torch.random.manual_seed(0)
         config_dict = deepcopy(LLAMA_4_MAVERICK_TWO_LAYER_CONFIG)

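The updated guard bounds the skip to the affected releases: transformers versions from 4.55.0 up to (but not including) 4.56.1 carry the Llama4 accuracy bug addressed by the upstream PR, so only those versions skip the test. The guard compares __version__ lexically as strings, which is adequate for this range; as a hedged alternative (not part of this commit, and assuming the packaging package is available), the same gate could use parsed versions, as in the sketch below.

# Illustrative sketch only: a parsed-version variant of the skip guard.
# Lexical string comparison can misorder versions in general (e.g. "4.9.0"
# sorts after "4.55.0"), while packaging.version compares numerically.
import unittest

import transformers
from packaging.version import parse


class ExampleLlamaTest(unittest.TestCase):

    def test_llama_allclose_to_hf_example(self):
        hf_version = parse(transformers.__version__)
        if parse("4.55.0") <= hf_version < parse("4.56.1"):
            self.skipTest(
                "transformers 4.55.0 - 4.56.0 have a Llama4 accuracy bug; see "
                "https://github.com/huggingface/transformers/pull/40609")
        # ... the actual allclose comparison against the HF reference would run here ...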