# Enable trl GRPO trainer #2088
Merged
## Commits (164)
- `5e88ce7` Upgrade to commit 74e19e81e2a23809af192532b9b0e7ea202be6f2 (regisss)
- `43bc4eb` Merge branch 'main' into transformers_future (regisss)
- `8eea643` Add specific commit in setup.py (regisss)
- `a7be363` Upgrade to commit e48e5f1f13e05380e24f4f31f5fee07aa6f959eb (regisss)
- `f0b909a` Merge branch 'main' into transformers_future (regisss)
- `d99f18f` Fix default cache (regisss)
- `39b7a76` Merge branch 'main' into transformers_future (regisss)
- `da66ecf` Merge branch 'main' into transformers_future (regisss)
- `5547767` Merge branch 'main' into transformers_future (regisss)
- `bf89e41` Merge branch 'main' into transformers_future (regisss)
- `98b0da5` Merge branch 'main' into transformers_future (regisss)
- `47ad03c` Upgrade to commit 238b13478df209ab534f2195a397dc64a3930883 (regisss)
- `94c23ba` Fix (regisss)
- `c19dedd` Upgrade to v4.45.0 (regisss)
- `c12fd7e` Merge branch 'main' into transformers_future (regisss)
- `fc399fa` Fix (regisss)
- `9216159` Add bias to gptj (#1363) (jiminha)
- `679365a` Switch roberta from sdpa to eager attn (#1361) (skaulintel)
- `1abd6ee` Update bloom attention forward reshape follwing the transformer chang… (yeonsily)
- `8043d2c` Workaround for Llava/Llava-next (regisss)
- `047e7ff` Fix reshape error in mamba (#1369) (hsubramony)
- `f89e03b` Merge branch 'main' into transformers_future (regisss)
- `2ae546a` Merge branch 'main' into transformers_future (regisss)
- `1b8a3f7` Fix contrastive search (regisss)
- `2332afb` Fix local variable 'image_features' referenced before assignment (#1383) (vidyasiv)
- `f62ecde` Use model.generation_config instead of model.config (#1384) (hsubramony)
- `a8fb8ac` Make style (regisss)
- `c7ff331` Merge branch 'main' into transformers_future (regisss)
- `dd07c16` Upgrade to Transformers v4.47.1 (regisss)
- `1924c89` Fix Transformers version to install (regisss)
- `f0926ae` Temporary workaround for GaudiTrainer (regisss)
- `e50e179` Fixes for text generation (regisss)
- `c804270` Set eager attention for distilbert, gpt_neox (regisss)
- `0000de5` Upgrade to Transformers v4.48 (regisss)
- `2a3affa` Small fixes (regisss)
- `064f4c1` Fix integration tests (regisss)
- `21714f7` Fixes for text-generation (regisss)
- `1cfd53b` Fixes (regisss)
- `573cc57` Style (regisss)
- `a7bc517` Again (regisss)
- `3d402ea` Merge branch 'main' into transformers_future (regisss)
- `f69e957` Fix for image2text lora llama test (#1731) (vidyasiv)
- `265e6a1` Cherry-pick https://github.com/huggingface/transformers/pull/35651 (regisss)
- `6dad1c4` Merge branch 'main' into transformers_future (regisss)
- `32478f5` Upgrade to Transformers v4.48.2 (regisss)
- `1b79cf3` Fix deprecated imports following merged changes for DETR and Qwen2-VL (regisss)
- `c1f30d8` Workaround for textual inversion (regisss)
- `1460856` Merge branch 'main' into transformers_future (regisss)
- `7eadac6` Fixes for v4.48 pytest (#1699) (imangohari1)
- `5cee218` fea(): Applied changes in HF #35235 (#1738) (imangohari1)
- `417cbee` Merge branch 'main' into transformers_future (regisss)
- `c1e3232` Merge branch 'main' into transformers_future (regisss)
- `0f68bbb` Merge branch 'main' into transformers_future (regisss)
- `17943de` Removing HL_DS_DISTRIBUTED_ATTENTION_SEQ_DIM as it's not needed from … (bhargaveede)
- `d214819` Update DS config to align with recommended settings (#1730) (ckvermaAI)
- `6a520ff` Fix graph breaks in Mixtral (#65) (#1705) (Solaryee)
- `58de6b6` Merge branch 'main' into synapse_1_20 (regisss)
- `bedc041` Add batch dim idx to support latest deepspeed DistributedAttention (… (bhargaveede)
- `ce57e40` Add _prepare_inputs_for_generation (#1743) (yafshar)
- `ef77fac` Merge branch 'main' into synapse_1_20 (regisss)
- `be34027` Upgrade to v4.48.3 (regisss)
- `bd9a60e` Fix the issue with --load_quantized_model_with_autoawq (#1747) (schoi-habana)
- `fc6a92b` Merge branch 'main' into synapse_1_20 (regisss)
- `01bb4af` Merge branch 'synapse_1_20' into transformers_future (regisss)
- `2f665e8` Fix dpo crash in transformers 4.48 (#1750) (sywangyi)
- `595b816` Fix for Falcon image-to-text crash (#1760) (schoi-habana)
- `f3729a4` Fix llama attr (#1771) (atakaha)
- `bcb0778` Update llama scaling (#1775) (atakaha)
- `bd87113` Merge branch 'main' into synapse_1_20 (regisss)
- `a13b5d2` Merge branch 'synapse_1_20' into transformers_future (regisss)
- `ce1bf08` Merge branch 'main' into synapse_1_20 (regisss)
- `9b8bb2e` Merge branch 'synapse_1_20' into transformers_future (regisss)
- `d053218` Fix loss calculation (Workaround), final fix TBD (#1784) (emascarenhas)
- `8044aa4` Merge branch 'main' into synapse_1_20 (regisss)
- `fe01ca2` Merge branch 'synapse_1_20' into transformers_future (regisss)
- `8b006c4` Simplify text-gen readme (#1780) (libinta)
- `0eb5d79` Merge branch 'main' into synapse_1_20 (regisss)
- `a03f1d0` Merge branch 'synapse_1_20' into transformers_future (regisss)
- `6772b4f` Diffusers: Simplified the README files. Updated CI tests. (#1718) (imangohari1)
- `06644af` Merge branch 'main' into synapse_1_20 (regisss)
- `244b19e` Merge branch 'synapse_1_20' into transformers_future (regisss)
- `9279ab1` Merge branch 'main' into synapse_1_20 (regisss)
- `5c7fea2` Merge branch 'synapse_1_20' into transformers_future (regisss)
- `fe65b05` Switch version number (regisss)
- `523370d` Merge branch 'main' into synapse_1_20 (regisss)
- `3dfceb9` Merge branch 'main' into synapse_1_20 (regisss)
- `836961e` Merge branch 'main' into synapse_1_20 (regisss)
- `ffda2a0` Temporary WA for get_type error (#1806) (12010486)
- `2688527` Merge branch 'main' into transformers_future (regisss)
- `167a218` Loss Computation for Compatibility with Transformers 4.48.3 (#1794) (yafshar)
- `379524c` Move model to device before wrapping with FSDP (#1801) (mieshkiwrk)
- `3197dd8` Merge branch 'main' into synapse_1_20 (regisss)
- `46bad3b` v1.16 Llama3-405B text-generation. Added DEEPSPEED_USE_HABANA_FRAMEWO… (dsmertin)
- `0078227` Make style (regisss)
- `9b7ca11` Merge branch 'synapse_1_20' into transformers_future (regisss)
- `6d575e8` Merge branch 'main' into synapse_1_20 (regisss)
- `73be6a2` Merge branch 'synapse_1_20' into transformers_future (regisss)
- `81f33ed` Revert placing llama on cpu (#1827) (ugolowic)
- `b46ed25` Merge branch 'synapse_1_20' into transformers_future (regisss)
- `73dd3ed` Merge branch 'main' into synapse_1_20 (regisss)
- `4527647` Merge branch 'synapse_1_20' into transformers_future (regisss)
- `d0b54b8` Merge branch 'main' into transformers_future (regisss)
- `195fdf8` Fix contrastive search (regisss)
- `38f59eb` Merge branch 'main' into transformers_future (regisss)
- `b6602f7` Merge branch 'main' into transformers_future (regisss)
- `2f98dec` Merge branch 'main' into transformers_future (regisss)
- `45dc3aa` Merge branch 'main' into transformers_future (regisss)
- `a22b821` Upgrade to Transformers v4.49 (#1810) (regisss)
- `dd42c92` Fix `get_num_items_in_batches` for iterable datasets and when resumin… (regisss)
- `69f7e6d` Fixes pytest runtime error - Incompatible input shapes, broadcast not… (srajabos)
- `50d1f2e` Merge branch 'main' into transformers_future (regisss)
- `d0d0172` Fix for AutoModelForCausalLM.from_pretrained() (#1844) (dsmertin)
- `adbaa23` Fix unexpected 'num_items_in_batch' argument in GPT-NeoX forward (#1850) (mounikamandava)
- `e802f5f` Make style (regisss)
- `f461199` Fix for `GaudiLlamaAttention` object has no attribute 'max_position_e… (12010486)
- `9cf57be` Fix error with TRL examples (regisss)
- `b780d70` [skip ci] Merge branch 'main' into transformers_future (regisss)
- `dbd987b` Adjust precision of eval_accuracy to avoid random failure in pytest f… (hchauhan123)
- `78e50b9` Missing num_key_value_heads attribute in GaudiGemmaAttention (#1861) (hsubramony)
- `bff3803` Update Sentence Transformer CI/Ref (#1862) (ZhengHongming888)
- `5d2fbde` Fix typo in modeling llama (#1864) (hsubramony)
- `0ec8b04` fea(): Added the updated skip list for mistral/mixtral tests (#1863) (imangohari1)
- `639f96d` Fix llama internal bucketing issue (#1871) (dsocek)
- `f3124e7` Fix regression for test_run_image2text_lora_finetune_idefics2-8b_mult… (srajabos)
- `bbada81` Revert "Move model to device before wrapping with FSDP (#1801)" (#1865) (12010486)
- `0732389` added GRPO Trainer and config / resolved import issues (alekseyfa)
- `2366a00` Resolved import issues (alekseyfa)
- `ee669dc` Updated requirements (alekseyfa)
- `4126553` GRPO simple training script (alekseyfa)
- `45fb347` Updated README (alekseyfa)
- `b6af175` Updated data collator (alekseyfa)
- `247c590` updated sample (alekseyfa)
- `942bd01` Updated README (alekseyfa)
- `6c78be8` Added LORA config (alekseyfa)
- `83cd501` Checking pad_token (alekseyfa)
- `35ac2c6` enable flash attn and pad inputs to the max seq len (schoi-habana)
- `3df10a5` Merge branch 'main' into schoi/grpo_from_pr1898 (schoi-habana)
- `46fc724` README changes for Llama3.1 8B Finetuning with LoRA (#1947) (bhargaveede)
- `493bfd2` pt2e quant changes into the main script (#191) (#1875) (vivek5-ai)
- `cf889b5` working, convergence not sure (schoi-habana)
- `736b1df` Merge branch 'main' into v1.18-release (regisss)
- `3cbc2f1` added bucketing (schoi-habana)
- `8c061df` Merge branch 'main' into v1.18-release (regisss)
- `c55388b` multicard breaks with sync error even when gradient checkpointing is … (schoi-habana)
- `890ceb6` Merge branch 'main' into schoi/grpo_from_pr1898_gradient_ckpt (schoi-habana)
- `62a6ed5` multi card works without --gradient_checkpointing (schoi-habana)
- `15acb06` bump trl version to 0.17.0 (schoi-habana)
- `42218b2` multicard works w/ gradient_checkpointing (schoi-habana)
- `e397864` Hot fix regional compilation (#2005) (IlyasMoutawwakil)
- `5f2bb76` Enable mixtral 8x7b accuracy evaluation (#1986) (rbogdano)
- `2188aaa` Update readme files for explicit lazy mode (#1921) (jasi306)
- `167e07a` [llama-vision] Remove token_idx_cpu parameter (#2018) (ugolowic)
- `822f4b2` Update README examples (#2020) (pbielak)
- `d3ef327` Pin latest optimum to force mutual updates (#2016) (IlyasMoutawwakil)
- `c0856d5` Fix FP8 support and address related issues (#2010) (IlyasMoutawwakil)
- `e72327d` trl==0.17.0 working version for trl example 6/11 (schoi-habana)
- `ea00dc2` Release: v1.18.0 (IlyasMoutawwakil)
- `47ae40b` cleaned and formatted, upto 4x tested with and without gradient check… (schoi-habana)
- `b14dc6f` Merge tag 'v1.18.0' into schoi/grpo_from_pr1898_gradient_ckpt (schoi-habana)
- `d4231b2` Merge branch 'main' into schoi/grpo_from_pr1898_gradient_ckpt (schoi-habana)
- `a65a9d6` resolve trl version mismatch with other trl trainers in OH (schoi-habana)
- `86dcc6a` incorporating the review (schoi-habana)
- `3c0a6a6` add tests for grpo (schoi-habana)
- `7962be0` update tests in Makefile (schoi-habana)
## Files changed

**README** (GRPO example section added):

## Requirements

First, you should install the requirements:

- For the **GRPO example**:
  ```bash
  $ pip install -U -r requirements_grpo.txt
  ```

- For **all other examples**:
  ```bash
  $ pip install -U -r requirements.txt
  ```

## GRPO Training

Install DeepSpeed:

```sh
pip install git+https://github.com/HabanaAI/[email protected]
```

Run single-card training:

```sh
PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 grpo.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name AI-MO/NuminaMath-TIR \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --use_habana \
    --use_lazy_mode \
    --bf16 True \
    --gradient_accumulation_steps=16 \
    --max_prompt_length 512 \
    --num_generations 4 \
    --max_completion_length 64 \
    --use_peft True \
    --lora_target_modules q_proj k_proj \
    --num_train_epochs 1 \
    --save_strategy="epoch"
```

Run multi-card training:

```sh
PT_HPU_MAX_COMPOUND_OP_SIZE=10 PT_HPU_LAZY_MODE=1 python3 ../gaudi_spawn.py --world_size 8 --use_deepspeed grpo.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name AI-MO/NuminaMath-TIR \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --use_habana \
    --use_lazy_mode \
    --bf16 True \
    --gradient_accumulation_steps=16 \
    --gradient_checkpointing \
    --max_prompt_length 512 \
    --num_generations 4 \
    --max_completion_length 64 \
    --use_peft True \
    --lora_target_modules q_proj k_proj \
    --max_steps=500 \
    --logging_steps=10 \
    --save_steps=100
```

## Supervised Finetuning

1. The following example shows supervised LoRA finetuning of a Qwen2 model on a conversational-format dataset.
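Before launching either command, it can help to confirm that PyTorch can see the HPUs. A minimal sanity check (not part of the PR), assuming a Gaudi machine with the Habana PyTorch bridge (`habana_frameworks`) installed:

```python
# Hedged sanity check: confirm HPU devices are visible before launching training.
import torch
import habana_frameworks.torch.core  # noqa: F401  # loading the bridge registers the "hpu" device

print(torch.hpu.is_available())  # expected: True on a Gaudi machine
print(torch.hpu.device_count())  # e.g. 8 on an 8-card node, matching --world_size above
```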
**grpo.py** (new file, the GRPO example script):

```python
import contextlib
import io
import logging
import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch
import transformers
from datasets import load_dataset
from math_verify import LatexExtractionConfig, parse, verify
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
from transformers.integrations.deepspeed import (
    is_deepspeed_available,
)
from transformers.trainer_utils import is_main_process

from optimum.habana import GaudiConfig
from optimum.habana.trl import GaudiGRPOConfig, GaudiGRPOTrainer
from optimum.habana.utils import set_seed


logger = logging.getLogger(__name__)
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)


def make_conversation(example):
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ],
    }


ideal_length = 50


def reward_len(completions, **kwargs):
    return [-abs(ideal_length - len(completion)) for completion in completions]  # penalize response when len != 50


def format_reward(completions, **kwargs):
    # Checks if the reasoning process is enclosed within <think> and </think> tags,
    # while the final answer is enclosed within <answer> and </answer> tags.
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]


def accuracy_reward(completions, **kwargs):
    # Checks if the completion is the same as the ground truth.
    solutions = kwargs["solution"]
    completion_contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, solution in zip(completion_contents, solutions):
        gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        answer_parsed = parse(content, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()])
        if len(gold_parsed) != 0:
            try:
                with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
                    rewards.append(float(verify(answer_parsed, gold_parsed)))
            except Exception:
                rewards.append(0.0)
        else:
            rewards.append(1.0)
    return rewards


@dataclass
class ScriptArguments:
    model_name_or_path: Optional[str] = field(default="Qwen/Qwen2-0.5B-Instruct", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(default=None, metadata={"help": "the dataset name"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"})
    num_workers: Optional[int] = field(default=1, metadata={"help": "the number of workers"})
    subset: Optional[str] = field(default=None, metadata={"help": "the subset to use"})
    streaming: Optional[bool] = field(default=False, metadata={"help": "whether to stream the dataset"})
    dataset_train_split: str = field(default="train[:5%]", metadata={"help": "Dataset split to use for training."})
    dataset_test_split: str = field(default="test[:5%]", metadata={"help": "Dataset split to use for evaluation."})
    reward_model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or "
            "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
        },
    )

    use_flash_attention: Optional[bool] = field(
        default=True, metadata={"help": "Whether to use Habana flash attention for fine-tuning."}
    )
    flash_attention_recompute: Optional[bool] = field(
        default=False, metadata={"help": "Whether to enable recompute in Habana flash attention for fine-tuning."}
    )
    flash_attention_causal_mask: Optional[bool] = field(
        default=False, metadata={"help": "Whether to enable causal mask in Habana flash attention for fine-tuning."}
    )

    # LoraConfig
    lora_alpha: Optional[float] = field(default=32, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.1, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
    lora_target_modules: List[str] = field(
        default_factory=lambda: None,
        metadata={"help": "Target modules for the LoRA method."},
    )


if __name__ == "__main__":
    parser = HfArgumentParser((GaudiGRPOConfig, ScriptArguments))
    (training_args, script_args) = parser.parse_args_into_dataclasses()

    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.bf16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info(f"Training/evaluation parameters {training_args}")
    # Set seed before initializing model.
    set_seed(training_args.seed)

    use_deepspeed = training_args.world_size > 1

    if script_args.use_peft:
        peft_config = LoraConfig(
            r=script_args.lora_r,
            lora_alpha=script_args.lora_alpha,
            lora_dropout=script_args.lora_dropout,
            target_modules=script_args.lora_target_modules,
            task_type="CAUSAL_LM",
        )
    else:
        peft_config = None

    tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True)
    if training_args.chat_template is not None:
        tokenizer.chat_template = training_args.chat_template

    train_dataset, test_dataset = load_dataset(
        script_args.dataset_name,
        data_dir=None if script_args.subset == "None" else script_args.subset,
        num_proc=script_args.num_workers if not script_args.streaming else None,
        split=[script_args.dataset_train_split, script_args.dataset_test_split],
    )

    train_dataset = train_dataset.map(make_conversation)
    test_dataset = test_dataset.map(make_conversation)
    train_dataset = train_dataset.remove_columns(["messages", "problem"])

    low_cpu_mem_usage = True
    if is_deepspeed_available() and use_deepspeed:
        from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

        if is_deepspeed_zero3_enabled():
            low_cpu_mem_usage = False

    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        low_cpu_mem_usage=low_cpu_mem_usage,
        torch_dtype=torch.bfloat16,
    )

    model.config.use_cache = False
    # flash_attention_recompute and flash_attention_causal_mask require use_flash_attention.
    if not script_args.use_flash_attention and (
        script_args.flash_attention_recompute or script_args.flash_attention_causal_mask
    ):
        raise ValueError("Need to enable use_flash_attention")
    model.generation_config.use_flash_attention = script_args.use_flash_attention
    model.generation_config.flash_attention_recompute = script_args.flash_attention_recompute
    model.generation_config.flash_attention_causal_mask = script_args.flash_attention_causal_mask

    reward_funcs = [format_reward, accuracy_reward]
    if script_args.reward_model_name_or_path:
        reward_funcs = AutoModelForSequenceClassification.from_pretrained(
            script_args.reward_model_name_or_path,
            trust_remote_code=True,
        )

    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    gaudi_config = GaudiConfig()
    gaudi_config.use_fused_adam = True
    gaudi_config.use_fused_clip_norm = True

    trainer = GaudiGRPOTrainer(
        model=model,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        processing_class=tokenizer,
        gaudi_config=gaudi_config,
        peft_config=peft_config,
    )

    trainer.train()

    print("Done!")
```
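To see how the reward shaping above behaves, here is a small self-contained sketch (not part of the PR) that mirrors `format_reward` on two hand-written completions:

```python
import re


def format_reward(completions, **kwargs):
    # Same logic as format_reward in grpo.py: 1.0 if the completion follows
    # the <think>...</think><answer>...</answer> template, else 0.0.
    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    contents = [completion[0]["content"] for completion in completions]
    return [1.0 if re.match(pattern, c) else 0.0 for c in contents]


completions = [
    [{"role": "assistant", "content": "<think>2 + 2 = 4</think><answer>4</answer>"}],
    [{"role": "assistant", "content": "The answer is 4."}],
]
print(format_reward(completions))  # [1.0, 0.0]
```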
**requirements_grpo.txt** (new file):

```
trl == 0.17.0
peft == 0.12.0
datasets
tyro
evaluate
scikit-learn == 1.5.2
accelerate
math_verify
```
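A quick way to confirm the pins resolved as expected after `pip install -U -r requirements_grpo.txt` (a sketch, not part of the PR):

```python
# Compare installed versions against the pins in requirements_grpo.txt.
import importlib.metadata

for pkg, pinned in [("trl", "0.17.0"), ("peft", "0.12.0"), ("scikit-learn", "1.5.2")]:
    installed = importlib.metadata.version(pkg)
    print(f"{pkg}: installed {installed}, pinned {pinned}, match: {installed == pinned}")
```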
**optimum.habana.trl `__init__`** (after the change, trainer imports are gated on the installed `trl` version):

```python
import importlib.metadata

from packaging import version

from .models.modeling_base import adapt_PreTrainedModelWrapper_to_gaudi
from .models.modeling_sd_base import GaudiDefaultDDPOStableDiffusionPipeline
from .trainer.ddpo_trainer import GaudiDDPOTrainer
from .trainer.dpo_config import GaudiDPOConfig
from .trainer.dpo_trainer import GaudiDPOTrainer


trl_version = importlib.metadata.version("trl")
if version.parse(trl_version) < version.parse("0.17.0"):
    from .trainer.ppo_config import GaudiPPOConfig
    from .trainer.ppo_trainer import GaudiPPOTrainer
else:
    from .trainer.grpo_config import GaudiGRPOConfig
    from .trainer.grpo_trainer import GaudiGRPOTrainer
from .trainer.reward_trainer import GaudiRewardTrainer, RewardDataCollatorWithPadding
from .trainer.sft_config import GaudiSFTConfig
from .trainer.sft_trainer import GaudiSFTTrainer
```
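Sketched from the caller's side, the effect of this gate is that the set of exported trainers depends on the installed `trl` (assuming the environment from `requirements_grpo.txt`):

```python
# Which trainer classes optimum.habana.trl exposes depends on the installed trl version.
import importlib.metadata

from packaging import version

if version.parse(importlib.metadata.version("trl")) >= version.parse("0.17.0"):
    from optimum.habana.trl import GaudiGRPOConfig, GaudiGRPOTrainer  # noqa: F401

    print("trl >= 0.17.0: GRPO trainer available, PPO trainer not exported")
else:
    from optimum.habana.trl import GaudiPPOConfig, GaudiPPOTrainer  # noqa: F401

    print("trl < 0.17.0: PPO trainer available, GRPO trainer not exported")
```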