diff --git a/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py b/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py
index 1d33ef91e624..8c1123b8c84b 100644
--- a/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py
+++ b/src/transformers/models/megatron_bert/convert_megatron_bert_checkpoint.py
@@ -17,13 +17,14 @@
 ####################################################################################################
 
 import argparse
-import json
 import os
 import re
 import zipfile
 
 import torch
 
+from transformers import MegatronBertConfig
+
 
 ####################################################################################################
 
@@ -48,13 +49,62 @@ def recursive_print(name, val, spaces=0):
         print(msg, ":", val)
 
 
+def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
+    # Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :]
+    # for compatibility with later versions of NVIDIA Megatron-LM.
+    # The inverse operation is performed inside Megatron-LM to read checkpoints:
+    # https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209
+    # If param is the weight tensor of the self-attention block, the returned tensor
+    # will have to be transposed one more time to be read by HuggingFace BERT.
+    input_shape = param.size()
+    if checkpoint_version == 1.0:
+        # version 1.0 stores [num_heads * hidden_size * num_splits, :]
+        saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
+        param = param.view(*saved_shape)
+        param = param.transpose(0, 2)
+        param = param.transpose(1, 2).contiguous()
+    elif checkpoint_version >= 2.0:
+        # other versions store [num_heads * num_splits * hidden_size, :]
+        saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
+        param = param.view(*saved_shape)
+        param = param.transpose(0, 1).contiguous()
+    param = param.view(*input_shape)
+    return param
+
+
 ####################################################################################################
 
 
-def convert_megatron_checkpoint(args, input_state_dict):
+def convert_megatron_checkpoint(args, input_state_dict, config):
     # The converted output model.
     output_state_dict = {}
 
+    # old versions did not store training args
+    ds_args = input_state_dict.get("args", None)
+    if ds_args is not None:
+        # do not make the user write a config file when the exact dimensions/sizes are already in the checkpoint
+        # from pprint import pprint
+        # pprint(vars(ds_args))
+
+        config.tokenizer_type = ds_args.tokenizer_type
+        config.vocab_size = ds_args.padded_vocab_size
+        config.max_position_embeddings = ds_args.max_position_embeddings
+        config.hidden_size = ds_args.hidden_size
+        config.num_hidden_layers = ds_args.num_layers
+        config.num_attention_heads = ds_args.num_attention_heads
+        config.intermediate_size = getattr(ds_args, "ffn_hidden_size", 4 * ds_args.hidden_size)
+        # pprint(config)
+
+    # The number of heads.
+    heads = config.num_attention_heads
+    # The hidden_size per head.
+    hidden_size_per_head = config.hidden_size // heads
+    # Megatron-LM checkpoint version
+    if "checkpoint_version" in input_state_dict.keys():
+        checkpoint_version = input_state_dict["checkpoint_version"]
+    else:
+        checkpoint_version = 0.0
+
     # The model.
     model = input_state_dict["model"]
     # The language model.
@@ -64,13 +114,14 @@ def convert_megatron_checkpoint(args, input_state_dict):
 
     # The word embeddings.
     word_embeddings = embeddings["word_embeddings"]["weight"]
+    # Truncate the embedding table to vocab_size rows.
+    word_embeddings = word_embeddings[: config.vocab_size, :]
     # Store the word embeddings.
     output_state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings
 
     # The position embeddings.
     pos_embeddings = embeddings["position_embeddings"]["weight"]
-    # Trained for 512 x 1024.
-    assert pos_embeddings.size(0) == 512 and pos_embeddings.size(1) == 1024
+    assert pos_embeddings.size(0) == config.max_position_embeddings and pos_embeddings.size(1) == config.hidden_size
     # Store the position embeddings.
     output_state_dict["bert.embeddings.position_embeddings.weight"] = pos_embeddings
 
@@ -80,7 +131,7 @@ def convert_megatron_checkpoint(args, input_state_dict):
     output_state_dict["bert.embeddings.token_type_embeddings.weight"] = tokentype_embeddings
 
     # The transformer.
-    transformer = lm["transformer"]
+    transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]
 
     # The regex to extract layer names.
     layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
@@ -126,8 +177,9 @@ def convert_megatron_checkpoint(args, input_state_dict):
 
             # Make sure the QKV pointer is nil.
             assert attention_qkv_weight is None, ""
 
+            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
             # Store the tensor as we need the bias as well to interleave QKV and biases.
-            attention_qkv_weight = val
+            attention_qkv_weight = out_val
 
         # Transpose the bias.
         elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
@@ -136,14 +188,15 @@ def convert_megatron_checkpoint(args, input_state_dict):
             assert attention_qkv_weight is not None, ""
 
             # Split the QKV matrix into Q, K and V. Megatron stores Q,K,V interleaved.
-            q = attention_qkv_weight[0 * 1024 : 1 * 1024, :]
-            k = attention_qkv_weight[1 * 1024 : 2 * 1024, :]
-            v = attention_qkv_weight[2 * 1024 : 3 * 1024, :]
+            q = attention_qkv_weight[0 * config.hidden_size : 1 * config.hidden_size, :]
+            k = attention_qkv_weight[1 * config.hidden_size : 2 * config.hidden_size, :]
+            v = attention_qkv_weight[2 * config.hidden_size : 3 * config.hidden_size, :]
 
+            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
             # Split the bias.
-            q_bias = val[0 * 1024 : 1 * 1024]
-            k_bias = val[1 * 1024 : 2 * 1024]
-            v_bias = val[2 * 1024 : 3 * 1024]
+            q_bias = out_val[0 * config.hidden_size : 1 * config.hidden_size]
+            k_bias = out_val[1 * config.hidden_size : 2 * config.hidden_size]
+            v_bias = out_val[2 * config.hidden_size : 3 * config.hidden_size]
 
             # Store.
             output_state_dict[f"{layer_name}.attention.self.query.weight"] = q
@@ -166,24 +219,6 @@ def convert_megatron_checkpoint(args, input_state_dict):
     output_state_dict["bert.encoder.ln.weight"] = transformer["final_layernorm.weight"]
     output_state_dict["bert.encoder.ln.bias"] = transformer["final_layernorm.bias"]
 
-    # The config.
-    output_config = {
-        "vocab_size": word_embeddings.size(0),
-        "hidden_size": 1024,
-        "num_hidden_layers": 24,
-        "num_attention_heads": 16,
-        "hidden_act": "gelu_new",
-        "intermediate_size": 4096,
-        "hidden_dropout_prob": 0.1,
-        "attention_probs_dropout_prob": 0.1,
-        "max_position_embeddings": 512,
-        "type_vocab_size": 2,
-        "initializer_range": 0.2,
-        "layer_norm_eps": 1e-12,
-        "position_embedding_type": "absolute",
-        "use_cache": False,
-    }
-
     # The pooler.
     pooler = lm["pooler"]
 
@@ -214,7 +249,7 @@ def convert_megatron_checkpoint(args, input_state_dict):
     output_state_dict["cls.seq_relationship.bias"] = binary_head["bias"]
 
     # It should be done!
-    return output_state_dict, output_config
+    return output_state_dict
 
 
 ####################################################################################################
 
 
@@ -225,30 +260,44 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--print-checkpoint-structure", action="store_true")
     parser.add_argument("path_to_checkpoint", type=str, help="Path to the ZIP file containing the checkpoint")
+    parser.add_argument(
+        "--config_file",
+        default="",
+        type=str,
+        help="An optional config json file describing the pre-trained model.",
+    )
     args = parser.parse_args()
 
     # Extract the basename.
     basename = os.path.dirname(args.path_to_checkpoint)
 
     # Load the model.
+    # the .zip is very optional, let's keep it for backward compatibility
    print(f'Extracting PyTorch state dictionary from "{args.path_to_checkpoint}"')
-    with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
-        with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
-            input_state_dict = torch.load(pytorch_dict, map_location="cpu")
+    if args.path_to_checkpoint.endswith(".zip"):
+        with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
+            with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
+                input_state_dict = torch.load(pytorch_dict, map_location="cpu")
+    else:
+        input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")
+
+    if args.config_file == "":
+        # Default config of megatron-bert 345m
+        config = MegatronBertConfig()
+    else:
+        config = MegatronBertConfig.from_json_file(args.config_file)
 
     # Convert.
     print("Converting")
-    output_state_dict, output_config = convert_megatron_checkpoint(args, input_state_dict)
+    output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)
 
     # Print the structure of converted state dict.
     if args.print_checkpoint_structure:
         recursive_print(None, output_state_dict)
 
     # Store the config to file.
-    output_config_file = os.path.join(basename, "config.json")
-    print(f'Saving config to "{output_config_file}"')
-    with open(output_config_file, "w") as f:
-        json.dump(output_config, f)
+    print("Saving config")
+    config.save_pretrained(basename)
 
     # Store the state_dict to file.
     output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
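
A minimal, runnable sketch of the row permutation that fix_query_key_value_ordering performs for checkpoint_version >= 2.0 may help reviewers; all sizes and row labels below are illustrative and not taken from a real checkpoint. Megatron stores the fused QKV matrix grouped per head, and the function regroups it per split (Q, K, V) so that the later slicing by config.hidden_size picks out whole Q, K and V blocks:

    import torch

    # Illustrative sizes, assumed for the demo (a real model is much larger).
    num_heads, num_splits, head_dim, cols = 2, 3, 4, 5

    # Megatron >= 2.0 layout: rows are [h0-q, h0-k, h0-v, h1-q, h1-k, h1-v].
    # Label each row with head_index * num_splits + split_index so the
    # permutation is visible.
    labels = torch.arange(num_heads * num_splits).repeat_interleave(head_dim)
    fused = labels.unsqueeze(1).repeat(1, cols)

    # The same permutation the script applies for checkpoint_version >= 2.0:
    # view as (heads, splits, head_dim, cols), swap heads and splits, flatten.
    out = fused.view(num_heads, num_splits, head_dim, cols)
    out = out.transpose(0, 1).contiguous().view(*fused.size())

    # Rows are now grouped by split: all Q rows first, then K, then V,
    # which is exactly what the q/k/v slicing by hidden_size expects.
    hidden = num_heads * head_dim
    q, k, v = out[:hidden], out[hidden : 2 * hidden], out[2 * hidden :]
    print(q[:, 0].tolist())  # [0, 0, 0, 0, 3, 3, 3, 3] -> head 0, then head 1
    print(k[:, 0].tolist())  # [1, 1, 1, 1, 4, 4, 4, 4]
    print(v[:, 0].tolist())  # [2, 2, 2, 2, 5, 5, 5, 5]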
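
And a hypothetical end-to-end run, with made-up paths for illustration. Because the script writes config.json (via config.save_pretrained) and pytorch_model.bin into the directory containing the checkpoint, that directory can then be loaded directly:

    # python convert_megatron_bert_checkpoint.py checkpoints/megatron_bert_345m.zip
    # -> writes checkpoints/config.json and checkpoints/pytorch_model.bin
    from transformers import MegatronBertModel

    model = MegatronBertModel.from_pretrained("checkpoints")
    print(model.config.num_hidden_layers, model.config.hidden_size)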