diff --git a/Makefile b/Makefile index e93f23296a..da051959ef 100644 --- a/Makefile +++ b/Makefile @@ -159,11 +159,15 @@ slow_tests_video_llava_example: test_installs slow_tests_fsdp: test_installs python -m pytest tests/test_fsdp_examples.py -v -s --token $(TOKEN) -slow_tests_trl: test_installs +slow_tests_trl_ddpo: test_installs python -m pip install trl==0.9.6 python -m pip install peft==0.12.0 python -m pytest tests/test_trl.py -v -s -k "test_calculate_loss" +slow_tests_trl_grpo: test_installs + python -m pip install -r examples/trl/requirements_grpo.txt + python -m pytest tests/test_trl.py -v -s -k "GaudiGRPOTrainerTester" + slow_tests_object_segmentation: test_installs python -m pytest tests/test_object_segmentation.py diff --git a/examples/trl/README.md b/examples/trl/README.md index 286b51a8bf..bce2f83b78 100644 --- a/examples/trl/README.md +++ b/examples/trl/README.md @@ -4,10 +4,74 @@ ## Requirements First, you should install the requirements: + +- For **GRPO example**: +```bash +$ pip install -U -r requirements_grpo.txt +``` + +- For **all other examples**: ```bash $ pip install -U -r requirements.txt ``` +## GRPO Training + +Installing DeepSpeed + +```sh +pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.21.0 +``` + +Running single-card training + +```sh +PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 grpo.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name AI-MO/NuminaMath-TIR \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 8 \ + --do_train \ + --do_eval \ + --use_habana \ + --use_lazy_mode \ + --bf16 True \ + --gradient_accumulation_steps=16 \ + --max_prompt_length 512 \ + --num_generations 4 \ + --max_completion_length 64 \ + --use_peft True \ + --lora_target_modules q_proj k_proj \ + --num_train_epochs 1 \ + --save_strategy="epoch" +``` + + +Running multi-card training + +```sh +PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py --world_size 8 --use_deepspeed grpo.py \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --dataset_name AI-MO/NuminaMath-TIR \ + --per_device_train_batch_size 8 \ + --per_device_eval_batch_size 8 \ + --do_train \ + --do_eval \ + --use_habana \ + --use_lazy_mode \ + --bf16 True \ + --gradient_accumulation_steps=16 \ + --gradient_checkpointing \ + --max_prompt_length 512 \ + --num_generations 4 \ + --max_completion_length 64 \ + --use_peft True \ + --lora_target_modules q_proj k_proj \ + --max_steps=500 \ + --logging_steps=10 \ + --save_steps=100 +``` + ## Supervised Finetuning 1. The following example is for the supervised Lora finetune with Qwen2 model for conversational format dataset.
diff --git a/examples/trl/grpo.py b/examples/trl/grpo.py new file mode 100644 index 0000000000..89124c424f --- /dev/null +++ b/examples/trl/grpo.py @@ -0,0 +1,210 @@ +import contextlib +import io +import logging +import re +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import transformers +from datasets import load_dataset +from math_verify import LatexExtractionConfig, parse, verify +from peft import LoraConfig +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser +from transformers.integrations.deepspeed import ( + is_deepspeed_available, +) +from transformers.trainer_utils import is_main_process + +from optimum.habana import GaudiConfig +from optimum.habana.trl import GaudiGRPOConfig, GaudiGRPOTrainer +from optimum.habana.utils import set_seed + + +logger = logging.getLogger(__name__) +SYSTEM_PROMPT = ( + "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " + "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " + "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., " + "<think> reasoning process here </think><answer> answer here </answer>" +) + + +def make_conversation(example): + return { + "prompt": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": example["problem"]}, + ], + } + + +ideal_length = 50 + + +def reward_len(completions, **kwargs): + return [-abs(ideal_length - len(completion)) for completion in completions] # penalize response when len!=50 + + +def format_reward(completions, **kwargs): + # Checks if the reasoning process is enclosed within <think> and </think> tags, + # while the final answer is enclosed within <answer> and </answer> tags. + pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$" + completion_contents = [completion[0]["content"] for completion in completions] + matches = [re.match(pattern, content) for content in completion_contents] + return [1.0 if match else 0.0 for match in matches] + + +def accuracy_reward(completions, **kwargs): + # Checks if the completion is the same as the ground truth.
+ solutions = kwargs["solution"] + completion_contents = [completion[0]["content"] for completion in completions] + rewards = [] + for content, solution in zip(completion_contents, solutions): + gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()]) + answer_parsed = parse(content, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()]) + if len(gold_parsed) != 0: + try: + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): + rewards.append(float(verify(answer_parsed, gold_parsed))) + except Exception: + rewards.append(0.0) + else: + rewards.append(1.0) + return rewards + + +@dataclass +class ScriptArguments: + model_name_or_path: Optional[str] = field(default="Qwen/Qwen2-0.5B-Instruct", metadata={"help": "the model name"}) + dataset_name: Optional[str] = field(default=None, metadata={"help": "the dataset name"}) + use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"}) + num_workers: Optional[int] = field(default=1, metadata={"help": "the number of workers"}) + subset: Optional[str] = field(default=None, metadata={"help": "the subset to use"}) + streaming: Optional[bool] = field(default=False, metadata={"help": "whether to stream the dataset"}) + dataset_train_split: str = field(default="train[:5%]", metadata={"help": "Dataset split to use for training."}) + dataset_test_split: str = field(default="test[:5%]", metadata={"help": "Dataset split to use for evaluation."}) + reward_model_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or " + "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`." 
+ }, + ) + + use_flash_attention: Optional[bool] = field( + default=True, metadata={"help": "Whether to use Habana flash attention for fine-tuning."} + ) + flash_attention_recompute: Optional[bool] = field( + default=False, metadata={"help": "Whether to enable recompute in Habana flash attention for fine-tuning."} + ) + flash_attention_causal_mask: Optional[bool] = field( + default=False, metadata={"help": "Whether to enable causal mask in Habana flash attention for fine-tuning."} + ) + + # LoraConfig + lora_alpha: Optional[float] = field(default=32, metadata={"help": "the lora alpha parameter"}) + lora_dropout: Optional[float] = field(default=0.1, metadata={"help": "the lora dropout parameter"}) + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) + lora_target_modules: List[str] = field( + default_factory=lambda: None, + metadata={"help": "Target modules for the LoRA method."}, + ) + + +if __name__ == "__main__": + parser = HfArgumentParser((GaudiGRPOConfig, ScriptArguments)) + (training_args, script_args) = parser.parse_args_into_dataclasses() + + logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.bf16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + logger.info(f"Training/evaluation parameters {training_args}") + # Set seed before initializing model. 
+ set_seed(training_args.seed) + + use_deepspeed = training_args.world_size > 1 + + if script_args.use_peft: + peft_config = LoraConfig( + r=script_args.lora_r, + lora_alpha=script_args.lora_alpha, + lora_dropout=script_args.lora_dropout, + target_modules=script_args.lora_target_modules, + task_type="CAUSAL_LM", + ) + else: + peft_config = None + + tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True) + if training_args.chat_template is not None: + tokenizer.chat_template = training_args.chat_template + + train_dataset, test_dataset = load_dataset( + script_args.dataset_name, + data_dir=None if script_args.subset == "None" else script_args.subset, + num_proc=script_args.num_workers if not script_args.streaming else None, + split=[script_args.dataset_train_split, script_args.dataset_test_split], + ) + + train_dataset = train_dataset.map(make_conversation) + test_dataset = test_dataset.map(make_conversation) + train_dataset = train_dataset.remove_columns(["messages", "problem"]) + + low_cpu_mem_usage = True + if is_deepspeed_available() and use_deepspeed: + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled + + if is_deepspeed_zero3_enabled(): + low_cpu_mem_usage = False + + model = AutoModelForCausalLM.from_pretrained( + script_args.model_name_or_path, + low_cpu_mem_usage=low_cpu_mem_usage, + torch_dtype=torch.bfloat16, + ) + + model.config.use_cache = False + if not script_args.use_flash_attention and ( + script_args.flash_attention_recompute or script_args.flash_attention_causal_mask + ): + raise ValueError("Need to enable use_flash_attention") + model.generation_config.use_flash_attention = script_args.use_flash_attention + model.generation_config.flash_attention_recompute = script_args.flash_attention_recompute + model.generation_config.flash_attention_causal_mask = script_args.flash_attention_causal_mask + + reward_funcs = [format_reward, accuracy_reward] + if script_args.reward_model_name_or_path: + reward_funcs = AutoModelForSequenceClassification.from_pretrained( + script_args.reward_model_name_or_path, + trust_remote_code=True, + ) + + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + gaudi_config = GaudiConfig() + gaudi_config.use_fused_adam = True + gaudi_config.use_fused_clip_norm = True + + trainer = GaudiGRPOTrainer( + model=model, + reward_funcs=reward_funcs, + args=training_args, + train_dataset=train_dataset, + eval_dataset=test_dataset, + processing_class=tokenizer, + gaudi_config=gaudi_config, + peft_config=peft_config, + ) + + trainer.train() + + print("Done!") diff --git a/examples/trl/requirements_grpo.txt b/examples/trl/requirements_grpo.txt new file mode 100644 index 0000000000..e7475bbc91 --- /dev/null +++ b/examples/trl/requirements_grpo.txt @@ -0,0 +1,8 @@ +trl == 0.17.0 +peft == 0.12.0 +datasets +tyro +evaluate +scikit-learn == 1.5.2 +accelerate +math_verify diff --git a/optimum/habana/trl/__init__.py b/optimum/habana/trl/__init__.py index 37f4d1156f..060e0b1379 100644 --- a/optimum/habana/trl/__init__.py +++ b/optimum/habana/trl/__init__.py @@ -1,10 +1,21 @@ +import importlib.metadata + +from packaging import version + from .models.modeling_base import adapt_PreTrainedModelWrapper_to_gaudi from .models.modeling_sd_base import GaudiDefaultDDPOStableDiffusionPipeline from .trainer.ddpo_trainer import GaudiDDPOTrainer from .trainer.dpo_config import GaudiDPOConfig from .trainer.dpo_trainer import GaudiDPOTrainer -from .trainer.ppo_config import GaudiPPOConfig
-from .trainer.ppo_trainer import GaudiPPOTrainer + + +trl_version = importlib.metadata.version("trl") +if version.parse(trl_version) < version.parse("0.17.0"): + from .trainer.ppo_config import GaudiPPOConfig + from .trainer.ppo_trainer import GaudiPPOTrainer +else: + from .trainer.grpo_config import GaudiGRPOConfig + from .trainer.grpo_trainer import GaudiGRPOTrainer from .trainer.reward_trainer import GaudiRewardTrainer, RewardDataCollatorWithPadding from .trainer.sft_config import GaudiSFTConfig from .trainer.sft_trainer import GaudiSFTTrainer diff --git a/optimum/habana/trl/trainer/__init__.py b/optimum/habana/trl/trainer/__init__.py index 6da9debbd8..f6de8bf253 100644 --- a/optimum/habana/trl/trainer/__init__.py +++ b/optimum/habana/trl/trainer/__init__.py @@ -16,13 +16,22 @@ # There is a circular import in the PPOTrainer if we let isort sort these # isort: on +import importlib.metadata +from packaging import version from .sft_trainer import GaudiSFTTrainer from .dpo_trainer import GaudiDPOTrainer -from .ppo_config import GaudiPPOConfig -from .ppo_trainer import GaudiPPOTrainer + from .reward_trainer import GaudiRewardTrainer, RewardDataCollatorWithPadding from .ddpo_trainer import GaudiDDPOTrainer from .dpo_config import GaudiDPOConfig from .sft_config import GaudiSFTConfig + +trl_version = importlib.metadata.version("trl") +if version.parse(trl_version) < version.parse("0.17.0"): + from .ppo_config import GaudiPPOConfig + from .ppo_trainer import GaudiPPOTrainer +else: + from .grpo_trainer import GaudiGRPOTrainer + from .grpo_config import GaudiGRPOConfig diff --git a/optimum/habana/trl/trainer/dpo_trainer.py b/optimum/habana/trl/trainer/dpo_trainer.py index d57a032983..64846da42f 100644 --- a/optimum/habana/trl/trainer/dpo_trainer.py +++ b/optimum/habana/trl/trainer/dpo_trainer.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import importlib.metadata import inspect import warnings from collections import defaultdict @@ -24,6 +25,7 @@ from accelerate import PartialState from accelerate.utils import is_deepspeed_available from datasets import Dataset +from packaging import version from transformers import ( AutoModelForCausalLM, DataCollator, @@ -34,12 +36,10 @@ from transformers.trainer_callback import TrainerCallback from transformers.trainer_utils import EvalLoopOutput from trl import DPOTrainer, create_reference_model -from trl.import_utils import is_peft_available, is_wandb_available from trl.trainer.dpo_config import FDivergenceConstants from trl.trainer.utils import ( DPODataCollatorWithPadding, RunningMoments, - SyncRefModelCallback, disable_dropout_in_model, pad_to_length, ) @@ -48,6 +48,16 @@ from .dpo_config import GaudiDPOConfig +trl_version = importlib.metadata.version("trl") +if version.parse(trl_version) < version.parse("0.17.0"): + from trl.import_utils import is_peft_available, is_wandb_available + from trl.trainer.utils import SyncRefModelCallback +else: + from transformers import is_wandb_available + from transformers.utils import is_peft_available + from trl.trainer.callbacks import SyncRefModelCallback + + if is_peft_available(): from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training diff --git a/optimum/habana/trl/trainer/grpo_config.py b/optimum/habana/trl/trainer/grpo_config.py new file mode 100644 index 0000000000..62df6c2e07 --- /dev/null +++ b/optimum/habana/trl/trainer/grpo_config.py @@ -0,0 +1,316 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from typing import Optional + +from ... import GaudiTrainingArguments + + +@dataclass +class GaudiGRPOConfig(GaudiTrainingArguments): + r""" + Initialize GaudiGRPOConfig. + Adapted from https://github.com/huggingface/trl/blob/v0.17.0/trl/trainer/grpo_config.py + - inherit from GaudiTrainingArguments + """ + + # Parameters that control the model and reference model + model_init_kwargs: Optional[dict] = field( + default=None, + metadata={ + "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` " + "argument of the `GRPOTrainer` is provided as a string." + }, + ) + + # Parameters that control the data preprocessing + # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on + # additional columns to compute the reward + remove_unused_columns: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function " + "that requires any column other than 'prompts' and 'completions', you should keep this to `False`." + }, + ) + max_prompt_length: Optional[int] = field( + default=512, + metadata={ + "help": "Maximum length of the prompt. 
If the prompt is longer than this value, it will be truncated left." + }, + ) + num_generations: Optional[int] = field( + default=4, + metadata={ + "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) " + "must be divisible by this value." + }, + ) + max_completion_length: Optional[int] = field( + default=64, + metadata={"help": "Maximum length of the generated completion."}, + ) + ds3_gather_for_generation: bool = field( + default=True, + metadata={ + "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " + "generation, improving generation speed. However, disabling this option allows training models that " + "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option " + "is not compatible with vLLM generation." + }, + ) + shuffle_dataset: Optional[bool] = field( + default=True, + metadata={"help": "Whether to shuffle the training dataset."}, + ) + + # Parameters that control generation + temperature: float = field( + default=0.9, + metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, + ) + top_p: float = field( + default=1.0, + metadata={ + "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. " + "Set to 1.0 to consider all tokens." + }, + ) + top_k: Optional[int] = field( + default=50, + metadata={ + "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, " + "top-k-filtering is disabled." + }, + ) + min_p: Optional[float] = field( + default=None, + metadata={ + "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It " + "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range." + }, + ) + repetition_penalty: float = field( + default=1.0, + metadata={ + "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated " + "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model " + "to repeat tokens." + }, + ) + cache_implementation: Optional[str] = field( + default=None, + metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."}, + ) + + # Parameters that control generation acceleration powered by vLLM + use_vllm: bool = field( + default=False, + metadata={ + "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a vLLM server is " + "running. To run the server, install vLLM (`pip install vllm`) and run `trl vllm-serve`." + }, + ) + vllm_server_host: str = field( + default="0.0.0.0", + metadata={"help": "Host of the vLLM server to connect to."}, + ) + vllm_server_port: int = field( + default=8000, + metadata={"help": "Port of the vLLM server to connect to."}, + ) + vllm_server_timeout: float = field( + default=120.0, + metadata={ + "help": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up " + "after the timeout, a `ConnectionError` is raised." + }, + ) + vllm_guided_decoding_regex: Optional[str] = field( + default=None, + metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."}, + ) + + # Parameters that control the training + learning_rate: float = field( + default=2e-5, + metadata={ + "help": "Initial learning rate for `AdamW` optimizer. 
The default value replaces that of " + "`transformers.TrainingArguments`." + }, + ) + beta: float = field( + default=0.04, + metadata={ + "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving " + "training speed, but may be numerically unstable for long training runs." + }, + ) + num_iterations: int = field( + default=1, + metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."}, + ) + epsilon: float = field( + default=0.2, + metadata={"help": "Epsilon value for clipping."}, + ) + epsilon_high: Optional[float] = field( + default=None, + metadata={ + "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the " + "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`." + }, + ) + reward_weights: Optional[list[float]] = field( + default=None, + metadata={ + "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all " + "rewards are weighted equally with weight `1.0`." + }, + ) + scale_rewards: bool = field( + default=True, + metadata={ + "help": "Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), " + "the rewards are normalized by the standard deviation, ensuring they have unit variance. If `False`, no " + "scaling is applied. The Dr. GRPO paper recommends not scaling the rewards, as scaling by the standard " + "deviation introduces a question-level difficulty bias." + }, + ) + mask_truncated_completions: bool = field( + default=False, + metadata={ + "help": "When enabled, truncated completions are excluded from the loss calculation, preventing them from " + "being incorrectly penalized and introducing noise during training. According to the DAPO paper, this is " + "a good practice for training stability." + }, + ) + sync_ref_model: bool = field( + default=False, + metadata={ + "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` " + "steps, using the `ref_model_mixup_alpha` parameter." + }, + ) + ref_model_mixup_alpha: float = field( + default=0.6, + metadata={ + "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the " + "previous reference policy during updates. The reference policy is updated according to the equation: " + "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + ref_model_sync_steps: int = field( + default=512, + metadata={ + "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is " + "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`." + }, + ) + + # Parameters that control the logging + log_completions: bool = field( + default=False, + metadata={ + "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is " + "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`." + }, + ) + num_completions_to_print: Optional[int] = field( + default=None, + metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."}, + ) + + # Deprecated parameters + vllm_device: Optional[str] = field( + default=None, + metadata={ + "help": "This parameter is deprecated and will be removed in version 0.18.0. To use vLLM, start a vLLM " + "server with the `trl vllm-serve` command." 
+ }, + ) + vllm_gpu_memory_utilization: Optional[float] = field( + default=None, + metadata={ + "help": "This parameter is deprecated and will be removed in version 0.18.0. To control the GPU memory " + "utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server " + "configuration." + }, + ) + vllm_dtype: Optional[str] = field( + default=None, + metadata={ + "help": "This parameter is deprecated and will be removed in version 0.18.0. To control the data type for " + "vLLM generation, you should now use the `dtype` parameter in the vLLM server configuration." + }, + ) + vllm_max_model_len: Optional[int] = field( + default=None, + metadata={ + "help": "This parameter is deprecated and will be removed in version 0.18.0. To control the " + "`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server " + "configuration." + }, + ) + vllm_enable_prefix_caching: Optional[bool] = field( + default=None, + metadata={ + "help": "This parameter is deprecated and will be removed in version 0.18.0. To control prefix caching in " + "vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server configuration." + }, + ) + chat_template: Optional[str] = field(default=None, metadata={"help": "chat_template"}) + + def __post_init__(self): + super().__post_init__() + + if self.vllm_device is not None: + warnings.warn( + "`vllm_device` is deprecated and will be removed in version 0.18.0. To use vLLM, start a vLLM server " + "with the `trl vllm-serve` command.", + DeprecationWarning, + ) + + if self.vllm_gpu_memory_utilization is not None: + warnings.warn( + "`vllm_gpu_memory_utilization` is deprecated and will be removed in v0.18. To control the GPU memory " + "utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server " + "configuration.", + DeprecationWarning, + ) + + if self.vllm_dtype is not None: + warnings.warn( + "`vllm_dtype` is deprecated and will be removed in version 0.18.0. To control the data type for vLLM " + "generation, you should now use the `dtype` parameter in the vLLM server configuration.", + DeprecationWarning, + ) + + if self.vllm_max_model_len is not None: + warnings.warn( + "`vllm_max_model_len` is deprecated and will be removed in version 0.18.0. To control the " + "`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server " + "configuration.", + DeprecationWarning, + ) + + if self.vllm_enable_prefix_caching is not None: + warnings.warn( + "`vllm_enable_prefix_caching` is deprecated and will be removed in version 0.18.0. To control prefix " + "caching in vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server " + "configuration.", + DeprecationWarning, + ) diff --git a/optimum/habana/trl/trainer/grpo_trainer.py b/optimum/habana/trl/trainer/grpo_trainer.py new file mode 100644 index 0000000000..5bdb460b6f --- /dev/null +++ b/optimum/habana/trl/trainer/grpo_trainer.py @@ -0,0 +1,835 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import bisect +import copy +import warnings +from collections import defaultdict, deque +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import pandas as pd +import torch +import torch.utils.data +from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed +from datasets import Dataset, IterableDataset +from torch import nn +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizerBase, + Trainer, + TrainerCallback, + is_wandb_available, +) +from transformers.integrations.deepspeed import is_deepspeed_available, is_deepspeed_zero3_enabled +from transformers.utils import is_datasets_available, is_peft_available +from trl import GRPOTrainer +from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template +from trl.extras.profiling import profiling_context, profiling_decorator +from trl.extras.vllm_client import VLLMClient +from trl.import_utils import is_rich_available, is_vllm_available +from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation +from trl.trainer.callbacks import SyncRefModelCallback +from trl.trainer.utils import ( + pad, + print_prompt_completions_sample, + selective_log_softmax, +) + +from optimum.utils import logging + +from ... import GaudiConfig, GaudiTrainer +from ...transformers import trainer as habana_trainer +from ...transformers.trainer import _get_input_update_settings +from .grpo_config import GaudiGRPOConfig + + +logger = logging.get_logger(__name__) + + +if is_deepspeed_available(): + pass + +if is_peft_available(): + from peft import PeftConfig, get_peft_model + +if is_datasets_available(): + pass + +if is_wandb_available(): + import wandb + +# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of +# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. +RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] + + +def grpo_get_input_update_settings(model, lazy_mode: Optional[bool] = None) -> Tuple[bool, Dict]: + # For GRPOTrainer, skip input update in the _inner_training_loop() + # because it expects a dict type input, but the GRPOTrainer input is a list of dict. 
+ # Instead, the update is done in _get_per_token_logps() + return False, {} + + +class GaudiGRPOTrainer(GRPOTrainer, GaudiTrainer): + _tag_names = ["trl", "grpo"] + + def __init__( + self, + model: Union[str, PreTrainedModel], + reward_funcs: Union[RewardFunc, list[RewardFunc]], + args: Optional[GaudiGRPOConfig] = None, + gaudi_config: GaudiConfig = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, + eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, + processing_class: Optional[PreTrainedTokenizerBase] = None, + reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None, + callbacks: Optional[list[TrainerCallback]] = None, + optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), + peft_config: Optional["PeftConfig"] = None, + ): + """ + Copied from GRPOTrainer.__init__: https://github.com/huggingface/trl/blob/v0.17.0/trl/trainer/grpo_trainer.py#L264 + The only differences are: + - Add new args gaudi_config + - Use GaudiTrainer instead of Trainer + - Add bucketing to reduce dynamic input shape + - Toggle of use_cache and gradient_checkpointing for the rollout performance with gradient_checkpointing + """ + # Args + if args is None: + model_name = model if isinstance(model, str) else model.config._name_or_path + model_name = model_name.split("/")[-1] + args = GaudiGRPOConfig(f"{model_name}-GRPO") + self.args = args + + # Models + # Trained model + model_init_kwargs = args.model_init_kwargs or {} + if isinstance(model, str): + model_id = model + torch_dtype = model_init_kwargs.get("torch_dtype") + if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: + pass # torch_dtype is already a torch.dtype or "auto" or None + elif isinstance(torch_dtype, str): # it's a str, but not "auto" + torch_dtype = getattr(torch, torch_dtype) + model_init_kwargs["torch_dtype"] = torch_dtype + else: + raise ValueError( + "Invalid `torch_dtype` passed to `GaudiGRPOConfig`. Expected either 'auto' or a string representing " + f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." + ) + + # Disable caching if gradient checkpointing is enabled (not supported) + model_init_kwargs["use_cache"] = ( + False if args.gradient_checkpointing else model_init_kwargs.get("use_cache") + ) + model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) + else: + model_id = model.config._name_or_path + if args.model_init_kwargs is not None: + raise ValueError( + "You passed `model_init_kwargs` to the `GaudiGRPOConfig`, but your model is already instantiated. " + "This argument can only be used when the `model` argument is a string." + ) + + if peft_config is not None: + if not is_peft_available(): + raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.") + model = get_peft_model(model, peft_config) + + # Enable gradient checkpointing if requested + if args.gradient_checkpointing: + model = self._enable_gradient_checkpointing(model, args) + + # Reference model + self.beta = args.beta + if self.beta == 0.0: + # If beta is 0.0, the reference model is not needed + self.ref_model = None + elif is_deepspeed_zero3_enabled(): + self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) + elif is_peft_model(model): + # If PEFT is used, the reference model is not needed since the adapter can be disabled + # to revert to the initial model. 
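+ # (reference log-probs are then obtained in _generate_and_score_completions by temporarily calling disable_adapter() on the policy, so no separate copy of the weights is kept)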
+ self.ref_model = None + else: + # If PEFT configuration is not provided, create a reference model based on the initial model. + self.ref_model = create_reference_model(model) + + # Processing class + if processing_class is None: + processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left") + if processing_class.pad_token is None: + processing_class.pad_token = processing_class.eos_token + + # Reward functions + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + self.reward_func_names = [] + for i, reward_func in enumerate(reward_funcs): + if isinstance(reward_func, str): + reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( + reward_func, num_labels=1, **model_init_kwargs + ) + if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models + self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1]) + else: + self.reward_func_names.append(reward_funcs[i].__name__) + self.reward_funcs = reward_funcs + + # Reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError( + f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " + f"functions ({len(reward_funcs)})" + ) + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) + else: + self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) + + # Reward processing class + if reward_processing_classes is None: + reward_processing_classes = [None] * len(reward_funcs) + elif not isinstance(reward_processing_classes, list): + reward_processing_classes = [reward_processing_classes] + else: + if len(reward_processing_classes) != len(reward_funcs): + raise ValueError("The number of reward processing classes must match the number of reward functions.") + + for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)): + if isinstance(reward_func, PreTrainedModel): + if reward_processing_class is None: + reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) + if reward_processing_class.pad_token_id is None: + reward_processing_class.pad_token = reward_processing_class.eos_token + # The reward model computes the reward for the latest non-padded token in the input sequence. + # So it's important to set the pad token ID to the padding token ID of the processing class. 
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id + reward_processing_classes[i] = reward_processing_class + self.reward_processing_classes = reward_processing_classes + + def data_collator(features): + return features + + # Training arguments + self.max_prompt_length = args.max_prompt_length + self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper + self.num_generations = args.num_generations # = G in the GRPO paper + self.temperature = args.temperature + self.top_p = args.top_p + self.top_k = args.top_k + self.min_p = args.min_p + self.repetition_penalty = args.repetition_penalty + self.use_vllm = args.use_vllm + + self.scale_rewards = args.scale_rewards + self.mask_truncated_completions = args.mask_truncated_completions + + self.buckets = self._get_buckets(train_dataset, processing_class) + + self.shuffle_dataset = args.shuffle_dataset + + if ( + isinstance(train_dataset, IterableDataset) + or isinstance(eval_dataset, IterableDataset) + or ( + isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values()) + ) + ): + # See https://github.com/huggingface/trl/issues/3213 + raise NotImplementedError( + "Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead." + ) + + # Multi-step + self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle + self._step = 0 + # Buffer the batch to reuse generated outputs across multiple updates. For more details, see + # `_get_train_sampler` and `_prepare_inputs`. + self._buffered_inputs = None + + # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the + # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the + # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning: + # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To + # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True. + # This acts as a flag to indicate that the warning has already been issued. + model.warnings_issued["estimate_tokens"] = True + + GaudiTrainer.__init__( + self, + model=model, + args=args, + gaudi_config=gaudi_config, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + processing_class=processing_class, + callbacks=callbacks, + optimizers=optimizers, + ) + habana_trainer._get_input_update_settings = grpo_get_input_update_settings + + # Initialize the metrics + self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} + self._total_train_tokens = 0 + self.log_completions = args.log_completions + self.num_completions_to_print = args.num_completions_to_print + # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the + # final optimization step. 
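+ # e.g. the multi-card README command (8 processes, per-device batch size 8, 16 gradient-accumulation steps) gives maxlen = 8 * 8 * 16 = 1024 retained samples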
+ maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps + self._textual_logs = { + "prompt": deque(maxlen=maxlen), + "completion": deque(maxlen=maxlen), + "rewards": defaultdict(lambda: deque(maxlen=maxlen)), + } + # Check if the effective batch size can be divided by the number of generations + if self.num_generations < 2: + raise ValueError( + "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided " + f"{self.num_generations}, which is less than the minimum required." + ) + + num_processes = self.accelerator.num_processes + global_batch_size = args.per_device_train_batch_size * num_processes * args.gradient_accumulation_steps + possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0] + if self.num_generations not in possible_values: + raise ValueError( + f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly " + f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train " + f"batch size, the valid values for the number of generations are: {possible_values}." + ) + if self.args.eval_strategy != "no": + global_batch_size = args.per_device_eval_batch_size * num_processes + possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0] + if self.num_generations not in possible_values: + raise ValueError( + f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly " + f"divisible by the number of generations per prompt ({self.num_generations}). Given the current " + f"eval batch size, the valid values for the number of generations are: {possible_values}." + ) + + # Ensure each process receives a unique seed to prevent duplicate completions when generating with + # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but + # it's safer to set it in all cases. + set_seed(args.seed, device_specific=True) + + if self.use_vllm: + if not is_vllm_available(): + raise ImportError( + "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " + "`pip install vllm` to use it." + ) + + if self.accelerator.is_main_process: + self.vllm_client = VLLMClient( + args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout + ) + self.vllm_client.init_communicator() + + # vLLM specific sampling arguments + self.guided_decoding_regex = args.vllm_guided_decoding_regex + + self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation + + # When using vLLM, the main process is responsible for loading the model weights. This can cause process + # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we + # synchronize all processes after vLLM has been fully initialized. 
+ self.accelerator.wait_for_everyone() + else: + self.generation_config = copy.deepcopy(model.generation_config) + self.generation_config.max_new_tokens = self.max_completion_length + self.generation_config.do_sample = True + self.generation_config.pad_token_id = processing_class.pad_token_id + self.generation_config.bos_token_id = processing_class.bos_token_id + self.generation_config.eos_token_id = processing_class.eos_token_id + self.generation_config.temperature = self.temperature + self.generation_config.top_p = self.top_p + self.generation_config.top_k = self.top_k + self.generation_config.min_p = self.min_p + self.generation_config.repetition_penalty = self.repetition_penalty + self.generation_config.cache_implementation = args.cache_implementation + self.generation_config.use_cache = True + self.generation_config.static_shapes = True + self.generation_config.reuse_cache = True + + # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the + # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set + # self.model_accepts_loss_kwargs to False to enable scaling. + self.model_accepts_loss_kwargs = False + + # Add tags to the model + self.model.add_model_tags(self._tag_names) + + if self.ref_model is not None: + if self.is_deepspeed_enabled: + self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) + else: + self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) + + if args.sync_ref_model: + self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) + + for i, reward_func in enumerate(self.reward_funcs): + if isinstance(reward_func, PreTrainedModel): + if self.is_deepspeed_enabled: + self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator) + else: + self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) + + def _get_buckets(self, train_dataset, tokenizer, num_buckets=5): + # Collect all seq lens + sentence_lengths = [] + for batch in train_dataset: + formatted_prompt = maybe_apply_chat_template(batch, tokenizer)["prompt"] + formatted_prompt_len = len(tokenizer(formatted_prompt)["input_ids"]) + sentence_lengths.append(formatted_prompt_len) + + # Assign bucket labels to each sentence + bucket_label_per_sentence = pd.qcut(sentence_lengths, q=num_buckets, labels=False, duplicates="drop") + + # Get max len per bucket + df = pd.DataFrame({"value": sentence_lengths, "bucket": bucket_label_per_sentence}) + buckets = df.groupby("bucket")["value"].max().tolist() + # Make sure that no bucket exceeds self.max_prompt_length + buckets = [min(b, self.max_prompt_length) for b in buckets] + return buckets + + # Get the per-token log probabilities for the completions for the model and the reference model + @profiling_decorator + def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): + # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "logits_to_keep": logits_to_keep + 1, + } + + if hasattr(model, "module"): + # For distributed + should_update_inputs, input_updates = _get_input_update_settings(model.module) + inputs.update(input_updates) + else: + # For non distributed + should_update_inputs, input_updates = _get_input_update_settings(model) + inputs.update(input_updates) + + logits = model(**inputs).logits + + logits = logits[:, :-1, :] # (B, 
L-1, V), exclude the last logit: it corresponds to the next token pred + + input_ids = input_ids[:, -logits_to_keep:] + # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. + # See https://github.com/huggingface/trl/issues/2770 + logits = logits[:, -logits_to_keep:] + # Divide logits by sampling temperature. + # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details + logits = logits / self.temperature + + return selective_log_softmax(logits, input_ids) + + def _generate_and_score_completions( + self, inputs: dict[str, Union[torch.Tensor, Any]] + ) -> dict[str, Union[torch.Tensor, Any]]: + device = self.accelerator.device + + prompts = [x["prompt"] for x in inputs] + prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] + + # Get unique seq len within a batch + max_prompt_len_per_batch = 0 + for prompt_idx in range( + 0, len(prompts_text), self.num_generations + ): # Prompts are repeated self.num_generations times + prompt_len = len( + self.processing_class( + text=prompts_text[prompt_idx], return_tensors="pt", padding=False, add_special_tokens=False + )["input_ids"][0] + ) + max_prompt_len_per_batch = max(max_prompt_len_per_batch, prompt_len) + + # Search bucket and the tokenize prompts with padding + bucket_indices = bisect.bisect_left(self.buckets, max_prompt_len_per_batch) + bucket_indices = min(bucket_indices, len(self.buckets) - 1) + prompt_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding="max_length", + padding_side="left", + max_length=self.buckets[bucket_indices], + truncation=True, + add_special_tokens=False, + ) + + prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs) + prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + + if self.max_prompt_length is not None: + prompt_ids = prompt_ids[:, -self.max_prompt_length :] + prompt_mask = prompt_mask[:, -self.max_prompt_length :] + + # Generate completions using either vLLM or regular generation + if self.args.use_vllm: + # First, have main process load weights if needed + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + all_prompts_text = gather_object(prompts_text) + if self.accelerator.is_main_process: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. + ordered_set_of_prompts = all_prompts_text[:: self.num_generations] + with profiling_context(self, "vLLM.generate"): + completion_ids = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + guided_decoding_regex=self.guided_decoding_regex, + ) + else: + completion_ids = [None] * len(all_prompts_text) + # Broadcast the completions from the main process to all processes, ensuring each process receives its + # corresponding slice. 
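+ # broadcast_object_list() sends the complete list to every rank; process_slice below keeps only the entries belonging to this rank's prompts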
+ completion_ids = broadcast_object_list(completion_ids, from_process=0) + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + completion_ids = completion_ids[process_slice] + + # Pad the completions, and concatenate them with the prompts + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + else: + with unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model: + if self.args.gradient_checkpointing: + unwrapped_model.gradient_checkpointing_disable() + unwrapped_model.config.use_cache = True + + unwrapped_model.eval() + + with torch.no_grad(): + prompt_completion_ids = unwrapped_model.generate( + prompt_ids, + attention_mask=prompt_mask, + hpu_graphs=True, + generation_config=self.generation_config, + lazy_mode=True, + ) + + unwrapped_model.train() + + # KV cache is not used during training. Delete KV cache to save memory. + if is_peft_model(unwrapped_model): + for layer in unwrapped_model.base_model.model.model.layers: + layer.self_attn.k_cache.cache = None + layer.self_attn.v_cache.cache = None + else: + for layer in unwrapped_model.model.layers: + layer.self_attn.k_cache.cache = None + layer.self_attn.v_cache.cache = None + + # Compute prompt length and extract completion ids + prompt_length = prompt_ids.size(1) + prompt_ids = prompt_completion_ids[:, :prompt_length] + completion_ids = prompt_completion_ids[:, prompt_length:] + + # Mask everything after the first EOS token + is_eos = completion_ids == self.processing_class.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + truncated_completions = ~is_eos.any(dim=1) + completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() + + # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need + # to re-tokenize completions if the reward is computed from tokens. + completion_ids_list = [ + [id.item() for id, m in zip(row, mask_row) if m] for row, mask_row in zip(completion_ids, completion_mask) + ] + + # Concatenate prompt_mask with completion_mask for logit computation + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + + with torch.no_grad(): + # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's + # computation here, and use per_token_logps.detach() instead. 
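+ # (with a single optimization step per generation batch, exp(per_token_logps - per_token_logps.detach()) is identically 1 in value while still carrying gradients, so storing old log-probs would be redundant)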
+        if self.num_iterations > 1:
+            old_per_token_logps = self._get_per_token_logps(
+                self.model, prompt_completion_ids, attention_mask, logits_to_keep
+            )
+        else:
+            old_per_token_logps = None
+
+        if self.beta == 0.0:
+            ref_per_token_logps = None
+        elif self.ref_model is not None:
+            ref_per_token_logps = self._get_per_token_logps(
+                self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
+            )
+        else:
+            with self.accelerator.unwrap_model(self.model).disable_adapter():
+                ref_per_token_logps = self._get_per_token_logps(
+                    self.model, prompt_completion_ids, attention_mask, logits_to_keep
+                )
+
+        # Decode the generated completions
+        completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+
+        if is_conversational(inputs[0]):
+            completions = []
+            for prompt, completion in zip(prompts, completions_text):
+                bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
+                completions.append([{"role": "assistant", "content": bootstrap + completion}])
+        else:
+            completions = completions_text
+
+        rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+        for i, (reward_func, reward_processing_class) in enumerate(
+            zip(self.reward_funcs, self.reward_processing_classes)
+        ):
+            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
+                reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
+            else:
+                reward_func_name = reward_func.__name__
+            with profiling_context(self, reward_func_name):
+                if isinstance(
+                    reward_func, nn.Module
+                ):  # Module instead of PretrainedModel for compat with compiled models
+                    if is_conversational(inputs[0]):
+                        messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
+                        texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                    else:
+                        texts = [p + c for p, c in zip(prompts, completions)]
+                    reward_inputs = reward_processing_class(
+                        text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
+                    )
+                    reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
+                    with torch.inference_mode():
+                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+                else:
+                    # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+                    keys = [
+                        key
+                        for key in inputs[0]
+                        if key
+                        not in [
+                            "prompt",
+                            "completion",
+                            "completion_ids",
+                            "use_flash_attention",
+                            "flash_attention_fast_softmax",
+                            "lazy_mode",
+                        ]
+                    ]
+
+                    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+                    output_reward_func = reward_func(
+                        prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
+                    )
+                    # Convert None values to NaN
+                    output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
+
+                    rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+        # If all reward functions return None for a given row, issue a detailed warning
+        if torch.isnan(rewards_per_func).all(dim=1).any():
+            nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
+            row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
+            row_reward_kwargs["prompt"] = prompts[nan_row_idx]
+            row_reward_kwargs["completion"] = completions[nan_row_idx]
+            warnings.warn(
+                f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
+                "Please ensure that at least one reward function returns a valid reward."
+            )
+
+        # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
+        # completions may be distributed across processes
+        rewards_per_func = gather(rewards_per_func)
+
+        # Apply weights to each reward function's output and sum
+        rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
+
+        # Compute grouped-wise rewards
+        mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+        std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+        # Normalize the rewards to compute the advantages
+        mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+        std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+        advantages = rewards - mean_grouped_rewards
+
+        if self.args.scale_rewards:
+            advantages = advantages / (std_grouped_rewards + 1e-4)
+
+        # Slice to keep only the local part of the data
+        process_slice = slice(
+            self.accelerator.process_index * len(prompts),
+            (self.accelerator.process_index + 1) * len(prompts),
+        )
+        advantages = advantages[process_slice]
+
+        # Log the metrics
+        mode = "eval" if self.control.should_evaluate else "train"
+
+        if mode == "train":
+            self._total_train_tokens += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
+        self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
+
+        completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+        self._metrics[mode]["completion_length"].append(completion_length)
+
+        # Calculate mean reward per function, but only for samples where the function was applied
+        for i, reward_func in enumerate(self.reward_funcs):
+            if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
+                reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+            else:
+                reward_func_name = reward_func.__name__
+            # Only calculate mean for samples where this reward function was applied (non-NaN values)
+            mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
+            self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
+        self._metrics[mode]["reward"].append(rewards.mean().item())
+        self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
+
+        if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
+            prompts_to_log = gather_object(prompts_text)
+            completions_to_log = gather_object(completions_text)
+            rewards_to_log = rewards.tolist()
+
+            if self.accelerator.is_main_process:
+                if is_rich_available():
+                    print_prompt_completions_sample(
+                        prompts_to_log,
+                        completions_to_log,
+                        rewards_to_log,
+                        self.state.global_step,
+                    )
+                if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
+                    import pandas as pd
+
+                    # For logging
+                    table = {
+                        "step": [str(self.state.global_step)] * len(rewards),
+                        "prompt": prompts_to_log,
+                        "completion": completions_to_log,
+                        "reward": rewards.tolist(),
+                    }
+                    df = pd.DataFrame(table)
+                    wandb.log({"completions": wandb.Table(dataframe=df)})
+
+        return {
+            "prompt_ids": prompt_ids,
+            "prompt_mask": prompt_mask,
+            "completion_ids": completion_ids,
+            "completion_mask": completion_mask,
+            "old_per_token_logps": old_per_token_logps,
+            "ref_per_token_logps": ref_per_token_logps,
+            "advantages": advantages,
+        }
+
+    @profiling_decorator
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        if return_outputs:
+            raise ValueError("The GRPOTrainer does not support returning outputs")
+
+        if self.args.gradient_checkpointing:
+            if hasattr(model, "module"):
+                # Distributed
+                model.module.config.use_cache = False
+                if is_peft_model(model.module):
+                    model.module.base_model.gradient_checkpointing_enable()
+                    model.module.base_model.enable_input_require_grads()
+                else:
+                    model.module.gradient_checkpointing_enable()
+                    model.module.enable_input_require_grads()
+
+            else:
+                # Single card
+                model.config.use_cache = False
+                if is_peft_model(model):
+                    model.base_model.gradient_checkpointing_enable()
+                    model.base_model.enable_input_require_grads()
+                else:
+                    # Enable gradient checkpointing for non-PEFT models
+                    model.gradient_checkpointing_enable()
+                    model.enable_input_require_grads()
+
+        # Compute the per-token log probabilities for the model
+        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+
+        per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+
+        # Compute the KL divergence between the model and the reference model
+        if self.beta != 0.0:
+            ref_per_token_logps = inputs["ref_per_token_logps"]
+            per_token_kl = (
+                torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+            )
+
+        # Compute the loss
+        advantages = inputs["advantages"]
+        # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation (see
+        # _generate_and_score_completions) and use per_token_logps.detach() instead.
+
+        old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
+        if self.beta != 0.0:
+            per_token_loss = per_token_loss + self.beta * per_token_kl
+
+        loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
+
+        # Log the metrics
+        mode = "eval" if self.control.should_evaluate else "train"
+
+        if self.beta != 0.0:
+            mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
+            self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+        is_clipped = (per_token_loss1 < per_token_loss2).float()
+        clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
+        self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
+        return loss
diff --git a/optimum/habana/trl/trainer/sft_trainer.py b/optimum/habana/trl/trainer/sft_trainer.py
index 6fb6365655..281b5015f7 100644
--- a/optimum/habana/trl/trainer/sft_trainer.py
+++ b/optimum/habana/trl/trainer/sft_trainer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import dataclasses
+import importlib.metadata
 import inspect
 import warnings
 from collections.abc import Mapping
@@ -23,6 +24,7 @@
 import torch.nn as nn
 from accelerate import PartialState
 from datasets import Dataset
+from packaging import version
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -36,20 +38,27 @@
 from transformers.trainer_utils import EvalPrediction
 from trl import SFTTrainer
 from trl.extras.dataset_formatting import get_formatting_func_from_dataset
-from trl.import_utils import is_peft_available
-from trl.trainer.utils import (
-    ConstantLengthDataset,
-    DataCollatorForCompletionOnlyLM,
-    RichProgressCallback,
-)
+
+from ... import GaudiConfig, GaudiTrainer
+from .sft_config import GaudiSFTConfig
+trl_version = importlib.metadata.version("trl")
+if version.parse(trl_version) < version.parse("0.17.0"):
+    from trl.import_utils import is_peft_available
+    from trl.trainer.utils import (
+        ConstantLengthDataset,
+        DataCollatorForCompletionOnlyLM,
+        RichProgressCallback,
+    )
+else:
+    from transformers.utils import is_peft_available
+    from trl.trainer.callbacks import RichProgressCallback
+    from trl.trainer.utils import ConstantLengthDataset, DataCollatorForCompletionOnlyLM
+
 if is_peft_available():
     from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training

-from ... import GaudiConfig, GaudiTrainer
-from .sft_config import GaudiSFTConfig
-

 class BucketedDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
     def _get_bucketed_len(self, examples):
diff --git a/tests/ci/slow_tests_trl.sh b/tests/ci/slow_tests_trl.sh
index 90a81ec892..d64d001833 100644
--- a/tests/ci/slow_tests_trl.sh
+++ b/tests/ci/slow_tests_trl.sh
@@ -2,4 +2,4 @@
 python -m pip install --upgrade pip
 export RUN_SLOW=true
-make slow_tests_trl
+make slow_tests_trl_ddpo && make slow_tests_trl_grpo
diff --git a/tests/test_trl.py b/tests/test_trl.py
index ebb64edf73..f723361532 100644
--- a/tests/test_trl.py
+++ b/tests/test_trl.py
@@ -13,14 +13,35 @@
 # limitations under the License.

 import gc
+import importlib.metadata
+import tempfile
 import unittest

 import torch
-from transformers.testing_utils import slow
+from datasets import load_dataset
+from packaging import version
+from parameterized import parameterized
+from transformers.testing_utils import require_peft, slow
+from transformers.utils import is_peft_available
 from trl import DDPOConfig

 from optimum.habana import GaudiConfig
-from optimum.habana.trl import GaudiDDPOTrainer, GaudiDefaultDDPOStableDiffusionPipeline
+
+
+trl_version = importlib.metadata.version("trl")
+if version.parse(trl_version) < version.parse("0.17.0"):
+    from optimum.habana.trl import (
+        GaudiDDPOTrainer,
+        GaudiDefaultDDPOStableDiffusionPipeline,
+    )
+else:
+    from optimum.habana.trl import (
+        GaudiGRPOConfig,
+        GaudiGRPOTrainer,
+    )
+
+if is_peft_available():
+    from peft import LoraConfig, PeftModel


 def scorer_function(images, prompts, metadata):
@@ -154,3 +175,156 @@ def setUp(self):
         )

         return super().setUp()
+
+
+class GaudiGRPOTrainerTester(unittest.TestCase):
+    """
+    Test the GaudiGRPOTrainer class.
+
+    Adapted from https://github.com/huggingface/trl/blob/main/tests/test_grpo_trainer.py#L216
+    The main changes are:
+    - use GaudiGRPOConfig and GaudiGRPOTrainer instead of GRPOConfig and GRPOTrainer
+    - add GaudiConfig
+    """
+
+    def test_init_minimal(self):
+        # Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GaudiGRPOConfig(
+            use_habana=True,
+            use_lazy_mode=True,
+        )
+        gaudi_config = GaudiConfig()
+
+        GaudiGRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+            gaudi_config=gaudi_config,
+        )
+
+    @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
+    def test_training(self, config_name):
+        dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")
+
+        gaudi_config = GaudiConfig()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GaudiGRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=8,  # reduce the completion length to reduce memory usage
+                report_to="none",
+                use_habana=True,
+                use_lazy_mode=True,
+            )
+            trainer = GaudiGRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+                gaudi_config=gaudi_config,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    @require_peft
+    def test_training_peft(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        gaudi_config = GaudiConfig()
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GaudiGRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=8,  # reduce the completion length to reduce memory usage
+                report_to="none",
+                use_habana=True,
+                use_lazy_mode=True,
+            )
+            trainer = GaudiGRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+                peft_config=LoraConfig(),
+                gaudi_config=gaudi_config,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the peft params have changed and the base model params have not changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                if "lora" in n.lower():  # We expect the lora params to be different
+                    self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
+                else:  # We expect the rest of params to be the same
+                    self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
+
+    @require_peft
+    def test_training_peft_with_gradient_checkpointing(self):
+        """Test that training works with PEFT and gradient checkpointing enabled."""
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        gaudi_config = GaudiConfig()
+
+        lora_config = LoraConfig(
+            r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
+        )
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GaudiGRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,
+                per_device_train_batch_size=3,
+                num_generations=3,
+                max_completion_length=8,
+                gradient_checkpointing=True,  # Enable gradient checkpointing
+                report_to="none",
+                use_habana=True,
+                use_lazy_mode=True,
+            )
+            trainer = GaudiGRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+                peft_config=lora_config,
+                gaudi_config=gaudi_config,
+            )
+
+            # Verify that the model has been wrapped with PEFT before training
+            self.assertIsInstance(trainer.model, PeftModel)
+
+            # Store initial parameters to check which ones change
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that only LoRA parameters have changed, base model parameters remain unchanged
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                if "lora" in n.lower():  # LoRA parameters should change
+                    self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.")
+                else:  # Base model parameters should not change
+                    self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.")
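Reviewer note: the GRPO-specific math added in `_generate_and_score_completions` and `compute_loss` above can be hard to follow inside the full diff. The sketch below is illustrative only; the toy tensor shapes and the epsilon/beta constants are assumptions for the example, not values taken from the patch. It reproduces the three steps the trainer performs: weighted combination of per-function rewards, per-group normalization of rewards into advantages, and the clipped surrogate loss with an optional KL penalty toward the reference model.

```python
# Minimal sketch of the GRPO math used by the trainer above.
# Assumption: toy shapes and constants below are for illustration only.
import torch

num_generations = 4   # completions sampled per prompt (G)
num_prompts = 2       # prompts in the local batch (B)
seq_len = 8           # completion tokens per sample

# 1) Combine per-function rewards with weights; NaN would mark rows where a
#    reward function was not applicable (nansum ignores them).
rewards_per_func = torch.randn(num_prompts * num_generations, 2)
reward_weights = torch.tensor([1.0, 0.5])
rewards = (rewards_per_func * reward_weights.unsqueeze(0)).nansum(dim=1)  # (B*G,)

# 2) Normalize rewards within each group of G completions to get advantages.
#    Groups are formed with view(-1, G), so the G completions of one prompt
#    are assumed to be contiguous in the batch.
grouped = rewards.view(-1, num_generations)
mean = grouped.mean(dim=1).repeat_interleave(num_generations)
std = grouped.std(dim=1).repeat_interleave(num_generations)
advantages = (rewards - mean) / (std + 1e-4)  # scale_rewards=True branch

# 3) Clipped surrogate loss on per-token log-probabilities.
per_token_logps = torch.randn(num_prompts * num_generations, seq_len)  # current policy
old_per_token_logps = per_token_logps + 0.1 * torch.randn_like(per_token_logps)
# (with num_iterations == 1 the trainer uses per_token_logps.detach() as "old")
completion_mask = torch.ones_like(per_token_logps)

coef = torch.exp(per_token_logps - old_per_token_logps)
coef_clipped = torch.clamp(coef, 1 - 0.2, 1 + 0.2)  # epsilon_low = epsilon_high = 0.2 assumed
per_token_loss = -torch.min(coef * advantages.unsqueeze(1), coef_clipped * advantages.unsqueeze(1))

# Optional KL penalty against a reference policy, weighted by beta.
beta = 0.04  # assumed value
ref_per_token_logps = torch.randn_like(per_token_logps)
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
per_token_loss = per_token_loss + beta * per_token_kl

# Average over completion tokens, then over the batch.
loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
print(loss.item())
```

The sketch mirrors the design choice visible in the patch: rewards are gathered and normalized per group before being sliced back to the local process, so the advantage of each completion is always relative to the other completions of the same prompt.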