Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 52 additions & 54 deletions examples/falcon/convert-hf-to-ggml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# Usage:
#
# python3 models/convert-h5-to-ggml.py
# python3 models/convert-h5-to-ggml.py
#
# This script is similar to "convert-pt-to-ggml.py"
#
Expand Down Expand Up @@ -40,15 +40,17 @@ def bytes_to_unicode():
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))

if len(sys.argv) < 3:
print("Usage: python convert-hf-to-ggml.py model_name dir-output [use-f32]")
if len(sys.argv) < 4:
print("Usage: python convert-hf-to-ggml.py num_parts model_name dir-output [use-f32]")
print(" num_parts: number of pytorch parts, use 0 if not a multipart model. example: 9")
print(" model_name: name of the model to convert. Example: 'bigscience/bloomz-560m'")
print(" dir-output: directory where the output file will be written")
print(" use-f32: if present, use float32 instead of float16")
sys.exit(1)

model_name = sys.argv[1]
dir_out = sys.argv[2]
num_parts = int(sys.argv[1])
model_name = sys.argv[2]
dir_out = sys.argv[3]

# make sure the output directory exists
os.makedirs(dir_out, exist_ok=True)
Expand All @@ -60,19 +62,17 @@ def bytes_to_unicode():
# map from ftype to string
ftype_str = ["f32", "f16"]
ftype = 1
if len(sys.argv) > 3:
if len(sys.argv) > 4:
ftype = 0

tokenizer = AutoTokenizer.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
hparams = config.to_dict()
print("Loading model: ", model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.float16 if ftype == 1 else torch.float32, low_cpu_mem_usage=True, trust_remote_code=True)
print("Model loaded: ", model_name)

n_head = hparams["n_head"]
n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1
head_dim = hparams["hidden_size"] // n_head
print("* Loading model from: ", model_name)

fname_out = dir_out + f"/ggml-model-{model_name.split('/')[-1]}-{ftype_str[ftype]}.bin"
fout = open(fname_out, "wb")
Expand All @@ -93,51 +93,49 @@ def bytes_to_unicode():
text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
fout.write(struct.pack("i", len(text)))
fout.write(text)

# Legacy single-checkpoint path: serialize every tensor from the already
# loaded HF model's state_dict into the GGML binary stream (`fout`).
list_vars = model.state_dict()
for name in list_vars.keys():
    src = name

    # The original query_key_value tensor contains n_head_kv "kv groups",
    # each consisting of n_head/n_head_kv query weights followed by one key
    # and one value weight (shared by all query heads in the kv group).
    # This layout makes it a big pain to work with in GGML.
    # So we rearrange them here, so that we have n_head query weights
    # followed by n_head_kv key weights followed by n_head_kv value weights,
    # in contiguous fashion.

    if "query_key_value" in src:
        qkv = list_vars[src].view(
            n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)

        q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head)
        k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
        v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)

        list_vars[src] = torch.cat((q,k,v)).reshape_as(list_vars[src])

    data = list_vars[src].squeeze().numpy()
    data = data.astype(np.float32)

    n_dims = len(data.shape)
    print(name, n_dims, data.shape)

    # default type is fp32; only multi-dimensional tensors are converted
    # to fp16 (1-D tensors such as biases and norms stay fp32)
    ftype_cur = 0
    if ftype == 1 and n_dims > 1:
        print("  Converting to float16")
        data = data.astype(np.float16)
        ftype_cur = 1

    # header: n_dims, name length, per-tensor ftype, then the shape written
    # innermost-dimension-first (as GGML expects), then the raw name bytes.
    # (renamed from `str` to avoid shadowing the builtin)
    sname = name.encode('utf-8')
    fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(sname)

    # data
    data.tofile(fout)

# Determine which checkpoint file(s) hold the weights: a single
# `pytorch_model.bin` for non-sharded models (num_parts == 0), or
# numbered shards `pytorch_model-00001-of-000NN.bin` otherwise.
if num_parts == 0:
    partnames = ('pytorch_model.bin',)
else:
    partnames = (f'pytorch_model-{n:05}-of-{num_parts:05}.bin' for n in range(1, num_parts + 1))
for partname in partnames:
    filename = f'{model_name}/{partname}'
    print(f'\n* Loading part: {partname}')
    # NOTE(review): torch.load unpickles arbitrary objects — only run this
    # on checkpoints from a trusted source.
    model = torch.load(filename, map_location = 'cpu')
    for name in model.keys():
        src = name
        # The original query_key_value tensor contains n_head_kv "kv groups",
        # each consisting of n_head/n_head_kv query weights followed by one key
        # and one value weight (shared by all query heads in the kv group).
        # This layout makes it a big pain to work with in GGML.
        # So we rearrange them here, so that we have n_head query weights
        # followed by n_head_kv key weights followed by n_head_kv value weights,
        # in contiguous fashion.

        if "query_key_value" in src:
            qkv = model[src].view(
                n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)

            q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head)
            k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
            v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)

            model[src] = torch.cat((q,k,v)).reshape_as(model[src])
        data = model[src].squeeze()
        n_dims = len(data.shape)
        # default type is fp32; only multi-dimensional tensors are stored
        # as fp16 (1-D tensors such as biases and norms stay fp32)
        ftype_cur = 1 if ftype == 1 and n_dims > 1 else 0
        data = data.to(dtype = torch.float16 if ftype_cur == 1 else torch.float32).numpy()
        print(f' |', name, data.shape, '->', data.dtype)
        # header: n_dims, name length, per-tensor ftype, then the shape written
        # innermost-dimension-first (as GGML expects), then the raw name bytes.
        # (renamed from `str` to avoid shadowing the builtin)
        sname = name.encode('utf-8')
        fout.write(struct.pack("iii", n_dims, len(sname), ftype_cur))
        for i in range(n_dims):
            fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
        fout.write(sname)

        # data
        data.tofile(fout)

fout.close()

Expand Down