From 53e7a52f6751ddffe5d125b9f50c9912f559b9c3 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Tue, 19 Jan 2021 16:02:25 -0500
Subject: [PATCH] Fix Funnel Transformer conversion script

---
 ...vert_funnel_original_tf_checkpoint_to_pytorch.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py
index daff283053e5..dda913c74dbc 100755
--- a/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py
+++ b/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py
@@ -19,18 +19,18 @@
 
 import torch
 
-from transformers import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
+from transformers import FunnelBaseModel, FunnelConfig, FunnelModel, load_tf_weights_in_funnel
 from transformers.utils import logging
 
 
 logging.set_verbosity_info()
 
 
-def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
+def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path, base_model):
     # Initialise PyTorch model
     config = FunnelConfig.from_json_file(config_file)
     print("Building PyTorch model from configuration: {}".format(str(config)))
-    model = FunnelForPreTraining(config)
+    model = FunnelBaseModel(config) if base_model else FunnelModel(config)
 
     # Load weights from tf checkpoint
     load_tf_weights_in_funnel(model, config, tf_checkpoint_path)
@@ -57,5 +57,10 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
     parser.add_argument(
         "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
     )
+    parser.add_argument(
+        "--base_model", action="store_true", help="Whether you want just the base model (no decoder) or not."
+    )
     args = parser.parse_args()
-    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
+    convert_tf_checkpoint_to_pytorch(
+        args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path, args.base_model
+    )