|
| 1 | +import contextlib |
| 2 | +import io |
| 3 | +import logging |
| 4 | +import re |
| 5 | +from dataclasses import dataclass, field |
| 6 | +from typing import List, Optional |
| 7 | + |
| 8 | +import torch |
| 9 | +import transformers |
| 10 | +from datasets import load_dataset |
| 11 | +from math_verify import LatexExtractionConfig, parse, verify |
| 12 | +from peft import LoraConfig |
| 13 | +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser |
| 14 | +from transformers.integrations.deepspeed import ( |
| 15 | + is_deepspeed_available, |
| 16 | +) |
| 17 | +from transformers.trainer_utils import is_main_process |
| 18 | + |
| 19 | +from optimum.habana import GaudiConfig |
| 20 | +from optimum.habana.trl import GaudiGRPOConfig, GaudiGRPOTrainer |
| 21 | +from optimum.habana.utils import set_seed |
| 22 | + |
| 23 | + |
| 24 | +logger = logging.getLogger(__name__) |
| 25 | +SYSTEM_PROMPT = ( |
| 26 | + "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " |
| 27 | + "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " |
| 28 | + "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., " |
| 29 | + "<think> reasoning process here </think><answer> answer here </answer>" |
| 30 | +) |
| 31 | + |
| 32 | + |
| 33 | +def make_conversation(example): |
| 34 | + return { |
| 35 | + "prompt": [ |
| 36 | + {"role": "system", "content": SYSTEM_PROMPT}, |
| 37 | + {"role": "user", "content": example["problem"]}, |
| 38 | + ], |
| 39 | + } |
| 40 | + |
| 41 | + |
| 42 | +ideal_length = 50 |
| 43 | + |
| 44 | + |
| 45 | +def reward_len(completions, **kwargs): |
| 46 | + return [-abs(ideal_length - len(completion)) for completion in completions] # penalize response when len!=50 |
| 47 | + |
| 48 | + |
| 49 | +def format_reward(completions, **kwargs): |
| 50 | + # Checks if the reasoning process is enclosed within <think> and </think> tags, |
| 51 | + # while the final answer is enclosed within <answer> and </answer> tags. |
| 52 | + pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$" |
| 53 | + completion_contents = [completion[0]["content"] for completion in completions] |
| 54 | + matches = [re.match(pattern, content) for content in completion_contents] |
| 55 | + return [1.0 if match else 0.0 for match in matches] |
| 56 | + |
| 57 | + |
| 58 | +def accuracy_reward(completions, **kwargs): |
| 59 | + # Checks if the completion is the same as the ground truth. |
| 60 | + solutions = kwargs["solution"] |
| 61 | + completion_contents = [completion[0]["content"] for completion in completions] |
| 62 | + rewards = [] |
| 63 | + for content, solution in zip(completion_contents, solutions): |
| 64 | + gold_parsed = parse(solution, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()]) |
| 65 | + answer_parsed = parse(content, extraction_mode="first_match", extraction_config=[LatexExtractionConfig()]) |
| 66 | + if len(gold_parsed) != 0: |
| 67 | + try: |
| 68 | + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()): |
| 69 | + rewards.append(float(verify(answer_parsed, gold_parsed))) |
| 70 | + except Exception: |
| 71 | + rewards.append(0.0) |
| 72 | + else: |
| 73 | + rewards.append(1.0) |
| 74 | + return rewards |
| 75 | + |
| 76 | + |
| 77 | +@dataclass |
| 78 | +class ScriptArguments: |
| 79 | + model_name_or_path: Optional[str] = field(default="Qwen/Qwen2-0.5B-Instruct", metadata={"help": "the model name"}) |
| 80 | + dataset_name: Optional[str] = field(default=None, metadata={"help": "the dataset name"}) |
| 81 | + use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"}) |
| 82 | + num_workers: Optional[int] = field(default=1, metadata={"help": "the number of workers"}) |
| 83 | + subset: Optional[str] = field(default=None, metadata={"help": "the subset to use"}) |
| 84 | + streaming: Optional[bool] = field(default=False, metadata={"help": "whether to stream the dataset"}) |
| 85 | + dataset_train_split: str = field(default="train[:5%]", metadata={"help": "Dataset split to use for training."}) |
| 86 | + dataset_test_split: str = field(default="test[:5%]", metadata={"help": "Dataset split to use for evaluation."}) |
| 87 | + reward_model_name_or_path: Optional[str] = field( |
| 88 | + default=None, |
| 89 | + metadata={ |
| 90 | + "help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or " |
| 91 | + "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`." |
| 92 | + }, |
| 93 | + ) |
| 94 | + |
| 95 | + use_flash_attention: Optional[bool] = field( |
| 96 | + default=True, metadata={"help": "Whether to use Habana flash attention for fine-tuning."} |
| 97 | + ) |
| 98 | + flash_attention_recompute: Optional[bool] = field( |
| 99 | + default=False, metadata={"help": "Whether to enable recompute in Habana flash attention for fine-tuning."} |
| 100 | + ) |
| 101 | + flash_attention_causal_mask: Optional[bool] = field( |
| 102 | + default=False, metadata={"help": "Whether to enable causal mask in Habana flash attention for fine-tuning."} |
| 103 | + ) |
| 104 | + |
| 105 | + # LoraConfig |
| 106 | + lora_alpha: Optional[float] = field(default=32, metadata={"help": "the lora alpha parameter"}) |
| 107 | + lora_dropout: Optional[float] = field(default=0.1, metadata={"help": "the lora dropout parameter"}) |
| 108 | + lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"}) |
| 109 | + lora_target_modules: List[str] = field( |
| 110 | + default_factory=lambda: None, |
| 111 | + metadata={"help": "Target modules for the LoRA method."}, |
| 112 | + ) |
| 113 | + |
| 114 | + |
| 115 | +if __name__ == "__main__": |
| 116 | + parser = HfArgumentParser((GaudiGRPOConfig, ScriptArguments)) |
| 117 | + (training_args, script_args) = parser.parse_args_into_dataclasses() |
| 118 | + |
| 119 | + logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) |
| 120 | + logger.warning( |
| 121 | + f"Process rank: {training_args.local_rank}, device: {training_args.device}, " |
| 122 | + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.bf16}" |
| 123 | + ) |
| 124 | + # Set the verbosity to info of the Transformers logger (on main process only): |
| 125 | + if is_main_process(training_args.local_rank): |
| 126 | + transformers.utils.logging.set_verbosity_info() |
| 127 | + transformers.utils.logging.enable_default_handler() |
| 128 | + transformers.utils.logging.enable_explicit_format() |
| 129 | + logger.info(f"Training/evaluation parameters {training_args}") |
| 130 | + # Set seed before initializing model. |
| 131 | + set_seed(training_args.seed) |
| 132 | + |
| 133 | + use_deepspeed = training_args.world_size > 1 |
| 134 | + |
| 135 | + if script_args.use_peft: |
| 136 | + peft_config = LoraConfig( |
| 137 | + r=script_args.lora_r, |
| 138 | + lora_alpha=script_args.lora_alpha, |
| 139 | + lora_dropout=script_args.lora_dropout, |
| 140 | + target_modules=script_args.lora_target_modules, |
| 141 | + task_type="CAUSAL_LM", |
| 142 | + ) |
| 143 | + else: |
| 144 | + peft_config = None |
| 145 | + |
| 146 | + tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path, trust_remote_code=True) |
| 147 | + if training_args.chat_template is not None: |
| 148 | + tokenizer.chat_template = training_args.chat_template |
| 149 | + |
| 150 | + train_dataset, test_dataset = load_dataset( |
| 151 | + script_args.dataset_name, |
| 152 | + data_dir=None if script_args.subset == "None" else script_args.subset, |
| 153 | + num_proc=script_args.num_workers if not script_args.streaming else None, |
| 154 | + split=[script_args.dataset_train_split, script_args.dataset_test_split], |
| 155 | + ) |
| 156 | + |
| 157 | + train_dataset = train_dataset.map(make_conversation) |
| 158 | + test_dataset = test_dataset.map(make_conversation) |
| 159 | + train_dataset = train_dataset.remove_columns(["messages", "problem"]) |
| 160 | + |
| 161 | + low_cpu_mem_usage = True |
| 162 | + if is_deepspeed_available() and use_deepspeed: |
| 163 | + from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled |
| 164 | + |
| 165 | + if is_deepspeed_zero3_enabled(): |
| 166 | + low_cpu_mem_usage = False |
| 167 | + |
| 168 | + model = AutoModelForCausalLM.from_pretrained( |
| 169 | + script_args.model_name_or_path, |
| 170 | + low_cpu_mem_usage=low_cpu_mem_usage, |
| 171 | + torch_dtype=torch.bfloat16, |
| 172 | + ) |
| 173 | + |
| 174 | + model.config.use_cache = False |
| 175 | + if not script_args.use_flash_attention and ( |
| 176 | + script_args.flash_attention_recompute or script_args.flash_attention_recompute |
| 177 | + ): |
| 178 | + assert "Need to enable use_flash_attention" |
| 179 | + model.generation_config.use_flash_attention = script_args.use_flash_attention |
| 180 | + model.generation_config.flash_attention_recompute = script_args.flash_attention_recompute |
| 181 | + model.generation_config.flash_attention_causal_mask = script_args.flash_attention_causal_mask |
| 182 | + |
| 183 | + reward_funcs = [format_reward, accuracy_reward] |
| 184 | + if script_args.reward_model_name_or_path: |
| 185 | + reward_funcs = AutoModelForSequenceClassification.from_pretrained( |
| 186 | + script_args.reward_model_name_or_path, |
| 187 | + trust_remote_code=True, |
| 188 | + ) |
| 189 | + |
| 190 | + if getattr(tokenizer, "pad_token", None) is None: |
| 191 | + tokenizer.pad_token = tokenizer.eos_token |
| 192 | + |
| 193 | + gaudi_config = GaudiConfig() |
| 194 | + gaudi_config.use_fused_adam = True |
| 195 | + gaudi_config.use_fused_clip_norm = True |
| 196 | + |
| 197 | + trainer = GaudiGRPOTrainer( |
| 198 | + model=model, |
| 199 | + reward_funcs=reward_funcs, |
| 200 | + args=training_args, |
| 201 | + train_dataset=train_dataset, |
| 202 | + eval_dataset=test_dataset, |
| 203 | + processing_class=tokenizer, |
| 204 | + gaudi_config=gaudi_config, |
| 205 | + peft_config=peft_config, |
| 206 | + ) |
| 207 | + |
| 208 | + trainer.train() |
| 209 | + |
| 210 | + print("Done!") |
0 commit comments