|
8 | 8 | import gguf |
9 | 9 | from sentencepiece import SentencePieceProcessor # type: ignore[import] |
10 | 10 |
|
| 11 | +try: |
| 12 | + from safetensors import safe_open |
| 13 | +except ImportError: |
| 14 | + print("Please install `safetensors` python package") |
| 15 | + sys.exit(1) |
| 16 | + |
11 | 17 |
|
12 | 18 | def count_model_parts(dir_model: Path) -> int: |
| 19 | + # get number of model parts |
13 | 20 | num_parts = 0 |
14 | 21 | for filename in os.listdir(dir_model): |
15 | | - if filename.startswith("pytorch_model-"): |
| 22 | + if filename.startswith("model-00"): |
16 | 23 | num_parts += 1 |
17 | 24 |
|
18 | 25 | if num_parts > 0: |
@@ -161,22 +168,22 @@ def parse_args() -> argparse.Namespace: |
161 | 168 | print("gguf: get tensor metadata") |
162 | 169 |
|
163 | 170 | if num_parts == 0: |
164 | | - part_names = iter(("pytorch_model.bin",)) |
| 171 | + part_names = iter(("model.safetensors",)) |
165 | 172 | else: |
166 | 173 | part_names = ( |
167 | | - f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) |
| 174 | + f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1) |
168 | 175 | ) |
169 | 176 |
|
170 | 177 | for part_name in part_names: |
171 | 178 | if args.vocab_only: |
172 | 179 | break |
173 | 180 | print("gguf: loading model part '" + part_name + "'") |
174 | | - model_part = torch.load(dir_model / part_name, map_location="cpu") |
| 181 | + model_part = safe_open(dir_model / part_name, framework="pt") |
175 | 182 |
|
176 | 183 | for name in model_part.keys(): |
177 | 184 | if "self_attn.rotary_emb.inv_freq" in name: |
178 | 185 | continue |
179 | | - data = model_part[name] |
| 186 | + data = model_part.get_tensor(name) |
180 | 187 |
|
181 | 188 | old_dtype = data.dtype |
182 | 189 |
|
|
0 commit comments