Skip to content

Commit 34c1e9c

Browse files
authored
[None][feat] Skip prefetching consolidated safetensors when appropriate (#7225)
* Why? Some models (e.g. anything produced by Mistral) can have both sharded safetensors and a consolidated safetensor in the same checkpoint directory. In such cases, prefetching both into memory is a waste of both time and memory. * What? This commit skips over consolidated safetensors when they are not the only safetensor file present in the checkpoint directory. Signed-off-by: William Zhang <[email protected]>
1 parent 85b4ae2 commit 34c1e9c

File tree

4 files changed

+95
-0
lines changed

4 files changed

+95
-0
lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ extend_skip_glob = [
3333
"tensorrt_llm/top_model_mixin.py",
3434
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
3535
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
36+
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
3637
]
3738

3839
[tool.yapf]
@@ -63,6 +64,7 @@ ignore_patterns = [
6364
"tensorrt_llm/top_model_mixin.py",
6465
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
6566
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
67+
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
6668
]
6769

6870
[tool.codespell]
@@ -97,6 +99,7 @@ exclude = [
9799
"tensorrt_llm/top_model_mixin.py",
98100
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
99101
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
102+
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
100103
]
101104

102105

@@ -140,6 +143,7 @@ include = [
140143
"tensorrt_llm/top_model_mixin.py",
141144
"tests/unittest/_torch/modeling/test_modeling_mistral.py",
142145
"tests/unittest/_torch/modeling/test_modeling_pixtral.py",
146+
"tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py",
143147
]
144148
exclude = [
145149
"**3rdparty/**",

tensorrt_llm/_torch/models/checkpoints/hf/weight_loader.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@ class HfWeightLoader(BaseWeightLoader):
2626

2727
def load_weights(self, checkpoint_dir: str) -> dict[str, Any]:
2828
weight_files = glob.glob(f"{checkpoint_dir}/*.safetensors")
29+
# Some model checkpoint directories contain not only the sharded safetensors, but one
30+
# consolidated tensor. In the presence of both, we favor the former, as there really is no need
31+
# to prefetch the (usually) ridiculously large consolidated tensor into memory in such a case.
32+
filtered_weight_files = [
33+
x for x in weight_files if "consolidated" not in os.path.split(x)[1]
34+
]
35+
if len(filtered_weight_files) > 0:
36+
weight_files = filtered_weight_files
2937
if weight_files:
3038
# Prefetch the weight files to CPU memory if the size is less than 90% of the available memory.
3139
# This is a heuristic to avoid prefetching files that are too large and causing file cache thrashing.

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ l0_a10:
1616
# ------------- PyTorch tests ---------------
1717
- unittest/_torch/modeling/test_modeling_mistral.py
1818
- unittest/_torch/modeling/test_modeling_pixtral.py
19+
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
20+
# test list either).
21+
- unittest/_torch/models/checkpoints/hf/test_weight_loader.py
1922
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0]
2023
- disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun_trt_backend[TinyLlama-1.1B-Chat-v1.0]
2124
- disaggregated/test_disaggregated.py::test_disaggregated_cuda_graph[TinyLlama-1.1B-Chat-v1.0]
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from unittest import mock
2+
3+
import pytest
4+
5+
from tensorrt_llm._torch.models.checkpoints import HfWeightLoader
6+
7+
8+
class MyError(Exception):
    """Sentinel exception used to abort ``load_weights`` once the calls under test were made."""
10+
11+
12+
@pytest.mark.parametrize(
    "dir_name, safetensor_filenames, expected_safetensor_filenames",
    [
        # Sharded and consolidated safetensors both present: the consolidated one is skipped.
        (
            "foo",
            [
                "model-00001-of-00002.safetensors",
                "model-00002-of-00002.safetensors",
                "consolidated.safetensors",
            ],
            ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"],
        ),
        # Many shards plus a prefixed consolidated file: every shard kept, consolidated skipped.
        # NOTE: `:05d` keeps shard names uniformly zero-padded (the hand-written variant drifted
        # between five and six digits, which no real HF checkpoint does).
        (
            "foo",
            [
                *(f"model-{i:05d}-of-00010.safetensors" for i in range(1, 11)),
                "foo-consolidated.safetensors",
            ],
            [f"model-{i:05d}-of-00010.safetensors" for i in range(1, 11)],
        ),
        # If there is only a consolidated safetensor, that one should still be used.
        (
            "foo",
            ["consolidated.safetensors"],
            ["consolidated.safetensors"],
        ),
        # If the directory contains "consolidated" in its name but holds sharded tensors, the
        # directory name must not trigger the filtering (only file basenames matter).
        (
            "consolidated-model",
            [
                "model-00001-of-00002.safetensors",
                "model-00002-of-00002.safetensors",
                "consolidated.safetensors",
            ],
            ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"],
        ),
    ],
)
def test_load_weights_ignores_consolidated_ckpt_when_sharded_ckpt_exists(
    tmp_path,
    dir_name: str,
    safetensor_filenames: list[str],
    expected_safetensor_filenames: list[str],
):
    """Verify that ``HfWeightLoader.load_weights`` selects only the expected safetensor files.

    Empty (touched) files suffice: ``_load_weights_in_parallel`` is mocked to raise
    ``MyError`` so the loader is cut off before it ever parses safetensor contents; we
    only observe which paths were handed to ``prefetch_files`` and to the parallel loader.
    """
    checkpoint_dir = tmp_path / dir_name
    checkpoint_dir.mkdir()
    for filename in safetensor_filenames:
        (checkpoint_dir / filename).touch()
    expected_paths = {
        str(checkpoint_dir / filename) for filename in expected_safetensor_filenames
    }

    loader = HfWeightLoader()
    with (
        mock.patch.object(
            loader, "_load_weights_in_parallel", side_effect=MyError
        ) as load_weights_in_parallel,
        mock.patch.object(loader, "prefetch_files") as prefetch_files,
        pytest.raises(MyError),
    ):
        loader.load_weights(checkpoint_dir=str(checkpoint_dir))

    # Both the prefetch step and the actual load must have seen exactly the filtered set.
    prefetch_files.assert_called_once()
    prefetched_files = prefetch_files.call_args[0][0]
    assert set(prefetched_files) == expected_paths

    load_weights_in_parallel.assert_called_once()
    loaded_weight_files = load_weights_in_parallel.call_args[0][0]
    assert set(loaded_weight_files) == expected_paths

0 commit comments

Comments
 (0)