Skip to content

Commit bc4dadb

Browse files
committed
Allow converting Falcon models one part at a time
1 parent: 7db8803 · commit: bc4dadb

1 file changed

Lines changed: 34 additions & 36 deletions

File tree

examples/falcon/convert-hf-to-ggml.py

Lines changed: 34 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
#
33
# Usage:
44
#
5-
# python3 models/convert-h5-to-ggml.py
5+
# python3 models/convert-h5-to-ggml.py
66
#
77
# This script is similar to "convert-pt-to-ggml.py"
88
#
@@ -40,15 +40,17 @@ def bytes_to_unicode():
4040
cs = [chr(n) for n in cs]
4141
return dict(zip(bs, cs))
4242

43-
if len(sys.argv) < 3:
44-
print("Usage: python convert-hf-to-ggml.py model_name dir-output [use-f32]")
43+
if len(sys.argv) < 4:
44+
print("Usage: python convert-hf-to-ggml.py num_parts model_name dir-output [use-f32]")
45+
print(" num_parts: number of pytorch parts, use 0 if not a multipart model. example: 9")
4546
print(" model_name: name of the model to convert. Example: 'bigscience/bloomz-560m'")
4647
print(" dir-output: directory where the output file will be written")
4748
print(" use-f32: if present, use float32 instead of float16")
4849
sys.exit(1)
4950

50-
model_name = sys.argv[1]
51-
dir_out = sys.argv[2]
51+
num_parts = int(sys.argv[1])
52+
model_name = sys.argv[2]
53+
dir_out = sys.argv[3]
5254

5355
# make sure the output directory exists
5456
os.makedirs(dir_out, exist_ok=True)
@@ -60,16 +62,13 @@ def bytes_to_unicode():
6062
# map from ftype to string
6163
ftype_str = ["f32", "f16"]
6264
ftype = 1
63-
if len(sys.argv) > 3:
65+
if len(sys.argv) > 4:
6466
ftype = 0
6567

6668
tokenizer = AutoTokenizer.from_pretrained(model_name)
6769
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
6870
hparams = config.to_dict()
69-
print("Loading model: ", model_name)
70-
model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16 if ftype == 1 else torch.float32, low_cpu_mem_usage=True, trust_remote_code=True)
71-
print("Model loaded: ", model_name)
72-
71+
print("* Loading model from: ", model_name)
7372

7473
fname_out = dir_out + f"/ggml-model-{model_name.split('/')[-1]}-{ftype_str[ftype]}.bin"
7574
fout = open(fname_out, "wb")
@@ -90,32 +89,31 @@ def bytes_to_unicode():
9089
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
9190
fout.write(struct.pack("i", len(text)))
9291
fout.write(text)
93-
94-
list_vars = model.state_dict()
95-
for name in list_vars.keys():
96-
src = name
97-
data = list_vars[src].squeeze().numpy()
98-
data = data.astype(np.float32)
99-
100-
n_dims = len(data.shape)
101-
print(name, n_dims, data.shape)
102-
103-
# default type is fp32
104-
ftype_cur = 0
105-
if ftype == 1 and n_dims > 1:
106-
print(" Converting to float16")
107-
data = data.astype(np.float16)
108-
ftype_cur = 1
109-
110-
# header
111-
str = name.encode('utf-8')
112-
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
113-
for i in range(n_dims):
114-
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
115-
fout.write(str)
116-
117-
# data
118-
data.tofile(fout)
92+
if num_parts == 0:
93+
partnames= ('pytorch_model.bin',)
94+
else:
95+
partnames = (f'pytorch_model-{n:05}-of-{num_parts:05}.bin' for n in range(1, num_parts + 1))
96+
for partname in partnames:
97+
filename = f'{model_name}/{partname}'
98+
print(f'\n* Loading part: {partname}')
99+
model = torch.load(filename, map_location = 'cpu')
100+
for name in model.keys():
101+
src = name
102+
data = model[src].squeeze()
103+
n_dims = len(data.shape)
104+
# default type is fp32
105+
ftype_cur = 1 if ftype == 1 and n_dims > 1 else 0
106+
data = data.to(dtype = torch.float16 if ftype_cur == 1 else torch.float32).numpy()
107+
print(f' |', name, data.shape, '->', data.dtype)
108+
# header
109+
str = name.encode('utf-8')
110+
fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
111+
for i in range(n_dims):
112+
fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
113+
fout.write(str)
114+
115+
# data
116+
data.tofile(fout)
119117

120118
fout.close()
121119

0 commit comments

Comments (0)