Merged

164 commits
5e88ce7
Upgrade to commit 74e19e81e2a23809af192532b9b0e7ea202be6f2
regisss Aug 28, 2024
43bc4eb
Merge branch 'main' into transformers_future
regisss Sep 2, 2024
8eea643
Add specific commit in setup.py
regisss Sep 3, 2024
a7be363
Upgrade to commit e48e5f1f13e05380e24f4f31f5fee07aa6f959eb
regisss Sep 6, 2024
f0b909a
Merge branch 'main' into transformers_future
regisss Sep 6, 2024
d99f18f
Fix default cache
regisss Sep 9, 2024
39b7a76
Merge branch 'main' into transformers_future
regisss Sep 9, 2024
da66ecf
Merge branch 'main' into transformers_future
regisss Sep 9, 2024
5547767
Merge branch 'main' into transformers_future
regisss Sep 10, 2024
bf89e41
Merge branch 'main' into transformers_future
regisss Sep 10, 2024
98b0da5
Merge branch 'main' into transformers_future
regisss Sep 24, 2024
47ad03c
Upgrade to commit 238b13478df209ab534f2195a397dc64a3930883
regisss Sep 24, 2024
94c23ba
Fix
regisss Sep 24, 2024
c19dedd
Upgrade to v4.45.0
regisss Sep 25, 2024
c12fd7e
Merge branch 'main' into transformers_future
regisss Sep 25, 2024
fc399fa
Fix
regisss Sep 25, 2024
9216159
Add bias to gptj (#1363)
jiminha Sep 26, 2024
679365a
Switch roberta from sdpa to eager attn (#1361)
skaulintel Sep 26, 2024
1abd6ee
Update bloom attention forward reshape following the transformer chang…
yeonsily Sep 26, 2024
8043d2c
Workaround for Llava/Llava-next
regisss Sep 26, 2024
047e7ff
Fix reshape error in mamba (#1369)
hsubramony Sep 28, 2024
f89e03b
Merge branch 'main' into transformers_future
regisss Sep 30, 2024
2ae546a
Merge branch 'main' into transformers_future
regisss Sep 30, 2024
1b8a3f7
Fix contrastive search
regisss Oct 1, 2024
2332afb
Fix local variable 'image_features' referenced before assignment (#1383)
vidyasiv Oct 1, 2024
f62ecde
Use model.generation_config instead of model.config (#1384)
hsubramony Oct 2, 2024
a8fb8ac
Make style
regisss Oct 2, 2024
c7ff331
Merge branch 'main' into transformers_future
regisss Jan 3, 2025
dd07c16
Upgrade to Transformers v4.47.1
regisss Jan 7, 2025
1924c89
Fix Transformers version to install
regisss Jan 8, 2025
f0926ae
Temporary workaround for GaudiTrainer
regisss Jan 9, 2025
e50e179
Fixes for text generation
regisss Jan 10, 2025
c804270
Set eager attention for distilbert, gpt_neox
regisss Jan 10, 2025
0000de5
Upgrade to Transformers v4.48
regisss Jan 15, 2025
2a3affa
Small fixes
regisss Jan 15, 2025
064f4c1
Fix integration tests
regisss Jan 15, 2025
21714f7
Fixes for text-generation
regisss Jan 16, 2025
1cfd53b
Fixes
regisss Jan 16, 2025
573cc57
Style
regisss Jan 16, 2025
a7bc517
Again
regisss Jan 16, 2025
3d402ea
Merge branch 'main' into transformers_future
regisss Jan 16, 2025
f69e957
Fix for image2text lora llama test (#1731)
vidyasiv Jan 28, 2025
265e6a1
Cherry-pick https://github.com/huggingface/transformers/pull/35651
regisss Jan 17, 2025
6dad1c4
Merge branch 'main' into transformers_future
regisss Jan 30, 2025
32478f5
Upgrade to Transformers v4.48.2
regisss Jan 31, 2025
1b79cf3
Fix deprecated imports following merged changes for DETR and Qwen2-VL
regisss Jan 31, 2025
c1f30d8
Workaround for textual inversion
regisss Jan 31, 2025
1460856
Merge branch 'main' into transformers_future
regisss Jan 31, 2025
7eadac6
Fixes for v4.48 pytest (#1699)
imangohari1 Jan 31, 2025
5cee218
fea(): Applied changes in HF #35235 (#1738)
imangohari1 Jan 31, 2025
417cbee
Merge branch 'main' into transformers_future
regisss Jan 31, 2025
c1e3232
Merge branch 'main' into transformers_future
regisss Feb 3, 2025
0f68bbb
Merge branch 'main' into transformers_future
regisss Feb 4, 2025
17943de
Removing HL_DS_DISTRIBUTED_ATTENTION_SEQ_DIM as it's not needed from …
bhargaveede Feb 5, 2025
d214819
Update DS config to align with recommended settings (#1730)
ckvermaAI Feb 5, 2025
6a520ff
Fix graph breaks in Mixtral (#65) (#1705)
Solaryee Feb 5, 2025
58de6b6
Merge branch 'main' into synapse_1_20
regisss Feb 5, 2025
bedc041
Add batch dim idx to support latest deepspeed DistributedAttention (…
bhargaveede Feb 6, 2025
ce57e40
Add _prepare_inputs_for_generation (#1743)
yafshar Feb 7, 2025
ef77fac
Merge branch 'main' into synapse_1_20
regisss Feb 7, 2025
be34027
Upgrade to v4.48.3
regisss Feb 7, 2025
bd9a60e
Fix the issue with --load_quantized_model_with_autoawq (#1747)
schoi-habana Feb 7, 2025
fc6a92b
Merge branch 'main' into synapse_1_20
regisss Feb 7, 2025
01bb4af
Merge branch 'synapse_1_20' into transformers_future
regisss Feb 7, 2025
2f665e8
Fix dpo crash in transformers 4.48 (#1750)
sywangyi Feb 12, 2025
595b816
Fix for Falcon image-to-text crash (#1760)
schoi-habana Feb 12, 2025
f3729a4
Fix llama attr (#1771)
atakaha Feb 12, 2025
bcb0778
Update llama scaling (#1775)
atakaha Feb 12, 2025
bd87113
Merge branch 'main' into synapse_1_20
regisss Feb 12, 2025
a13b5d2
Merge branch 'synapse_1_20' into transformers_future
regisss Feb 12, 2025
ce1bf08
Merge branch 'main' into synapse_1_20
regisss Feb 14, 2025
9b8bb2e
Merge branch 'synapse_1_20' into transformers_future
regisss Feb 14, 2025
d053218
Fix loss calculation (Workaround), final fix TBD (#1784)
emascarenhas Feb 14, 2025
8044aa4
Merge branch 'main' into synapse_1_20
regisss Feb 17, 2025
fe01ca2
Merge branch 'synapse_1_20' into transformers_future
regisss Feb 17, 2025
8b006c4
Simplify text-gen readme (#1780)
libinta Feb 18, 2025
0eb5d79
Merge branch 'main' into synapse_1_20
regisss Feb 18, 2025
a03f1d0
Merge branch 'synapse_1_20' into transformers_future
regisss Feb 18, 2025
6772b4f
Diffusers: Simplified the README files. Updated CI tests. (#1718)
imangohari1 Feb 20, 2025
06644af
Merge branch 'main' into synapse_1_20
regisss Feb 20, 2025
244b19e
Merge branch 'synapse_1_20' into transformers_future
regisss Feb 20, 2025
9279ab1
Merge branch 'main' into synapse_1_20
regisss Feb 21, 2025
5c7fea2
Merge branch 'synapse_1_20' into transformers_future
regisss Feb 21, 2025
fe65b05
Switch version number
regisss Feb 21, 2025
523370d
Merge branch 'main' into synapse_1_20
regisss Feb 21, 2025
3dfceb9
Merge branch 'main' into synapse_1_20
regisss Feb 26, 2025
836961e
Merge branch 'main' into synapse_1_20
regisss Feb 26, 2025
ffda2a0
Temporary WA for get_type error (#1806)
12010486 Feb 28, 2025
2688527
Merge branch 'main' into transformers_future
regisss Feb 28, 2025
167a218
Loss Computation for Compatibility with Transformers 4.48.3 (#1794)
yafshar Mar 5, 2025
379524c
Move model to device before wrapping with FSDP (#1801)
mieshkiwrk Mar 5, 2025
3197dd8
Merge branch 'main' into synapse_1_20
regisss Mar 5, 2025
46bad3b
v1.16 Llama3-405B text-generation. Added DEEPSPEED_USE_HABANA_FRAMEWO…
dsmertin Mar 5, 2025
0078227
Make style
regisss Mar 6, 2025
9b7ca11
Merge branch 'synapse_1_20' into transformers_future
regisss Mar 6, 2025
6d575e8
Merge branch 'main' into synapse_1_20
regisss Mar 6, 2025
73be6a2
Merge branch 'synapse_1_20' into transformers_future
regisss Mar 6, 2025
81f33ed
Revert placing llama on cpu (#1827)
ugolowic Mar 6, 2025
b46ed25
Merge branch 'synapse_1_20' into transformers_future
regisss Mar 7, 2025
73dd3ed
Merge branch 'main' into synapse_1_20
regisss Mar 7, 2025
4527647
Merge branch 'synapse_1_20' into transformers_future
regisss Mar 7, 2025
d0b54b8
Merge branch 'main' into transformers_future
regisss Mar 10, 2025
195fdf8
Fix contrastive search
regisss Mar 11, 2025
38f59eb
Merge branch 'main' into transformers_future
regisss Mar 12, 2025
b6602f7
Merge branch 'main' into transformers_future
regisss Mar 13, 2025
2f98dec
Merge branch 'main' into transformers_future
regisss Mar 13, 2025
45dc3aa
Merge branch 'main' into transformers_future
regisss Mar 14, 2025
a22b821
Upgrade to Transformers v4.49 (#1810)
regisss Mar 14, 2025
dd42c92
Fix `get_num_items_in_batches` for iterable datasets and when resumin…
regisss Mar 14, 2025
69f7e6d
Fixes pytest runtime error - Incompatible input shapes, broadcast not…
srajabos Mar 14, 2025
50d1f2e
Merge branch 'main' into transformers_future
regisss Mar 14, 2025
d0d0172
Fix for AutoModelForCausalLM.from_pretrained() (#1844)
dsmertin Mar 14, 2025
adbaa23
Fix unexpected 'num_items_in_batch' argument in GPT-NeoX forward (#1850)
mounikamandava Mar 14, 2025
e802f5f
Make style
regisss Mar 14, 2025
f461199
Fix for `GaudiLlamaAttention` object has no attribute 'max_position_e…
12010486 Mar 17, 2025
9cf57be
Fix error with TRL examples
regisss Mar 17, 2025
b780d70
[skip ci] Merge branch 'main' into transformers_future
regisss Mar 18, 2025
dbd987b
Adjust precision of eval_accuracy to avoid random failure in pytest f…
hchauhan123 Mar 19, 2025
78e50b9
Missing num_key_value_heads attribute in GaudiGemmaAttention (#1861)
hsubramony Mar 19, 2025
bff3803
Update Sentence Transformer CI/Ref (#1862)
ZhengHongming888 Mar 19, 2025
5d2fbde
Fix typo in modeling llama (#1864)
hsubramony Mar 20, 2025
0ec8b04
fea(): Added the updated skip list for mistral/mixtral tests (#1863)
imangohari1 Mar 20, 2025
639f96d
Fix llama internal bucketing issue (#1871)
dsocek Mar 21, 2025
f3124e7
Fix regression for test_run_image2text_lora_finetune_idefics2-8b_mult…
srajabos Mar 21, 2025
bbada81
Revert "Move model to device before wrapping with FSDP (#1801)" (#1865)
12010486 Mar 24, 2025
0732389
added GRPO Trainer and config / resolved import issues
alekseyfa Mar 26, 2025
2366a00
Resolved import issues
alekseyfa Mar 26, 2025
ee669dc
Updated requirements
alekseyfa Mar 28, 2025
4126553
GRPO simple training script
alekseyfa Mar 28, 2025
45fb347
Updated README
alekseyfa Mar 28, 2025
b6af175
Updated data collator
alekseyfa Apr 3, 2025
247c590
updated sample
alekseyfa Apr 3, 2025
942bd01
Updated README
alekseyfa Apr 4, 2025
6c78be8
Added LORA config
alekseyfa Apr 7, 2025
83cd501
Checking pad_token
alekseyfa Apr 9, 2025
35ac2c6
enable flash attn and pad inputs to the max seq len
schoi-habana May 6, 2025
3df10a5
Merge branch 'main' into schoi/grpo_from_pr1898
schoi-habana May 6, 2025
46fc724
README changes for Llama3.1 8B Finetuning with LoRA (#1947)
bhargaveede May 14, 2025
493bfd2
pt2e quant changes into the main script (#191) (#1875)
vivek5-ai May 14, 2025
cf889b5
working, convergence not sure
schoi-habana May 9, 2025
736b1df
Merge branch 'main' into v1.18-release
regisss May 15, 2025
3cbc2f1
added bucketing
schoi-habana May 17, 2025
8c061df
Merge branch 'main' into v1.18-release
regisss May 20, 2025
c55388b
multicard breaks with sync error even when gradient checkpointing is …
schoi-habana May 28, 2025
890ceb6
Merge branch 'main' into schoi/grpo_from_pr1898_gradient_ckpt
schoi-habana May 28, 2025
62a6ed5
multi card works without --gradient_checkpointing
schoi-habana May 28, 2025
15acb06
bump trl version to 0.17.0
schoi-habana May 28, 2025
42218b2
multicard works w/ gradient_checkpointing
schoi-habana May 29, 2025
e397864
Hot fix regional compilation (#2005)
IlyasMoutawwakil Jun 2, 2025
5f2bb76
Enable mixtral 8x7b accuracy evaluation (#1986)
rbogdano Jun 3, 2025
2188aaa
Update readme files for explicit lazy mode (#1921)
jasi306 Jun 3, 2025
167e07a
[llama-vision] Remove token_idx_cpu parameter (#2018)
ugolowic Jun 5, 2025
822f4b2
Update README examples (#2020)
pbielak Jun 9, 2025
d3ef327
Pin latest optimum to force mutual updates (#2016)
IlyasMoutawwakil Jun 6, 2025
c0856d5
Fix FP8 support and address related issues (#2010)
IlyasMoutawwakil Jun 10, 2025
e72327d
trl==0.17.0 working version for trl example 6/11
schoi-habana Jun 11, 2025
ea00dc2
Release: v1.18.0
IlyasMoutawwakil Jun 12, 2025
47ae40b
cleaned and formatted, upto 4x tested with and without gradient check…
schoi-habana Jun 26, 2025
b14dc6f
Merge tag 'v1.18.0' into schoi/grpo_from_pr1898_gradient_ckpt
schoi-habana Jun 26, 2025
d4231b2
Merge branch 'main' into schoi/grpo_from_pr1898_gradient_ckpt
schoi-habana Jun 26, 2025
a65a9d6
resolve trl version mismatch with other trl trainers in OH
schoi-habana Jun 27, 2025
86dcc6a
incorporating the review
schoi-habana Jul 8, 2025
3c0a6a6
add tests for grpo
schoi-habana Jul 9, 2025
7962be0
update tests in Makefile
schoi-habana Jul 9, 2025
6 changes: 5 additions & 1 deletion Makefile
@@ -159,11 +159,15 @@ slow_tests_video_llava_example: test_installs
slow_tests_fsdp: test_installs
python -m pytest tests/test_fsdp_examples.py -v -s --token $(TOKEN)

slow_tests_trl: test_installs
slow_tests_trl_ddpo: test_installs
python -m pip install trl==0.9.6
python -m pip install peft==0.12.0
python -m pytest tests/test_trl.py -v -s -k "test_calculate_loss"

slow_tests_trl_grpo: test_installs
python -m pip install -r examples/trl/requirements_grpo.txt
python -m pytest tests/test_trl.py -v -s -k "GaudiGRPOTrainerTester"

slow_tests_object_segmentation: test_installs
python -m pytest tests/test_object_segmentation.py

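With this split, the GRPO slow tests get their own target. Assuming the repo's usual `make` workflow, they can be run on their own with:

```sh
make slow_tests_trl_grpo
```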
64 changes: 64 additions & 0 deletions examples/trl/README.md
@@ -4,10 +4,74 @@
## Requirements

First, you should install the requirements:

- For **GRPO example**:
```bash
$ pip install -U -r requirements_grpo.txt
```

- For **all other examples**:
```bash
$ pip install -U -r requirements.txt
```

## GRPO Training

Installing DeepSpeed

```sh
pip install git+https://github.com/HabanaAI/[email protected]
```

Running single-card training

```sh
PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 grpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name AI-MO/NuminaMath-TIR \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--do_train \
--do_eval \
--use_habana \
--use_lazy_mode \
--bf16 True \
--gradient_accumulation_steps=16 \
--max_prompt_length 512 \
--num_generations 4 \
--max_completion_length 64 \
--use_peft True \
--lora_target_modules q_proj k_proj \
--num_train_epochs 1 \
--save_strategy="epoch"
```


Running multi-card training

```sh
PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py --world_size 8 --use_deepspeed grpo.py \
--model_name_or_path Qwen/Qwen2-0.5B-Instruct \
--dataset_name AI-MO/NuminaMath-TIR \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--do_train \
--do_eval \
--use_habana \
--use_lazy_mode \
--bf16 True \
--gradient_accumulation_steps=16 \
--gradient_checkpointing \
--max_prompt_length 512 \
--num_generations 4 \
--max_completion_length 64 \
--use_peft True \
--lora_target_modules q_proj k_proj \
--max_steps=500 \
--logging_steps=10 \
--save_steps=100
```

## Supervised Finetuning

1. The following example performs supervised LoRA fine-tuning of the Qwen2 model on a conversational-format dataset.
210 changes: 210 additions & 0 deletions examples/trl/grpo.py
@@ -0,0 +1,210 @@
import contextlib
import io
import logging
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import transformers
from datasets import load_dataset
from math_verify import LatexExtractionConfig, parse, verify
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
from transformers.integrations.deepspeed import (
is_deepspeed_available,
)
from transformers.trainer_utils import is_main_process

from optimum.habana import GaudiConfig
from optimum.habana.trl import GaudiGRPOConfig, GaudiGRPOTrainer
from optimum.habana.utils import set_seed


logger = logging.getLogger(__name__)
SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
)


def make_conversation(example):
return {
"prompt": [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": example["problem"]},
],
}
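# As an illustration with a hypothetical record: {"problem": "What is 2+2?"} maps to
# {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},
#             {"role": "user", "content": "What is 2+2?"}]}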


ideal_length = 50


def reward_len(completions, **kwargs):
    # Penalize completions whose text length deviates from ideal_length characters
    return [-abs(ideal_length - len(completion[0]["content"])) for completion in completions]


def format_reward(completions, **kwargs):
# Checks if the reasoning process is enclosed within <think> and </think> tags,
# while the final answer is enclosed within <answer> and </answer> tags.
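    # e.g. "<think>2+2=4</think> <answer>4</answer>" matches; a bare "4" does not.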
pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
completion_contents = [completion[0]["content"] for completion in completions]
matches = [re.match(pattern, content) for content in completion_contents]
return [1.0 if match else 0.0 for match in matches]


def accuracy_reward(completions, **kwargs):
# Checks if the completion is the same as the ground truth.
solutions = kwargs["solution"]
completion_contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, solution in zip(completion_contents, solutions):
gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
answer_parsed = parse(content, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
if len(gold_parsed) != 0:
try:
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
rewards.append(float(verify(answer_parsed, gold_parsed)))
except Exception:
rewards.append(0.0)
else:
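            # the gold solution could not be parsed; skip verification and grant full reward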
rewards.append(1.0)
return rewards


@dataclass
class ScriptArguments:
model_name_or_path: Optional[str] = field(default="Qwen/Qwen2-0.5B-Instruct", metadata={"help": "the model name"})
dataset_name: Optional[str] = field(default=None, metadata={"help": "the dataset name"})
use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"})
num_workers: Optional[int] = field(default=1, metadata={"help": "the number of workers"})
subset: Optional[str] = field(default=None, metadata={"help": "the subset to use"})
streaming: Optional[bool] = field(default=False, metadata={"help": "whether to stream the dataset"})
dataset_train_split: str = field(default="train[:5%]", metadata={"help": "Dataset split to use for training."})
dataset_test_split: str = field(default="test[:5%]", metadata={"help": "Dataset split to use for evaluation."})
reward_model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or "
"local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
},
)

use_flash_attention: Optional[bool] = field(
default=True, metadata={"help": "Whether to use Habana flash attention for fine-tuning."}
)
flash_attention_recompute: Optional[bool] = field(
default=False, metadata={"help": "Whether to enable recompute in Habana flash attention for fine-tuning."}
)
flash_attention_causal_mask: Optional[bool] = field(
default=False, metadata={"help": "Whether to enable causal mask in Habana flash attention for fine-tuning."}
)

# LoraConfig
lora_alpha: Optional[float] = field(default=32, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.1, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
lora_target_modules: List[str] = field(
default_factory=lambda: None,
metadata={"help": "Target modules for the LoRA method."},
)


if __name__ == "__main__":
parser = HfArgumentParser((GaudiGRPOConfig, ScriptArguments))
(training_args, script_args) = parser.parse_args_into_dataclasses()

logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.bf16}"
)
# Set the verbosity to info of the Transformers logger (on main process only):
if is_main_process(training_args.local_rank):
transformers.utils.logging.set_verbosity_info()
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
logger.info(f"Training/evaluation parameters {training_args}")
# Set seed before initializing model.
set_seed(training_args.seed)

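    # world_size > 1 implies a distributed launch; the README example uses gaudi_spawn.py with DeepSpeed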
use_deepspeed = training_args.world_size > 1

if script_args.use_peft:
peft_config = LoraConfig(
r=script_args.lora_r,
lora_alpha=script_args.lora_alpha,
lora_dropout=script_args.lora_dropout,
target_modules=script_args.lora_target_modules,
task_type="CAUSAL_LM",
)
else:
peft_config = None

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True)
if training_args.chat_template is not None:
tokenizer.chat_template = training_args.chat_template

train_dataset, test_dataset = load_dataset(
script_args.dataset_name,
data_dir=None if script_args.subset == "None" else script_args.subset,
num_proc=script_args.num_workers if not script_args.streaming else None,
split=[script_args.dataset_train_split, script_args.dataset_test_split],
)

train_dataset = train_dataset.map(make_conversation)
test_dataset = test_dataset.map(make_conversation)
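    # the "solution" column is kept on purpose: accuracy_reward reads it through kwargs["solution"]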
train_dataset = train_dataset.remove_columns(["messages", "problem"])

low_cpu_mem_usage = True
if is_deepspeed_available() and use_deepspeed:
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

if is_deepspeed_zero3_enabled():
low_cpu_mem_usage = False

model = AutoModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=low_cpu_mem_usage,
torch_dtype=torch.bfloat16,
)

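    # disable the KV cache for training forward passes (generation is configured separately below)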
model.config.use_cache = False
    if not script_args.use_flash_attention and (
        script_args.flash_attention_recompute or script_args.flash_attention_causal_mask
    ):
        raise ValueError("flash_attention_recompute and flash_attention_causal_mask require use_flash_attention=True")
model.generation_config.use_flash_attention = script_args.use_flash_attention
model.generation_config.flash_attention_recompute = script_args.flash_attention_recompute
model.generation_config.flash_attention_causal_mask = script_args.flash_attention_causal_mask

reward_funcs = [format_reward, accuracy_reward]
if script_args.reward_model_name_or_path:
reward_funcs = AutoModelForSequenceClassification.from_pretrained(
script_args.reward_model_name_or_path,
trust_remote_code=True,
)

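    # some tokenizers ship without a pad token; fall back to EOS so padded batches work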
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token

gaudi_config = GaudiConfig()
gaudi_config.use_fused_adam = True
gaudi_config.use_fused_clip_norm = True

trainer = GaudiGRPOTrainer(
model=model,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset,
processing_class=tokenizer,
gaudi_config=gaudi_config,
peft_config=peft_config,
)

trainer.train()

print("Done!")
8 changes: 8 additions & 0 deletions examples/trl/requirements_grpo.txt
@@ -0,0 +1,8 @@
trl == 0.17.0
peft == 0.12.0
datasets
tyro
evaluate
scikit-learn == 1.5.2
accelerate
math_verify
15 changes: 13 additions & 2 deletions optimum/habana/trl/__init__.py
@@ -1,10 +1,21 @@
import importlib.metadata

from packaging import version

from .models.modeling_base import adapt_PreTrainedModelWrapper_to_gaudi
from .models.modeling_sd_base import GaudiDefaultDDPOStableDiffusionPipeline
from .trainer.ddpo_trainer import GaudiDDPOTrainer
from .trainer.dpo_config import GaudiDPOConfig
from .trainer.dpo_trainer import GaudiDPOTrainer
from .trainer.ppo_config import GaudiPPOConfig
from .trainer.ppo_trainer import GaudiPPOTrainer


trl_version = importlib.metadata.version("trl")
if version.parse(trl_version) < version.parse("0.17.0"):
from .trainer.ppo_config import GaudiPPOConfig
from .trainer.ppo_trainer import GaudiPPOTrainer
else:
from .trainer.grpo_config import GaudiGRPOConfig
from .trainer.grpo_trainer import GaudiGRPOTrainer
from .trainer.reward_trainer import GaudiRewardTrainer, RewardDataCollatorWithPadding
from .trainer.sft_config import GaudiSFTConfig
from .trainer.sft_trainer import GaudiSFTTrainer
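Since `GaudiGRPOConfig` and `GaudiGRPOTrainer` are only exported when trl >= 0.17.0 is installed, downstream code that has to work across both ranges can mirror the same gate. A minimal sketch, assuming the import paths introduced in this diff:

```python
import importlib.metadata

from packaging import version

if version.parse(importlib.metadata.version("trl")) < version.parse("0.17.0"):
    # pre-0.17 installs expose the PPO entry points
    from optimum.habana.trl import GaudiPPOConfig, GaudiPPOTrainer
else:
    # 0.17.0+ installs expose the GRPO entry points instead
    from optimum.habana.trl import GaudiGRPOConfig, GaudiGRPOTrainer
```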
13 changes: 11 additions & 2 deletions optimum/habana/trl/trainer/__init__.py
@@ -16,13 +16,22 @@

# There is a circular import in the PPOTrainer if we let isort sort these
# isort: on
import importlib.metadata
from packaging import version

from .sft_trainer import GaudiSFTTrainer
from .dpo_trainer import GaudiDPOTrainer
from .ppo_config import GaudiPPOConfig
from .ppo_trainer import GaudiPPOTrainer

from .reward_trainer import GaudiRewardTrainer, RewardDataCollatorWithPadding

from .ddpo_trainer import GaudiDDPOTrainer
from .dpo_config import GaudiDPOConfig
from .sft_config import GaudiSFTConfig

trl_version = importlib.metadata.version("trl")
if version.parse(trl_version) < version.parse("0.17.0"):
from .ppo_config import GaudiPPOConfig
from .ppo_trainer import GaudiPPOTrainer
else:
from .grpo_trainer import GaudiGRPOTrainer
from .grpo_config import GaudiGRPOConfig