From d72d4ac4ff7af544e9fa4c9226c104930e84b6ec Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Thu, 24 Mar 2022 01:09:13 -0700 Subject: [PATCH 01/16] temp commit --- examples/finetune/__init__.py | 0 examples/finetune/config.jsonnet | 148 +++++++++++++++++++ examples/finetune/snli_steps.py | 88 +++++++++++ examples/finetune/test.py | 44 ++++++ tango/common/from_params.py | 3 +- tango/integrations/transformers/finetune.py | 155 ++++++++++++++++++++ 6 files changed, 437 insertions(+), 1 deletion(-) create mode 100644 examples/finetune/__init__.py create mode 100644 examples/finetune/config.jsonnet create mode 100644 examples/finetune/snli_steps.py create mode 100644 examples/finetune/test.py create mode 100644 tango/integrations/transformers/finetune.py diff --git a/examples/finetune/__init__.py b/examples/finetune/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/finetune/config.jsonnet b/examples/finetune/config.jsonnet new file mode 100644 index 000000000..65e0058dc --- /dev/null +++ b/examples/finetune/config.jsonnet @@ -0,0 +1,148 @@ +################## +# Model settings # +################## + +//local pretrained_model = "sshleifer/tiny-gpt2"; +local pretrained_model = "patrickvonplaten/t5-tiny-random"; +local model_type = "seq2seq"; //TODO: autodetect. + +# This doesn't seem to work with gpt2, but works fine with gpt-j. +local load_with_low_cpu_mem_usage = false; //std.startsWith(pretrained_model, "EleutherAI/gpt-j"); + +######################## +# Put in correct place # +######################## + + +#################### +# Trainer settings # +#################### + +# Trainer settings, adjust to your use-case. +local training_steps = 20; # total number of optimization steps to train for +local validate_every = 5; # how often to validate and save checkpoints + +local devices = 1; # number of devices to train on (will use GPUs if enough are available, otherwise CPU) +local grad_accum = 1; # number of gradient accumulation steps (changes the effective batch size) +# This is the batch size per GPU, ignoring gradient accumulation: +local batch_size = 2; +# So the effective batch size is `batch_size * grad_accum * devices` + +local activation_checkpointing = false; # use activation/gradient checkpointing (probably need this GPT-J 6B, but not gpt2) +local amp = false; # use PyTorch's native automatic mixed precision +local fsdp = false; # Use FairScale's FullyShardedDataParallel (probably need this GPT-J 6B, but not gpt2) +local cpu_offloading = false; # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow. 
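The comments above pin down two things worth keeping in mind when editing these settings: the effective batch size is batch_size * grad_accum * devices, and cpu_offloading is only meaningful together with fsdp. A minimal Python sketch of the same arithmetic and constraint (the helper names are ours, not part of Tango):

```python
def effective_batch_size(batch_size: int, grad_accum: int, devices: int) -> int:
    # Mirrors the comment above: per-device batch size, times gradient-accumulation
    # steps, times the number of data-parallel workers.
    return batch_size * grad_accum * devices


def check_cpu_offloading(fsdp: bool, cpu_offloading: bool) -> None:
    # Same constraint as the jsonnet assert further down: CPU offloading is a
    # FairScale FSDP feature, so it cannot be enabled on its own.
    if cpu_offloading and not fsdp:
        raise ValueError("cpu_offloading only available with fsdp")


assert effective_batch_size(batch_size=2, grad_accum=1, devices=1) == 2
check_cpu_offloading(fsdp=False, cpu_offloading=False)  # the defaults above pass
```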
+ +###################### +# Optimizer settings # +###################### + +local warmup_steps = 20; +local learning_rate = 0.00005; # you can probably use a higher LR for a small model like "gpt2" + + +assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp"; + +# FullyShardedDataParallel config: +local fsdp_config = if fsdp then { + reshard_after_forward: true, + move_params_to_cpu: cpu_offloading, + move_grads_to_cpu: cpu_offloading, + mixed_precision: amp, +} else null; + +local training_engine = { + type: if fsdp then "fairscale" else "torch", + optimizer: { + type: "torch::AdamW", + lr: learning_rate, + betas: [0.9, 0.95], + eps: 1e-6, + }, + lr_scheduler: { + type: "transformers::linear", + num_warmup_steps: warmup_steps, + num_training_steps: training_steps, + }, + amp: amp, + [if fsdp then "fsdp_config" else null]: fsdp_config, +}; + +local collate_fn = { + //type: "transformers::DefaultDataCollator" + type: "transformers::DataCollatorForSeq2Seq", + tokenizer: { pretrained_model_name_or_path: pretrained_model } +}; + +local distributed_dataloader = { + batch_size: batch_size, + collate_fn: collate_fn, + sampler: { + type: "torch::DistributedSampler", + shuffle: true, + drop_last: true, + }, +}; + +local single_device_dataloader = { + shuffle: true, + batch_size: batch_size, + collate_fn: collate_fn, +}; + +local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader; + +{ + steps: { + raw_data: { + type: "datasets::load", + path: "snli", + }, + "subset_data": { + type: "subset-data", + data: { type: "ref", ref: "raw_data" }, + max_samples: 10, + }, + processed_data: { + type: "snli-text2text", + data: { type: "ref", ref: "subset_data" }, + }, + "tokenized_data": { + type: "tokenize_text2text", + data: { type: "ref", ref: "processed_data" }, + tokenizer: { pretrained_model_name_or_path: pretrained_model } + }, + trained_model: { + type: "torch::train", + model: { + type: "fairscale::with_wrapped_modules", + model: { + //type: "transformers::AutoModelForSeq2SeqLM::from_pretrained", + "type": "transformers::finetune-wrapper", + pretrained_model_name_or_path: pretrained_model, + low_cpu_mem_usage: load_with_low_cpu_mem_usage, + }, + modules_to_wrap: ["model\\.encoder\\.block\\.[0-9]+", "model\\.decoder\\.block\\.[0-9]+"], # tell FairScale to wrap the transformer's blocks individually + fsdp_config: fsdp_config, + activation_checkpointing: activation_checkpointing, + }, + dataset_dict: { type: "ref", ref: "tokenized_data" }, + train_dataloader: dataloader, + validation_split: "validation", + grad_accum: grad_accum, + train_steps: training_steps, + validate_every: validate_every, + checkpoint_every: validate_every, + log_every: 1, + device_count: devices, + training_engine: training_engine, + }, + final_metrics: { + type: "torch::eval", + model: { type: "ref", ref: "trained_model" }, + dataset_dict: { type: "ref", ref: "tokenized_data" }, + dataloader: single_device_dataloader, + test_split: "test", + }, + } +} \ No newline at end of file diff --git a/examples/finetune/snli_steps.py b/examples/finetune/snli_steps.py new file mode 100644 index 000000000..99dacc8de --- /dev/null +++ b/examples/finetune/snli_steps.py @@ -0,0 +1,88 @@ +from typing import Any, Dict, List, Union + +import datasets as ds + +from tango.integrations.datasets import DatasetsFormat +from tango.integrations.transformers import Tokenizer +from tango.step import Step + + +@Step.register("subset-data") +class SubsetData(Step): + DETERMINISTIC = True + 
CACHEABLE = True + VERSION = "001" + + FORMAT = DatasetsFormat() + + def run( + self, + data: Union[ds.DatasetDict, ds.Dataset], + max_samples: int = 5, + ) -> Union[ds.DatasetDict, ds.Dataset]: + """ + Unlike `ds.Dataset.select`, this works on both `ds.Dataset` and `ds.DatasetDict`. + """ + + def filter_fn(example, indices): + return indices < max_samples + + return data.filter(filter_fn, with_indices=True) + + +@Step.register("snli-text2text") +class SnliText2Text(Step): + DETERMINISTIC = True + CACHEABLE = True + VERSION = "001" + + FORMAT = DatasetsFormat() + + def run( + self, + data: Union[ds.DatasetDict, ds.Dataset], + source_prefix: str = "nli", + premise_prefix: str = "premise", + hypothesis_prefix: str = "hypothesis", + label_prefix: str = "label", + num_workers: int = 1, + seq2seq: bool = True, + ) -> Union[ds.DatasetDict, ds.Dataset]: + def filter_no_gold(example, indices): + if example["label"] == -1: + return False + return True + + data = data.filter(filter_no_gold, with_indices=True) + + label_map = {0: "entails", 1: "neutral", 2: "contradiction"} + + def _seq2seq_mapper(example): + return { + "source": f'{source_prefix} {premise_prefix}: {example["premise"]} {hypothesis_prefix}: {example["hypothesis"]}', + "target": f'{label_prefix}: {label_map[example["label"]]}', + } + + def _causal_mapper(example): + text = ( + f'{source_prefix} {premise_prefix}: {example["premise"]} {hypothesis_prefix}: {example["hypothesis"]} ' + f'{label_prefix}: {label_map[example["label"]]}' + ) + return {"source": text, "target": text} + + if isinstance(data, ds.Dataset): + old_cols = data.column_names + else: + old_cols = list(data.column_names.values())[0] + + _mapper = _seq2seq_mapper if seq2seq else _causal_mapper + + dataset = data.map( + _mapper, + batched=False, + num_proc=num_workers, + remove_columns=old_cols, # remove all old columns + desc="Converting data to seq2seq format", + ) + + return dataset diff --git a/examples/finetune/test.py b/examples/finetune/test.py new file mode 100644 index 000000000..91139eb96 --- /dev/null +++ b/examples/finetune/test.py @@ -0,0 +1,44 @@ +# from .snli_steps import SnliText2Text +import datasets as ds + +from tango.common import Params +from tango.common.testing import TangoTestCase, run_experiment + + +class TestSnliText2Text(TangoTestCase): + def test_config(self): + config = Params.from_file("test_config.jsonnet") + with run_experiment(config, include_package=["snli_steps.py"]) as run_dir: + assert (run_dir / "processed_data").is_dir() + processed = ds.load_from_disk(run_dir / "processed_data" / "data") + assert len(processed["train"][0].keys()) == 2 + assert "source" in processed["train"][0].keys() + assert "target" in processed["train"][0].keys() + assert processed["train"][0]["source"].startswith("nli premise:") + + assert (run_dir / "tokenized_data").is_dir() + tokenized = ds.load_from_disk(run_dir / "tokenized_data" / "data") + assert "input_ids" in tokenized["train"][0] + + assert (run_dir / "trained_model").is_dir() + + # def test_config_with_overrides(self): + # overrides = { + # "steps.processed_data.seq2seq": False, + # } + # config = Params.from_file("test_config.jsonnet", params_overrides=overrides) + # with run_experiment(config, include_package=["snli_steps.py"]) as run_dir: + # assert (run_dir / "processed_data").is_dir() + # processed = ds.load_from_disk(run_dir / "processed_data" / "data") + # assert len(processed["train"][0].keys()) == 1 + # assert "source" in processed["train"][0].keys() + # # assert "target" in 
processed["train"][0].keys() + # assert processed["train"][0]["source"].startswith("nli premise:") + + +if __name__ == "__main__": + config = Params.from_file("config.jsonnet") + with run_experiment( + config, include_package=["snli_steps.py", "tango.integrations.transformers.finetune"] + ) as run_dir: + assert (run_dir / "processed_data").is_dir() diff --git a/tango/common/from_params.py b/tango/common/from_params.py index 25d44e4e0..1c28e5cbc 100644 --- a/tango/common/from_params.py +++ b/tango/common/from_params.py @@ -447,7 +447,8 @@ def construct_arg( if origin != Step and _params_contain_step(popped_params): result = WithUnresolvedSteps(annotation.from_params, popped_params) else: - result = annotation.from_params(popped_params, **subextras) + # TODO: temporary, until Pete's PR is merged. + result = annotation.from_params(popped_params) # , **subextras) if isinstance(result, Step): expected_return_type = args[0] diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py new file mode 100644 index 000000000..24b8048f4 --- /dev/null +++ b/tango/integrations/transformers/finetune.py @@ -0,0 +1,155 @@ +import logging +from typing import Any, Dict, Optional + +import datasets as ds +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer + +from tango.integrations.datasets import DatasetsFormat +from tango.integrations.torch import Model +from tango.integrations.torch.util import set_seed_all +from tango.integrations.transformers.tokenizer import Tokenizer +from tango.step import Step + +logger = logging.getLogger(__name__) + + +@Model.register("transformers::finetune-wrapper") +class FinetuneWrapper(Model): + def __init__( + self, pretrained_model_name_or_path: str, tokenizer: Optional[Tokenizer] = None, **kwargs + ): + super().__init__() + try: + self.model = AutoModelForSeq2SeqLM.from_pretrained( + pretrained_model_name_or_path, **kwargs + ) + self.seq2seq = True # Seq2Seq models don't return their own prefix. + except ValueError: + self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path) + self.seq2seq = False + + if tokenizer: + # TODO: is this required? This is the only reason why we have tokenizer here. + self.model.resize_token_embeddings(len(tokenizer)) + + def forward(self, *args, **kwargs): + # TODO: decode and compute other metrics? + return self.model.forward(*args, **kwargs) + + +# def _model_for_finetuning( +# model_name: str, +# tokenizer: Optional[Tokenizer] = None, +# max_source_length: Optional[int] = 1024, +# resize_position_embeddings: Optional[bool] = None, +# seed: int = 42, +# ): +# set_seed_all(seed) +# tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_name) +# +# try: +# model = AutoModelForSeq2SeqLM.from_pretrained(model_name) +# seq2seq_model = True # Seq2Seq models don't return their own prefix. +# except ValueError: +# model = AutoModelForCausalLM.from_pretrained(model_name) +# seq2seq_model = False +# +# # TODO: is this required? This is the only reason why we have tokenizer here. +# model.resize_token_embeddings(len(tokenizer)) +# +# # TODO: MBart specific tokenizer update. 
+# +# if model.config.decoder_start_token_id is None: +# raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") +# +# if ( +# hasattr(model.config, "max_position_embeddings") +# and model.config.max_position_embeddings < max_source_length +# ): +# if resize_position_embeddings is None: +# logger.warning( +# f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} " +# f"to {max_source_length}." +# ) +# model.resize_position_embeddings(max_source_length) +# elif resize_position_embeddings: +# model.resize_position_embeddings(max_source_length) +# else: +# raise ValueError( +# f"`max_source_length` is set to {max_source_length}, but the model only has {model.config.max_position_embeddings}" +# f" position encodings. Consider either reducing `max_source_length` to {model.config.max_position_embeddings} or to automatically " +# "resize the model's position encodings by setting `resize_position_embeddings`." +# ) +# +# return model + + +# @Step.register("get-model-for-finetuning") +# class GetModelForFinetuning(Step): +# DETERMINISTIC = True +# CACHEABLE = False +# +# def run(self, model): + + +@Step.register("tokenize_text2text") +class TokenizeText2TextData(Step): + DETERMINISTIC = True + CACHEABLE = True + FORMAT = DatasetsFormat() + + def run( # type: ignore[override] + self, + data: ds.DatasetDict, + tokenizer: Tokenizer, + num_workers: int = 1, + source_field: str = "source", + target_field: str = "target", + max_source_length: Optional[int] = 1024, + max_target_length: Optional[int] = 1024, + pad_to_max_length: bool = False, + ignore_pad_token_for_loss: bool = True, + ) -> ds.DatasetDict: + + # Set max_target_length for training. + max_target_length = max_target_length + padding = "max_length" if pad_to_max_length else False + + def preprocess_function(examples): + # remove pairs where at least one record is None + inputs, targets = [], [] + for i in range(len(examples[source_field])): + if examples[source_field][i] is not None and examples[target_field][i] is not None: + inputs.append(examples[source_field][i]) + targets.append(examples[target_field][i]) + + model_inputs = tokenizer( + inputs, max_length=max_source_length, padding=padding, truncation=True + ) + + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + labels = tokenizer( + targets, max_length=max_target_length, padding=padding, truncation=True + ) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore + # padding in the loss. 
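+            # (Note: -100 matches the default ignore_index of torch.nn.CrossEntropyLoss,
+            # which the Hugging Face LM heads rely on, so these padded label positions
+            # drop out of the loss.)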
+ if padding == "max_length" and ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] + for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + data = data.map( + preprocess_function, + batched=True, + num_proc=num_workers, + remove_columns=list(data.column_names.values())[0], # remove all old columns + desc="Tokenizing dataset", + ) + + return data From 3f085728b038d4f4991deebfcba1a55057cb15e3 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Thu, 24 Mar 2022 15:39:35 -0700 Subject: [PATCH 02/16] move_to_device should work for UserDict too --- tango/integrations/torch/training_engine.py | 18 +++--------------- tango/integrations/torch/util.py | 3 ++- 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/tango/integrations/torch/training_engine.py b/tango/integrations/torch/training_engine.py index b39ebe124..66b7e8255 100644 --- a/tango/integrations/torch/training_engine.py +++ b/tango/integrations/torch/training_engine.py @@ -13,6 +13,7 @@ from .model import Model from .optim import LRScheduler, Optimizer from .train_config import TrainConfig +from .util import move_to_device class TrainingEngine(Registrable): @@ -60,19 +61,6 @@ def _construct_lr_scheduler(self, lr_scheduler: Lazy[LRScheduler]) -> LRSchedule lr_scheduler: LRScheduler = lr_scheduler.construct(optimizer=self.optimizer) return lr_scheduler - @classmethod - def _move_to_device(cls, o: Any, device: torch.device) -> Any: - if isinstance(o, torch.Tensor): - return o.to(device) - elif isinstance(o, dict): - return {k: cls._move_to_device(v, device) for k, v in o.items()} - elif isinstance(o, list): - return [cls._move_to_device(x, device) for x in o] - elif isinstance(o, tuple): - return tuple((cls._move_to_device(x, device) for x in o)) - else: - return o - @abstractmethod def forward_train( self, micro_batch: Dict[str, Any], micro_batch_idx: int, num_micro_batches: int @@ -207,7 +195,7 @@ def forward_train( self.optimizer.zero_grad(set_to_none=True) # Move tensors to right device. - micro_batch = self._move_to_device(micro_batch, self.device) + micro_batch = move_to_device(micro_batch, self.device) with torch.autocast(self.train_config.device_type, enabled=self.amp, dtype=self.amp_dtype): outputs = self.model(**micro_batch) @@ -217,7 +205,7 @@ def forward_train( def forward_eval(self, batch: Dict[str, Any]) -> Dict[str, Any]: # Move tensors to right device. 
- batch = self._move_to_device(batch, self.device) + batch = move_to_device(batch, self.device) with torch.autocast(self.train_config.device_type, enabled=self.amp, dtype=self.amp_dtype): with torch.inference_mode(): diff --git a/tango/integrations/torch/util.py b/tango/integrations/torch/util.py index cd0d451a2..e241e5339 100644 --- a/tango/integrations/torch/util.py +++ b/tango/integrations/torch/util.py @@ -1,5 +1,6 @@ import random import warnings +from collections import UserDict from typing import Dict, Optional, TypeVar, Union import numpy as np @@ -15,7 +16,7 @@ def move_to_device(o: T, device: torch.device) -> T: if isinstance(o, torch.Tensor): return o.to(device) # type: ignore[return-value] - elif isinstance(o, dict): + elif isinstance(o, dict) or isinstance(o, UserDict): return {k: move_to_device(v, device) for k, v in o.items()} # type: ignore[return-value] elif isinstance(o, list): return [move_to_device(x, device) for x in o] # type: ignore[return-value] From 9a07a1bffdacd486a0c222b27b538da10c02a8aa Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Thu, 24 Mar 2022 15:48:39 -0700 Subject: [PATCH 03/16] works --- examples/finetune/config.jsonnet | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/finetune/config.jsonnet b/examples/finetune/config.jsonnet index 65e0058dc..ca62f0580 100644 --- a/examples/finetune/config.jsonnet +++ b/examples/finetune/config.jsonnet @@ -3,7 +3,7 @@ ################## //local pretrained_model = "sshleifer/tiny-gpt2"; -local pretrained_model = "patrickvonplaten/t5-tiny-random"; +local pretrained_model = "t5-small"; //"patrickvonplaten/t5-tiny-random"; local model_type = "seq2seq"; //TODO: autodetect. # This doesn't seem to work with gpt2, but works fine with gpt-j. 
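Aside on the move_to_device change in PATCH 02 above: a plain isinstance(o, dict) check misses tokenizer output because transformers' BatchEncoding derives from collections.UserDict rather than dict. A quick sketch to confirm, reusing the tiny GPT-2 checkpoint mentioned in these configs purely as an example:

```python
from collections import UserDict

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
batch = tok("nli premise: A dog runs. hypothesis: An animal moves.", return_tensors="pt")

print(type(batch).__name__)         # BatchEncoding
print(isinstance(batch, dict))      # False -- the old dict-only check returned it unmoved
print(isinstance(batch, UserDict))  # True  -- picked up by the new branch in move_to_device
```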
@@ -98,14 +98,14 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device type: "datasets::load", path: "snli", }, - "subset_data": { + /*"subset_data": { type: "subset-data", data: { type: "ref", ref: "raw_data" }, max_samples: 10, - }, + },*/ processed_data: { type: "snli-text2text", - data: { type: "ref", ref: "subset_data" }, + data: { type: "ref", ref: "raw_data" }, }, "tokenized_data": { type: "tokenize_text2text", @@ -117,12 +117,12 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device model: { type: "fairscale::with_wrapped_modules", model: { - //type: "transformers::AutoModelForSeq2SeqLM::from_pretrained", - "type": "transformers::finetune-wrapper", + type: "transformers::AutoModelForSeq2SeqLM::from_pretrained", + //type: "transformers::finetune-wrapper", pretrained_model_name_or_path: pretrained_model, low_cpu_mem_usage: load_with_low_cpu_mem_usage, }, - modules_to_wrap: ["model\\.encoder\\.block\\.[0-9]+", "model\\.decoder\\.block\\.[0-9]+"], # tell FairScale to wrap the transformer's blocks individually + modules_to_wrap: ["encoder\\.block\\.[0-9]+", "decoder\\.block\\.[0-9]+"], # tell FairScale to wrap the transformer's blocks individually fsdp_config: fsdp_config, activation_checkpointing: activation_checkpointing, }, @@ -145,4 +145,4 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device test_split: "test", }, } -} \ No newline at end of file +} From f7fc850e3d536d68c4eb45ca9e5a4a5f305bb668 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Thu, 24 Mar 2022 16:26:12 -0700 Subject: [PATCH 04/16] clean up --- examples/finetune/config.jsonnet | 20 ++---- examples/finetune/snli_steps.py | 17 +++--- examples/finetune/test.py | 8 --- tango/integrations/transformers/finetune.py | 68 ++------------------- 4 files changed, 21 insertions(+), 92 deletions(-) diff --git a/examples/finetune/config.jsonnet b/examples/finetune/config.jsonnet index ca62f0580..851f1442f 100644 --- a/examples/finetune/config.jsonnet +++ b/examples/finetune/config.jsonnet @@ -2,17 +2,8 @@ # Model settings # ################## -//local pretrained_model = "sshleifer/tiny-gpt2"; -local pretrained_model = "t5-small"; //"patrickvonplaten/t5-tiny-random"; -local model_type = "seq2seq"; //TODO: autodetect. - -# This doesn't seem to work with gpt2, but works fine with gpt-j. 
-local load_with_low_cpu_mem_usage = false; //std.startsWith(pretrained_model, "EleutherAI/gpt-j"); - -######################## -# Put in correct place # -######################## - +local pretrained_model = "patrickvonplaten/t5-tiny-random"; +local load_with_low_cpu_mem_usage = false; #################### # Trainer settings # @@ -69,7 +60,6 @@ local training_engine = { }; local collate_fn = { - //type: "transformers::DefaultDataCollator" type: "transformers::DataCollatorForSeq2Seq", tokenizer: { pretrained_model_name_or_path: pretrained_model } }; @@ -98,14 +88,14 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device type: "datasets::load", path: "snli", }, - /*"subset_data": { + "subset_data": { type: "subset-data", data: { type: "ref", ref: "raw_data" }, max_samples: 10, - },*/ + }, processed_data: { type: "snli-text2text", - data: { type: "ref", ref: "raw_data" }, + data: { type: "ref", ref: "subset_data" }, }, "tokenized_data": { type: "tokenize_text2text", diff --git a/examples/finetune/snli_steps.py b/examples/finetune/snli_steps.py index 99dacc8de..bad1946c9 100644 --- a/examples/finetune/snli_steps.py +++ b/examples/finetune/snli_steps.py @@ -1,9 +1,8 @@ -from typing import Any, Dict, List, Union +from typing import Union import datasets as ds from tango.integrations.datasets import DatasetsFormat -from tango.integrations.transformers import Tokenizer from tango.step import Step @@ -15,7 +14,7 @@ class SubsetData(Step): FORMAT = DatasetsFormat() - def run( + def run( # type: ignore self, data: Union[ds.DatasetDict, ds.Dataset], max_samples: int = 5, @@ -38,7 +37,7 @@ class SnliText2Text(Step): FORMAT = DatasetsFormat() - def run( + def run( # type: ignore self, data: Union[ds.DatasetDict, ds.Dataset], source_prefix: str = "nli", @@ -59,13 +58,17 @@ def filter_no_gold(example, indices): def _seq2seq_mapper(example): return { - "source": f'{source_prefix} {premise_prefix}: {example["premise"]} {hypothesis_prefix}: {example["hypothesis"]}', - "target": f'{label_prefix}: {label_map[example["label"]]}', + "source": ( + f'{source_prefix} {premise_prefix}: {example["premise"]} ' + f'{hypothesis_prefix}: {example["hypothesis"]} {label_prefix}: ' + ), + "target": f'{label_map[example["label"]]}', } def _causal_mapper(example): text = ( - f'{source_prefix} {premise_prefix}: {example["premise"]} {hypothesis_prefix}: {example["hypothesis"]} ' + f'{source_prefix} {premise_prefix}: {example["premise"]} ' + f'{hypothesis_prefix}: {example["hypothesis"]} ' f'{label_prefix}: {label_map[example["label"]]}' ) return {"source": text, "target": text} diff --git a/examples/finetune/test.py b/examples/finetune/test.py index 91139eb96..b93037a5c 100644 --- a/examples/finetune/test.py +++ b/examples/finetune/test.py @@ -34,11 +34,3 @@ def test_config(self): # assert "source" in processed["train"][0].keys() # # assert "target" in processed["train"][0].keys() # assert processed["train"][0]["source"].startswith("nli premise:") - - -if __name__ == "__main__": - config = Params.from_file("config.jsonnet") - with run_experiment( - config, include_package=["snli_steps.py", "tango.integrations.transformers.finetune"] - ) as run_dir: - assert (run_dir / "processed_data").is_dir() diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py index 24b8048f4..4941a76fa 100644 --- a/tango/integrations/transformers/finetune.py +++ b/tango/integrations/transformers/finetune.py @@ -1,12 +1,11 @@ import logging -from typing import Any, Dict, Optional 
+from typing import Optional import datasets as ds -from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM from tango.integrations.datasets import DatasetsFormat from tango.integrations.torch import Model -from tango.integrations.torch.util import set_seed_all from tango.integrations.transformers.tokenizer import Tokenizer from tango.step import Step @@ -30,68 +29,13 @@ def __init__( if tokenizer: # TODO: is this required? This is the only reason why we have tokenizer here. - self.model.resize_token_embeddings(len(tokenizer)) + self.model.resize_token_embeddings(len(tokenizer)) # type: ignore def forward(self, *args, **kwargs): # TODO: decode and compute other metrics? return self.model.forward(*args, **kwargs) -# def _model_for_finetuning( -# model_name: str, -# tokenizer: Optional[Tokenizer] = None, -# max_source_length: Optional[int] = 1024, -# resize_position_embeddings: Optional[bool] = None, -# seed: int = 42, -# ): -# set_seed_all(seed) -# tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_name) -# -# try: -# model = AutoModelForSeq2SeqLM.from_pretrained(model_name) -# seq2seq_model = True # Seq2Seq models don't return their own prefix. -# except ValueError: -# model = AutoModelForCausalLM.from_pretrained(model_name) -# seq2seq_model = False -# -# # TODO: is this required? This is the only reason why we have tokenizer here. -# model.resize_token_embeddings(len(tokenizer)) -# -# # TODO: MBart specific tokenizer update. -# -# if model.config.decoder_start_token_id is None: -# raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") -# -# if ( -# hasattr(model.config, "max_position_embeddings") -# and model.config.max_position_embeddings < max_source_length -# ): -# if resize_position_embeddings is None: -# logger.warning( -# f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} " -# f"to {max_source_length}." -# ) -# model.resize_position_embeddings(max_source_length) -# elif resize_position_embeddings: -# model.resize_position_embeddings(max_source_length) -# else: -# raise ValueError( -# f"`max_source_length` is set to {max_source_length}, but the model only has {model.config.max_position_embeddings}" -# f" position encodings. Consider either reducing `max_source_length` to {model.config.max_position_embeddings} or to automatically " -# "resize the model's position encodings by setting `resize_position_embeddings`." -# ) -# -# return model - - -# @Step.register("get-model-for-finetuning") -# class GetModelForFinetuning(Step): -# DETERMINISTIC = True -# CACHEABLE = False -# -# def run(self, model): - - @Step.register("tokenize_text2text") class TokenizeText2TextData(Step): DETERMINISTIC = True @@ -133,11 +77,11 @@ def preprocess_function(examples): targets, max_length=max_target_length, padding=padding, truncation=True ) - # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore - # padding in the loss. + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 + # when we want to ignore padding in the loss. 
if padding == "max_length" and ignore_pad_token_for_loss: labels["input_ids"] = [ - [(l if l != tokenizer.pad_token_id else -100) for l in label] + [(lb if lb != tokenizer.pad_token_id else -100) for lb in label] for label in labels["input_ids"] ] From 77ade5fadcdb9a9404c4cc4786532892dd9a591a Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 28 Mar 2022 13:30:39 -0700 Subject: [PATCH 05/16] run generation with model --- examples/eval_p3/config.jsonnet | 2 +- examples/finetune/config.jsonnet | 10 ++ .../transformers/run_generation.py | 116 ++++++++++++------ .../transformers/run_generation_test.py | 18 ++- 4 files changed, 108 insertions(+), 38 deletions(-) diff --git a/examples/eval_p3/config.jsonnet b/examples/eval_p3/config.jsonnet index 8c140ce2d..d42c8eda3 100644 --- a/examples/eval_p3/config.jsonnet +++ b/examples/eval_p3/config.jsonnet @@ -30,7 +30,7 @@ local dataset_steps = std.foldl( "max_length": 200, "input": {"ref": "dataset_" + dataset_name}, "batch_size": batch_size, - "model_name": model, + "model": model, "prompt_field": "inputs_pretokenized", "output_field": "generation", "splits": ["validation"] diff --git a/examples/finetune/config.jsonnet b/examples/finetune/config.jsonnet index 851f1442f..cc24d3897 100644 --- a/examples/finetune/config.jsonnet +++ b/examples/finetune/config.jsonnet @@ -134,5 +134,15 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device dataloader: single_device_dataloader, test_split: "test", }, + "generations": { + "type": "transformers::run_generation_dataset", + "max_length": 1, + "input": {"type": "ref", "ref": "processed_data"}, + "batch_size": batch_size, + "model": {"type": "ref", "ref": "trained_model"}, + "prompt_field": "source", + "output_field": "generation", + "splits": ["validation"] + }, } } diff --git a/tango/integrations/transformers/run_generation.py b/tango/integrations/transformers/run_generation.py index ab3556a88..6d7a002cc 100644 --- a/tango/integrations/transformers/run_generation.py +++ b/tango/integrations/transformers/run_generation.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Union +from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Union, cast import more_itertools import torch @@ -15,6 +15,9 @@ GPT2Tokenizer, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, TransfoXLLMHeadModel, TransfoXLTokenizer, XLMTokenizer, @@ -27,6 +30,7 @@ from tango.common import DatasetDict from tango.common.sequences import MappedSequence, SqliteSparseSequence from tango.common.tqdm import Tqdm +from tango.integrations.torch import Model from tango.integrations.torch.util import resolve_device, set_seed_all logger = logging.getLogger(__name__) @@ -61,6 +65,9 @@ the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing. """ +SEQ2SEQ = AutoModelForSeq2SeqLM._model_mapping.keys() # type: ignore +CAUSAL = AutoModelForCausalLM._model_mapping.keys() # type: ignore + def adjust_length_to_model(length, model): max_sequence_length = ( @@ -78,7 +85,9 @@ def adjust_length_to_model(length, model): def _generate( - model_name: str, + model: Model, + # TODO: Change type to `Tokenizer` once HF includes `convert_tokens_to_ids` in `PretrainedTokenizerBase` class. 
+ tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], prompts: Iterable[str], *, batch_size: int = 4, @@ -93,10 +102,15 @@ def _generate( num_return_sequences: int = 1, fp16: bool = False, ) -> Iterable[List[str]]: + + if not isinstance(model.config, tuple(SEQ2SEQ + CAUSAL)): + raise NotImplementedError( + "This function is only defined for huggingface models seq2seq/causal models." + ) + device = resolve_device() set_seed_all(seed) - tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer_kwargs: Dict[str, Any] = {} tokenizer.padding_side = "left" @@ -111,12 +125,8 @@ def _generate( eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token) pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) - try: - model = AutoModelForSeq2SeqLM.from_pretrained(model_name) - seq2seq_model = True # Seq2Seq models don't return their own prefix. - except ValueError: - model = AutoModelForCausalLM.from_pretrained(model_name) - seq2seq_model = False + # Seq2Seq models don't return their own prefix. + seq2seq_model = model.config_class in SEQ2SEQ # HF does not do this? WTF? model.eval() @@ -144,27 +154,27 @@ def prepare_batch_with_prefix(prompts: List[str]) -> Dict[str, torch.Tensor]: prepare_batch_fn = prepare_batch_with_prefix num_prefix_tokens: Optional[int] = None - # model-specific exceptions - if model.config_class.model_type == "xlm": - use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb - if hasattr(model.config, "lang2id") and use_lang_emb: - model.config.lang_id = xlm_language - # Original HF code ignores the prefix, but it looks like a bug? - prepare_batch_fn = prepare_batch_without_prefix - num_prefix_tokens = 0 - elif model.config_class.model_type in {"xlnet", "transfo-xl"}: - prefix = prefix if prefix else PREFIX - if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: - # This actually doesn't work in the current version of transformers, which is probably a bug in the - # transformers library. - tokenizer_kwargs = {"add_space_before_punct_symbol": True} + # transformer model-specific exceptions + if isinstance(model, PreTrainedModel) and model.config_class: + if model.config_class.model_type == "xlm": + use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb + if hasattr(model.config, "lang2id") and use_lang_emb: + model.config.lang_id = xlm_language + # Original HF code ignores the prefix, but it looks like a bug? + prepare_batch_fn = prepare_batch_without_prefix + num_prefix_tokens = 0 + elif model.config_class.model_type in {"xlnet", "transfo-xl"}: + prefix = prefix if prefix else PREFIX + if model.__class__.__name__ in ["TransfoXLLMHeadModel"]: + # This actually doesn't work in the current version of transformers, which is probably a bug in the + # transformers library. 
+ tokenizer_kwargs = {"add_space_before_punct_symbol": True} if num_prefix_tokens is None: num_prefix_tokens = len(tokenizer.tokenize(prefix)) batches = more_itertools.chunked(Tqdm.tqdm(prompts, desc="Pre-processing prompts"), batch_size) encoded_batches = map(prepare_batch_fn, batches) - # encoded_batches = threaded_generator(encoded_batches) for encoded_batch in Tqdm.tqdm(encoded_batches, desc="Processing batches"): if seq2seq_model: @@ -172,7 +182,7 @@ def prepare_batch_with_prefix(prompts: List[str]) -> Dict[str, torch.Tensor]: else: length = adjust_length_to_model(max_length + encoded_batch["input_ids"].size(1), model) with torch.inference_mode(): - generated_sequences = model.generate( + generated_sequences: torch.Tensor = model.generate( # type: ignore **encoded_batch, max_length=length, temperature=temperature, @@ -199,26 +209,36 @@ def strip_special_tokens(t: torch.Tensor) -> torch.Tensor: return t[start:end] # strip padding - generated_sequences = [ + generated_sequences_list = [ [strip_special_tokens(sequence) for sequence in per_prompt_sequences] for per_prompt_sequences in generated_sequences ] # strip prefix if not seq2seq_model: - generated_sequences = [ + generated_sequences_list = [ [sequence[num_prefix_tokens:] for sequence in per_prompt_sequences] - for per_prompt_sequences in generated_sequences + for per_prompt_sequences in generated_sequences_list ] texts = [ tokenizer.batch_decode(per_prompt_sequences, clean_up_tokenization_spaces=True) - for per_prompt_sequences in generated_sequences + for per_prompt_sequences in generated_sequences_list ] yield from texts +def _generate_with_model_name(model_name: str, *args, **kwargs) -> Iterable[List[str]]: + try: + model = AutoModelForSeq2SeqLM.from_pretrained(model_name) + except ValueError: + model = AutoModelForCausalLM.from_pretrained(model_name) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + return _generate(model, tokenizer, *args, **kwargs) + + @Step.register("transformers::run_generation") class RunGeneration(Step[Iterable[List[str]]]): """ @@ -236,9 +256,10 @@ class RunGeneration(Step[Iterable[List[str]]]): def run( # type: ignore self, - model_name: str, + model: Union[str, Model], prompts: Iterable[str], *, + tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None, batch_size: int = 4, max_length: int = 20, temperature: float = 1.0, @@ -254,11 +275,14 @@ def run( # type: ignore """ Run a Huggingface seq2seq model in inference mode. - :param model_name: + :param model: The name of the model to run. Any name that works in the transformers library works here. + Or, you can directly provide the model to run. :param prompts: The prompts to run through the model. You can specify prompts directly in the config, but more commonly the prompts are produced by another step that reads a dataset, for example. + :param tokenizer: + The tokenizer to run. :param batch_size: The number of sequences to process at one time. This has no bearing on the output, so you can change this number without invalidating cached results. @@ -292,9 +316,17 @@ def run( # type: ignore :returns: Returns an iterator of lists of string. Each list contains the predictions for one prompt. 
""" + if isinstance(model, str): + try: + model = cast(Model, AutoModelForSeq2SeqLM.from_pretrained(model)) + except ValueError: + model = cast(Model, AutoModelForCausalLM.from_pretrained(model)) + + tokenizer = tokenizer or AutoTokenizer.from_pretrained(model.name_or_path) return _generate( - model_name, + model, + tokenizer, prompts, batch_size=batch_size, max_length=max_length, @@ -328,10 +360,11 @@ class RunGenerationDataset(Step[DatasetDict]): def run( # type: ignore self, - model_name: str, + model: Union[str, Model], input: Union[DatasetDict, HfDatasetDict], prompt_field: str, *, + tokenizer: Optional[Union[PreTrainedTokenizer, PreTrainedTokenizerFast]] = None, output_field: Optional[str] = None, splits: Optional[Union[str, Set[str]]] = None, batch_size: int = 4, @@ -349,12 +382,15 @@ def run( # type: ignore """ Augment an input dataset with generations from a Huggingface seq2seq model. - :param model_name: + :param model: The name of the model to run. Any name that works in the transformers library works here. + Or, you can directly provide the model to run. :param input: The input dataset. :param prompt_field: The field in the dataset that contains the text of the prompts. + :param tokenizer: + The tokenizer to run. :param output_field: The field in the dataset that we will write the predictions into. In the result, this field will contain ``List[str]``. @@ -393,6 +429,15 @@ def run( # type: ignore :returns: Returns a dataset with an extra field containing the predictions. """ + + if isinstance(model, str): + try: + model = cast(Model, AutoModelForSeq2SeqLM.from_pretrained(model)) + except ValueError: + model = cast(Model, AutoModelForCausalLM.from_pretrained(model)) + + tokenizer = tokenizer or AutoTokenizer.from_pretrained(model.name_or_path) + if isinstance(input, HfDatasetDict): input = DatasetDict(input, {}) if splits is None: @@ -417,7 +462,8 @@ def run( # type: ignore input_split = input_split[len(output_split) :] prompts = MappedSequence(lambda i: i[prompt_field], input_split) generations = _generate( - model_name, + model, + tokenizer, prompts, batch_size=batch_size, max_length=max_length, diff --git a/tests/integrations/transformers/run_generation_test.py b/tests/integrations/transformers/run_generation_test.py index a2eb8fc93..43375bd67 100644 --- a/tests/integrations/transformers/run_generation_test.py +++ b/tests/integrations/transformers/run_generation_test.py @@ -10,7 +10,21 @@ def test_run_generation(self): { "type": "transformers::run_generation", "prompts": ["Tango is the future of", "Everybody should be using Tango to"], - "model_name": "sshleifer/tiny-gpt2", + "model": "sshleifer/tiny-gpt2", + }, + ) + result = list(step.result()) + assert len(result) == 2 + + def test_run_generation_with_model(self): + step = Step.from_params( # type: ignore[assignment] + { + "type": "transformers::run_generation", + "prompts": ["Tango is the future of", "Everybody should be using Tango to"], + "model": { + "type": "transformers::AutoModelForCausalLM::from_pretrained", + "pretrained_model_name_or_path": "sshleifer/tiny-gpt2", + }, }, ) result = list(step.result()) @@ -28,7 +42,7 @@ def test_run_generation_dataset(self): ) step = RunGenerationDataset( - model_name="sshleifer/tiny-gpt2", input=dataset, prompt_field="prompt" + model="sshleifer/tiny-gpt2", input=dataset, prompt_field="prompt" ) result = step.result() From 71c485cea93b9820f08422a78e3734b9070b6bbe Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Tue, 29 Mar 2022 22:58:22 -0700 Subject: [PATCH 06/16] causal lm --- 
examples/finetune/config.jsonnet | 10 +- examples/finetune/config_gpt2.jsonnet | 158 ++++++++++ examples/finetune/snli_steps.py | 6 +- examples/finetune/test.py | 43 ++- tango/integrations/transformers/finetune.py | 318 ++++++++++++++++++-- 5 files changed, 504 insertions(+), 31 deletions(-) create mode 100644 examples/finetune/config_gpt2.jsonnet diff --git a/examples/finetune/config.jsonnet b/examples/finetune/config.jsonnet index cc24d3897..3d6fa06b0 100644 --- a/examples/finetune/config.jsonnet +++ b/examples/finetune/config.jsonnet @@ -5,6 +5,8 @@ local pretrained_model = "patrickvonplaten/t5-tiny-random"; local load_with_low_cpu_mem_usage = false; +local modules_to_wrap = ["encoder\\.block\\.[0-9]+", "decoder\\.block\\.[0-9]+"]; # tell FairScale to wrap the transformer's blocks individually + #################### # Trainer settings # #################### @@ -107,12 +109,12 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device model: { type: "fairscale::with_wrapped_modules", model: { - type: "transformers::AutoModelForSeq2SeqLM::from_pretrained", - //type: "transformers::finetune-wrapper", + //type: "transformers::AutoModelForSeq2SeqLM::from_pretrained", + type: "transformers::finetune::from_pretrained", pretrained_model_name_or_path: pretrained_model, low_cpu_mem_usage: load_with_low_cpu_mem_usage, }, - modules_to_wrap: ["encoder\\.block\\.[0-9]+", "decoder\\.block\\.[0-9]+"], # tell FairScale to wrap the transformer's blocks individually + modules_to_wrap: modules_to_wrap, # tell FairScale to wrap the transformer's blocks individually fsdp_config: fsdp_config, activation_checkpointing: activation_checkpointing, }, @@ -136,7 +138,7 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device }, "generations": { "type": "transformers::run_generation_dataset", - "max_length": 1, + "max_length": 5, "input": {"type": "ref", "ref": "processed_data"}, "batch_size": batch_size, "model": {"type": "ref", "ref": "trained_model"}, diff --git a/examples/finetune/config_gpt2.jsonnet b/examples/finetune/config_gpt2.jsonnet new file mode 100644 index 000000000..080aae4c1 --- /dev/null +++ b/examples/finetune/config_gpt2.jsonnet @@ -0,0 +1,158 @@ +################## +# Model settings # +################## + +local pretrained_model = "sshleifer/tiny-gpt2"; //""patrickvonplaten/t5-tiny-random"; +local load_with_low_cpu_mem_usage = false; + +//local modules_to_wrap = ["encoder\\.block\\.[0-9]+", "decoder\\.block\\.[0-9]+"]; +local modules_to_wrap = ["transformer\\.h\\.[0-9]+"]; + + +#################### +# Trainer settings # +#################### + +# Trainer settings, adjust to your use-case. 
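Aside: the modules_to_wrap regexes above are model-specific, transformer.h.N for GPT-2 here versus encoder.block.N and decoder.block.N for T5 in config.jsonnet. One way to find the right patterns for a new architecture is to list the module names yourself. A sketch, assuming full-name matching and using the tiny GPT-2 checkpoint from this config:

```python
import re

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

pattern = re.compile(r"transformer\.h\.[0-9]+")
blocks = [name for name, _ in model.named_modules() if pattern.fullmatch(name)]
print(blocks)  # e.g. ['transformer.h.0', 'transformer.h.1'] for this tiny model
```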
+local training_steps = 20; # total number of optimization steps to train for +local validate_every = 5; # how often to validate and save checkpoints + +local devices = 1; # number of devices to train on (will use GPUs if enough are available, otherwise CPU) +local grad_accum = 1; # number of gradient accumulation steps (changes the effective batch size) +# This is the batch size per GPU, ignoring gradient accumulation: +local batch_size = 2; +# So the effective batch size is `batch_size * grad_accum * devices` + +local activation_checkpointing = false; # use activation/gradient checkpointing (probably need this GPT-J 6B, but not gpt2) +local amp = false; # use PyTorch's native automatic mixed precision +local fsdp = false; # Use FairScale's FullyShardedDataParallel (probably need this GPT-J 6B, but not gpt2) +local cpu_offloading = false; # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow. + +###################### +# Optimizer settings # +###################### + +local warmup_steps = 20; +local learning_rate = 0.00005; # you can probably use a higher LR for a small model like "gpt2" + + +assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp"; + +# FullyShardedDataParallel config: +local fsdp_config = if fsdp then { + reshard_after_forward: true, + move_params_to_cpu: cpu_offloading, + move_grads_to_cpu: cpu_offloading, + mixed_precision: amp, +} else null; + +local training_engine = { + type: if fsdp then "fairscale" else "torch", + optimizer: { + type: "torch::AdamW", + lr: learning_rate, + betas: [0.9, 0.95], + eps: 1e-6, + }, + lr_scheduler: { + type: "transformers::linear", + num_warmup_steps: warmup_steps, + num_training_steps: training_steps, + }, + amp: amp, + [if fsdp then "fsdp_config" else null]: fsdp_config, +}; + +local collate_fn = { + type: "transformers::DefaultDataCollator", + //tokenizer: { pretrained_model_name_or_path: pretrained_model }, + //mlm: false, +}; + +local distributed_dataloader = { + batch_size: batch_size, + collate_fn: collate_fn, + sampler: { + type: "torch::DistributedSampler", + shuffle: true, + drop_last: true, + }, +}; + +local single_device_dataloader = { + shuffle: true, + batch_size: batch_size, + collate_fn: collate_fn, +}; + +local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader; + +{ + steps: { + raw_data: { + type: "datasets::load", + path: "snli", + }, + "subset_data": { + type: "subset-data", + data: { type: "ref", ref: "raw_data" }, + max_samples: 10, + }, + processed_data: { + type: "snli-text2text", + data: { type: "ref", ref: "subset_data" }, + seq2seq: false, + }, + "tokenized_data": { + type: "tokenize_text2text", + data: { type: "ref", ref: "processed_data" }, + tokenizer: { pretrained_model_name_or_path: pretrained_model }, + max_source_length: 500, + max_target_length: 500, + pad_to_max_length: true, + seq2seq: false, + }, + trained_model: { + type: "torch::train", + model: { + type: "fairscale::with_wrapped_modules", + model: { + //type: "transformers::AutoModelForSeq2SeqLM::from_pretrained", + type: "transformers::finetune::from_pretrained", + pretrained_model_name_or_path: pretrained_model, + low_cpu_mem_usage: load_with_low_cpu_mem_usage, + }, + modules_to_wrap: modules_to_wrap, # tell FairScale to wrap the transformer's blocks individually + fsdp_config: fsdp_config, + activation_checkpointing: activation_checkpointing, + }, + dataset_dict: { type: "ref", ref: "tokenized_data" }, + 
train_dataloader: dataloader, + validation_split: "validation", + grad_accum: grad_accum, + train_steps: training_steps, + validate_every: validate_every, + checkpoint_every: validate_every, + log_every: 1, + device_count: devices, + training_engine: training_engine, + }, + final_metrics: { + type: "torch::eval", + model: { type: "ref", ref: "trained_model" }, + dataset_dict: { type: "ref", ref: "tokenized_data" }, + dataloader: single_device_dataloader, + test_split: "test", + }, + "generations": { + "type": "transformers::run_generation_dataset", + //"max_length": 2, + "input": {"type": "ref", "ref": "processed_data"}, + "batch_size": batch_size, + "model": {"type": "ref", "ref": "trained_model"}, + "prompt_field": "source", + "output_field": "generation", + "splits": ["validation"] + }, + } +} diff --git a/examples/finetune/snli_steps.py b/examples/finetune/snli_steps.py index bad1946c9..4fbb9a973 100644 --- a/examples/finetune/snli_steps.py +++ b/examples/finetune/snli_steps.py @@ -60,9 +60,9 @@ def _seq2seq_mapper(example): return { "source": ( f'{source_prefix} {premise_prefix}: {example["premise"]} ' - f'{hypothesis_prefix}: {example["hypothesis"]} {label_prefix}: ' + f'{hypothesis_prefix}: {example["hypothesis"]} ' ), - "target": f'{label_map[example["label"]]}', + "target": f'{label_prefix}: {label_map[example["label"]]}', } def _causal_mapper(example): @@ -78,7 +78,7 @@ def _causal_mapper(example): else: old_cols = list(data.column_names.values())[0] - _mapper = _seq2seq_mapper if seq2seq else _causal_mapper + _mapper = _seq2seq_mapper # if seq2seq else _causal_mapper dataset = data.map( _mapper, diff --git a/examples/finetune/test.py b/examples/finetune/test.py index b93037a5c..91398bcb6 100644 --- a/examples/finetune/test.py +++ b/examples/finetune/test.py @@ -6,8 +6,47 @@ class TestSnliText2Text(TangoTestCase): - def test_config(self): - config = Params.from_file("test_config.jsonnet") + def test_config_with_t5(self): + model = "patrickvonplaten/t5-tiny-random" + overrides = { + "steps.trained_model.model.model.pretrained_model_name_or_path": model, + } + config = Params.from_file("config.jsonnet", params_overrides=overrides) + # Make sure we've overrode the model entirely. 
+ flattened = config.as_flat_dict() + for key, value in flattened.items(): + if "model_name" in key or (isinstance(value, str) and "t5" in value): + assert value == model + + with run_experiment(config, include_package=["snli_steps.py"]) as run_dir: + assert (run_dir / "processed_data").is_dir() + processed = ds.load_from_disk(run_dir / "processed_data" / "data") + assert len(processed["train"][0].keys()) == 2 + assert "source" in processed["train"][0].keys() + assert "target" in processed["train"][0].keys() + assert processed["train"][0]["source"].startswith("nli premise:") + + assert (run_dir / "tokenized_data").is_dir() + tokenized = ds.load_from_disk(run_dir / "tokenized_data" / "data") + assert "input_ids" in tokenized["train"][0] + + assert (run_dir / "trained_model").is_dir() + + def test_config_with_gpt2(self): + model = "sshleifer/tiny-gpt2" + overrides = { + "steps.trained_model.model.model.pretrained_model_name_or_path": model, + "steps.tokenized_data.tokenizer.pretrained_model_name_or_path": model, + "steps.trained_model.train_dataloader.collate_fn.tokenizer.pretrained_model_name_or_path": model, + "steps.final_metrics.dataloader.collate_fn.tokenizer.pretrained_model_name_or_path": model, + } + config = Params.from_file("config.jsonnet", params_overrides=overrides) + # Make sure we've overrode the model entirely. + flattened = config.as_flat_dict() + for key, value in flattened.items(): + if "model_name" in key or (isinstance(value, str) and "gpt2" in value): + assert value == model, key + with run_experiment(config, include_package=["snli_steps.py"]) as run_dir: assert (run_dir / "processed_data").is_dir() processed = ds.load_from_disk(run_dir / "processed_data" / "data") diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py index 4941a76fa..cfe138ff9 100644 --- a/tango/integrations/transformers/finetune.py +++ b/tango/integrations/transformers/finetune.py @@ -2,38 +2,43 @@ from typing import Optional import datasets as ds -from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM +from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, PreTrainedModel from tango.integrations.datasets import DatasetsFormat -from tango.integrations.torch import Model +from tango.integrations.torch import Model, TorchFormat from tango.integrations.transformers.tokenizer import Tokenizer from tango.step import Step logger = logging.getLogger(__name__) -@Model.register("transformers::finetune-wrapper") -class FinetuneWrapper(Model): - def __init__( - self, pretrained_model_name_or_path: str, tokenizer: Optional[Tokenizer] = None, **kwargs - ): - super().__init__() +class FinetuneWrapper(PreTrainedModel): + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, tokenizer: Optional[Tokenizer] = None, **kwargs + ) -> PreTrainedModel: try: - self.model = AutoModelForSeq2SeqLM.from_pretrained( - pretrained_model_name_or_path, **kwargs - ) - self.seq2seq = True # Seq2Seq models don't return their own prefix. 
+ model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path, **kwargs) except ValueError: - self.model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path) - self.seq2seq = False + model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs) if tokenizer: + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if tokenizer.sep_token is None: + tokenizer.add_special_tokens({"sep_token": "[SEP]"}) + if tokenizer.eos_token is None: + tokenizer.add_special_tokens({"eos_token": "[EOS]"}) # TODO: is this required? This is the only reason why we have tokenizer here. - self.model.resize_token_embeddings(len(tokenizer)) # type: ignore + model.resize_token_embeddings(len(tokenizer)) # type: ignore + return model + - def forward(self, *args, **kwargs): - # TODO: decode and compute other metrics? - return self.model.forward(*args, **kwargs) +Model.register("transformers::finetune::from_pretrained", constructor="from_pretrained")( + FinetuneWrapper +) @Step.register("tokenize_text2text") @@ -53,19 +58,36 @@ def run( # type: ignore[override] max_target_length: Optional[int] = 1024, pad_to_max_length: bool = False, ignore_pad_token_for_loss: bool = True, + seq2seq: bool = True, ) -> ds.DatasetDict: - # Set max_target_length for training. - max_target_length = max_target_length + if not seq2seq: + pad_to_max_length = True # TODO: address this. padding = "max_length" if pad_to_max_length else False + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if tokenizer.sep_token is None: + tokenizer.add_special_tokens({"sep_token": "[SEP]"}) + if tokenizer.eos_token is None: + tokenizer.add_special_tokens({"eos_token": "[EOS]"}) + def preprocess_function(examples): # remove pairs where at least one record is None inputs, targets = [], [] + input_lengths = [] for i in range(len(examples[source_field])): if examples[source_field][i] is not None and examples[target_field][i] is not None: - inputs.append(examples[source_field][i]) - targets.append(examples[target_field][i]) + if seq2seq: + inputs.append(examples[source_field][i]) + targets.append(examples[target_field][i]) + else: + text = examples[source_field][i] + " " + examples[target_field][i] + inputs.append(text) + targets.append(text) + input_lengths.append(len(examples[source_field][i])) model_inputs = tokenizer( inputs, max_length=max_source_length, padding=padding, truncation=True @@ -73,6 +95,7 @@ def preprocess_function(examples): # Setup the tokenizer for targets with tokenizer.as_target_tokenizer(): + # TODO: do something so the loss isn't counted. labels = tokenizer( targets, max_length=max_target_length, padding=padding, truncation=True ) @@ -97,3 +120,254 @@ def preprocess_function(examples): ) return data + +# from tango.integrations.torch.train import * # TODO: fix +# +# @Step.register("transformers::finetune") +# class FinetuneStep(Step): +# """ +# Mostly similar to :class:`~tango.integrations.torch.train.TorchTrainStep` with additional +# preprocessing for data. +# +# .. tip:: +# +# Registered as a :class:`~tango.step.Step` under the name "transformers::finetune". +# +# .. important:: +# +# The training loop will use GPU(s) automatically when available, as long as at least +# ``device_count`` CUDA devices are available. 
+# +# Distributed data parallel training is activated when the ``device_count`` is greater than 1. +# +# You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``. +# For example, to only use the GPUs with IDs 0 and 1, set ``CUDA_VISIBLE_DEVICES=0,1`` +# (and ``device_count`` to 2). +# +# .. warning:: +# +# During validation, the validation metric (specified by the ``val_metric_name`` parameter) +# is aggregated by simply averaging across validation batches and distributed processes. +# This behavior is usually correct when your validation metric is "loss" or "accuracy", +# for example, but may not be correct for other metrics like "F1". +# +# If this is not correct for your metric you will need to handle the aggregation +# internally in your model or with a :class:`TrainCallback` +# using the :meth:`TrainCallback.post_val_batch()` method. +# Then set the parameter ``auto_aggregate_val_metric`` to ``False``. +# +# Note that correctly aggregating your metric during distributed training will +# involve distributed communication. +# +# """ +# +# DETERMINISTIC = True +# CACHEABLE = True +# FORMAT: Format = TorchFormat() +# SKIP_ID_ARGUMENTS = {"distributed_port", "log_every"} +# +# def run( # type: ignore[override] +# self, +# model: Lazy[Model], +# training_engine: Lazy[TrainingEngine], +# dataset_dict: DatasetDictBase, +# train_dataloader: Lazy[DataLoader], +# *, +# train_split: str = "train", +# validation_split: Optional[str] = None, +# validation_dataloader: Optional[Lazy[DataLoader]] = None, +# seed: int = 42, +# train_steps: Optional[int] = None, +# train_epochs: Optional[int] = None, +# validation_steps: Optional[int] = None, +# grad_accum: int = 1, +# log_every: int = 10, +# checkpoint_every: int = 100, +# validate_every: Optional[int] = None, +# device_count: int = 1, +# distributed_port: int = 54761, +# val_metric_name: str = "loss", +# minimize_val_metric: bool = True, +# auto_aggregate_val_metric: bool = True, +# callbacks: Optional[List[Lazy[TrainCallback]]] = None, +# remove_stale_checkpoints: bool = True, +# ) -> Model: +# """ +# Run a basic training loop to train the ``model``. +# +# :param model: +# The model to train. It should return a ``dict`` that includes the ``loss`` +# during training and the ``val_metric_name`` during validation. +# :param training_engine: +# The :class:`TrainingEngine` to use to train the model. +# :param dataset_dict: +# The train and optional validation data. +# :param train_dataloader: +# The data loader that generates training batches. The batches should be :class:`dict` +# objects that will be used as ``kwargs`` for the model's ``forward()`` method. +# :param train_split: +# The name of the data split used for training in the ``dataset_dict``. +# Default is "train". +# :param validation_split: +# Optional name of the validation split in the ``dataset_dict``. Default is ``None``, +# which means no validation. +# :param validation_dataloader: +# An optional data loader for generating validation batches. The batches should be +# :class:`dict` objects. If not specified, but ``validation_split`` is given, +# the validation ``DataLoader`` will be constructed from the same parameters +# as the train ``DataLoader``. +# :param seed: +# Used to set the RNG states at the beginning of training. +# :param train_steps: +# The number of steps to train for. If not specified training will +# stop after a complete iteration through the ``train_dataloader``. +# :param train_epochs: +# The number of epochs to train for. 
You cannot specify ``train_steps`` and ``train_epochs`` +# at the same time. +# :param validation_steps: +# The number of steps to validate for. If not specified validation +# will stop after a complete iteration through the ``validation_dataloader``. +# :param grad_accum: +# The number of gradient accumulation steps. Defaults to 1. +# +# .. note:: +# This parameter - in conjuction with the settings of your data loader +# and the number distributed workers - +# determines the *effective batch size* of your training run. +# +# :param log_every: +# Log every this many steps. +# :param checkpoint_every: +# Save a checkpoint every this many steps. +# :param validate_every: +# Run the validation loop every this many steps. +# :param device_count: +# The number of devices to train on, i.e. the number of distributed data parallel workers. +# :param distributed_port: +# The port of the distributed process group. Default = "54761". +# :param val_metric_name: +# The name of the validation metric, i.e. the key of the metric in the dictionary +# returned by the forward pass of the model. Default is "loss". +# :param minimize_val_metric: +# Whether the validation metric is meant to be minimized (such as the loss). +# Default is ``True``. When using a metric such as accuracy, you should set +# this to ``False``. +# :param auto_aggregate_val_metric: +# If ``True`` (the default), the validation metric will be averaged across +# validation batches and distributed processes. This may not be the correct +# behavior for some metrics (such as F1), in which you should set this to +# ``False`` and handle the aggregation internally in your model +# or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`). +# :param callbacks: +# A list of :class:`TrainCallback`. +# :param remove_stale_checkpoints: +# If ``True`` (the default), stale checkpoints will be removed throughout training so that +# only the latest and best checkpoints are kept. +# +# :returns: +# The trained model on CPU with the weights from the best checkpoint loaded. +# +# """ +# # Validate device(s). +# if device_count <= 0: +# raise ConfigurationError("Invalid value for 'device_count'. Must be at least 1.") +# devices: List[int] +# if torch.cuda.is_available() and torch.cuda.device_count() >= device_count: +# devices = list(range(device_count)) +# self.logger.info("Training on %d GPU%s", device_count, "s" if device_count > 1 else "") +# else: +# devices = [-1] * device_count +# self.logger.info( +# "Training on CPU with %d worker%s", device_count, "s" if device_count > 1 else "" +# ) +# +# if validate_every is not None and validation_split is None: +# raise ConfigurationError( +# "You have set a validation interval, but no validation split. " +# "That's probably unintentional." +# ) +# +# is_distributed = False +# num_workers = 1 +# if devices and len(devices) > 1: +# is_distributed = True +# num_workers = len(devices) +# +# if (train_steps is not None) == (train_epochs is not None): +# raise ConfigurationError( +# "One of 'train_steps' or 'train_epochs' needs to be specified, but not both." 
+# ) +# +# # Tokenize data +# +# # dataset dict +# +# # end tokenization +# +# config = TrainConfig( +# self.unique_id, +# self.work_dir, +# train_split=train_split, +# validation_split=validation_split, +# seed=seed, +# train_steps=train_steps, +# train_epochs=train_epochs, +# grad_accum=grad_accum, +# log_every=log_every, +# checkpoint_every=checkpoint_every, +# validate_every=validate_every, +# validation_steps=validation_steps, +# is_distributed=is_distributed, +# devices=devices, +# distributed_port=distributed_port, +# val_metric_name=val_metric_name, +# minimize_val_metric=minimize_val_metric, +# auto_aggregate_val_metric=auto_aggregate_val_metric, +# remove_stale_checkpoints=remove_stale_checkpoints, +# world_size=num_workers, +# ) +# +# final_model: Model +# if is_distributed: +# import torch.multiprocessing as mp +# +# mp.spawn( +# _train, +# args=( +# config, +# model, +# training_engine, +# dataset_dict, +# train_dataloader, +# validation_dataloader, +# callbacks, +# get_extra_imported_modules(), +# ), +# nprocs=num_workers, +# ) +# self.logger.info("Constructing final model") +# final_model = model.construct() +# else: +# final_model = _train( # type: ignore[assignment] +# 0, +# config, +# model, +# training_engine, +# dataset_dict, +# train_dataloader, +# validation_dataloader=validation_dataloader, +# callbacks=callbacks, +# ) +# assert final_model is not None +# final_model = final_model.cpu() +# +# # Load best checkpoint before returning model. +# if config.final_weights_path.is_file(): +# self.logger.info( +# f"Loading best weights from {str(config.final_weights_path.resolve())}" +# ) +# state = torch.load(config.final_weights_path, map_location="cpu") +# # We use `strict=False` because there might be missing keys due to weight tying. +# final_model.load_state_dict(state, strict=False) +# +# return final_model \ No newline at end of file From ded17f4a8a5347af105df94c0739c3365c6c7416 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Tue, 29 Mar 2022 23:45:56 -0700 Subject: [PATCH 07/16] change label --- examples/finetune/snli_steps.py | 6 +++--- tango/integrations/transformers/finetune.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/finetune/snli_steps.py b/examples/finetune/snli_steps.py index 4fbb9a973..65495c6ee 100644 --- a/examples/finetune/snli_steps.py +++ b/examples/finetune/snli_steps.py @@ -54,15 +54,15 @@ def filter_no_gold(example, indices): data = data.filter(filter_no_gold, with_indices=True) - label_map = {0: "entails", 1: "neutral", 2: "contradiction"} + label_map = {0: "entailment", 1: "neutral", 2: "contradiction"} def _seq2seq_mapper(example): return { "source": ( f'{source_prefix} {premise_prefix}: {example["premise"]} ' - f'{hypothesis_prefix}: {example["hypothesis"]} ' + f'{hypothesis_prefix}: {example["hypothesis"]} {label_prefix}: ' ), - "target": f'{label_prefix}: {label_map[example["label"]]}', + "target": f'{label_map[example["label"]]}', } def _causal_mapper(example): diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py index cfe138ff9..341c4c22b 100644 --- a/tango/integrations/transformers/finetune.py +++ b/tango/integrations/transformers/finetune.py @@ -121,6 +121,7 @@ def preprocess_function(examples): return data + # from tango.integrations.torch.train import * # TODO: fix # # @Step.register("transformers::finetune") @@ -370,4 +371,4 @@ def preprocess_function(examples): # # We use `strict=False` because there might be missing keys due to weight tying. 
# final_model.load_state_dict(state, strict=False) # -# return final_model \ No newline at end of file +# return final_model From 87ff48efb5ff318f00f802856b632856692f2837 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Fri, 1 Apr 2022 14:26:33 -0700 Subject: [PATCH 08/16] single step finetune --- examples/finetune/config.jsonnet | 7 - examples/finetune/config_gpt2.jsonnet | 7 - examples/finetune/new_config.jsonnet | 133 ++++ examples/finetune/snli_steps.py | 12 +- tango/integrations/transformers/finetune.py | 746 +++++++++++--------- 5 files changed, 556 insertions(+), 349 deletions(-) create mode 100644 examples/finetune/new_config.jsonnet diff --git a/examples/finetune/config.jsonnet b/examples/finetune/config.jsonnet index 3d6fa06b0..296b1e547 100644 --- a/examples/finetune/config.jsonnet +++ b/examples/finetune/config.jsonnet @@ -129,13 +129,6 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device device_count: devices, training_engine: training_engine, }, - final_metrics: { - type: "torch::eval", - model: { type: "ref", ref: "trained_model" }, - dataset_dict: { type: "ref", ref: "tokenized_data" }, - dataloader: single_device_dataloader, - test_split: "test", - }, "generations": { "type": "transformers::run_generation_dataset", "max_length": 5, diff --git a/examples/finetune/config_gpt2.jsonnet b/examples/finetune/config_gpt2.jsonnet index 080aae4c1..5f245b716 100644 --- a/examples/finetune/config_gpt2.jsonnet +++ b/examples/finetune/config_gpt2.jsonnet @@ -137,13 +137,6 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device device_count: devices, training_engine: training_engine, }, - final_metrics: { - type: "torch::eval", - model: { type: "ref", ref: "trained_model" }, - dataset_dict: { type: "ref", ref: "tokenized_data" }, - dataloader: single_device_dataloader, - test_split: "test", - }, "generations": { "type": "transformers::run_generation_dataset", //"max_length": 2, diff --git a/examples/finetune/new_config.jsonnet b/examples/finetune/new_config.jsonnet new file mode 100644 index 000000000..bea1ceab6 --- /dev/null +++ b/examples/finetune/new_config.jsonnet @@ -0,0 +1,133 @@ +################## +# Model settings # +################## + +local pretrained_model = "patrickvonplaten/t5-tiny-random"; +local load_with_low_cpu_mem_usage = false; + +local modules_to_wrap = ["encoder\\.block\\.[0-9]+", "decoder\\.block\\.[0-9]+"]; # tell FairScale to wrap the transformer's blocks individually + +#################### +# Trainer settings # +#################### + +# Trainer settings, adjust to your use-case. 
+local training_steps = 20; # total number of optimization steps to train for +local validate_every = 5; # how often to validate and save checkpoints + +local devices = 1; # number of devices to train on (will use GPUs if enough are available, otherwise CPU) +local grad_accum = 1; # number of gradient accumulation steps (changes the effective batch size) +# This is the batch size per GPU, ignoring gradient accumulation: +local batch_size = 2; +# So the effective batch size is `batch_size * grad_accum * devices` + +local activation_checkpointing = false; # use activation/gradient checkpointing (probably need this GPT-J 6B, but not gpt2) +local amp = false; # use PyTorch's native automatic mixed precision +local fsdp = false; # Use FairScale's FullyShardedDataParallel (probably need this GPT-J 6B, but not gpt2) +local cpu_offloading = false; # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow. + +###################### +# Optimizer settings # +###################### + +local warmup_steps = 20; +local learning_rate = 0.00005; # you can probably use a higher LR for a small model like "gpt2" + + +assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp"; + +# FullyShardedDataParallel config: +local fsdp_config = if fsdp then { + reshard_after_forward: true, + move_params_to_cpu: cpu_offloading, + move_grads_to_cpu: cpu_offloading, + mixed_precision: amp, +} else null; + +local training_engine = { + type: if fsdp then "fairscale" else "torch", + optimizer: { + type: "torch::AdamW", + lr: learning_rate, + betas: [0.9, 0.95], + eps: 1e-6, + }, + lr_scheduler: { + type: "transformers::linear", + num_warmup_steps: warmup_steps, + num_training_steps: training_steps, + }, + amp: amp, + [if fsdp then "fsdp_config" else null]: fsdp_config, +}; + +local distributed_dataloader = { + batch_size: batch_size, + sampler: { + type: "torch::DistributedSampler", + shuffle: true, + drop_last: true, + }, +}; + +local single_device_dataloader = { + shuffle: true, + batch_size: batch_size, +}; + +local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader; + +{ + steps: { + raw_data: { + type: "datasets::load", + path: "snli", + }, + "subset_data": { + type: "subset-data", + data: { type: "ref", ref: "raw_data" }, + max_samples: 10, + }, + processed_data: { + type: "snli-text2text", + data: { type: "ref", ref: "subset_data" }, + }, + trained_model: { + type: "transformers::finetune", + model: { + type: "fairscale::with_wrapped_modules", + model: { + type: "transformers::finetune::from_pretrained", + pretrained_model_name_or_path: pretrained_model, + low_cpu_mem_usage: load_with_low_cpu_mem_usage, + }, + modules_to_wrap: modules_to_wrap, # tell FairScale to wrap the transformer's blocks individually + fsdp_config: fsdp_config, + activation_checkpointing: activation_checkpointing, + }, + tokenizer: { + pretrained_model_name_or_path: pretrained_model + }, + dataset_dict: { type: "ref", ref: "processed_data" }, + train_dataloader: dataloader, + validation_split: "validation", + grad_accum: grad_accum, + train_steps: training_steps, + validate_every: validate_every, + checkpoint_every: validate_every, + log_every: 1, + device_count: devices, + training_engine: training_engine, + }, + "generations": { + "type": "transformers::run_generation_dataset", + "max_length": 5, + "input": {"type": "ref", "ref": "processed_data"}, + "batch_size": batch_size, + "model": {"type": "ref", "ref": 
"trained_model"}, + "prompt_field": "source", + "output_field": "generation", + "splits": ["validation"] + }, + } +} diff --git a/examples/finetune/snli_steps.py b/examples/finetune/snli_steps.py index 65495c6ee..5ff74fdae 100644 --- a/examples/finetune/snli_steps.py +++ b/examples/finetune/snli_steps.py @@ -56,7 +56,7 @@ def filter_no_gold(example, indices): label_map = {0: "entailment", 1: "neutral", 2: "contradiction"} - def _seq2seq_mapper(example): + def _mapper(example): return { "source": ( f'{source_prefix} {premise_prefix}: {example["premise"]} ' @@ -65,21 +65,11 @@ def _seq2seq_mapper(example): "target": f'{label_map[example["label"]]}', } - def _causal_mapper(example): - text = ( - f'{source_prefix} {premise_prefix}: {example["premise"]} ' - f'{hypothesis_prefix}: {example["hypothesis"]} ' - f'{label_prefix}: {label_map[example["label"]]}' - ) - return {"source": text, "target": text} - if isinstance(data, ds.Dataset): old_cols = data.column_names else: old_cols = list(data.column_names.values())[0] - _mapper = _seq2seq_mapper # if seq2seq else _causal_mapper - dataset = data.map( _mapper, batched=False, diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py index 341c4c22b..2d6dff86b 100644 --- a/tango/integrations/transformers/finetune.py +++ b/tango/integrations/transformers/finetune.py @@ -1,38 +1,56 @@ import logging -from typing import Optional +from os import PathLike +from typing import Any, List, Optional, Union, cast import datasets as ds -from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, PreTrainedModel +import torch +from transformers import ( + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, + DataCollatorForSeq2Seq, + DefaultDataCollator, + PreTrainedModel, +) -from tango.integrations.datasets import DatasetsFormat -from tango.integrations.torch import Model, TorchFormat +from tango.common import Lazy +from tango.common.exceptions import ConfigurationError +from tango.common.util import get_extra_imported_modules +from tango.format import Format +from tango.integrations.datasets import DatasetsFormat, convert_to_tango_dataset_dict +from tango.integrations.torch import ( + DataCollator, + DataLoader, + Model, + TorchFormat, + TrainCallback, + TrainConfig, + TrainingEngine, +) +from tango.integrations.torch.train import _train from tango.integrations.transformers.tokenizer import Tokenizer from tango.step import Step logger = logging.getLogger(__name__) +SEQ2SEQ = AutoModelForSeq2SeqLM._model_mapping.keys() # type: ignore +CAUSAL = AutoModelForCausalLM._model_mapping.keys() # type: ignore + class FinetuneWrapper(PreTrainedModel): @classmethod - def from_pretrained( - cls, pretrained_model_name_or_path: str, tokenizer: Optional[Tokenizer] = None, **kwargs + def from_pretrained( # type: ignore + cls, + pretrained_model_name_or_path: Union[str, PathLike], + num_tokens: Optional[int] = None, + **kwargs, ) -> PreTrainedModel: try: model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path, **kwargs) except ValueError: model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs) - if tokenizer: - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - if tokenizer.pad_token is None: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - if tokenizer.sep_token is None: - tokenizer.add_special_tokens({"sep_token": "[SEP]"}) - if tokenizer.eos_token is None: - tokenizer.add_special_tokens({"eos_token": "[EOS]"}) - # TODO: is this required? 
This is the only reason why we have tokenizer here. - model.resize_token_embeddings(len(tokenizer)) # type: ignore + if num_tokens: + model.resize_token_embeddings(num_tokens) return model @@ -41,6 +59,90 @@ def from_pretrained( ) +def _add_special_tokens(tokenizer: Tokenizer) -> None: + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + if tokenizer.sep_token is None: + tokenizer.add_special_tokens({"sep_token": "[SEP]"}) + if tokenizer.eos_token is None: + tokenizer.add_special_tokens({"eos_token": "[EOS]"}) + + +def tokenize_data( + data: ds.DatasetDict, + tokenizer: Tokenizer, + num_workers: int = 1, + source_field: str = "source", + target_field: str = "target", + max_source_length: Optional[int] = 1024, + max_target_length: Optional[int] = 1024, + pad_to_max_length: bool = False, + ignore_pad_token_for_loss: bool = True, + seq2seq: bool = True, +) -> ds.DatasetDict: + """ + If it's seq2seq, we use `DataCollatorForSeq2Seq`, and take care of padding there. + If it's causal, we use the `DefaultDataCollator`, and take care of padding here. + """ + + if not seq2seq: + pad_to_max_length = True # TODO: address this. + padding = "max_length" if pad_to_max_length else False + + _add_special_tokens(tokenizer) + + def preprocess_function(examples): + # remove pairs where at least one record is None + inputs, targets = [], [] + input_lengths = [] + for i in range(len(examples[source_field])): + if examples[source_field][i] is not None and examples[target_field][i] is not None: + if seq2seq: + inputs.append(examples[source_field][i]) + targets.append(examples[target_field][i]) + else: + text = ( + examples[source_field][i] + tokenizer.sep_token + examples[target_field][i] + ) + inputs.append(text) + targets.append(text) + input_lengths.append(len(examples[source_field][i])) + + model_inputs = tokenizer( + inputs, max_length=max_source_length, padding=padding, truncation=True + ) + + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + # TODO: remove source seq loss. + labels = tokenizer( + targets, max_length=max_target_length, padding=padding, truncation=True + ) + + # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 + # when we want to ignore padding in the loss. + if padding == "max_length" and ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(lb if lb != tokenizer.pad_token_id else -100) for lb in label] + for label in labels["input_ids"] + ] + + model_inputs["labels"] = labels["input_ids"] + return model_inputs + + data = data.map( + preprocess_function, + batched=True, + num_proc=num_workers, + remove_columns=list(data.column_names.values())[0], # remove all old columns + desc="Tokenizing dataset", + ) + + return data + + @Step.register("tokenize_text2text") class TokenizeText2TextData(Step): DETERMINISTIC = True @@ -61,314 +163,310 @@ def run( # type: ignore[override] seq2seq: bool = True, ) -> ds.DatasetDict: - if not seq2seq: - pad_to_max_length = True # TODO: address this. 
- padding = "max_length" if pad_to_max_length else False - - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - if tokenizer.pad_token is None: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - if tokenizer.sep_token is None: - tokenizer.add_special_tokens({"sep_token": "[SEP]"}) - if tokenizer.eos_token is None: - tokenizer.add_special_tokens({"eos_token": "[EOS]"}) - - def preprocess_function(examples): - # remove pairs where at least one record is None - inputs, targets = [], [] - input_lengths = [] - for i in range(len(examples[source_field])): - if examples[source_field][i] is not None and examples[target_field][i] is not None: - if seq2seq: - inputs.append(examples[source_field][i]) - targets.append(examples[target_field][i]) - else: - text = examples[source_field][i] + " " + examples[target_field][i] - inputs.append(text) - targets.append(text) - input_lengths.append(len(examples[source_field][i])) - - model_inputs = tokenizer( - inputs, max_length=max_source_length, padding=padding, truncation=True + return tokenize_data( + data, + tokenizer=tokenizer, + num_workers=num_workers, + source_field=source_field, + target_field=target_field, + max_source_length=max_source_length, + max_target_length=max_target_length, + pad_to_max_length=pad_to_max_length, + ignore_pad_token_for_loss=ignore_pad_token_for_loss, + seq2seq=seq2seq, + ) + + +@Step.register("transformers::finetune") +class FinetuneStep(Step): + """ + Mostly similar to :class:`~tango.integrations.torch.train.TorchTrainStep` with additional + preprocessing for data. + + .. tip:: + + Registered as a :class:`~tango.step.Step` under the name "transformers::finetune". + + .. important:: + + The training loop will use GPU(s) automatically when available, as long as at least + ``device_count`` CUDA devices are available. + + Distributed data parallel training is activated when the ``device_count`` is greater than 1. + + You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``. + For example, to only use the GPUs with IDs 0 and 1, set ``CUDA_VISIBLE_DEVICES=0,1`` + (and ``device_count`` to 2). + + .. warning:: + + During validation, the validation metric (specified by the ``val_metric_name`` parameter) + is aggregated by simply averaging across validation batches and distributed processes. + This behavior is usually correct when your validation metric is "loss" or "accuracy", + for example, but may not be correct for other metrics like "F1". + + If this is not correct for your metric you will need to handle the aggregation + internally in your model or with a :class:`TrainCallback` + using the :meth:`TrainCallback.post_val_batch()` method. + Then set the parameter ``auto_aggregate_val_metric`` to ``False``. + + Note that correctly aggregating your metric during distributed training will + involve distributed communication. 
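+
+    The ``dataset_dict`` passed to this step is expected to hold *untokenized* text
+    fields (``"source"`` and ``"target"`` by default, per ``source_field`` and
+    ``target_field``); tokenization and collation happen inside the step. As a rough
+    sketch, a single training instance in the style produced by the example
+    "snli-text2text" step (the premise and hypothesis below are made up) might look like::
+
+        {
+            "source": "nli premise: A person is outside. hypothesis: A person is indoors. label: ",
+            "target": "contradiction",
+        }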
+ + """ + + DETERMINISTIC = True + CACHEABLE = True + FORMAT: Format = TorchFormat() + SKIP_ID_ARGUMENTS = {"distributed_port", "log_every"} + + def run( # type: ignore[override] + self, + model: Lazy[Model], + tokenizer: Tokenizer, # TODO: restrict the type + training_engine: Lazy[TrainingEngine], + dataset_dict: ds.DatasetDict, + train_dataloader: Lazy[DataLoader], + *, + train_split: str = "train", + validation_split: Optional[str] = None, + validation_dataloader: Optional[Lazy[DataLoader]] = None, + source_field: str = "source", + target_field: str = "target", + max_source_length: Optional[int] = 1024, + max_target_length: Optional[int] = 1024, + seq2seq: bool = True, + seed: int = 42, + train_steps: Optional[int] = None, + train_epochs: Optional[int] = None, + validation_steps: Optional[int] = None, + grad_accum: int = 1, + log_every: int = 10, + checkpoint_every: int = 100, + validate_every: Optional[int] = None, + device_count: int = 1, + distributed_port: int = 54761, + val_metric_name: str = "loss", + minimize_val_metric: bool = True, + auto_aggregate_val_metric: bool = True, + callbacks: Optional[List[Lazy[TrainCallback]]] = None, + remove_stale_checkpoints: bool = True, + ) -> Model: + """ + Run a basic training loop to train the ``model``. + + :param model: + The model to train. It should return a ``dict`` that includes the ``loss`` + during training and the ``val_metric_name`` during validation. + :param training_engine: + The :class:`TrainingEngine` to use to train the model. + :param dataset_dict: + The train and optional validation data. + :param train_dataloader: + The data loader that generates training batches. The batches should be :class:`dict` + objects that will be used as ``kwargs`` for the model's ``forward()`` method. + :param train_split: + The name of the data split used for training in the ``dataset_dict``. + Default is "train". + :param validation_split: + Optional name of the validation split in the ``dataset_dict``. Default is ``None``, + which means no validation. + :param validation_dataloader: + An optional data loader for generating validation batches. The batches should be + :class:`dict` objects. If not specified, but ``validation_split`` is given, + the validation ``DataLoader`` will be constructed from the same parameters + as the train ``DataLoader``. + :param seed: + Used to set the RNG states at the beginning of training. + :param train_steps: + The number of steps to train for. If not specified training will + stop after a complete iteration through the ``train_dataloader``. + :param train_epochs: + The number of epochs to train for. You cannot specify ``train_steps`` and ``train_epochs`` + at the same time. + :param validation_steps: + The number of steps to validate for. If not specified validation + will stop after a complete iteration through the ``validation_dataloader``. + :param grad_accum: + The number of gradient accumulation steps. Defaults to 1. + + .. note:: + This parameter - in conjuction with the settings of your data loader + and the number distributed workers - + determines the *effective batch size* of your training run. + + :param log_every: + Log every this many steps. + :param checkpoint_every: + Save a checkpoint every this many steps. + :param validate_every: + Run the validation loop every this many steps. + :param device_count: + The number of devices to train on, i.e. the number of distributed data parallel workers. + :param distributed_port: + The port of the distributed process group. Default = "54761". 
+ :param val_metric_name: + The name of the validation metric, i.e. the key of the metric in the dictionary + returned by the forward pass of the model. Default is "loss". + :param minimize_val_metric: + Whether the validation metric is meant to be minimized (such as the loss). + Default is ``True``. When using a metric such as accuracy, you should set + this to ``False``. + :param auto_aggregate_val_metric: + If ``True`` (the default), the validation metric will be averaged across + validation batches and distributed processes. This may not be the correct + behavior for some metrics (such as F1), in which you should set this to + ``False`` and handle the aggregation internally in your model + or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`). + :param callbacks: + A list of :class:`TrainCallback`. + :param remove_stale_checkpoints: + If ``True`` (the default), stale checkpoints will be removed throughout training so that + only the latest and best checkpoints are kept. + + :returns: + The trained model on CPU with the weights from the best checkpoint loaded. + + """ + # Validate device(s). + if device_count <= 0: + raise ConfigurationError("Invalid value for 'device_count'. Must be at least 1.") + devices: List[int] + if torch.cuda.is_available() and torch.cuda.device_count() >= device_count: + devices = list(range(device_count)) + self.logger.info("Training on %d GPU%s", device_count, "s" if device_count > 1 else "") + else: + devices = [-1] * device_count + self.logger.info( + "Training on CPU with %d worker%s", device_count, "s" if device_count > 1 else "" + ) + + if validate_every is not None and validation_split is None: + raise ConfigurationError( + "You have set a validation interval, but no validation split. " + "That's probably unintentional." + ) + + is_distributed = False + num_workers = 1 + if devices and len(devices) > 1: + is_distributed = True + num_workers = len(devices) + + if (train_steps is not None) == (train_epochs is not None): + raise ConfigurationError( + "One of 'train_steps' or 'train_epochs' needs to be specified, but not both." ) - # Setup the tokenizer for targets - with tokenizer.as_target_tokenizer(): - # TODO: do something so the loss isn't counted. - labels = tokenizer( - targets, max_length=max_target_length, padding=padding, truncation=True - ) - - # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 - # when we want to ignore padding in the loss. - if padding == "max_length" and ignore_pad_token_for_loss: - labels["input_ids"] = [ - [(lb if lb != tokenizer.pad_token_id else -100) for lb in label] - for label in labels["input_ids"] - ] - - model_inputs["labels"] = labels["input_ids"] - return model_inputs - - data = data.map( - preprocess_function, - batched=True, - num_proc=num_workers, - remove_columns=list(data.column_names.values())[0], # remove all old columns - desc="Tokenizing dataset", + # Setup the tokenizer + _add_special_tokens(tokenizer) + model = Lazy( + model._constructor, + model._params, + constructor_extras=model._constructor_extras, + num_tokens=len(tokenizer), # type: ignore ) - return data - - -# from tango.integrations.torch.train import * # TODO: fix -# -# @Step.register("transformers::finetune") -# class FinetuneStep(Step): -# """ -# Mostly similar to :class:`~tango.integrations.torch.train.TorchTrainStep` with additional -# preprocessing for data. -# -# .. tip:: -# -# Registered as a :class:`~tango.step.Step` under the name "transformers::finetune". -# -# .. 
important:: -# -# The training loop will use GPU(s) automatically when available, as long as at least -# ``device_count`` CUDA devices are available. -# -# Distributed data parallel training is activated when the ``device_count`` is greater than 1. -# -# You can control which CUDA devices to use with the environment variable ``CUDA_VISIBLE_DEVICES``. -# For example, to only use the GPUs with IDs 0 and 1, set ``CUDA_VISIBLE_DEVICES=0,1`` -# (and ``device_count`` to 2). -# -# .. warning:: -# -# During validation, the validation metric (specified by the ``val_metric_name`` parameter) -# is aggregated by simply averaging across validation batches and distributed processes. -# This behavior is usually correct when your validation metric is "loss" or "accuracy", -# for example, but may not be correct for other metrics like "F1". -# -# If this is not correct for your metric you will need to handle the aggregation -# internally in your model or with a :class:`TrainCallback` -# using the :meth:`TrainCallback.post_val_batch()` method. -# Then set the parameter ``auto_aggregate_val_metric`` to ``False``. -# -# Note that correctly aggregating your metric during distributed training will -# involve distributed communication. -# -# """ -# -# DETERMINISTIC = True -# CACHEABLE = True -# FORMAT: Format = TorchFormat() -# SKIP_ID_ARGUMENTS = {"distributed_port", "log_every"} -# -# def run( # type: ignore[override] -# self, -# model: Lazy[Model], -# training_engine: Lazy[TrainingEngine], -# dataset_dict: DatasetDictBase, -# train_dataloader: Lazy[DataLoader], -# *, -# train_split: str = "train", -# validation_split: Optional[str] = None, -# validation_dataloader: Optional[Lazy[DataLoader]] = None, -# seed: int = 42, -# train_steps: Optional[int] = None, -# train_epochs: Optional[int] = None, -# validation_steps: Optional[int] = None, -# grad_accum: int = 1, -# log_every: int = 10, -# checkpoint_every: int = 100, -# validate_every: Optional[int] = None, -# device_count: int = 1, -# distributed_port: int = 54761, -# val_metric_name: str = "loss", -# minimize_val_metric: bool = True, -# auto_aggregate_val_metric: bool = True, -# callbacks: Optional[List[Lazy[TrainCallback]]] = None, -# remove_stale_checkpoints: bool = True, -# ) -> Model: -# """ -# Run a basic training loop to train the ``model``. -# -# :param model: -# The model to train. It should return a ``dict`` that includes the ``loss`` -# during training and the ``val_metric_name`` during validation. -# :param training_engine: -# The :class:`TrainingEngine` to use to train the model. -# :param dataset_dict: -# The train and optional validation data. -# :param train_dataloader: -# The data loader that generates training batches. The batches should be :class:`dict` -# objects that will be used as ``kwargs`` for the model's ``forward()`` method. -# :param train_split: -# The name of the data split used for training in the ``dataset_dict``. -# Default is "train". -# :param validation_split: -# Optional name of the validation split in the ``dataset_dict``. Default is ``None``, -# which means no validation. -# :param validation_dataloader: -# An optional data loader for generating validation batches. The batches should be -# :class:`dict` objects. If not specified, but ``validation_split`` is given, -# the validation ``DataLoader`` will be constructed from the same parameters -# as the train ``DataLoader``. -# :param seed: -# Used to set the RNG states at the beginning of training. -# :param train_steps: -# The number of steps to train for. 
If not specified training will -# stop after a complete iteration through the ``train_dataloader``. -# :param train_epochs: -# The number of epochs to train for. You cannot specify ``train_steps`` and ``train_epochs`` -# at the same time. -# :param validation_steps: -# The number of steps to validate for. If not specified validation -# will stop after a complete iteration through the ``validation_dataloader``. -# :param grad_accum: -# The number of gradient accumulation steps. Defaults to 1. -# -# .. note:: -# This parameter - in conjuction with the settings of your data loader -# and the number distributed workers - -# determines the *effective batch size* of your training run. -# -# :param log_every: -# Log every this many steps. -# :param checkpoint_every: -# Save a checkpoint every this many steps. -# :param validate_every: -# Run the validation loop every this many steps. -# :param device_count: -# The number of devices to train on, i.e. the number of distributed data parallel workers. -# :param distributed_port: -# The port of the distributed process group. Default = "54761". -# :param val_metric_name: -# The name of the validation metric, i.e. the key of the metric in the dictionary -# returned by the forward pass of the model. Default is "loss". -# :param minimize_val_metric: -# Whether the validation metric is meant to be minimized (such as the loss). -# Default is ``True``. When using a metric such as accuracy, you should set -# this to ``False``. -# :param auto_aggregate_val_metric: -# If ``True`` (the default), the validation metric will be averaged across -# validation batches and distributed processes. This may not be the correct -# behavior for some metrics (such as F1), in which you should set this to -# ``False`` and handle the aggregation internally in your model -# or with a :class:`TrainCallback` (using :meth:`TrainCallback.post_val_batch()`). -# :param callbacks: -# A list of :class:`TrainCallback`. -# :param remove_stale_checkpoints: -# If ``True`` (the default), stale checkpoints will be removed throughout training so that -# only the latest and best checkpoints are kept. -# -# :returns: -# The trained model on CPU with the weights from the best checkpoint loaded. -# -# """ -# # Validate device(s). -# if device_count <= 0: -# raise ConfigurationError("Invalid value for 'device_count'. Must be at least 1.") -# devices: List[int] -# if torch.cuda.is_available() and torch.cuda.device_count() >= device_count: -# devices = list(range(device_count)) -# self.logger.info("Training on %d GPU%s", device_count, "s" if device_count > 1 else "") -# else: -# devices = [-1] * device_count -# self.logger.info( -# "Training on CPU with %d worker%s", device_count, "s" if device_count > 1 else "" -# ) -# -# if validate_every is not None and validation_split is None: -# raise ConfigurationError( -# "You have set a validation interval, but no validation split. " -# "That's probably unintentional." -# ) -# -# is_distributed = False -# num_workers = 1 -# if devices and len(devices) > 1: -# is_distributed = True -# num_workers = len(devices) -# -# if (train_steps is not None) == (train_epochs is not None): -# raise ConfigurationError( -# "One of 'train_steps' or 'train_epochs' needs to be specified, but not both." 
-# ) -# -# # Tokenize data -# -# # dataset dict -# -# # end tokenization -# -# config = TrainConfig( -# self.unique_id, -# self.work_dir, -# train_split=train_split, -# validation_split=validation_split, -# seed=seed, -# train_steps=train_steps, -# train_epochs=train_epochs, -# grad_accum=grad_accum, -# log_every=log_every, -# checkpoint_every=checkpoint_every, -# validate_every=validate_every, -# validation_steps=validation_steps, -# is_distributed=is_distributed, -# devices=devices, -# distributed_port=distributed_port, -# val_metric_name=val_metric_name, -# minimize_val_metric=minimize_val_metric, -# auto_aggregate_val_metric=auto_aggregate_val_metric, -# remove_stale_checkpoints=remove_stale_checkpoints, -# world_size=num_workers, -# ) -# -# final_model: Model -# if is_distributed: -# import torch.multiprocessing as mp -# -# mp.spawn( -# _train, -# args=( -# config, -# model, -# training_engine, -# dataset_dict, -# train_dataloader, -# validation_dataloader, -# callbacks, -# get_extra_imported_modules(), -# ), -# nprocs=num_workers, -# ) -# self.logger.info("Constructing final model") -# final_model = model.construct() -# else: -# final_model = _train( # type: ignore[assignment] -# 0, -# config, -# model, -# training_engine, -# dataset_dict, -# train_dataloader, -# validation_dataloader=validation_dataloader, -# callbacks=callbacks, -# ) -# assert final_model is not None -# final_model = final_model.cpu() -# -# # Load best checkpoint before returning model. -# if config.final_weights_path.is_file(): -# self.logger.info( -# f"Loading best weights from {str(config.final_weights_path.resolve())}" -# ) -# state = torch.load(config.final_weights_path, map_location="cpu") -# # We use `strict=False` because there might be missing keys due to weight tying. -# final_model.load_state_dict(state, strict=False) -# -# return final_model + # seq2seq: bool = model.config_class in SEQ2SEQ # TODO: without model construction. 
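+        # The ``Lazy`` re-wrap above adds ``num_tokens=len(tokenizer)`` as an extra
+        # constructor argument for the lazily-built model, presumably so that the token
+        # embeddings can be resized (see ``FinetuneWrapper.from_pretrained``) to cover
+        # any special tokens added by ``_add_special_tokens``.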
+ + dataset_dict = tokenize_data( + dataset_dict, + tokenizer=tokenizer, + source_field=source_field, + target_field=target_field, + max_source_length=max_source_length, + max_target_length=max_target_length, + seq2seq=seq2seq, + ) + + if is_distributed: + from torch.utils.data.distributed import DistributedSampler + + sampler = Lazy(DistributedSampler, drop_last=True, shuffle=True) + train_dataloader = Lazy( + train_dataloader._constructor, + train_dataloader._params, + constructor_extras=train_dataloader._constructor_extras, + sampler=sampler, + ) + + collate_fn: DataCollator + if seq2seq: + collate_fn = cast(DataCollator, DataCollatorForSeq2Seq(tokenizer=tokenizer)) + else: + collate_fn = cast(DataCollator, DefaultDataCollator()) + + train_dataloader = Lazy( + train_dataloader._constructor, + train_dataloader._params, + constructor_extras=train_dataloader._constructor_extras, + collate_fn=collate_fn, + ) + + config = TrainConfig( + self.unique_id, + self.work_dir, + train_split=train_split, + validation_split=validation_split, + seed=seed, + train_steps=train_steps, + train_epochs=train_epochs, + grad_accum=grad_accum, + log_every=log_every, + checkpoint_every=checkpoint_every, + validate_every=validate_every, + validation_steps=validation_steps, + is_distributed=is_distributed, + devices=devices, + distributed_port=distributed_port, + val_metric_name=val_metric_name, + minimize_val_metric=minimize_val_metric, + auto_aggregate_val_metric=auto_aggregate_val_metric, + remove_stale_checkpoints=remove_stale_checkpoints, + world_size=num_workers, + ) + + final_model: Model + if is_distributed: + import torch.multiprocessing as mp + + mp.spawn( + _train, + args=( + config, + model, + training_engine, + convert_to_tango_dataset_dict(dataset_dict), + train_dataloader, + validation_dataloader, + callbacks, + get_extra_imported_modules(), + ), + nprocs=num_workers, + ) + self.logger.info("Constructing final model") + final_model = model.construct() + else: + final_model = _train( # type: ignore[assignment] + 0, + config, + model, + training_engine, + convert_to_tango_dataset_dict(dataset_dict), + train_dataloader, + validation_dataloader=validation_dataloader, + callbacks=callbacks, + ) + assert final_model is not None + final_model = final_model.cpu() + + # Load best checkpoint before returning model. + if config.final_weights_path.is_file(): + self.logger.info( + f"Loading best weights from {str(config.final_weights_path.resolve())}" + ) + state = torch.load(config.final_weights_path, map_location="cpu") + # We use `strict=False` because there might be missing keys due to weight tying. 
+ final_model.load_state_dict(state, strict=False) + + return final_model From ff1e6a39d39c20e8e29f3da782c4685f6680eacf Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Sun, 3 Apr 2022 23:13:40 -0700 Subject: [PATCH 09/16] docstrings, tests, cleanup --- examples/finetune/config.jsonnet | 30 ++-- examples/finetune/config_gpt2.jsonnet | 151 ------------------ examples/finetune/new_config.jsonnet | 133 --------------- examples/finetune/snli_steps.py | 50 +++++- examples/finetune/test.py | 66 ++------ tango/integrations/transformers/__init__.py | 11 +- tango/integrations/transformers/finetune.py | 126 ++++++++++++--- .../transformers/finetune_test.py | 46 ++++++ 8 files changed, 235 insertions(+), 378 deletions(-) delete mode 100644 examples/finetune/config_gpt2.jsonnet delete mode 100644 examples/finetune/new_config.jsonnet create mode 100644 tests/integrations/transformers/finetune_test.py diff --git a/examples/finetune/config.jsonnet b/examples/finetune/config.jsonnet index 296b1e547..8677d5ea5 100644 --- a/examples/finetune/config.jsonnet +++ b/examples/finetune/config.jsonnet @@ -2,10 +2,10 @@ # Model settings # ################## -local pretrained_model = "patrickvonplaten/t5-tiny-random"; +local pretrained_model = "t5-base"; local load_with_low_cpu_mem_usage = false; -local modules_to_wrap = ["encoder\\.block\\.[0-9]+", "decoder\\.block\\.[0-9]+"]; # tell FairScale to wrap the transformer's blocks individually +local modules_to_wrap = ["[a-zA-Z_.]+\\.[0-9]+"]; # TODO: works for t5 and gpt2. confirm with other models too. #################### # Trainer settings # @@ -61,14 +61,8 @@ local training_engine = { [if fsdp then "fsdp_config" else null]: fsdp_config, }; -local collate_fn = { - type: "transformers::DataCollatorForSeq2Seq", - tokenizer: { pretrained_model_name_or_path: pretrained_model } -}; - local distributed_dataloader = { batch_size: batch_size, - collate_fn: collate_fn, sampler: { type: "torch::DistributedSampler", shuffle: true, @@ -79,7 +73,6 @@ local distributed_dataloader = { local single_device_dataloader = { shuffle: true, batch_size: batch_size, - collate_fn: collate_fn, }; local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader; @@ -90,26 +83,20 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device type: "datasets::load", path: "snli", }, - "subset_data": { + /*"subset_data": { type: "subset-data", data: { type: "ref", ref: "raw_data" }, max_samples: 10, - }, + },*/ processed_data: { type: "snli-text2text", - data: { type: "ref", ref: "subset_data" }, - }, - "tokenized_data": { - type: "tokenize_text2text", - data: { type: "ref", ref: "processed_data" }, - tokenizer: { pretrained_model_name_or_path: pretrained_model } + data: { type: "ref", ref: "raw_data" }, }, trained_model: { - type: "torch::train", + type: "transformers::finetune", model: { type: "fairscale::with_wrapped_modules", model: { - //type: "transformers::AutoModelForSeq2SeqLM::from_pretrained", type: "transformers::finetune::from_pretrained", pretrained_model_name_or_path: pretrained_model, low_cpu_mem_usage: load_with_low_cpu_mem_usage, @@ -118,7 +105,10 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device fsdp_config: fsdp_config, activation_checkpointing: activation_checkpointing, }, - dataset_dict: { type: "ref", ref: "tokenized_data" }, + tokenizer: { + pretrained_model_name_or_path: pretrained_model + }, + dataset_dict: { type: "ref", ref: "processed_data" }, train_dataloader: dataloader, 
validation_split: "validation", grad_accum: grad_accum, diff --git a/examples/finetune/config_gpt2.jsonnet b/examples/finetune/config_gpt2.jsonnet deleted file mode 100644 index 5f245b716..000000000 --- a/examples/finetune/config_gpt2.jsonnet +++ /dev/null @@ -1,151 +0,0 @@ -################## -# Model settings # -################## - -local pretrained_model = "sshleifer/tiny-gpt2"; //""patrickvonplaten/t5-tiny-random"; -local load_with_low_cpu_mem_usage = false; - -//local modules_to_wrap = ["encoder\\.block\\.[0-9]+", "decoder\\.block\\.[0-9]+"]; -local modules_to_wrap = ["transformer\\.h\\.[0-9]+"]; - - -#################### -# Trainer settings # -#################### - -# Trainer settings, adjust to your use-case. -local training_steps = 20; # total number of optimization steps to train for -local validate_every = 5; # how often to validate and save checkpoints - -local devices = 1; # number of devices to train on (will use GPUs if enough are available, otherwise CPU) -local grad_accum = 1; # number of gradient accumulation steps (changes the effective batch size) -# This is the batch size per GPU, ignoring gradient accumulation: -local batch_size = 2; -# So the effective batch size is `batch_size * grad_accum * devices` - -local activation_checkpointing = false; # use activation/gradient checkpointing (probably need this GPT-J 6B, but not gpt2) -local amp = false; # use PyTorch's native automatic mixed precision -local fsdp = false; # Use FairScale's FullyShardedDataParallel (probably need this GPT-J 6B, but not gpt2) -local cpu_offloading = false; # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow. - -###################### -# Optimizer settings # -###################### - -local warmup_steps = 20; -local learning_rate = 0.00005; # you can probably use a higher LR for a small model like "gpt2" - - -assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp"; - -# FullyShardedDataParallel config: -local fsdp_config = if fsdp then { - reshard_after_forward: true, - move_params_to_cpu: cpu_offloading, - move_grads_to_cpu: cpu_offloading, - mixed_precision: amp, -} else null; - -local training_engine = { - type: if fsdp then "fairscale" else "torch", - optimizer: { - type: "torch::AdamW", - lr: learning_rate, - betas: [0.9, 0.95], - eps: 1e-6, - }, - lr_scheduler: { - type: "transformers::linear", - num_warmup_steps: warmup_steps, - num_training_steps: training_steps, - }, - amp: amp, - [if fsdp then "fsdp_config" else null]: fsdp_config, -}; - -local collate_fn = { - type: "transformers::DefaultDataCollator", - //tokenizer: { pretrained_model_name_or_path: pretrained_model }, - //mlm: false, -}; - -local distributed_dataloader = { - batch_size: batch_size, - collate_fn: collate_fn, - sampler: { - type: "torch::DistributedSampler", - shuffle: true, - drop_last: true, - }, -}; - -local single_device_dataloader = { - shuffle: true, - batch_size: batch_size, - collate_fn: collate_fn, -}; - -local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader; - -{ - steps: { - raw_data: { - type: "datasets::load", - path: "snli", - }, - "subset_data": { - type: "subset-data", - data: { type: "ref", ref: "raw_data" }, - max_samples: 10, - }, - processed_data: { - type: "snli-text2text", - data: { type: "ref", ref: "subset_data" }, - seq2seq: false, - }, - "tokenized_data": { - type: "tokenize_text2text", - data: { type: "ref", ref: "processed_data" }, - tokenizer: { 
pretrained_model_name_or_path: pretrained_model }, - max_source_length: 500, - max_target_length: 500, - pad_to_max_length: true, - seq2seq: false, - }, - trained_model: { - type: "torch::train", - model: { - type: "fairscale::with_wrapped_modules", - model: { - //type: "transformers::AutoModelForSeq2SeqLM::from_pretrained", - type: "transformers::finetune::from_pretrained", - pretrained_model_name_or_path: pretrained_model, - low_cpu_mem_usage: load_with_low_cpu_mem_usage, - }, - modules_to_wrap: modules_to_wrap, # tell FairScale to wrap the transformer's blocks individually - fsdp_config: fsdp_config, - activation_checkpointing: activation_checkpointing, - }, - dataset_dict: { type: "ref", ref: "tokenized_data" }, - train_dataloader: dataloader, - validation_split: "validation", - grad_accum: grad_accum, - train_steps: training_steps, - validate_every: validate_every, - checkpoint_every: validate_every, - log_every: 1, - device_count: devices, - training_engine: training_engine, - }, - "generations": { - "type": "transformers::run_generation_dataset", - //"max_length": 2, - "input": {"type": "ref", "ref": "processed_data"}, - "batch_size": batch_size, - "model": {"type": "ref", "ref": "trained_model"}, - "prompt_field": "source", - "output_field": "generation", - "splits": ["validation"] - }, - } -} diff --git a/examples/finetune/new_config.jsonnet b/examples/finetune/new_config.jsonnet deleted file mode 100644 index bea1ceab6..000000000 --- a/examples/finetune/new_config.jsonnet +++ /dev/null @@ -1,133 +0,0 @@ -################## -# Model settings # -################## - -local pretrained_model = "patrickvonplaten/t5-tiny-random"; -local load_with_low_cpu_mem_usage = false; - -local modules_to_wrap = ["encoder\\.block\\.[0-9]+", "decoder\\.block\\.[0-9]+"]; # tell FairScale to wrap the transformer's blocks individually - -#################### -# Trainer settings # -#################### - -# Trainer settings, adjust to your use-case. -local training_steps = 20; # total number of optimization steps to train for -local validate_every = 5; # how often to validate and save checkpoints - -local devices = 1; # number of devices to train on (will use GPUs if enough are available, otherwise CPU) -local grad_accum = 1; # number of gradient accumulation steps (changes the effective batch size) -# This is the batch size per GPU, ignoring gradient accumulation: -local batch_size = 2; -# So the effective batch size is `batch_size * grad_accum * devices` - -local activation_checkpointing = false; # use activation/gradient checkpointing (probably need this GPT-J 6B, but not gpt2) -local amp = false; # use PyTorch's native automatic mixed precision -local fsdp = false; # Use FairScale's FullyShardedDataParallel (probably need this GPT-J 6B, but not gpt2) -local cpu_offloading = false; # Can only be used with 'fsdp' - saves a lot of GPU memory by offloading params+gradients to CPU, but is very slow. 
- -###################### -# Optimizer settings # -###################### - -local warmup_steps = 20; -local learning_rate = 0.00005; # you can probably use a higher LR for a small model like "gpt2" - - -assert fsdp == true || cpu_offloading == false : "cpu_offloading only available with fsdp"; - -# FullyShardedDataParallel config: -local fsdp_config = if fsdp then { - reshard_after_forward: true, - move_params_to_cpu: cpu_offloading, - move_grads_to_cpu: cpu_offloading, - mixed_precision: amp, -} else null; - -local training_engine = { - type: if fsdp then "fairscale" else "torch", - optimizer: { - type: "torch::AdamW", - lr: learning_rate, - betas: [0.9, 0.95], - eps: 1e-6, - }, - lr_scheduler: { - type: "transformers::linear", - num_warmup_steps: warmup_steps, - num_training_steps: training_steps, - }, - amp: amp, - [if fsdp then "fsdp_config" else null]: fsdp_config, -}; - -local distributed_dataloader = { - batch_size: batch_size, - sampler: { - type: "torch::DistributedSampler", - shuffle: true, - drop_last: true, - }, -}; - -local single_device_dataloader = { - shuffle: true, - batch_size: batch_size, -}; - -local dataloader = if devices > 1 then distributed_dataloader else single_device_dataloader; - -{ - steps: { - raw_data: { - type: "datasets::load", - path: "snli", - }, - "subset_data": { - type: "subset-data", - data: { type: "ref", ref: "raw_data" }, - max_samples: 10, - }, - processed_data: { - type: "snli-text2text", - data: { type: "ref", ref: "subset_data" }, - }, - trained_model: { - type: "transformers::finetune", - model: { - type: "fairscale::with_wrapped_modules", - model: { - type: "transformers::finetune::from_pretrained", - pretrained_model_name_or_path: pretrained_model, - low_cpu_mem_usage: load_with_low_cpu_mem_usage, - }, - modules_to_wrap: modules_to_wrap, # tell FairScale to wrap the transformer's blocks individually - fsdp_config: fsdp_config, - activation_checkpointing: activation_checkpointing, - }, - tokenizer: { - pretrained_model_name_or_path: pretrained_model - }, - dataset_dict: { type: "ref", ref: "processed_data" }, - train_dataloader: dataloader, - validation_split: "validation", - grad_accum: grad_accum, - train_steps: training_steps, - validate_every: validate_every, - checkpoint_every: validate_every, - log_every: 1, - device_count: devices, - training_engine: training_engine, - }, - "generations": { - "type": "transformers::run_generation_dataset", - "max_length": 5, - "input": {"type": "ref", "ref": "processed_data"}, - "batch_size": batch_size, - "model": {"type": "ref", "ref": "trained_model"}, - "prompt_field": "source", - "output_field": "generation", - "splits": ["validation"] - }, - } -} diff --git a/examples/finetune/snli_steps.py b/examples/finetune/snli_steps.py index 5ff74fdae..346195977 100644 --- a/examples/finetune/snli_steps.py +++ b/examples/finetune/snli_steps.py @@ -8,6 +8,10 @@ @Step.register("subset-data") class SubsetData(Step): + """ + Creates a subset of the data; mostly to be used for testing/debugging. + """ + DETERMINISTIC = True CACHEABLE = True VERSION = "001" @@ -20,9 +24,16 @@ def run( # type: ignore max_samples: int = 5, ) -> Union[ds.DatasetDict, ds.Dataset]: """ - Unlike `ds.Dataset.select`, this works on both `ds.Dataset` and `ds.DatasetDict`. + Returns a copy of the `data` with number of samples limited to `max_samples` for + each split. + + :param data: + The dataset or dataset dict object. + :param max_samples: + The maximum number of samples to return per split. 
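+
+        A rough sketch of the intended behavior, using the ``datasets`` API directly
+        (``snli`` here is a stand-in for any loaded ``DatasetDict``)::
+
+            subset = snli.filter(lambda example, idx: idx < max_samples, with_indices=True)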
""" + # Unlike `ds.Dataset.select`, this works on both `ds.Dataset` and `ds.DatasetDict`. def filter_fn(example, indices): return indices < max_samples @@ -31,6 +42,25 @@ def filter_fn(example, indices): @Step.register("snli-text2text") class SnliText2Text(Step): + """ + Converts the snli dataset to a text-to-text format. + + Examples + -------- + + original_instance = { + "premise": "Two cats are sitting on a wall.", + "hypothesis": "The cats are chasing a mouse.", + "label": 2 # contradiction + } + + returned_instance = { + "source": "nli premise: Two cats are sitting on a wall. hypothesis: The cats are chasing a mouse. label: " + "target": "contradiction" + } + + """ + DETERMINISTIC = True CACHEABLE = True VERSION = "001" @@ -45,8 +75,22 @@ def run( # type: ignore hypothesis_prefix: str = "hypothesis", label_prefix: str = "label", num_workers: int = 1, - seq2seq: bool = True, ) -> Union[ds.DatasetDict, ds.Dataset]: + """ + :param data: + The snli `Dataset` or `DatasetDict` object. + :param source_prefix: + The str to add before the start of the source sequence. + :param premise_prefix: + The str to add before the start of the `premise` in the source sequence. + :param hypothesis_prefix: + The str to add before the start of the `hypothesis` in the source sequence. + :param label_prefix: + The str to add as the prompt for the label. + :param num_workers: + The number of workers to use for processing the data. + """ + def filter_no_gold(example, indices): if example["label"] == -1: return False @@ -75,7 +119,7 @@ def _mapper(example): batched=False, num_proc=num_workers, remove_columns=old_cols, # remove all old columns - desc="Converting data to seq2seq format", + desc="Converting data to text-to-text format", ) return dataset diff --git a/examples/finetune/test.py b/examples/finetune/test.py index 91398bcb6..d9f503c87 100644 --- a/examples/finetune/test.py +++ b/examples/finetune/test.py @@ -1,21 +1,31 @@ -# from .snli_steps import SnliText2Text import datasets as ds +import pytest from tango.common import Params from tango.common.testing import TangoTestCase, run_experiment -class TestSnliText2Text(TangoTestCase): - def test_config_with_t5(self): - model = "patrickvonplaten/t5-tiny-random" +class TestFinetuneSNLI(TangoTestCase): + @pytest.mark.parametrize( + "model, model_type", + [("patrickvonplaten/t5-tiny-random", "t5"), ("sshleifer/tiny-gpt2", "gpt2")], + ) + def test_config(self, model: str, model_type: str): overrides = { "steps.trained_model.model.model.pretrained_model_name_or_path": model, + "steps.trained_model.tokenizer.pretrained_model_name_or_path": model, + "steps.subset_data": { + "type": "subset-data", + "data": {"type": "ref", "ref": "raw_data"}, + "max_samples": 10, + }, + "steps.processed_data.data.ref": "subset_data", } config = Params.from_file("config.jsonnet", params_overrides=overrides) # Make sure we've overrode the model entirely. 
flattened = config.as_flat_dict() for key, value in flattened.items(): - if "model_name" in key or (isinstance(value, str) and "t5" in value): + if "model_name" in key or (isinstance(value, str) and model_type in value): assert value == model with run_experiment(config, include_package=["snli_steps.py"]) as run_dir: @@ -26,50 +36,4 @@ def test_config_with_t5(self): assert "target" in processed["train"][0].keys() assert processed["train"][0]["source"].startswith("nli premise:") - assert (run_dir / "tokenized_data").is_dir() - tokenized = ds.load_from_disk(run_dir / "tokenized_data" / "data") - assert "input_ids" in tokenized["train"][0] - assert (run_dir / "trained_model").is_dir() - - def test_config_with_gpt2(self): - model = "sshleifer/tiny-gpt2" - overrides = { - "steps.trained_model.model.model.pretrained_model_name_or_path": model, - "steps.tokenized_data.tokenizer.pretrained_model_name_or_path": model, - "steps.trained_model.train_dataloader.collate_fn.tokenizer.pretrained_model_name_or_path": model, - "steps.final_metrics.dataloader.collate_fn.tokenizer.pretrained_model_name_or_path": model, - } - config = Params.from_file("config.jsonnet", params_overrides=overrides) - # Make sure we've overrode the model entirely. - flattened = config.as_flat_dict() - for key, value in flattened.items(): - if "model_name" in key or (isinstance(value, str) and "gpt2" in value): - assert value == model, key - - with run_experiment(config, include_package=["snli_steps.py"]) as run_dir: - assert (run_dir / "processed_data").is_dir() - processed = ds.load_from_disk(run_dir / "processed_data" / "data") - assert len(processed["train"][0].keys()) == 2 - assert "source" in processed["train"][0].keys() - assert "target" in processed["train"][0].keys() - assert processed["train"][0]["source"].startswith("nli premise:") - - assert (run_dir / "tokenized_data").is_dir() - tokenized = ds.load_from_disk(run_dir / "tokenized_data" / "data") - assert "input_ids" in tokenized["train"][0] - - assert (run_dir / "trained_model").is_dir() - - # def test_config_with_overrides(self): - # overrides = { - # "steps.processed_data.seq2seq": False, - # } - # config = Params.from_file("test_config.jsonnet", params_overrides=overrides) - # with run_experiment(config, include_package=["snli_steps.py"]) as run_dir: - # assert (run_dir / "processed_data").is_dir() - # processed = ds.load_from_disk(run_dir / "processed_data" / "data") - # assert len(processed["train"][0].keys()) == 1 - # assert "source" in processed["train"][0].keys() - # # assert "target" in processed["train"][0].keys() - # assert processed["train"][0]["source"].startswith("nli premise:") diff --git a/tango/integrations/transformers/__init__.py b/tango/integrations/transformers/__init__.py index cb8cb69ef..e08d21d22 100644 --- a/tango/integrations/transformers/__init__.py +++ b/tango/integrations/transformers/__init__.py @@ -184,10 +184,19 @@ transformers::DefaultDataCollator """ -__all__ = ["RunGeneration", "RunGenerationDataset", "Tokenizer", "Config"] +__all__ = [ + "RunGeneration", + "RunGenerationDataset", + "Tokenizer", + "Config", + "FinetuneWrapper", + "FinetuneStep", + "TokenizeText2TextData", +] from .config import Config from .data import * # noqa: F403 +from .finetune import FinetuneStep, FinetuneWrapper, TokenizeText2TextData from .model import * # noqa: F403 from .optim import * # noqa: F403 from .run_generation import RunGeneration, RunGenerationDataset diff --git a/tango/integrations/transformers/finetune.py 
b/tango/integrations/transformers/finetune.py index 2d6dff86b..a7fe51794 100644 --- a/tango/integrations/transformers/finetune.py +++ b/tango/integrations/transformers/finetune.py @@ -1,10 +1,11 @@ import logging from os import PathLike -from typing import Any, List, Optional, Union, cast +from typing import List, Optional, Union, cast import datasets as ds import torch from transformers import ( + AutoConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, @@ -27,7 +28,7 @@ TrainingEngine, ) from tango.integrations.torch.train import _train -from tango.integrations.transformers.tokenizer import Tokenizer +from tango.integrations.transformers import Tokenizer from tango.step import Step logger = logging.getLogger(__name__) @@ -37,13 +38,23 @@ class FinetuneWrapper(PreTrainedModel): + """ + Wrapper `PreTrainedModel` class that returns either a `Seq2SeqLM` or `CausalLM` model. + """ + @classmethod def from_pretrained( # type: ignore cls, pretrained_model_name_or_path: Union[str, PathLike], - num_tokens: Optional[int] = None, + num_tokens: Optional[int] = None, # TODO: this seems to not be working correctly. **kwargs, ) -> PreTrainedModel: + """ + :param pretrained_model_name_or_path: + The name of the model to return. Any name that works in the transformers library works here. + :param num_tokens: + The number of token embeddings to have. + """ try: model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path, **kwargs) except ValueError: @@ -80,15 +91,44 @@ def tokenize_data( max_target_length: Optional[int] = 1024, pad_to_max_length: bool = False, ignore_pad_token_for_loss: bool = True, - seq2seq: bool = True, + concat_source_target: bool = False, ) -> ds.DatasetDict: """ - If it's seq2seq, we use `DataCollatorForSeq2Seq`, and take care of padding there. - If it's causal, we use the `DefaultDataCollator`, and take care of padding here. + Returns a `DatasetDict` with tokenized source and target fields. + + :param data: + The original dataset dict containing the source and target fields. + :param tokenizer: + The tokenizer to use. + :param num_workers: + The number of workers to use for processing the data. + :param source_field: + The string name of the field containing the source sequence. + :param target_field: + The string name of the field containing the target sequence. + :param max_source_length: + The maximum number of tokens in the source sequence. + :param max_target_length: + The maximum number of tokens in the target sequence. + :param pad_to_max_length: + Whether to pad to the maximum length when tokenizing. + :param ignore_pad_token_for_loss: + Whether to ignore the padded tokens for calculating loss. + If set to True, all the pad tokens in the labels are replaced + by -100, which is ignored by the loss function. + :param concat_source_target: + If the downstream model is decoder-only, like "gpt2", the source + and target sequences need to be concatenated and fed to the model + together. + + .. tip:: + If concat_source_target is set to True, we pad all sequences to max + length here. Otherwise, we leave it to the appropriate + :class:`~tango.integrations.torch.DataCollator` object. """ - if not seq2seq: - pad_to_max_length = True # TODO: address this. 
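FinetuneWrapper.from_pretrained above relies on a try/except fallback to support both encoder-decoder and decoder-only checkpoints. A self-contained sketch of that pattern (the function name is illustrative; the model ids are the tiny test models used elsewhere in this patch):

# Sketch of the Seq2Seq -> Causal fallback used by FinetuneWrapper.from_pretrained above.
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

def load_lm(pretrained_model_name_or_path, num_tokens=None, **kwargs):
    try:
        model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
    except ValueError:
        # Raised when the config has no seq2seq architecture; fall back to a causal LM.
        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs)
    if num_tokens is not None:
        model.resize_token_embeddings(num_tokens)
    return model

# load_lm("patrickvonplaten/t5-tiny-random")  -> T5 (seq2seq)
# load_lm("sshleifer/tiny-gpt2")              -> GPT-2 (causal)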
+ if concat_source_target: + pad_to_max_length = True padding = "max_length" if pad_to_max_length else False _add_special_tokens(tokenizer) @@ -99,13 +139,11 @@ def preprocess_function(examples): input_lengths = [] for i in range(len(examples[source_field])): if examples[source_field][i] is not None and examples[target_field][i] is not None: - if seq2seq: + if not concat_source_target: inputs.append(examples[source_field][i]) targets.append(examples[target_field][i]) else: - text = ( - examples[source_field][i] + tokenizer.sep_token + examples[target_field][i] - ) + text = examples[source_field][i] + " " + examples[target_field][i] inputs.append(text) targets.append(text) input_lengths.append(len(examples[source_field][i])) @@ -145,6 +183,13 @@ def preprocess_function(examples): @Step.register("tokenize_text2text") class TokenizeText2TextData(Step): + """ + A step that tokenizes data containing source and target sequences. + + .. tip:: + Registered as a :class:`~tango.step.Step` under the name "tokenize_text2text". + """ + DETERMINISTIC = True CACHEABLE = True FORMAT = DatasetsFormat() @@ -160,9 +205,41 @@ def run( # type: ignore[override] max_target_length: Optional[int] = 1024, pad_to_max_length: bool = False, ignore_pad_token_for_loss: bool = True, - seq2seq: bool = True, + concat_source_target: bool = False, ) -> ds.DatasetDict: - + """ + Returns a `DatasetDict` with tokenized source and target fields. + + :param data: + The original dataset dict containing the source and target fields. + :param tokenizer: + The tokenizer to use. + :param num_workers: + The number of workers to use for processing the data. + :param source_field: + The string name of the field containing the source sequence. + :param target_field: + The string name of the field containing the target sequence. + :param max_source_length: + The maximum number of tokens in the source sequence. + :param max_target_length: + The maximum number of tokens in the target sequence. + :param pad_to_max_length: + Whether to pad to the maximum length when tokenizing. + :param ignore_pad_token_for_loss: + Whether to ignore the padded tokens for calculating loss. + If set to True, all the pad tokens in the labels are replaced + by -100, which is ignored by the loss function. + :param concat_source_target: + If the downstream model is decoder-only, like "gpt2", the source + and target sequences need to be concatenated and fed to the model + together. + + .. tip:: + If concat_source_target is set to True, we pad all sequences to max + length here. Otherwise, we leave it to the appropriate + :class:`~tango.integrations.torch.DataCollator` object. 
+ """ return tokenize_data( data, tokenizer=tokenizer, @@ -173,7 +250,7 @@ def run( # type: ignore[override] max_target_length=max_target_length, pad_to_max_length=pad_to_max_length, ignore_pad_token_for_loss=ignore_pad_token_for_loss, - seq2seq=seq2seq, + concat_source_target=concat_source_target, ) @@ -223,7 +300,7 @@ class FinetuneStep(Step): def run( # type: ignore[override] self, model: Lazy[Model], - tokenizer: Tokenizer, # TODO: restrict the type + tokenizer: Tokenizer, training_engine: Lazy[TrainingEngine], dataset_dict: ds.DatasetDict, train_dataloader: Lazy[DataLoader], @@ -235,7 +312,6 @@ def run( # type: ignore[override] target_field: str = "target", max_source_length: Optional[int] = 1024, max_target_length: Optional[int] = 1024, - seq2seq: bool = True, seed: int = 42, train_steps: Optional[int] = None, train_epochs: Optional[int] = None, @@ -258,6 +334,8 @@ def run( # type: ignore[override] :param model: The model to train. It should return a ``dict`` that includes the ``loss`` during training and the ``val_metric_name`` during validation. + :param tokenizer: + The tokenizer to use for tokenizing source and target sequences. :param training_engine: The :class:`TrainingEngine` to use to train the model. :param dataset_dict: @@ -276,6 +354,14 @@ def run( # type: ignore[override] :class:`dict` objects. If not specified, but ``validation_split`` is given, the validation ``DataLoader`` will be constructed from the same parameters as the train ``DataLoader``. + :param source_field: + The string name of the field containing the source sequence. + :param target_field: + The string name of the field containing the target sequence. + :param max_source_length: + The maximum number of tokens in the source sequence. + :param max_target_length: + The maximum number of tokens in the target sequence. :param seed: Used to set the RNG states at the beginning of training. :param train_steps: @@ -367,7 +453,9 @@ def run( # type: ignore[override] num_tokens=len(tokenizer), # type: ignore ) - # seq2seq: bool = model.config_class in SEQ2SEQ # TODO: without model construction. + # Hacky way to get the config to check in order to check if the model is seq2seq or causal. 
+ config = AutoConfig.from_pretrained(tokenizer.name_or_path) + seq2seq: bool = type(config) in SEQ2SEQ dataset_dict = tokenize_data( dataset_dict, @@ -376,7 +464,7 @@ def run( # type: ignore[override] target_field=target_field, max_source_length=max_source_length, max_target_length=max_target_length, - seq2seq=seq2seq, + concat_source_target=not seq2seq, ) if is_distributed: diff --git a/tests/integrations/transformers/finetune_test.py b/tests/integrations/transformers/finetune_test.py new file mode 100644 index 000000000..87c18be0d --- /dev/null +++ b/tests/integrations/transformers/finetune_test.py @@ -0,0 +1,46 @@ +from datasets import Dataset, DatasetDict +from transformers import AutoTokenizer + +from tango.common.testing import TangoTestCase +from tango.integrations.transformers import TokenizeText2TextData + + +class TestTokenizeText2TextData(TangoTestCase): + def test_tokenize_seq2seq(self): + dataset = Dataset.from_dict( + {"field1": ["hello", "hi"], "field2": ["world", "me"], "meta_field": [1, 0]} + ) + data_dict = DatasetDict({"train": dataset}) + tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random") + step = TokenizeText2TextData() + + tokenized = step.run( + data=data_dict, tokenizer=tokenizer, source_field="field1", target_field="field2" + ) + assert isinstance(tokenized, DatasetDict) + assert len(tokenized["train"]) == 2 + assert "input_ids" in tokenized["train"].column_names + assert "labels" in tokenized["train"].column_names + assert tokenized["train"][0]["input_ids"] == [21820, 1] + + def test_tokenize_concat(self): + dataset = Dataset.from_dict( + {"field1": ["hello", "hi"], "field2": ["world", "me"], "meta_field": [1, 0]} + ) + data_dict = DatasetDict({"train": dataset}) + tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random") + step = TokenizeText2TextData() + + tokenized = step.run( + data=data_dict, + tokenizer=tokenizer, + source_field="field1", + target_field="field2", + concat_source_target=True, + max_source_length=5, + ) + assert isinstance(tokenized, DatasetDict) + assert len(tokenized["train"]) == 2 + assert "input_ids" in tokenized["train"].column_names + assert "labels" in tokenized["train"].column_names + assert tokenized["train"][0]["input_ids"] == [21820, 296, 1, 0, 0] From bfa8b2430192babf500e231fc1056e2e8d8b6682 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 4 Apr 2022 01:55:31 -0700 Subject: [PATCH 10/16] fix bug with num tokens --- examples/finetune/config.jsonnet | 20 +++++++++--------- tango/common/lazy.py | 4 ++-- tango/integrations/transformers/finetune.py | 23 ++++++++++++++------- 3 files changed, 28 insertions(+), 19 deletions(-) diff --git a/examples/finetune/config.jsonnet b/examples/finetune/config.jsonnet index 8677d5ea5..485739742 100644 --- a/examples/finetune/config.jsonnet +++ b/examples/finetune/config.jsonnet @@ -119,15 +119,15 @@ local dataloader = if devices > 1 then distributed_dataloader else single_device device_count: devices, training_engine: training_engine, }, - "generations": { - "type": "transformers::run_generation_dataset", - "max_length": 5, - "input": {"type": "ref", "ref": "processed_data"}, - "batch_size": batch_size, - "model": {"type": "ref", "ref": "trained_model"}, - "prompt_field": "source", - "output_field": "generation", - "splits": ["validation"] - }, + generations: { + type: "transformers::run_generation_dataset", + max_length: 5, + input: {"type": "ref", "ref": "processed_data"}, + batch_size: batch_size, + model: {"type": "ref", "ref": 
"trained_model"}, + prompt_field: "source", + output_field: "generation", + splits: ["validation"] + } } } diff --git a/tango/common/lazy.py b/tango/common/lazy.py index 2b79b544d..e3640db8c 100644 --- a/tango/common/lazy.py +++ b/tango/common/lazy.py @@ -83,5 +83,5 @@ def construct(self, **kwargs) -> T: """ # If there are duplicate keys between self._constructor_extras and kwargs, # this will overwrite the ones in self._constructor_extras with what's in kwargs. - contructor_kwargs = {**self._constructor_extras, **kwargs} - return self.constructor(**contructor_kwargs) + constructor_kwargs = {**self._constructor_extras, **kwargs} + return self.constructor(**constructor_kwargs) diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py index a7fe51794..b95868ce1 100644 --- a/tango/integrations/transformers/finetune.py +++ b/tango/integrations/transformers/finetune.py @@ -13,7 +13,7 @@ PreTrainedModel, ) -from tango.common import Lazy +from tango.common import Lazy, Params from tango.common.exceptions import ConfigurationError from tango.common.util import get_extra_imported_modules from tango.format import Format @@ -46,7 +46,7 @@ class FinetuneWrapper(PreTrainedModel): def from_pretrained( # type: ignore cls, pretrained_model_name_or_path: Union[str, PathLike], - num_tokens: Optional[int] = None, # TODO: this seems to not be working correctly. + num_tokens: Optional[int] = None, **kwargs, ) -> PreTrainedModel: """ @@ -60,7 +60,7 @@ def from_pretrained( # type: ignore except ValueError: model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **kwargs) - if num_tokens: + if num_tokens is not None: model.resize_token_embeddings(num_tokens) return model @@ -143,7 +143,9 @@ def preprocess_function(examples): inputs.append(examples[source_field][i]) targets.append(examples[target_field][i]) else: - text = examples[source_field][i] + " " + examples[target_field][i] + text = ( + examples[source_field][i] + tokenizer.sep_token + examples[target_field][i] + ) inputs.append(text) targets.append(text) input_lengths.append(len(examples[source_field][i])) @@ -446,14 +448,21 @@ def run( # type: ignore[override] # Setup the tokenizer _add_special_tokens(tokenizer) + + # Hacky way to deal with resizing the model embeddings. + model_params_dict = model._params.as_dict() + if "fairscale" in model_params_dict["type"]: + model_params_dict["model"]["num_tokens"] = len(tokenizer) # type: ignore + else: + model_params_dict["num_tokens"] = len(tokenizer) # type: ignore + model = Lazy( model._constructor, - model._params, + Params(model_params_dict), constructor_extras=model._constructor_extras, - num_tokens=len(tokenizer), # type: ignore ) - # Hacky way to get the config to check in order to check if the model is seq2seq or causal. + # Get the config to check in order to check if the model is seq2seq or causal. config = AutoConfig.from_pretrained(tokenizer.name_or_path) seq2seq: bool = type(config) in SEQ2SEQ From dbfe36e33e74fd769d6cad45c94113733cafd5f8 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 4 Apr 2022 01:58:32 -0700 Subject: [PATCH 11/16] update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c4d51002..477f7b96c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added an `ExecutorOutput` dataclass that is returned by `Executor.execute_step_graph()`. 
- `StepGraph` now prints itself in a readable way. - Tango now automatically detects when it's running under a debugger, and disables multicore support accordingly. Many debuggers can't properly follow sub-processes, so this is a convenience for people who love debuggers. +- Added new example for finetuning text-to-text models. ### Changed @@ -24,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Refactored `tango.step_graph.StepGraph` to allow initialization from a `Dict[str, Step]`. - `Executor.execute_step_graph()` now attempts to execute all steps and summarizes success/failures. - Upgraded PyTorch version in `tango` Docker image to latest `v1.11.0+cu113`. +- `RunGeneration` now allows model object as input. ### Fixed From 0fe64a812925acb796da8150b89cf2eed7163d1a Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 4 Apr 2022 02:06:22 -0700 Subject: [PATCH 12/16] fix test --- tests/integrations/transformers/finetune_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integrations/transformers/finetune_test.py b/tests/integrations/transformers/finetune_test.py index 87c18be0d..581a7ed81 100644 --- a/tests/integrations/transformers/finetune_test.py +++ b/tests/integrations/transformers/finetune_test.py @@ -43,4 +43,4 @@ def test_tokenize_concat(self): assert len(tokenized["train"]) == 2 assert "input_ids" in tokenized["train"].column_names assert "labels" in tokenized["train"].column_names - assert tokenized["train"][0]["input_ids"] == [21820, 296, 1, 0, 0] + assert tokenized["train"][0]["input_ids"] == [21820, 32100, 296, 1, 0] From b04c8ff6c3e7d7055291f2e6b7ffb3712f9a9484 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 4 Apr 2022 02:36:58 -0700 Subject: [PATCH 13/16] test with different model --- tango/integrations/transformers/finetune.py | 1 - tests/integrations/transformers/finetune_test.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py index b95868ce1..d7f1660f1 100644 --- a/tango/integrations/transformers/finetune.py +++ b/tango/integrations/transformers/finetune.py @@ -156,7 +156,6 @@ def preprocess_function(examples): # Setup the tokenizer for targets with tokenizer.as_target_tokenizer(): - # TODO: remove source seq loss. 
labels = tokenizer( targets, max_length=max_target_length, padding=padding, truncation=True ) diff --git a/tests/integrations/transformers/finetune_test.py b/tests/integrations/transformers/finetune_test.py index 581a7ed81..85141c608 100644 --- a/tests/integrations/transformers/finetune_test.py +++ b/tests/integrations/transformers/finetune_test.py @@ -28,7 +28,7 @@ def test_tokenize_concat(self): {"field1": ["hello", "hi"], "field2": ["world", "me"], "meta_field": [1, 0]} ) data_dict = DatasetDict({"train": dataset}) - tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random") + tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2") step = TokenizeText2TextData() tokenized = step.run( @@ -43,4 +43,4 @@ def test_tokenize_concat(self): assert len(tokenized["train"]) == 2 assert "input_ids" in tokenized["train"].column_names assert "labels" in tokenized["train"].column_names - assert tokenized["train"][0]["input_ids"] == [21820, 32100, 296, 1, 0] + assert tokenized["train"][0]["input_ids"] == [31373, 50257, 6894, 50256, 50256] From 2ace30663c9681c122ab2db3c281f51aedcafab5 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 4 Apr 2022 15:19:02 -0700 Subject: [PATCH 14/16] simplify --- tango/integrations/transformers/finetune.py | 19 +++++-------------- .../transformers/finetune_test.py | 3 +-- 2 files changed, 6 insertions(+), 16 deletions(-) diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py index d7f1660f1..1677b7ff2 100644 --- a/tango/integrations/transformers/finetune.py +++ b/tango/integrations/transformers/finetune.py @@ -9,7 +9,6 @@ AutoModelForCausalLM, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, - DefaultDataCollator, PreTrainedModel, ) @@ -120,15 +119,7 @@ def tokenize_data( If the downstream model is decoder-only, like "gpt2", the source and target sequences need to be concatenated and fed to the model together. - - .. tip:: - If concat_source_target is set to True, we pad all sequences to max - length here. Otherwise, we leave it to the appropriate - :class:`~tango.integrations.torch.DataCollator` object. 
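PATCH 14/16 above standardizes on DataCollatorForSeq2Seq for both model types, so per-example padding to max length is no longer needed at tokenization time. A small usage sketch of what that collator does with already-tokenized features (token ids are borrowed from the T5 test above and are otherwise illustrative):

# Usage sketch: DataCollatorForSeq2Seq pads input_ids and labels per batch,
# padding labels with -100 so padded positions are ignored by the loss.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("patrickvonplaten/t5-tiny-random")
collate_fn = DataCollatorForSeq2Seq(tokenizer=tokenizer)

features = [
    {"input_ids": [21820, 1], "labels": [296, 1]},
    {"input_ids": [21820, 296, 1], "labels": [21820, 296, 1]},
]
batch = collate_fn(features)     # pads to the longest sequence in the batch
print(batch["input_ids"].shape)  # torch.Size([2, 3])
print(batch["labels"][0])        # tensor([ 296,    1, -100])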
""" - - if concat_source_target: - pad_to_max_length = True padding = "max_length" if pad_to_max_length else False _add_special_tokens(tokenizer) @@ -144,7 +135,10 @@ def preprocess_function(examples): targets.append(examples[target_field][i]) else: text = ( - examples[source_field][i] + tokenizer.sep_token + examples[target_field][i] + examples[source_field][i] + + tokenizer.sep_token + + examples[target_field][i] + + tokenizer.eos_token ) inputs.append(text) targets.append(text) @@ -487,10 +481,7 @@ def run( # type: ignore[override] ) collate_fn: DataCollator - if seq2seq: - collate_fn = cast(DataCollator, DataCollatorForSeq2Seq(tokenizer=tokenizer)) - else: - collate_fn = cast(DataCollator, DefaultDataCollator()) + collate_fn = cast(DataCollator, DataCollatorForSeq2Seq(tokenizer=tokenizer)) train_dataloader = Lazy( train_dataloader._constructor, diff --git a/tests/integrations/transformers/finetune_test.py b/tests/integrations/transformers/finetune_test.py index 85141c608..15678f89d 100644 --- a/tests/integrations/transformers/finetune_test.py +++ b/tests/integrations/transformers/finetune_test.py @@ -37,10 +37,9 @@ def test_tokenize_concat(self): source_field="field1", target_field="field2", concat_source_target=True, - max_source_length=5, ) assert isinstance(tokenized, DatasetDict) assert len(tokenized["train"]) == 2 assert "input_ids" in tokenized["train"].column_names assert "labels" in tokenized["train"].column_names - assert tokenized["train"][0]["input_ids"] == [31373, 50257, 6894, 50256, 50256] + assert tokenized["train"][0]["input_ids"] == [31373, 50257, 6894, 50256] From da291db42087260d1e9e60a7c56df2b2903fa884 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 4 Apr 2022 16:37:06 -0700 Subject: [PATCH 15/16] limit loss calculation to actual labels --- tango/integrations/transformers/finetune.py | 19 ++++++++++++++----- .../transformers/finetune_test.py | 1 + 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py index 1677b7ff2..b026459ee 100644 --- a/tango/integrations/transformers/finetune.py +++ b/tango/integrations/transformers/finetune.py @@ -148,11 +148,20 @@ def preprocess_function(examples): inputs, max_length=max_source_length, padding=padding, truncation=True ) - # Setup the tokenizer for targets - with tokenizer.as_target_tokenizer(): - labels = tokenizer( - targets, max_length=max_target_length, padding=padding, truncation=True - ) + if not concat_source_target: + # Setup the tokenizer for targets + with tokenizer.as_target_tokenizer(): + labels = tokenizer( + targets, max_length=max_target_length, padding=padding, truncation=True + ) + else: + labels = {"input_ids": []} + sep_token_idx = tokenizer.convert_tokens_to_ids([tokenizer.sep_token])[0] + for input_ids in model_inputs["input_ids"]: + label_start_idx = input_ids.index(sep_token_idx) + label_ids = [-100] * len(input_ids) + label_ids[label_start_idx + 1 :] = input_ids[label_start_idx + 1 :] + labels["input_ids"].append(label_ids) # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 # when we want to ignore padding in the loss. 
diff --git a/tests/integrations/transformers/finetune_test.py b/tests/integrations/transformers/finetune_test.py index 15678f89d..8ec447484 100644 --- a/tests/integrations/transformers/finetune_test.py +++ b/tests/integrations/transformers/finetune_test.py @@ -43,3 +43,4 @@ def test_tokenize_concat(self): assert "input_ids" in tokenized["train"].column_names assert "labels" in tokenized["train"].column_names assert tokenized["train"][0]["input_ids"] == [31373, 50257, 6894, 50256] + assert tokenized["train"][0]["labels"] == [-100, -100, 6894, 50256] From c247c486915078aff704a7489743697e3c779075 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Thu, 14 Apr 2022 23:22:35 -0700 Subject: [PATCH 16/16] address comments --- tango/integrations/transformers/finetune.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tango/integrations/transformers/finetune.py b/tango/integrations/transformers/finetune.py index b026459ee..421b2657f 100644 --- a/tango/integrations/transformers/finetune.py +++ b/tango/integrations/transformers/finetune.py @@ -156,9 +156,8 @@ def preprocess_function(examples): ) else: labels = {"input_ids": []} - sep_token_idx = tokenizer.convert_tokens_to_ids([tokenizer.sep_token])[0] for input_ids in model_inputs["input_ids"]: - label_start_idx = input_ids.index(sep_token_idx) + label_start_idx = input_ids.index(tokenizer.sep_token_id) label_ids = [-100] * len(input_ids) label_ids[label_start_idx + 1 :] = input_ids[label_start_idx + 1 :] labels["input_ids"].append(label_ids) @@ -185,13 +184,13 @@ def preprocess_function(examples): return data -@Step.register("tokenize_text2text") +@Step.register("transformers::tokenize_text2text") class TokenizeText2TextData(Step): """ A step that tokenizes data containing source and target sequences. .. tip:: - Registered as a :class:`~tango.step.Step` under the name "tokenize_text2text". + Registered as a :class:`~tango.step.Step` under the name "transformers::tokenize_text2text". """ DETERMINISTIC = True