diff --git a/Makefile b/Makefile
index e93f23296a..da051959ef 100644
--- a/Makefile
+++ b/Makefile
@@ -159,11 +159,15 @@ slow_tests_video_llava_example: test_installs
slow_tests_fsdp: test_installs
python -m pytest tests/test_fsdp_examples.py -v -s --token $(TOKEN)
-slow_tests_trl: test_installs
+slow_tests_trl_ddpo: test_installs
python -m pip install trl==0.9.6
python -m pip install peft==0.12.0
python -m pytest tests/test_trl.py -v -s -k "test_calculate_loss"
+slow_tests_trl_grpo: test_installs
+ python -m pip install -r examples/trl/requirements_grpo.txt
+ python -m pytest tests/test_trl.py -v -s -k "GaudiGRPOTrainerTester"
+
slow_tests_object_segmentation: test_installs
python -m pytest tests/test_object_segmentation.py
diff --git a/examples/trl/README.md b/examples/trl/README.md
index 286b51a8bf..bce2f83b78 100644
--- a/examples/trl/README.md
+++ b/examples/trl/README.md
@@ -4,10 +4,74 @@
## Requirements
First, you should install the requirements:
+
+- For the **GRPO example**:
+```bash
+$ pip install -U -r requirements_grpo.txt
+```
+
+- For **all other examples**:
```bash
$ pip install -U -r requirements.txt
```
+## GRPO Training
+
+First, install the Habana fork of DeepSpeed:
+
+```bash
+pip install git+https://github.com/HabanaAI/DeepSpeed.git@1.21.0
+```
+
+Run single-card training:
+
+```bash
+PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 grpo.py \
+ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+ --dataset_name AI-MO/NuminaMath-TIR \
+ --per_device_train_batch_size 8 \
+ --per_device_eval_batch_size 8 \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --bf16 True \
+ --gradient_accumulation_steps=16 \
+ --max_prompt_length 512 \
+ --num_generations 4 \
+ --max_completion_length 64 \
+ --use_peft True \
+ --lora_target_modules q_proj k_proj \
+ --num_train_epochs 1 \
+ --save_strategy="epoch"
+```
+
+Run multi-card training:
+
+```bash
+PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py --world_size 8 --use_deepspeed grpo.py \
+ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+ --dataset_name AI-MO/NuminaMath-TIR \
+ --per_device_train_batch_size 8 \
+ --per_device_eval_batch_size 8 \
+ --do_train \
+ --do_eval \
+ --use_habana \
+ --use_lazy_mode \
+ --bf16 True \
+ --gradient_accumulation_steps=16 \
+ --gradient_checkpointing \
+ --max_prompt_length 512 \
+ --num_generations 4 \
+ --max_completion_length 64 \
+ --use_peft True \
+ --lora_target_modules q_proj k_proj \
+ --max_steps=500 \
+ --logging_steps=10 \
+ --save_steps=100
+```
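+
+To verify the setup, the GRPO slow tests can also be run from the repository root with the corresponding Makefile target:
+
+```bash
+make slow_tests_trl_grpo
+```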
+
## Supervised Finetuning
1. The following example is for the supervised Lora finetune with Qwen2 model for conversational format dataset.
diff --git a/examples/trl/grpo.py b/examples/trl/grpo.py
new file mode 100644
index 0000000000..89124c424f
--- /dev/null
+++ b/examples/trl/grpo.py
@@ -0,0 +1,210 @@
+import contextlib
+import io
+import logging
+import re
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import torch
+import transformers
+from datasets import load_dataset
+from math_verify import LatexExtractionConfig, parse, verify
+from peft import LoraConfig
+from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
+from transformers.integrations.deepspeed import (
+ is_deepspeed_available,
+)
+from transformers.trainer_utils import is_main_process
+
+from optimum.habana import GaudiConfig
+from optimum.habana.trl import GaudiGRPOConfig, GaudiGRPOTrainer
+from optimum.habana.utils import set_seed
+
+
+logger = logging.getLogger(__name__)
+SYSTEM_PROMPT = (
+    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
+    "<think> reasoning process here </think><answer> answer here </answer>"
+)
+
+
+def make_conversation(example):
+ return {
+ "prompt": [
+ {"role": "system", "content": SYSTEM_PROMPT},
+ {"role": "user", "content": example["problem"]},
+ ],
+ }
+
+
+ideal_length = 50
+
+
+def reward_len(completions, **kwargs):
+    return [-abs(ideal_length - len(completion)) for completion in completions]  # penalty grows as length deviates from ideal_length
+
+
+def format_reward(completions, **kwargs):
+    # Checks that the reasoning process is enclosed within <think> </think> tags,
+    # while the final answer is enclosed within <answer> </answer> tags.
+    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
+ completion_contents = [completion[0]["content"] for completion in completions]
+ matches = [re.match(pattern, content) for content in completion_contents]
+ return [1.0 if match else 0.0 for match in matches]
+
+
+def accuracy_reward(completions, **kwargs):
+ # Checks if the completion is the same as the ground truth.
+ solutions = kwargs["solution"]
+ completion_contents = [completion[0]["content"] for completion in completions]
+ rewards = []
+ for content, solution in zip(completion_contents, solutions):
+ gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
+ answer_parsed = parse(content, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
+ if len(gold_parsed) != 0:
+ try:
+ with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
+ rewards.append(float(verify(answer_parsed, gold_parsed)))
+ except Exception:
+ rewards.append(0.0)
+ else:
+ rewards.append(1.0)
+ return rewards
+
+
+@dataclass
+class ScriptArguments:
+ model_name_or_path: Optional[str] = field(default="Qwen/Qwen2-0.5B-Instruct", metadata={"help": "the model name"})
+ dataset_name: Optional[str] = field(default=None, metadata={"help": "the dataset name"})
+ use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"})
+ num_workers: Optional[int] = field(default=1, metadata={"help": "the number of workers"})
+ subset: Optional[str] = field(default=None, metadata={"help": "the subset to use"})
+ streaming: Optional[bool] = field(default=False, metadata={"help": "whether to stream the dataset"})
+ dataset_train_split: str = field(default="train[:5%]", metadata={"help": "Dataset split to use for training."})
+ dataset_test_split: str = field(default="test[:5%]", metadata={"help": "Dataset split to use for evaluation."})
+ reward_model_name_or_path: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or "
+ "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
+ },
+ )
+
+ use_flash_attention: Optional[bool] = field(
+ default=True, metadata={"help": "Whether to use Habana flash attention for fine-tuning."}
+ )
+ flash_attention_recompute: Optional[bool] = field(
+ default=False, metadata={"help": "Whether to enable recompute in Habana flash attention for fine-tuning."}
+ )
+ flash_attention_causal_mask: Optional[bool] = field(
+ default=False, metadata={"help": "Whether to enable causal mask in Habana flash attention for fine-tuning."}
+ )
+
+ # LoraConfig
+ lora_alpha: Optional[float] = field(default=32, metadata={"help": "the lora alpha parameter"})
+ lora_dropout: Optional[float] = field(default=0.1, metadata={"help": "the lora dropout parameter"})
+ lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
+ lora_target_modules: List[str] = field(
+ default_factory=lambda: None,
+ metadata={"help": "Target modules for the LoRA method."},
+ )
+
+
+if __name__ == "__main__":
+ parser = HfArgumentParser((GaudiGRPOConfig, ScriptArguments))
+ (training_args, script_args) = parser.parse_args_into_dataclasses()
+
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+ logger.warning(
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.bf16}"
+ )
+ # Set the verbosity to info of the Transformers logger (on main process only):
+ if is_main_process(training_args.local_rank):
+ transformers.utils.logging.set_verbosity_info()
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+ logger.info(f"Training/evaluation parameters {training_args}")
+ # Set seed before initializing model.
+ set_seed(training_args.seed)
+
+ use_deepspeed = training_args.world_size > 1
+
+ if script_args.use_peft:
+ peft_config = LoraConfig(
+ r=script_args.lora_r,
+ lora_alpha=script_args.lora_alpha,
+ lora_dropout=script_args.lora_dropout,
+ target_modules=script_args.lora_target_modules,
+ task_type="CAUSAL_LM",
+ )
+ else:
+ peft_config = None
+
+ tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True)
+ if training_args.chat_template is not None:
+ tokenizer.chat_template = training_args.chat_template
+
+ train_dataset, test_dataset = load_dataset(
+ script_args.dataset_name,
+ data_dir=None if script_args.subset == "None" else script_args.subset,
+ num_proc=script_args.num_workers if not script_args.streaming else None,
+ split=[script_args.dataset_train_split, script_args.dataset_test_split],
+ )
+
+ train_dataset = train_dataset.map(make_conversation)
+ test_dataset = test_dataset.map(make_conversation)
+ train_dataset = train_dataset.remove_columns(["messages", "problem"])
+
+ low_cpu_mem_usage = True
+ if is_deepspeed_available() and use_deepspeed:
+ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+
+ if is_deepspeed_zero3_enabled():
+ low_cpu_mem_usage = False
+
+ model = AutoModelForCausalLM.from_pretrained(
+ script_args.model_name_or_path,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ torch_dtype=torch.bfloat16,
+ )
+
+ model.config.use_cache = False
+    if not script_args.use_flash_attention and (
+        script_args.flash_attention_recompute or script_args.flash_attention_causal_mask
+    ):
+        raise ValueError(
+            "`flash_attention_recompute` and `flash_attention_causal_mask` require `use_flash_attention` to be enabled."
+        )
+ model.generation_config.use_flash_attention = script_args.use_flash_attention
+ model.generation_config.flash_attention_recompute = script_args.flash_attention_recompute
+ model.generation_config.flash_attention_causal_mask = script_args.flash_attention_causal_mask
+
+ reward_funcs = [format_reward, accuracy_reward]
+ if script_args.reward_model_name_or_path:
+ reward_funcs = AutoModelForSequenceClassification.from_pretrained(
+ script_args.reward_model_name_or_path,
+ trust_remote_code=True,
+ )
+
+ if getattr(tokenizer, "pad_token", None) is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ gaudi_config = GaudiConfig()
+ gaudi_config.use_fused_adam = True
+ gaudi_config.use_fused_clip_norm = True
+
+ trainer = GaudiGRPOTrainer(
+ model=model,
+ reward_funcs=reward_funcs,
+ args=training_args,
+ train_dataset=train_dataset,
+ eval_dataset=test_dataset,
+ processing_class=tokenizer,
+ gaudi_config=gaudi_config,
+ peft_config=peft_config,
+ )
+
+ trainer.train()
+
+ print("Done!")
diff --git a/examples/trl/requirements_grpo.txt b/examples/trl/requirements_grpo.txt
new file mode 100644
index 0000000000..e7475bbc91
--- /dev/null
+++ b/examples/trl/requirements_grpo.txt
@@ -0,0 +1,8 @@
+trl == 0.17.0
+peft == 0.12.0
+datasets
+tyro
+evaluate
+scikit-learn == 1.5.2
+accelerate
+math_verify
diff --git a/optimum/habana/trl/__init__.py b/optimum/habana/trl/__init__.py
index 37f4d1156f..060e0b1379 100644
--- a/optimum/habana/trl/__init__.py
+++ b/optimum/habana/trl/__init__.py
@@ -1,10 +1,21 @@
+import importlib.metadata
+
+from packaging import version
+
from .models.modeling_base import adapt_PreTrainedModelWrapper_to_gaudi
from .models.modeling_sd_base import GaudiDefaultDDPOStableDiffusionPipeline
from .trainer.ddpo_trainer import GaudiDDPOTrainer
from .trainer.dpo_config import GaudiDPOConfig
from .trainer.dpo_trainer import GaudiDPOTrainer
-from .trainer.ppo_config import GaudiPPOConfig
-from .trainer.ppo_trainer import GaudiPPOTrainer
+
+
+trl_version = importlib.metadata.version("trl")
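+# The Gaudi PPO integration targets the trl API from before v0.17.0, while the GRPO integration requires
+# trl >= 0.17.0, so only the classes matching the installed trl version are exposed.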
+if version.parse(trl_version) < version.parse("0.17.0"):
+ from .trainer.ppo_config import GaudiPPOConfig
+ from .trainer.ppo_trainer import GaudiPPOTrainer
+else:
+ from .trainer.grpo_config import GaudiGRPOConfig
+ from .trainer.grpo_trainer import GaudiGRPOTrainer
from .trainer.reward_trainer import GaudiRewardTrainer, RewardDataCollatorWithPadding
from .trainer.sft_config import GaudiSFTConfig
from .trainer.sft_trainer import GaudiSFTTrainer
diff --git a/optimum/habana/trl/trainer/__init__.py b/optimum/habana/trl/trainer/__init__.py
index 6da9debbd8..f6de8bf253 100644
--- a/optimum/habana/trl/trainer/__init__.py
+++ b/optimum/habana/trl/trainer/__init__.py
@@ -16,13 +16,22 @@
# There is a circular import in the PPOTrainer if we let isort sort these
# isort: on
+import importlib.metadata
+from packaging import version
from .sft_trainer import GaudiSFTTrainer
from .dpo_trainer import GaudiDPOTrainer
-from .ppo_config import GaudiPPOConfig
-from .ppo_trainer import GaudiPPOTrainer
+
from .reward_trainer import GaudiRewardTrainer, RewardDataCollatorWithPadding
from .ddpo_trainer import GaudiDDPOTrainer
from .dpo_config import GaudiDPOConfig
from .sft_config import GaudiSFTConfig
+
+trl_version = importlib.metadata.version("trl")
+if version.parse(trl_version) < version.parse("0.17.0"):
+ from .ppo_config import GaudiPPOConfig
+ from .ppo_trainer import GaudiPPOTrainer
+else:
+ from .grpo_trainer import GaudiGRPOTrainer
+ from .grpo_config import GaudiGRPOConfig
diff --git a/optimum/habana/trl/trainer/dpo_trainer.py b/optimum/habana/trl/trainer/dpo_trainer.py
index d57a032983..64846da42f 100644
--- a/optimum/habana/trl/trainer/dpo_trainer.py
+++ b/optimum/habana/trl/trainer/dpo_trainer.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib.metadata
import inspect
import warnings
from collections import defaultdict
@@ -24,6 +25,7 @@
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available
from datasets import Dataset
+from packaging import version
from transformers import (
AutoModelForCausalLM,
DataCollator,
@@ -34,12 +36,10 @@
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from trl import DPOTrainer, create_reference_model
-from trl.import_utils import is_peft_available, is_wandb_available
from trl.trainer.dpo_config import FDivergenceConstants
from trl.trainer.utils import (
DPODataCollatorWithPadding,
RunningMoments,
- SyncRefModelCallback,
disable_dropout_in_model,
pad_to_length,
)
@@ -48,6 +48,16 @@
from .dpo_config import GaudiDPOConfig
+trl_version = importlib.metadata.version("trl")
+if version.parse(trl_version) < version.parse("0.17.0"):
+ from trl.import_utils import is_peft_available, is_wandb_available
+ from trl.trainer.utils import SyncRefModelCallback
+else:
+ from transformers import is_wandb_available
+ from transformers.utils import is_peft_available
+ from trl.trainer.callbacks import SyncRefModelCallback
+
+
if is_peft_available():
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
diff --git a/optimum/habana/trl/trainer/grpo_config.py b/optimum/habana/trl/trainer/grpo_config.py
new file mode 100644
index 0000000000..62df6c2e07
--- /dev/null
+++ b/optimum/habana/trl/trainer/grpo_config.py
@@ -0,0 +1,316 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from dataclasses import dataclass, field
+from typing import Optional
+
+from ... import GaudiTrainingArguments
+
+
+@dataclass
+class GaudiGRPOConfig(GaudiTrainingArguments):
+ r"""
+    Configuration class for `GaudiGRPOTrainer`.
+    Adapted from https://github.com/huggingface/trl/blob/v0.17.0/trl/trainer/grpo_config.py
+    The only difference is that it inherits from `GaudiTrainingArguments`.
+ """
+
+ # Parameters that control the model and reference model
+ model_init_kwargs: Optional[dict] = field(
+ default=None,
+ metadata={
+ "help": "Keyword arguments for `transformers.AutoModelForCausalLM.from_pretrained`, used when the `model` "
+ "argument of the `GRPOTrainer` is provided as a string."
+ },
+ )
+
+ # Parameters that control the data preprocessing
+ # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
+ # additional columns to compute the reward
+ remove_unused_columns: Optional[bool] = field(
+ default=False,
+ metadata={
+ "help": "Whether to only keep the column 'prompt' in the dataset. If you use a custom reward function "
+ "that requires any column other than 'prompts' and 'completions', you should keep this to `False`."
+ },
+ )
+ max_prompt_length: Optional[int] = field(
+ default=512,
+ metadata={
+ "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left."
+ },
+ )
+ num_generations: Optional[int] = field(
+ default=4,
+ metadata={
+ "help": "Number of generations to sample. The global batch size (num_processes * per_device_batch_size) "
+ "must be divisible by this value."
+ },
+ )
+ max_completion_length: Optional[int] = field(
+ default=64,
+ metadata={"help": "Maximum length of the generated completion."},
+ )
+ ds3_gather_for_generation: bool = field(
+ default=True,
+ metadata={
+ "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
+ "generation, improving generation speed. However, disabling this option allows training models that "
+ "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option "
+ "is not compatible with vLLM generation."
+ },
+ )
+ shuffle_dataset: Optional[bool] = field(
+ default=True,
+ metadata={"help": "Whether to shuffle the training dataset."},
+ )
+
+ # Parameters that control generation
+ temperature: float = field(
+ default=0.9,
+ metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
+ )
+ top_p: float = field(
+ default=1.0,
+ metadata={
+ "help": "Float that controls the cumulative probability of the top tokens to consider. Must be in (0, 1]. "
+ "Set to 1.0 to consider all tokens."
+ },
+ )
+ top_k: Optional[int] = field(
+ default=50,
+ metadata={
+ "help": "Number of highest probability vocabulary tokens to keep for top-k-filtering. If `None`, "
+ "top-k-filtering is disabled."
+ },
+ )
+ min_p: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "Minimum token probability, which will be scaled by the probability of the most likely token. It "
+ "must be a value between 0.0 and 1.0. Typical values are in the 0.01-0.2 range."
+ },
+ )
+ repetition_penalty: float = field(
+ default=1.0,
+ metadata={
+ "help": "Float that penalizes new tokens based on whether they appear in the prompt and the generated "
+ "text so far. Values > 1.0 encourage the model to use new tokens, while values < 1.0 encourage the model "
+ "to repeat tokens."
+ },
+ )
+ cache_implementation: Optional[str] = field(
+ default=None,
+ metadata={"help": "Implementation of the cache method for faster generation when use_vllm is set to False."},
+ )
+
+ # Parameters that control generation acceleration powered by vLLM
+ use_vllm: bool = field(
+ default=False,
+ metadata={
+ "help": "Whether to use vLLM for generating completions. If set to `True`, ensure that a vLLM server is "
+ "running. To run the server, install vLLM (`pip install vllm`) and run `trl vllm-serve`."
+ },
+ )
+ vllm_server_host: str = field(
+ default="0.0.0.0",
+ metadata={"help": "Host of the vLLM server to connect to."},
+ )
+ vllm_server_port: int = field(
+ default=8000,
+ metadata={"help": "Port of the vLLM server to connect to."},
+ )
+ vllm_server_timeout: float = field(
+ default=120.0,
+ metadata={
+ "help": "Total timeout duration in seconds to wait for the vLLM server to be up. If the server is not up "
+ "after the timeout, a `ConnectionError` is raised."
+ },
+ )
+ vllm_guided_decoding_regex: Optional[str] = field(
+ default=None,
+ metadata={"help": "Regex for vLLM guided decoding. If `None` (default), guided decoding is disabled."},
+ )
+
+ # Parameters that control the training
+ learning_rate: float = field(
+ default=2e-5,
+ metadata={
+ "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
+ "`transformers.TrainingArguments`."
+ },
+ )
+ beta: float = field(
+ default=0.04,
+ metadata={
+ "help": "KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving "
+ "training speed, but may be numerically unstable for long training runs."
+ },
+ )
+ num_iterations: int = field(
+ default=1,
+ metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
+ )
+ epsilon: float = field(
+ default=0.2,
+ metadata={"help": "Epsilon value for clipping."},
+ )
+ epsilon_high: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the "
+ "lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
+ },
+ )
+ reward_weights: Optional[list[float]] = field(
+ default=None,
+ metadata={
+ "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all "
+ "rewards are weighted equally with weight `1.0`."
+ },
+ )
+ scale_rewards: bool = field(
+ default=True,
+ metadata={
+ "help": "Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), "
+ "the rewards are normalized by the standard deviation, ensuring they have unit variance. If `False`, no "
+ "scaling is applied. The Dr. GRPO paper recommends not scaling the rewards, as scaling by the standard "
+ "deviation introduces a question-level difficulty bias."
+ },
+ )
+ mask_truncated_completions: bool = field(
+ default=False,
+ metadata={
+ "help": "When enabled, truncated completions are excluded from the loss calculation, preventing them from "
+ "being incorrectly penalized and introducing noise during training. According to the DAPO paper, this is "
+ "a good practice for training stability."
+ },
+ )
+ sync_ref_model: bool = field(
+ default=False,
+ metadata={
+ "help": "Whether to synchronize the reference model with the active model every `ref_model_sync_steps` "
+ "steps, using the `ref_model_mixup_alpha` parameter."
+ },
+ )
+ ref_model_mixup_alpha: float = field(
+ default=0.6,
+ metadata={
+ "help": "α parameter from the TR-DPO paper, which controls the mix between the current policy and the "
+ "previous reference policy during updates. The reference policy is updated according to the equation: "
+ "`π_ref = α * π_θ + (1 - α) * π_ref_prev`. To use this parameter, you must set `sync_ref_model=True`."
+ },
+ )
+ ref_model_sync_steps: int = field(
+ default=512,
+ metadata={
+ "help": "τ parameter from the TR-DPO paper, which determines how frequently the current policy is "
+ "synchronized with the reference policy. To use this parameter, you must set `sync_ref_model=True`."
+ },
+ )
+
+ # Parameters that control the logging
+ log_completions: bool = field(
+ default=False,
+ metadata={
+ "help": "Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is "
+ "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
+ },
+ )
+ num_completions_to_print: Optional[int] = field(
+ default=None,
+ metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."},
+ )
+
+ # Deprecated parameters
+ vllm_device: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "This parameter is deprecated and will be removed in version 0.18.0. To use vLLM, start a vLLM "
+ "server with the `trl vllm-serve` command."
+ },
+ )
+ vllm_gpu_memory_utilization: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "This parameter is deprecated and will be removed in version 0.18.0. To control the GPU memory "
+ "utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server "
+ "configuration."
+ },
+ )
+ vllm_dtype: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "This parameter is deprecated and will be removed in version 0.18.0. To control the data type for "
+ "vLLM generation, you should now use the `dtype` parameter in the vLLM server configuration."
+ },
+ )
+ vllm_max_model_len: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "This parameter is deprecated and will be removed in version 0.18.0. To control the "
+ "`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server "
+ "configuration."
+ },
+ )
+ vllm_enable_prefix_caching: Optional[bool] = field(
+ default=None,
+ metadata={
+ "help": "This parameter is deprecated and will be removed in version 0.18.0. To control prefix caching in "
+ "vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server configuration."
+ },
+ )
+    chat_template: Optional[str] = field(
+        default=None, metadata={"help": "Chat template to set on the tokenizer before training."}
+    )
+
+ def __post_init__(self):
+ super().__post_init__()
+
+ if self.vllm_device is not None:
+ warnings.warn(
+ "`vllm_device` is deprecated and will be removed in version 0.18.0. To use vLLM, start a vLLM server "
+ "with the `trl vllm-serve` command.",
+ DeprecationWarning,
+ )
+
+ if self.vllm_gpu_memory_utilization is not None:
+ warnings.warn(
+ "`vllm_gpu_memory_utilization` is deprecated and will be removed in v0.18. To control the GPU memory "
+ "utilization for vLLM, you should now use the `gpu_memory_utilization` parameter in the vLLM server "
+ "configuration.",
+ DeprecationWarning,
+ )
+
+ if self.vllm_dtype is not None:
+ warnings.warn(
+ "`vllm_dtype` is deprecated and will be removed in version 0.18.0. To control the data type for vLLM "
+ "generation, you should now use the `dtype` parameter in the vLLM server configuration.",
+ DeprecationWarning,
+ )
+
+ if self.vllm_max_model_len is not None:
+ warnings.warn(
+ "`vllm_max_model_len` is deprecated and will be removed in version 0.18.0. To control the "
+ "`max_model_len` for vLLM, you should now use the `max_model_len` parameter in the vLLM server "
+ "configuration.",
+ DeprecationWarning,
+ )
+
+ if self.vllm_enable_prefix_caching is not None:
+ warnings.warn(
+ "`vllm_enable_prefix_caching` is deprecated and will be removed in version 0.18.0. To control prefix "
+ "caching in vLLM, you should now use the `enable_prefix_caching` parameter in the vLLM server "
+ "configuration.",
+ DeprecationWarning,
+ )
diff --git a/optimum/habana/trl/trainer/grpo_trainer.py b/optimum/habana/trl/trainer/grpo_trainer.py
new file mode 100644
index 0000000000..5bdb460b6f
--- /dev/null
+++ b/optimum/habana/trl/trainer/grpo_trainer.py
@@ -0,0 +1,835 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import bisect
+import copy
+import warnings
+from collections import defaultdict, deque
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+
+import pandas as pd
+import torch
+import torch.utils.data
+from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
+from datasets import Dataset, IterableDataset
+from torch import nn
+from transformers import (
+ AutoModelForCausalLM,
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ PreTrainedModel,
+ PreTrainedTokenizerBase,
+ Trainer,
+ TrainerCallback,
+ is_wandb_available,
+)
+from transformers.integrations.deepspeed import is_deepspeed_available, is_deepspeed_zero3_enabled
+from transformers.utils import is_datasets_available, is_peft_available
+from trl import GRPOTrainer
+from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
+from trl.extras.profiling import profiling_context, profiling_decorator
+from trl.extras.vllm_client import VLLMClient
+from trl.import_utils import is_rich_available, is_vllm_available
+from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
+from trl.trainer.callbacks import SyncRefModelCallback
+from trl.trainer.utils import (
+ pad,
+ print_prompt_completions_sample,
+ selective_log_softmax,
+)
+
+from optimum.utils import logging
+
+from ... import GaudiConfig, GaudiTrainer
+from ...transformers import trainer as habana_trainer
+from ...transformers.trainer import _get_input_update_settings
+from .grpo_config import GaudiGRPOConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_deepspeed_available():
+ pass
+
+if is_peft_available():
+ from peft import PeftConfig, get_peft_model
+
+if is_datasets_available():
+ pass
+
+if is_wandb_available():
+ import wandb
+
+# What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
+# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
+RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
+
+
+def grpo_get_input_update_settings(model, lazy_mode: Optional[bool] = None) -> Tuple[bool, Dict]:
+ # For GRPOTrainer, skip input update in the _inner_training_loop()
+ # because it expects a dict type input, but the GRPOTrainer input is a list of dict.
+ # Instead, the update is done in _get_per_token_logps()
+ return False, {}
+
+
+class GaudiGRPOTrainer(GRPOTrainer, GaudiTrainer):
+ _tag_names = ["trl", "grpo"]
+
+ def __init__(
+ self,
+ model: Union[str, PreTrainedModel],
+ reward_funcs: Union[RewardFunc, list[RewardFunc]],
+ args: Optional[GaudiGRPOConfig] = None,
+ gaudi_config: GaudiConfig = None,
+ train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
+ eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
+ processing_class: Optional[PreTrainedTokenizerBase] = None,
+ reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
+ callbacks: Optional[list[TrainerCallback]] = None,
+ optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+ peft_config: Optional["PeftConfig"] = None,
+ ):
+ """
+ Copied from GRPOTrainer.__init__: https://github.com/huggingface/trl/blob/v0.17.0/trl/trainer/grpo_trainer.py#L264
+ The only differences are:
+        - Add a new `gaudi_config` argument
+        - Use GaudiTrainer instead of Trainer
+        - Add bucketing of prompt lengths to reduce dynamic input shapes
+        - Toggle use_cache and gradient checkpointing around generation to speed up rollouts when gradient checkpointing is enabled
+ """
+ # Args
+ if args is None:
+ model_name = model if isinstance(model, str) else model.config._name_or_path
+ model_name = model_name.split("/")[-1]
+ args = GaudiGRPOConfig(f"{model_name}-GRPO")
+ self.args = args
+
+ # Models
+ # Trained model
+ model_init_kwargs = args.model_init_kwargs or {}
+ if isinstance(model, str):
+ model_id = model
+ torch_dtype = model_init_kwargs.get("torch_dtype")
+ if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
+ pass # torch_dtype is already a torch.dtype or "auto" or None
+ elif isinstance(torch_dtype, str): # it's a str, but not "auto"
+ torch_dtype = getattr(torch, torch_dtype)
+ model_init_kwargs["torch_dtype"] = torch_dtype
+ else:
+ raise ValueError(
+ "Invalid `torch_dtype` passed to `GaudiGRPOConfig`. Expected either 'auto' or a string representing "
+ f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
+ )
+
+ # Disable caching if gradient checkpointing is enabled (not supported)
+ model_init_kwargs["use_cache"] = (
+ False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
+ )
+ model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
+ else:
+ model_id = model.config._name_or_path
+ if args.model_init_kwargs is not None:
+ raise ValueError(
+ "You passed `model_init_kwargs` to the `GaudiGRPOConfig`, but your model is already instantiated. "
+ "This argument can only be used when the `model` argument is a string."
+ )
+
+ if peft_config is not None:
+ if not is_peft_available():
+ raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.")
+ model = get_peft_model(model, peft_config)
+
+ # Enable gradient checkpointing if requested
+ if args.gradient_checkpointing:
+ model = self._enable_gradient_checkpointing(model, args)
+
+ # Reference model
+ self.beta = args.beta
+ if self.beta == 0.0:
+ # If beta is 0.0, the reference model is not needed
+ self.ref_model = None
+ elif is_deepspeed_zero3_enabled():
+ self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
+ elif is_peft_model(model):
+ # If PEFT is used, the reference model is not needed since the adapter can be disabled
+ # to revert to the initial model.
+ self.ref_model = None
+ else:
+ # If PEFT configuration is not provided, create a reference model based on the initial model.
+ self.ref_model = create_reference_model(model)
+
+ # Processing class
+ if processing_class is None:
+ processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
+ if processing_class.pad_token is None:
+ processing_class.pad_token = processing_class.eos_token
+
+ # Reward functions
+ if not isinstance(reward_funcs, list):
+ reward_funcs = [reward_funcs]
+ self.reward_func_names = []
+ for i, reward_func in enumerate(reward_funcs):
+ if isinstance(reward_func, str):
+ reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
+ reward_func, num_labels=1, **model_init_kwargs
+ )
+ if isinstance(reward_funcs[i], nn.Module): # Use Module over PretrainedModel for compat w/ compiled models
+ self.reward_func_names.append(reward_funcs[i].config._name_or_path.split("/")[-1])
+ else:
+ self.reward_func_names.append(reward_funcs[i].__name__)
+ self.reward_funcs = reward_funcs
+
+ # Reward weights
+ if args.reward_weights is not None:
+ if len(args.reward_weights) != len(reward_funcs):
+ raise ValueError(
+ f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
+ f"functions ({len(reward_funcs)})"
+ )
+ self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
+ else:
+ self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
+
+ # Reward processing class
+ if reward_processing_classes is None:
+ reward_processing_classes = [None] * len(reward_funcs)
+ elif not isinstance(reward_processing_classes, list):
+ reward_processing_classes = [reward_processing_classes]
+ else:
+ if len(reward_processing_classes) != len(reward_funcs):
+ raise ValueError("The number of reward processing classes must match the number of reward functions.")
+
+ for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
+ if isinstance(reward_func, PreTrainedModel):
+ if reward_processing_class is None:
+ reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
+ if reward_processing_class.pad_token_id is None:
+ reward_processing_class.pad_token = reward_processing_class.eos_token
+ # The reward model computes the reward for the latest non-padded token in the input sequence.
+ # So it's important to set the pad token ID to the padding token ID of the processing class.
+ reward_func.config.pad_token_id = reward_processing_class.pad_token_id
+ reward_processing_classes[i] = reward_processing_class
+ self.reward_processing_classes = reward_processing_classes
+
+ def data_collator(features):
+ return features
+
+ # Training arguments
+ self.max_prompt_length = args.max_prompt_length
+ self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
+ self.num_generations = args.num_generations # = G in the GRPO paper
+ self.temperature = args.temperature
+ self.top_p = args.top_p
+ self.top_k = args.top_k
+ self.min_p = args.min_p
+ self.repetition_penalty = args.repetition_penalty
+ self.use_vllm = args.use_vllm
+
+ self.scale_rewards = args.scale_rewards
+ self.mask_truncated_completions = args.mask_truncated_completions
+
+ self.buckets = self._get_buckets(train_dataset, processing_class)
+
+ self.shuffle_dataset = args.shuffle_dataset
+
+ if (
+ isinstance(train_dataset, IterableDataset)
+ or isinstance(eval_dataset, IterableDataset)
+ or (
+ isinstance(eval_dataset, dict) and any(isinstance(ds, IterableDataset) for ds in eval_dataset.values())
+ )
+ ):
+ # See https://github.com/huggingface/trl/issues/3213
+ raise NotImplementedError(
+ "Iterable datasets are not yet supported in GRPOTrainer. Please use a standard dataset instead."
+ )
+
+ # Multi-step
+ self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
+ self.epsilon_low = args.epsilon
+ self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
+ # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
+ self._step = 0
+ # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
+ # `_get_train_sampler` and `_prepare_inputs`.
+ self._buffered_inputs = None
+
+ # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
+ # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
+ # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning:
+ # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
+ # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
+ # This acts as a flag to indicate that the warning has already been issued.
+ model.warnings_issued["estimate_tokens"] = True
+
+ GaudiTrainer.__init__(
+ self,
+ model=model,
+ args=args,
+ gaudi_config=gaudi_config,
+ data_collator=data_collator,
+ train_dataset=train_dataset,
+ eval_dataset=eval_dataset,
+ processing_class=processing_class,
+ callbacks=callbacks,
+ optimizers=optimizers,
+ )
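+        # Replace the module-level input-update hook with the GRPO-specific no-op defined above:
+        # GRPO passes a list of dicts through the training loop, so the Habana input updates are
+        # applied later in _get_per_token_logps() instead.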
+ habana_trainer._get_input_update_settings = grpo_get_input_update_settings
+
+ # Initialize the metrics
+ self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
+ self._total_train_tokens = 0
+ self.log_completions = args.log_completions
+ self.num_completions_to_print = args.num_completions_to_print
+ # maxlen is set to the total number of forward passes per step. This value of `maxlen` ensures we log only the
+ # final optimization step.
+ maxlen = self.accelerator.num_processes * args.per_device_train_batch_size * args.gradient_accumulation_steps
+ self._textual_logs = {
+ "prompt": deque(maxlen=maxlen),
+ "completion": deque(maxlen=maxlen),
+ "rewards": defaultdict(lambda: deque(maxlen=maxlen)),
+ }
+ # Check if the effective batch size can be divided by the number of generations
+ if self.num_generations < 2:
+ raise ValueError(
+ "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided "
+ f"{self.num_generations}, which is less than the minimum required."
+ )
+
+ num_processes = self.accelerator.num_processes
+ global_batch_size = args.per_device_train_batch_size * num_processes * args.gradient_accumulation_steps
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
+ if self.num_generations not in possible_values:
+ raise ValueError(
+ f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
+ f"batch size, the valid values for the number of generations are: {possible_values}."
+ )
+ if self.args.eval_strategy != "no":
+ global_batch_size = args.per_device_eval_batch_size * num_processes
+ possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
+ if self.num_generations not in possible_values:
+ raise ValueError(
+ f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
+ f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
+ f"eval batch size, the valid values for the number of generations are: {possible_values}."
+ )
+
+ # Ensure each process receives a unique seed to prevent duplicate completions when generating with
+ # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
+ # it's safer to set it in all cases.
+ set_seed(args.seed, device_specific=True)
+
+ if self.use_vllm:
+ if not is_vllm_available():
+ raise ImportError(
+ "vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
+ "`pip install vllm` to use it."
+ )
+
+ if self.accelerator.is_main_process:
+ self.vllm_client = VLLMClient(
+ args.vllm_server_host, args.vllm_server_port, connection_timeout=args.vllm_server_timeout
+ )
+ self.vllm_client.init_communicator()
+
+ # vLLM specific sampling arguments
+ self.guided_decoding_regex = args.vllm_guided_decoding_regex
+
+ self._last_loaded_step = -1 # tag to avoid useless loading during grad accumulation
+
+ # When using vLLM, the main process is responsible for loading the model weights. This can cause process
+ # desynchronization and seems to lead to DeepSpeed hanging during initialization. To prevent this, we
+ # synchronize all processes after vLLM has been fully initialized.
+ self.accelerator.wait_for_everyone()
+ else:
+ self.generation_config = copy.deepcopy(model.generation_config)
+ self.generation_config.max_new_tokens = self.max_completion_length
+ self.generation_config.do_sample = True
+ self.generation_config.pad_token_id = processing_class.pad_token_id
+ self.generation_config.bos_token_id = processing_class.bos_token_id
+ self.generation_config.eos_token_id = processing_class.eos_token_id
+ self.generation_config.temperature = self.temperature
+ self.generation_config.top_p = self.top_p
+ self.generation_config.top_k = self.top_k
+ self.generation_config.min_p = self.min_p
+ self.generation_config.repetition_penalty = self.repetition_penalty
+ self.generation_config.cache_implementation = args.cache_implementation
+ self.generation_config.use_cache = True
+ self.generation_config.static_shapes = True
+ self.generation_config.reuse_cache = True
+
+ # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
+ # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
+ # self.model_accepts_loss_kwargs to False to enable scaling.
+ self.model_accepts_loss_kwargs = False
+
+ # Add tags to the model
+ self.model.add_model_tags(self._tag_names)
+
+ if self.ref_model is not None:
+ if self.is_deepspeed_enabled:
+ self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
+ else:
+ self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+ if args.sync_ref_model:
+ self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
+
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, PreTrainedModel):
+ if self.is_deepspeed_enabled:
+ self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
+ else:
+ self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
+
+ def _get_buckets(self, train_dataset, tokenizer, num_buckets=5):
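+        """Compute quantile-based upper bounds on tokenized prompt lengths, capped at `max_prompt_length`,
+        so prompts can be padded to a small set of static shapes instead of fully dynamic ones."""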
+ # Collect all seq lens
+ sentence_lengths = []
+ for batch in train_dataset:
+ formatted_prompt = maybe_apply_chat_template(batch, tokenizer)["prompt"]
+ formatted_prompt_len = len(tokenizer(formatted_prompt)["input_ids"])
+ sentence_lengths.append(formatted_prompt_len)
+
+ # Assign bucket labels to each sentence
+ bucket_label_per_sentence = pd.qcut(sentence_lengths, q=num_buckets, labels=False, duplicates="drop")
+
+ # Get max len per bucket
+ df = pd.DataFrame({"value": sentence_lengths, "bucket": bucket_label_per_sentence})
+ buckets = df.groupby("bucket")["value"].max().tolist()
+ # Make sure that no bucket exceeds self.max_prompt_length
+ buckets = [min(b, self.max_prompt_length) for b in buckets]
+ return buckets
+
+ # Get the per-token log probabilities for the completions for the model and the reference model
+ @profiling_decorator
+ def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+ # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+ inputs = {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "logits_to_keep": logits_to_keep + 1,
+ }
+
+ if hasattr(model, "module"):
+ # For distributed
+ should_update_inputs, input_updates = _get_input_update_settings(model.module)
+ inputs.update(input_updates)
+ else:
+ # For non distributed
+ should_update_inputs, input_updates = _get_input_update_settings(model)
+ inputs.update(input_updates)
+
+ logits = model(**inputs).logits
+
+ logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
+ input_ids = input_ids[:, -logits_to_keep:]
+ # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
+ # See https://github.com/huggingface/trl/issues/2770
+ logits = logits[:, -logits_to_keep:]
+ # Divide logits by sampling temperature.
+ # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+ logits = logits / self.temperature
+
+ return selective_log_softmax(logits, input_ids)
+
+ def _generate_and_score_completions(
+ self, inputs: dict[str, Union[torch.Tensor, Any]]
+ ) -> dict[str, Union[torch.Tensor, Any]]:
+ device = self.accelerator.device
+
+ prompts = [x["prompt"] for x in inputs]
+ prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
+
+ # Get unique seq len within a batch
+ max_prompt_len_per_batch = 0
+ for prompt_idx in range(
+ 0, len(prompts_text), self.num_generations
+ ): # Prompts are repeated self.num_generations times
+ prompt_len = len(
+ self.processing_class(
+ text=prompts_text[prompt_idx], return_tensors="pt", padding=False, add_special_tokens=False
+ )["input_ids"][0]
+ )
+ max_prompt_len_per_batch = max(max_prompt_len_per_batch, prompt_len)
+
+        # Find the matching bucket, then tokenize the prompts padded to the bucket length
+ bucket_indices = bisect.bisect_left(self.buckets, max_prompt_len_per_batch)
+ bucket_indices = min(bucket_indices, len(self.buckets) - 1)
+ prompt_inputs = self.processing_class(
+ text=prompts_text,
+ return_tensors="pt",
+ padding="max_length",
+ padding_side="left",
+ max_length=self.buckets[bucket_indices],
+ truncation=True,
+ add_special_tokens=False,
+ )
+
+ prompt_inputs = Trainer._prepare_inputs(self, prompt_inputs)
+ prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+
+ if self.max_prompt_length is not None:
+ prompt_ids = prompt_ids[:, -self.max_prompt_length :]
+ prompt_mask = prompt_mask[:, -self.max_prompt_length :]
+
+ # Generate completions using either vLLM or regular generation
+ if self.args.use_vllm:
+ # First, have main process load weights if needed
+ if self.state.global_step != self._last_loaded_step:
+ self._move_model_to_vllm()
+ self._last_loaded_step = self.state.global_step
+
+ # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
+ all_prompts_text = gather_object(prompts_text)
+ if self.accelerator.is_main_process:
+ # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
+ # num_generations outputs for each one. This is faster than generating outputs for each duplicate
+ # prompt individually.
+ ordered_set_of_prompts = all_prompts_text[:: self.num_generations]
+ with profiling_context(self, "vLLM.generate"):
+ completion_ids = self.vllm_client.generate(
+ prompts=ordered_set_of_prompts,
+ n=self.num_generations,
+ repetition_penalty=self.repetition_penalty,
+ temperature=self.temperature,
+ top_p=self.top_p,
+ top_k=-1 if self.top_k is None else self.top_k,
+ min_p=0.0 if self.min_p is None else self.min_p,
+ max_tokens=self.max_completion_length,
+ guided_decoding_regex=self.guided_decoding_regex,
+ )
+ else:
+ completion_ids = [None] * len(all_prompts_text)
+ # Broadcast the completions from the main process to all processes, ensuring each process receives its
+ # corresponding slice.
+ completion_ids = broadcast_object_list(completion_ids, from_process=0)
+ process_slice = slice(
+ self.accelerator.process_index * len(prompts),
+ (self.accelerator.process_index + 1) * len(prompts),
+ )
+ completion_ids = completion_ids[process_slice]
+
+ # Pad the completions, and concatenate them with the prompts
+ completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+ completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
+ prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+ else:
+ with unwrap_model_for_generation(
+ self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
+ ) as unwrapped_model:
+ if self.args.gradient_checkpointing:
+ unwrapped_model.gradient_checkpointing_disable()
+ unwrapped_model.config.use_cache = True
+
+ unwrapped_model.eval()
+
+ with torch.no_grad():
+ prompt_completion_ids = unwrapped_model.generate(
+ prompt_ids,
+ attention_mask=prompt_mask,
+ hpu_graphs=True,
+ generation_config=self.generation_config,
+ lazy_mode=True,
+ )
+
+ unwrapped_model.train()
+
+ # KV cache is not used during training. Delete KV cache to save memory.
+ if is_peft_model(unwrapped_model):
+ for layer in unwrapped_model.base_model.model.model.layers:
+ layer.self_attn.k_cache.cache = None
+ layer.self_attn.v_cache.cache = None
+ else:
+ for layer in unwrapped_model.model.layers:
+ layer.self_attn.k_cache.cache = None
+ layer.self_attn.v_cache.cache = None
+
+ # Compute prompt length and extract completion ids
+ prompt_length = prompt_ids.size(1)
+ prompt_ids = prompt_completion_ids[:, :prompt_length]
+ completion_ids = prompt_completion_ids[:, prompt_length:]
+
+ # Mask everything after the first EOS token
+ is_eos = completion_ids == self.processing_class.eos_token_id
+ eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
+ eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
+ sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
+ completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
+
+ # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
+ if self.mask_truncated_completions:
+ truncated_completions = ~is_eos.any(dim=1)
+ completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
+
+ # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
+ # to re-tokenize completions if the reward is computed from tokens.
+ completion_ids_list = [
+ [id.item() for id, m in zip(row, mask_row) if m] for row, mask_row in zip(completion_ids, completion_mask)
+ ]
+
+ # Concatenate prompt_mask with completion_mask for logit computation
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
+
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
+
+ with torch.no_grad():
+            # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
+ # computation here, and use per_token_logps.detach() instead.
+ if self.num_iterations > 1:
+ old_per_token_logps = self._get_per_token_logps(
+ self.model, prompt_completion_ids, attention_mask, logits_to_keep
+ )
+ else:
+ old_per_token_logps = None
+
+ if self.beta == 0.0:
+ ref_per_token_logps = None
+ elif self.ref_model is not None:
+ ref_per_token_logps = self._get_per_token_logps(
+ self.ref_model, prompt_completion_ids, attention_mask, logits_to_keep
+ )
+ else:
+ with self.accelerator.unwrap_model(self.model).disable_adapter():
+ ref_per_token_logps = self._get_per_token_logps(
+ self.model, prompt_completion_ids, attention_mask, logits_to_keep
+ )
+
+ # Decode the generated completions
+ completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
+
+ if is_conversational(inputs[0]):
+ completions = []
+ for prompt, completion in zip(prompts, completions_text):
+ bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
+ completions.append([{"role": "assistant", "content": bootstrap + completion}])
+ else:
+ completions = completions_text
+
+ rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
+ for i, (reward_func, reward_processing_class) in enumerate(
+ zip(self.reward_funcs, self.reward_processing_classes)
+ ):
+ if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
+ reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
+ else:
+ reward_func_name = reward_func.__name__
+ with profiling_context(self, reward_func_name):
+ if isinstance(
+ reward_func, nn.Module
+ ): # Module instead of PretrainedModel for compat with compiled models
+ if is_conversational(inputs[0]):
+ messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
+ texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+ else:
+ texts = [p + c for p, c in zip(prompts, completions)]
+ reward_inputs = reward_processing_class(
+ text=texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
+ )
+ reward_inputs = Trainer._prepare_inputs(self, reward_inputs)
+ with torch.inference_mode():
+ rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
+ else:
+ # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+ keys = [
+ key
+ for key in inputs[0]
+ if key
+ not in [
+ "prompt",
+ "completion",
+ "completion_ids",
+ "use_flash_attention",
+ "flash_attention_fast_softmax",
+ "lazy_mode",
+ ]
+ ]
+
+ reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+ output_reward_func = reward_func(
+ prompts=prompts, completions=completions, completion_ids=completion_ids_list, **reward_kwargs
+ )
+ # Convert None values to NaN
+ output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
+
+ rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+
+ # If all reward functions return None for a given row, issue a detailed warning
+ if torch.isnan(rewards_per_func).all(dim=1).any():
+ nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
+ row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
+ row_reward_kwargs["prompt"] = prompts[nan_row_idx]
+ row_reward_kwargs["completion"] = completions[nan_row_idx]
+ warnings.warn(
+ f"All reward functions returned None for the following kwargs: {row_reward_kwargs}. "
+ "Please ensure that at least one reward function returns a valid reward."
+ )
+
+ # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
+ # completions may be distributed across processes
+ rewards_per_func = gather(rewards_per_func)
+
+ # Apply weights to each reward function's output and sum
+ rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)
+
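+ # Group-relative advantages: completions generated from the same prompt form a group of size
+ # num_generations; each completion's advantage is its reward minus the group mean, optionally divided
+ # by the group standard deviation when scale_rewards is set.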
+ # Compute grouped-wise rewards
+ mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
+ std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
+
+ # Normalize the rewards to compute the advantages
+ mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
+ advantages = rewards - mean_grouped_rewards
+
+ if self.args.scale_rewards:
+ advantages = advantages / (std_grouped_rewards + 1e-4)
+
+ # Slice to keep only the local part of the data
+ process_slice = slice(
+ self.accelerator.process_index * len(prompts),
+ (self.accelerator.process_index + 1) * len(prompts),
+ )
+ advantages = advantages[process_slice]
+
+ # Log the metrics
+ mode = "eval" if self.control.should_evaluate else "train"
+
+ if mode == "train":
+ self._total_train_tokens += self.accelerator.gather_for_metrics(attention_mask.sum()).sum().item()
+ self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
+
+ completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
+ self._metrics[mode]["completion_length"].append(completion_length)
+
+ # Calculate mean reward per function, but only for samples where the function was applied
+ for i, reward_func in enumerate(self.reward_funcs):
+ if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
+ reward_func_name = reward_func.config._name_or_path.split("/")[-1]
+ else:
+ reward_func_name = reward_func.__name__
+ # Only calculate mean for samples where this reward function was applied (non-NaN values)
+ mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
+ self._metrics[mode][f"rewards/{reward_func_name}"].append(mean_rewards)
+ self._metrics[mode]["reward"].append(rewards.mean().item())
+ self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
+
+ if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
+ prompts_to_log = gather_object(prompts_text)
+ completions_to_log = gather_object(completions_text)
+ rewards_to_log = rewards.tolist()
+
+ if self.accelerator.is_main_process:
+ if is_rich_available():
+ print_prompt_completions_sample(
+ prompts_to_log,
+ completions_to_log,
+ rewards_to_log,
+ self.state.global_step,
+ )
+ if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
+ import pandas as pd
+
+ # For logging
+ table = {
+ "step": [str(self.state.global_step)] * len(rewards),
+ "prompt": prompts_to_log,
+ "completion": completions_to_log,
+ "reward": rewards.tolist(),
+ }
+ df = pd.DataFrame(table)
+ wandb.log({"completions": wandb.Table(dataframe=df)})
+
+ return {
+ "prompt_ids": prompt_ids,
+ "prompt_mask": prompt_mask,
+ "completion_ids": completion_ids,
+ "completion_mask": completion_mask,
+ "old_per_token_logps": old_per_token_logps,
+ "ref_per_token_logps": ref_per_token_logps,
+ "advantages": advantages,
+ }
+
+ @profiling_decorator
+ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+ if return_outputs:
+ raise ValueError("The GRPOTrainer does not support returning outputs")
+
+ if self.args.gradient_checkpointing:
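+ # Gradient checkpointing must be enabled on the underlying module (model.module under DDP/DeepSpeed,
+ # or the PEFT base model); enable_input_require_grads makes the embedding outputs require gradients so
+ # checkpointed segments still backpropagate when the base weights are frozen (e.g. with PEFT adapters).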
+ if hasattr(model, "module"):
+ # Distributed
+ model.module.config.use_cache = False
+ if is_peft_model(model.module):
+ model.module.base_model.gradient_checkpointing_enable()
+ model.module.base_model.enable_input_require_grads()
+ else:
+ model.module.gradient_checkpointing_enable()
+ model.module.enable_input_require_grads()
+
+ else:
+ # Single card
+ model.config.use_cache = False
+ if is_peft_model(model):
+ model.base_model.gradient_checkpointing_enable()
+ model.base_model.enable_input_require_grads()
+ else:
+ # Enable gradient checkpointing for non-PEFT models
+ model.gradient_checkpointing_enable()
+ model.enable_input_require_grads()
+
+ # Compute the per-token log probabilities for the model
+ prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+ completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+ input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+ attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+ logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
+
+ per_token_logps = self._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+
+ # Compute the KL divergence between the model and the reference model
+ if self.beta != 0.0:
+ ref_per_token_logps = inputs["ref_per_token_logps"]
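+ # Per-token KL estimate via exp(d) - d - 1 with d = ref_logp - logp (the non-negative "k3" estimator
+ # of KL(policy || reference)).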
+ per_token_kl = (
+ torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
+ )
+
+ # Compute the loss
+ advantages = inputs["advantages"]
+ # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation (see
+ # _generate_and_score_completions) and use per_token_logps.detach() instead.
+
+ old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
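+ # PPO-style clipped surrogate: with ratio = exp(logp - old_logp), the per-token objective is
+ # min(ratio * A, clip(ratio, 1 - epsilon_low, 1 + epsilon_high) * A), and the loss is its negation.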
+ coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+ coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+ per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+ per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+ per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
+ if self.beta != 0.0:
+ per_token_loss = per_token_loss + self.beta * per_token_kl
+
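+ # Mask out padding, average the per-token loss over each completion's valid tokens (clamped to avoid
+ # division by zero for empty completions), then average over the batch.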
+ loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
+
+ # Log the metrics
+ mode = "eval" if self.control.should_evaluate else "train"
+
+ if self.beta != 0.0:
+ mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
+ self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+
+ is_clipped = (per_token_loss1 < per_token_loss2).float()
+ clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
+ self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
+ return loss
diff --git a/optimum/habana/trl/trainer/sft_trainer.py b/optimum/habana/trl/trainer/sft_trainer.py
index 6fb6365655..281b5015f7 100644
--- a/optimum/habana/trl/trainer/sft_trainer.py
+++ b/optimum/habana/trl/trainer/sft_trainer.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
+import importlib.metadata
import inspect
import warnings
from collections.abc import Mapping
@@ -23,6 +24,7 @@
import torch.nn as nn
from accelerate import PartialState
from datasets import Dataset
+from packaging import version
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
@@ -36,20 +38,27 @@
from transformers.trainer_utils import EvalPrediction
from trl import SFTTrainer
from trl.extras.dataset_formatting import get_formatting_func_from_dataset
-from trl.import_utils import is_peft_available
-from trl.trainer.utils import (
- ConstantLengthDataset,
- DataCollatorForCompletionOnlyLM,
- RichProgressCallback,
-)
+
+from ... import GaudiConfig, GaudiTrainer
+from .sft_config import GaudiSFTConfig
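+
+# TRL moved several of these utilities in version 0.17.0, so select the import locations that match the
+# installed trl version.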
+trl_version = importlib.metadata.version("trl")
+if version.parse(trl_version) < version.parse("0.17.0"):
+ from trl.import_utils import is_peft_available
+ from trl.trainer.utils import (
+ ConstantLengthDataset,
+ DataCollatorForCompletionOnlyLM,
+ RichProgressCallback,
+ )
+else:
+ from transformers.utils import is_peft_available
+ from trl.trainer.callbacks import RichProgressCallback
+ from trl.trainer.utils import ConstantLengthDataset, DataCollatorForCompletionOnlyLM
+
if is_peft_available():
from peft import PeftConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
-from ... import GaudiConfig, GaudiTrainer
-from .sft_config import GaudiSFTConfig
-
class BucketedDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
def _get_bucketed_len(self, examples):
diff --git a/tests/ci/slow_tests_trl.sh b/tests/ci/slow_tests_trl.sh
index 90a81ec892..d64d001833 100644
--- a/tests/ci/slow_tests_trl.sh
+++ b/tests/ci/slow_tests_trl.sh
@@ -2,4 +2,4 @@
python -m pip install --upgrade pip
export RUN_SLOW=true
-make slow_tests_trl
+make slow_tests_trl_ddpo && make slow_tests_trl_grpo
diff --git a/tests/test_trl.py b/tests/test_trl.py
index ebb64edf73..f723361532 100644
--- a/tests/test_trl.py
+++ b/tests/test_trl.py
@@ -13,14 +13,35 @@
# limitations under the License.
import gc
+import importlib.metadata
+import tempfile
import unittest
import torch
-from transformers.testing_utils import slow
+from datasets import load_dataset
+from packaging import version
+from parameterized import parameterized
+from transformers.testing_utils import require_peft, slow
+from transformers.utils import is_peft_available
from trl import DDPOConfig
from optimum.habana import GaudiConfig
-from optimum.habana.trl import GaudiDDPOTrainer, GaudiDefaultDDPOStableDiffusionPipeline
+
+
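+# Import whichever classes match the installed trl version: the DDPO tests target older TRL releases,
+# while the GRPO tests rely on trl >= 0.17.0.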
+trl_version = importlib.metadata.version("trl")
+if version.parse(trl_version) < version.parse("0.17.0"):
+ from optimum.habana.trl import (
+ GaudiDDPOTrainer,
+ GaudiDefaultDDPOStableDiffusionPipeline,
+ )
+else:
+ from optimum.habana.trl import (
+ GaudiGRPOConfig,
+ GaudiGRPOTrainer,
+ )
+
+if is_peft_available():
+ from peft import LoraConfig, PeftModel
def scorer_function(images, prompts, metadata):
@@ -154,3 +175,156 @@ def setUp(self):
)
return super().setUp()
+
+
+class GaudiGRPOTrainerTester(unittest.TestCase):
+ """
+ Test the GaudiGRPOTrainer class.
+
+ Adapted from https://github.com/huggingface/trl/blob/main/tests/test_grpo_trainer.py#L216
+ The main changes are:
+ - use GaudiGRPOConfig and GaudiGRPOTrainer instead of GRPOConfig and GRPOTrainer
+ - add GaudiConfig
+ """
+
+ def test_init_minimal(self):
+ # Test that GaudiGRPOTrainer can be instantiated with only a model, a reward function and a train dataset
+ dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+ training_args = GaudiGRPOConfig(
+ use_habana=True,
+ use_lazy_mode=True,
+ )
+ gaudi_config = GaudiConfig()
+
+ GaudiGRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ gaudi_config=gaudi_config,
+ )
+
+ @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)])
+ def test_training(self, config_name):
+ dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")
+
+ gaudi_config = GaudiConfig()
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ training_args = GaudiGRPOConfig(
+ output_dir=tmp_dir,
+ learning_rate=0.1, # increase the learning rate to speed up the test
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
+ num_generations=3, # reduce the number of generations to reduce memory usage
+ max_completion_length=8, # reduce the completion length to reduce memory usage
+ report_to="none",
+ use_habana=True,
+ use_lazy_mode=True,
+ )
+ trainer = GaudiGRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ gaudi_config=gaudi_config,
+ )
+
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ trainer.train()
+
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+ # Check that the params have changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+ @require_peft
+ def test_training_peft(self):
+ dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+ gaudi_config = GaudiConfig()
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ training_args = GaudiGRPOConfig(
+ output_dir=tmp_dir,
+ learning_rate=0.1, # increase the learning rate to speed up the test
+ per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
+ num_generations=3, # reduce the number of generations to reduce memory usage
+ max_completion_length=8, # reduce the completion length to reduce memory usage
+ report_to="none",
+ use_habana=True,
+ use_lazy_mode=True,
+ )
+ trainer = GaudiGRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ peft_config=LoraConfig(),
+ gaudi_config=gaudi_config,
+ )
+
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ trainer.train()
+
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+ # Check that the peft params have changed and the base model params have not changed
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ if "lora" in n.lower(): # We expect the lora params to be different
+ self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.")
+ else: # We expect the rest of params to be the same
+ self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.")
+
+ @require_peft
+ def test_training_peft_with_gradient_checkpointing(self):
+ """Test that training works with PEFT and gradient checkpointing enabled."""
+ dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+ gaudi_config = GaudiConfig()
+
+ lora_config = LoraConfig(
+ r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none"
+ )
+
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ training_args = GaudiGRPOConfig(
+ output_dir=tmp_dir,
+ learning_rate=0.1,
+ per_device_train_batch_size=3,
+ num_generations=3,
+ max_completion_length=8,
+ gradient_checkpointing=True, # Enable gradient checkpointing
+ report_to="none",
+ use_habana=True,
+ use_lazy_mode=True,
+ )
+ trainer = GaudiGRPOTrainer(
+ model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+ reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+ args=training_args,
+ train_dataset=dataset,
+ peft_config=lora_config,
+ gaudi_config=gaudi_config,
+ )
+
+ # Verify the model was wrapped with PEFT (gradient checkpointing itself is enabled via training_args)
+ self.assertIsInstance(trainer.model, PeftModel)
+
+ # Store initial parameters to check which ones change
+ previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+ trainer.train()
+
+ self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+ # Check that only LoRA parameters have changed, base model parameters remain unchanged
+ for n, param in previous_trainable_params.items():
+ new_param = trainer.model.get_parameter(n)
+ if "lora" in n.lower(): # LoRA parameters should change
+ self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.")
+ else: # Base model parameters should not change
+ self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.")