
Commit ab3fa33

Enable trl GRPO trainer (#2088)
Signed-off-by: Wang, Yi A <[email protected]>
Signed-off-by: Daniel Socek <[email protected]>
Signed-off-by: U. Artie Eoff <[email protected]>
Signed-off-by: Vivek Kumar <[email protected]>
Signed-off-by: Urszula <[email protected]>
Co-authored-by: regisss <[email protected]>
Co-authored-by: Jimin Ha <[email protected]>
Co-authored-by: Shiv Kaul <[email protected]>
Co-authored-by: Yeonsil Yoon <[email protected]>
Co-authored-by: Harish Subramony <[email protected]>
Co-authored-by: Vidya Galli <[email protected]>
Co-authored-by: Iman Gohari <[email protected]>
Co-authored-by: Bhargav <[email protected]>
Co-authored-by: Chetan Kumar Verma <[email protected]>
Co-authored-by: Sheng Yang <[email protected]>
Co-authored-by: Yaser Afshar <[email protected]>
Co-authored-by: Wang, Yi <[email protected]>
Co-authored-by: Akihiro Takahashi <[email protected]>
Co-authored-by: Edward Mascarenhas <[email protected]>
Co-authored-by: Libin Tang <[email protected]>
Co-authored-by: Sayantan Sarkar <[email protected]>
Co-authored-by: Daniel Socek <[email protected]>
Co-authored-by: Silvia Colabrese <[email protected]>
Co-authored-by: Mieszko Dziadowiec <[email protected]>
Co-authored-by: Dmitry <[email protected]>
Co-authored-by: Urszula Golowicz <[email protected]>
Co-authored-by: Nikolay Protasov <[email protected]>
Co-authored-by: U. Artie Eoff <[email protected]>
Co-authored-by: Luca Calabria <[email protected]>
Co-authored-by: Harshvardhan Chauhan <[email protected]>
Co-authored-by: Shifani Rajabose <[email protected]>
Co-authored-by: Dmitry <[email protected]>
Co-authored-by: Mounika Mandava <[email protected]>
Co-authored-by: ZhengHongming888 <[email protected]>
Co-authored-by: Alexey Fadeev <[email protected]>
Co-authored-by: Vivek Kumar <[email protected]>
Co-authored-by: Vivek Kumar <[email protected]>
Co-authored-by: Ilyas Moutawwakil <[email protected]>
Co-authored-by: Rafal Bogdanowicz <[email protected]>
Co-authored-by: Rafal <[email protected]>
Co-authored-by: Jan Kamiński <[email protected]>
Co-authored-by: Karol Brejna <[email protected]>
Co-authored-by: Piotr Bielak <[email protected]>
Co-authored-by: Piotr Bielak <[email protected]>
Co-authored-by: karol-brejna-i <[email protected]>
Co-authored-by: IlyasMoutawwakil <[email protected]>
1 parent 31f428b commit ab3fa33

12 files changed: +1670 −20 lines

Makefile

Lines changed: 5 additions & 1 deletion
@@ -162,11 +162,15 @@ slow_tests_video_llava_example: test_installs
 slow_tests_fsdp: test_installs
 	python -m pytest tests/test_fsdp_examples.py -v -s --token $(TOKEN)
 
-slow_tests_trl: test_installs
+slow_tests_trl_ddpo: test_installs
 	python -m pip install trl==0.9.6
 	python -m pip install peft==0.15.0
 	python -m pytest tests/test_trl.py -v -s -k "test_calculate_loss"
 
+slow_tests_trl_grpo: test_installs
+	python -m pip install -r examples/trl/requirements_grpo.txt
+	python -m pytest tests/test_trl.py -v -s -k "GaudiGRPOTrainerTester"
+
 slow_tests_object_segmentation: test_installs
 	python -m pytest tests/test_object_segmentation.py
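For reference, the test selection that the new `slow_tests_trl_grpo` target runs can also be driven from Python via pytest's programmatic entry point. A minimal sketch, assuming the GRPO requirements from `examples/trl/requirements_grpo.txt` are installed and the script is launched from the repository root (this helper is illustrative, not part of the commit):

```python
# Illustrative only: runs the same pytest selection as the
# `slow_tests_trl_grpo` Makefile target added above.
import sys

import pytest

if __name__ == "__main__":
    # -k "GaudiGRPOTrainerTester" selects the GRPO trainer tests in tests/test_trl.py
    sys.exit(pytest.main(["tests/test_trl.py", "-v", "-s", "-k", "GaudiGRPOTrainerTester"]))
```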

examples/trl/README.md

Lines changed: 64 additions & 0 deletions
@@ -4,10 +4,74 @@
 ## Requirements
 
 First, you should install the requirements:
+
+- For **GRPO example**:
+```bash
+$ pip install -U -r requirements_grpo.txt
+```
+
+- For **all other examples**:
 ```bash
 $ pip install -U -r requirements.txt
 ```
 
+## GRPO Training
+
+Installing DeepSpeed
+
+```sh
+pip install git+https://github.com/HabanaAI/[email protected]
+```
+
+Running single card training
+
+```sh
+PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 grpo.py \
+    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+    --dataset_name AI-MO/NuminaMath-TIR \
+    --per_device_train_batch_size 8 \
+    --per_device_eval_batch_size 8 \
+    --do_train \
+    --do_eval \
+    --use_habana \
+    --use_lazy_mode \
+    --bf16 True \
+    --gradient_accumulation_steps=16 \
+    --max_prompt_length 512 \
+    --num_generations 4 \
+    --max_completion_length 64 \
+    --use_peft True \
+    --lora_target_modules q_proj k_proj \
+    --num_train_epochs 1 \
+    --save_strategy="epoch"
+```
+
+Running multi-card training
+
+```sh
+PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py --world_size 8 --use_deepspeed grpo.py \
+    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+    --dataset_name AI-MO/NuminaMath-TIR \
+    --per_device_train_batch_size 8 \
+    --per_device_eval_batch_size 8 \
+    --do_train \
+    --do_eval \
+    --use_habana \
+    --use_lazy_mode \
+    --bf16 True \
+    --gradient_accumulation_steps=16 \
+    --gradient_checkpointing \
+    --max_prompt_length 512 \
+    --num_generations 4 \
+    --max_completion_length 64 \
+    --use_peft True \
+    --lora_target_modules q_proj k_proj \
+    --max_steps=500 \
+    --logging_steps=10 \
+    --save_steps=100
+```
+
 ## Supervised Finetuning
 
 1. The following example is for the supervised Lora finetune with Qwen2 model for conversational format dataset.
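For context on the `--num_generations` flag in the GRPO commands above: GRPO samples a group of completions per prompt, scores each with the reward functions, and computes advantages relative to the group rather than with a learned value baseline. A minimal conceptual sketch of that group-relative normalization follows; it is illustrative only and not the trl code path:

```python
# Illustrative sketch of GRPO's group-relative advantage, not the trl implementation.
import torch


def group_relative_advantages(rewards: torch.Tensor, num_generations: int, eps: float = 1e-4) -> torch.Tensor:
    """rewards: flat tensor of shape (num_prompts * num_generations,)."""
    grouped = rewards.view(-1, num_generations)       # one row per prompt
    mean = grouped.mean(dim=1, keepdim=True)          # group baseline
    std = grouped.std(dim=1, keepdim=True)            # group spread
    return ((grouped - mean) / (std + eps)).view(-1)  # normalize within each group


# Example: 2 prompts x 4 generations, matching --num_generations 4
adv = group_relative_advantages(torch.tensor([1.0, 0.0, 0.0, 1.0, 2.0, 2.0, 1.0, 3.0]), num_generations=4)
```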

examples/trl/grpo.py

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
import contextlib
import io
import logging
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import transformers
from datasets import load_dataset
from math_verify import LatexExtractionConfig, parse, verify
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
from transformers.integrations.deepspeed import (
    is_deepspeed_available,
)
from transformers.trainer_utils import is_main_process

from optimum.habana import GaudiConfig
from optimum.habana.trl import GaudiGRPOConfig, GaudiGRPOTrainer
from optimum.habana.utils import set_seed


logger = logging.getLogger(__name__)
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def make_conversation(example):
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ],
    }


ideal_length = 50


def reward_len(completions, **kwargs):
    return [-abs(ideal_length - len(completion)) for completion in completions]  # penalize response when len != 50


def format_reward(completions, **kwargs):
    # Checks if the reasoning process is enclosed within <think> and </think> tags,
    # while the final answer is enclosed within <answer> and </answer> tags.
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


def accuracy_reward(completions, **kwargs):
    # Checks if the completion is the same as the ground truth.
    solutions = kwargs["solution"]
    completion_contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, solution in zip(completion_contents, solutions):
        gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        answer_parsed = parse(content, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        if len(gold_parsed) != 0:
            try:
                with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
                    rewards.append(float(verify(answer_parsed, gold_parsed)))
            except Exception:
                rewards.append(0.0)
        else:
            rewards.append(1.0)
    return rewards


@dataclass
class ScriptArguments:
    model_name_or_path: Optional[str] = field(default="Qwen/Qwen2-0.5B-Instruct", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(default=None, metadata={"help": "the dataset name"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"})
    num_workers: Optional[int] = field(default=1, metadata={"help": "the number of workers"})
    subset: Optional[str] = field(default=None, metadata={"help": "the subset to use"})
    streaming: Optional[bool] = field(default=False, metadata={"help": "whether to stream the dataset"})
    dataset_train_split: str = field(default="train[:5%]", metadata={"help": "Dataset split to use for training."})
    dataset_test_split: str = field(default="test[:5%]", metadata={"help": "Dataset split to use for evaluation."})
    reward_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or "
            "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
        },
    )

    use_flash_attention: Optional[bool] = field(
        default=True, metadata={"help": "Whether to use Habana flash attention for fine-tuning."}
    )
    flash_attention_recompute: Optional[bool] = field(
        default=False, metadata={"help": "Whether to enable recompute in Habana flash attention for fine-tuning."}
    )
    flash_attention_causal_mask: Optional[bool] = field(
        default=False, metadata={"help": "Whether to enable causal mask in Habana flash attention for fine-tuning."}
    )

    # LoraConfig
    lora_alpha: Optional[float] = field(default=32, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.1, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
    lora_target_modules: List[str] = field(
        default_factory=lambda: None,
        metadata={"help": "Target modules for the LoRA method."},
    )


if __name__ == "__main__":
    parser = HfArgumentParser((GaudiGRPOConfig, ScriptArguments))
    (training_args, script_args) = parser.parse_args_into_dataclasses()

    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.bf16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info(f"Training/evaluation parameters {training_args}")
    # Set seed before initializing model.
    set_seed(training_args.seed)

    use_deepspeed = training_args.world_size > 1

    if script_args.use_peft:
        peft_config = LoraConfig(
            r=script_args.lora_r,
            lora_alpha=script_args.lora_alpha,
            lora_dropout=script_args.lora_dropout,
            target_modules=script_args.lora_target_modules,
            task_type="CAUSAL_LM",
        )
    else:
        peft_config = None

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True)
    if training_args.chat_template is not None:
        tokenizer.chat_template = training_args.chat_template

    train_dataset, test_dataset = load_dataset(
        script_args.dataset_name,
        data_dir=None if script_args.subset == "None" else script_args.subset,
        num_proc=script_args.num_workers if not script_args.streaming else None,
        split=[script_args.dataset_train_split, script_args.dataset_test_split],
    )

    train_dataset = train_dataset.map(make_conversation)
    test_dataset = test_dataset.map(make_conversation)
    train_dataset = train_dataset.remove_columns(["messages", "problem"])

    low_cpu_mem_usage = True
    if is_deepspeed_available() and use_deepspeed:
        from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

        if is_deepspeed_zero3_enabled():
            low_cpu_mem_usage = False

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=low_cpu_mem_usage,
        torch_dtype=torch.bfloat16,
    )

    model.config.use_cache = False
    # The flash attention sub-options only make sense when use_flash_attention is enabled.
    if not script_args.use_flash_attention and (
        script_args.flash_attention_recompute or script_args.flash_attention_causal_mask
    ):
        raise ValueError("Need to enable use_flash_attention")
    model.generation_config.use_flash_attention = script_args.use_flash_attention
    model.generation_config.flash_attention_recompute = script_args.flash_attention_recompute
    model.generation_config.flash_attention_causal_mask = script_args.flash_attention_causal_mask

    reward_funcs = [format_reward, accuracy_reward]
    if script_args.reward_model_name_or_path:
        reward_funcs = AutoModelForSequenceClassification.from_pretrained(
            script_args.reward_model_name_or_path,
            trust_remote_code=True,
        )

    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    gaudi_config = GaudiConfig()
    gaudi_config.use_fused_adam = True
    gaudi_config.use_fused_clip_norm = True

    trainer = GaudiGRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        processing_class=tokenizer,
        gaudi_config=gaudi_config,
        peft_config=peft_config,
    )

    trainer.train()

    print("Done!")

examples/trl/requirements_grpo.txt

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
trl == 0.17.0
peft == 0.12.0
datasets
tyro
evaluate
scikit-learn == 1.5.2
accelerate
math_verify

optimum/habana/trl/__init__.py

Lines changed: 13 additions & 2 deletions
@@ -1,10 +1,21 @@
+import importlib.metadata
+
+from packaging import version
+
 from .models.modeling_base import adapt_PreTrainedModelWrapper_to_gaudi
 from .models.modeling_sd_base import GaudiDefaultDDPOStableDiffusionPipeline
 from .trainer.ddpo_trainer import GaudiDDPOTrainer
 from .trainer.dpo_config import GaudiDPOConfig
 from .trainer.dpo_trainer import GaudiDPOTrainer
-from .trainer.ppo_config import GaudiPPOConfig
-from .trainer.ppo_trainer import GaudiPPOTrainer
+
+
+trl_version = importlib.metadata.version("trl")
+if version.parse(trl_version) < version.parse("0.17.0"):
+    from .trainer.ppo_config import GaudiPPOConfig
+    from .trainer.ppo_trainer import GaudiPPOTrainer
+else:
+    from .trainer.grpo_config import GaudiGRPOConfig
+    from .trainer.grpo_trainer import GaudiGRPOTrainer
 from .trainer.reward_trainer import GaudiRewardTrainer, RewardDataCollatorWithPadding
 from .trainer.sft_config import GaudiSFTConfig
 from .trainer.sft_trainer import GaudiSFTTrainer
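Because the exported trainers now depend on the installed trl release, downstream code that must work on either side of the 0.17.0 boundary can mirror the same gate. A minimal sketch (the GRPO names are exported only with trl >= 0.17.0, the PPO names only with older trl, per the diff above):

```python
import importlib.metadata

from packaging import version

# Mirror the gate used in optimum.habana.trl.__init__: GRPO classes exist only
# for trl >= 0.17.0, PPO classes only for earlier trl releases.
if version.parse(importlib.metadata.version("trl")) >= version.parse("0.17.0"):
    from optimum.habana.trl import GaudiGRPOConfig, GaudiGRPOTrainer
else:
    from optimum.habana.trl import GaudiPPOConfig, GaudiPPOTrainer
```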

optimum/habana/trl/trainer/__init__.py

Lines changed: 11 additions & 2 deletions
@@ -16,13 +16,22 @@
 
 # There is a circular import in the PPOTrainer if we let isort sort these
 # isort: on
+import importlib.metadata
+from packaging import version
 
 from .sft_trainer import GaudiSFTTrainer
 from .dpo_trainer import GaudiDPOTrainer
-from .ppo_config import GaudiPPOConfig
-from .ppo_trainer import GaudiPPOTrainer
+
 from .reward_trainer import GaudiRewardTrainer, RewardDataCollatorWithPadding
 
 from .ddpo_trainer import GaudiDDPOTrainer
 from .dpo_config import GaudiDPOConfig
 from .sft_config import GaudiSFTConfig
+
+trl_version = importlib.metadata.version("trl")
+if version.parse(trl_version) < version.parse("0.17.0"):
+    from .ppo_config import GaudiPPOConfig
+    from .ppo_trainer import GaudiPPOTrainer
+else:
+    from .grpo_trainer import GaudiGRPOTrainer
+    from .grpo_config import GaudiGRPOConfig