@@ -17,13 +17,14 @@
####################################################################################################

import argparse
import json
import os
import re
import zipfile

import torch

from transformers import MegatronBertConfig


####################################################################################################

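The new MegatronBertConfig import replaces the hand-written config dict that is deleted further down. When no --config_file is passed, main() falls back to the class defaults, which are sized for NVIDIA's 345M BERT release (see the "Default config of megatron-bert 345m" comment below). A quick way to inspect those defaults, as an illustration only, not part of the script:

from transformers import MegatronBertConfig

cfg = MegatronBertConfig()  # defaults target the 345M Megatron-BERT release
print(cfg.hidden_size, cfg.num_hidden_layers, cfg.num_attention_heads, cfg.max_position_embeddings)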
@@ -48,13 +49,62 @@ def recursive_print(name, val, spaces=0):
print(msg, ":", val)


def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
# Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :]
# for compatibility with later versions of NVIDIA Megatron-LM.
# The inverse operation is performed inside Megatron-LM to read checkpoints:
# https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209
# If param is the weight tensor of the self-attention block, the returned tensor
# will have to be transposed one more time to be read by HuggingFace BERT.
input_shape = param.size()
if checkpoint_version == 1.0:
# version 1.0 stores [num_heads * hidden_size * num_splits, :]
saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
param = param.view(*saved_shape)
param = param.transpose(0, 2)
param = param.transpose(1, 2).contiguous()
elif checkpoint_version >= 2.0:
# other versions store [num_heads * num_splits * hidden_size, :]
saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
param = param.view(*saved_shape)
param = param.transpose(0, 1).contiguous()
param = param.view(*input_shape)
return param


####################################################################################################


def convert_megatron_checkpoint(args, input_state_dict):
def convert_megatron_checkpoint(args, input_state_dict, config):
# The converted output model.
output_state_dict = {}

# old versions did not store training args
ds_args = input_state_dict.get("args", None)
if ds_args is not None:
# do not make the user write a config file when the exact dimensions/sizes are already in the checkpoint
# from pprint import pprint
# pprint(vars(ds_args))

config.tokenizer_type = ds_args.tokenizer_type
config.vocab_size = ds_args.padded_vocab_size
config.max_position_embeddings = ds_args.max_position_embeddings
config.hidden_size = ds_args.hidden_size
config.num_hidden_layers = ds_args.num_layers
config.num_attention_heads = ds_args.num_attention_heads
config.intermediate_size = getattr(ds_args, "ffn_hidden_size", 4 * ds_args.hidden_size)
# pprint(config)

# The number of heads.
heads = config.num_attention_heads
# The hidden_size per head.
hidden_size_per_head = config.hidden_size // heads
# Megatron-LM checkpoint version
if "checkpoint_version" in input_state_dict.keys():
checkpoint_version = input_state_dict["checkpoint_version"]
else:
checkpoint_version = 0.0

# The model.
model = input_state_dict["model"]
# The language model.
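A toy illustration of what fix_query_key_value_ordering does for checkpoint_version >= 2.0 (the dimensions below are invented; the real script takes them from the checkpoint): the rows move from Megatron's head-major interleaved layout to three contiguous blocks, one each for Q, K and V, which is what allows the plain hidden_size slicing used later in the conversion.

import torch

num_splits, num_heads, head_dim, cols = 3, 2, 4, 5
param = torch.randn(num_heads * num_splits * head_dim, cols)

# Megatron >= 2.0 stores rows as [num_heads, num_splits, head_dim]; swapping the
# first two axes makes Q, K and V contiguous row blocks of num_heads * head_dim each.
reordered = (
    param.view(num_heads, num_splits, head_dim, cols)
    .transpose(0, 1)
    .contiguous()
    .view(num_heads * num_splits * head_dim, cols)
)
q, k, v = reordered.chunk(num_splits, dim=0)  # each block is [num_heads * head_dim, cols]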
@@ -64,13 +114,14 @@ def convert_megatron_checkpoint(args, input_state_dict):

# The word embeddings.
word_embeddings = embeddings["word_embeddings"]["weight"]
# Truncate the embedding table to vocab_size rows.
word_embeddings = word_embeddings[: config.vocab_size, :]
# Store the word embeddings.
output_state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings

# The position embeddings.
pos_embeddings = embeddings["position_embeddings"]["weight"]
# Trained for 512 x 1024.
assert pos_embeddings.size(0) == 512 and pos_embeddings.size(1) == 1024
assert pos_embeddings.size(0) == config.max_position_embeddings and pos_embeddings.size(1) == config.hidden_size
# Store the position embeddings.
output_state_dict["bert.embeddings.position_embeddings.weight"] = pos_embeddings

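On the truncation above: Megatron pads its vocabulary (padded_vocab_size) so the embedding table splits evenly across tensor-parallel ranks, and the converter keeps only the first config.vocab_size rows. A minimal sketch with invented sizes:

import torch

padded_vocab_size, vocab_size, hidden_size = 29056, 28996, 8  # illustrative numbers
word_embeddings = torch.randn(padded_vocab_size, hidden_size)
word_embeddings = word_embeddings[:vocab_size, :]  # drop the padding rows
assert word_embeddings.shape == (vocab_size, hidden_size)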
@@ -80,7 +131,7 @@ def convert_megatron_checkpoint(args, input_state_dict):
output_state_dict["bert.embeddings.token_type_embeddings.weight"] = tokentype_embeddings

# The transformer.
transformer = lm["transformer"]
transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]

# The regex to extract layer names.
layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
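For reference, this is what the layer-name regex above extracts from a Megatron parameter name (output shown in the comment):

import re

layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
m = layer_re.match("layers.0.attention.query_key_value.weight")
print(m.groups())  # ('0', 'attention.query_key_value', 'weight')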
@@ -126,8 +177,9 @@ def convert_megatron_checkpoint(args, input_state_dict):
# Make sure the QKV pointer is nil.
assert attention_qkv_weight is None, ""

out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
# Store the tensor as we need the bias as well to interleave QKV and biases.
attention_qkv_weight = val
attention_qkv_weight = out_val

# Transpose the bias.
elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
@@ -136,14 +188,15 @@ def convert_megatron_checkpoint(args, input_state_dict):
assert attention_qkv_weight is not None, ""

# Split the QKV matrix into Q, K and V. Megatron stores Q,K,V interleaved.
q = attention_qkv_weight[0 * 1024 : 1 * 1024, :]
k = attention_qkv_weight[1 * 1024 : 2 * 1024, :]
v = attention_qkv_weight[2 * 1024 : 3 * 1024, :]
q = attention_qkv_weight[0 * config.hidden_size : 1 * config.hidden_size, :]
k = attention_qkv_weight[1 * config.hidden_size : 2 * config.hidden_size, :]
v = attention_qkv_weight[2 * config.hidden_size : 3 * config.hidden_size, :]

out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
# Split the bias.
q_bias = val[0 * 1024 : 1 * 1024]
k_bias = val[1 * 1024 : 2 * 1024]
v_bias = val[2 * 1024 : 3 * 1024]
q_bias = out_val[0 * config.hidden_size : 1 * config.hidden_size]
k_bias = out_val[1 * config.hidden_size : 2 * config.hidden_size]
v_bias = out_val[2 * config.hidden_size : 3 * config.hidden_size]

# Store.
output_state_dict[f"{layer_name}.attention.self.query.weight"] = q
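The change from a hard-coded 1024 to config.hidden_size matters for any checkpoint that is not the 345M model: with, say, hidden_size = 768 the old slices would grab the wrong rows. A tiny self-check of the equal-thirds slicing, with an illustrative size:

import torch

hidden_size = 768  # e.g. a BERT-base-sized Megatron checkpoint, not the 345M model
qkv = torch.randn(3 * hidden_size, hidden_size)
q = qkv[0 * hidden_size : 1 * hidden_size, :]
k = qkv[1 * hidden_size : 2 * hidden_size, :]
v = qkv[2 * hidden_size : 3 * hidden_size, :]
assert torch.equal(torch.cat([q, k, v], dim=0), qkv)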
@@ -166,24 +219,6 @@ def convert_megatron_checkpoint(args, input_state_dict):
output_state_dict["bert.encoder.ln.weight"] = transformer["final_layernorm.weight"]
output_state_dict["bert.encoder.ln.bias"] = transformer["final_layernorm.bias"]

# The config.
output_config = {
"vocab_size": word_embeddings.size(0),
"hidden_size": 1024,
"num_hidden_layers": 24,
"num_attention_heads": 16,
"hidden_act": "gelu_new",
"intermediate_size": 4096,
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 2,
"initializer_range": 0.2,
"layer_norm_eps": 1e-12,
"position_embedding_type": "absolute",
"use_cache": False,
}

# The pooler.
pooler = lm["pooler"]

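The deleted dict above is effectively what MegatronBertConfig now provides. For comparison, a sketch of the same size-related settings expressed directly as a config object (vocab_size here is illustrative; the old code read it from the embedding table):

from transformers import MegatronBertConfig

config = MegatronBertConfig(
    vocab_size=29056,              # the removed dict used word_embeddings.size(0)
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    intermediate_size=4096,
    max_position_embeddings=512,
    type_vocab_size=2,
)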
@@ -214,7 +249,7 @@ def convert_megatron_checkpoint(args, input_state_dict):
output_state_dict["cls.seq_relationship.bias"] = binary_head["bias"]

# It should be done!
return output_state_dict, output_config
return output_state_dict


####################################################################################################
@@ -225,30 +260,44 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--print-checkpoint-structure", action="store_true")
parser.add_argument("path_to_checkpoint", type=str, help="Path to the ZIP file containing the checkpoint")
parser.add_argument(
"--config_file",
default="",
type=str,
help="An optional config json file describing the pre-trained model.",
)
args = parser.parse_args()

# The directory containing the checkpoint file; the converted files are written there.
basename = os.path.dirname(args.path_to_checkpoint)

# Load the model.
# The .zip extension is optional; keep supporting it for backward compatibility.
print(f'Extracting PyTorch state dictionary from "{args.path_to_checkpoint}"')
with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
input_state_dict = torch.load(pytorch_dict, map_location="cpu")
if args.path_to_checkpoint.endswith(".zip"):
with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
input_state_dict = torch.load(pytorch_dict, map_location="cpu")
else:
input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")

if args.config_file == "":
# Default config of megatron-bert 345m
config = MegatronBertConfig()
else:
config = MegatronBertConfig.from_json_file(args.config_file)

# Convert.
print("Converting")
output_state_dict, output_config = convert_megatron_checkpoint(args, input_state_dict)
output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)

# Print the structure of converted state dict.
if args.print_checkpoint_structure:
recursive_print(None, output_state_dict)

# Store the config to file.
output_config_file = os.path.join(basename, "config.json")
print(f'Saving config to "{output_config_file}"')
with open(output_config_file, "w") as f:
json.dump(output_config, f)
print("Saving config")
config.save_pretrained(basename)

# Store the state_dict to file.
output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
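Once the script has written config.json (via config.save_pretrained) and pytorch_model.bin into the checkpoint's directory, the result can be loaded like any other Hugging Face model. A minimal sketch, with an illustrative local path:

from transformers import MegatronBertForPreTraining

# The path is whatever directory the converter wrote to (the checkpoint's dirname).
model = MegatronBertForPreTraining.from_pretrained("megatron_bert_345m")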