diff --git a/unsloth/models/rl.py b/unsloth/models/rl.py index ae01469ac..45b8ca633 100644 --- a/unsloth/models/rl.py +++ b/unsloth/models/rl.py @@ -168,7 +168,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): trainer = eval(f"trl.trainer.{trainer_file}") except Exception as error: return - + # Get SFTTrainer and SFTConfig names name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()] config = [x for x in dir(trainer) if x.endswith("Config") and x != "Config" and trainer_file.split("_")[0] in x.lower()] @@ -484,7 +484,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): "dataloader_persistent_workers" : True, # Keeps dataloader in RAM "dataloader_prefetch_factor" : 2, "dataloader_pin_memory" : True, - "dataloader_num_workers" : 0, # Default is 0 means 1 + "dataloader_num_workers" : 1, } for k, v in replacements.items(): x = f"{k}( = [^,\n]{{1,}})?,\n" @@ -565,7 +565,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass # Check GRPO num_generations mismatch - if "per_device_train_batch_size" in call_args and "num_generations" in call_args: + if "per_device_train_batch_size" in call_args and "num_generations" in call_args: check_num_generations = \ "if (per_device_train_batch_size // num_generations) * num_generations != per_device_train_batch_size:\n"\ " print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\ @@ -576,7 +576,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): pass # Check temperature must not be <= 0. Also stop if >= 10 - if "temperature" in call_args: + if "temperature" in call_args: check_temperature = \ "if temperature <= 0:\n"\ " raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')\n"\ @@ -625,7 +625,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): if "SamplingParams" in old_RLTrainer_source: RL_pre = RL_pre + "\n" + inspect.getsource(vLLMSamplingParams) pass - + # Selective log softmax selective_log_softmax_code = inspect.getsource(selective_log_softmax) @@ -651,12 +651,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): selective_log_softmax_code = selective_log_softmax_code, ) - + if RLTrainer_name == "SFTTrainer": original_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]' new_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask","labels"]' RLTrainer_source = RLTrainer_source.replace(original_text, new_text) - + # Remove multiple doc strings if __RLConfig_doc__ != "" and RLTrainer_source.count(__RLTrainer_doc__) == 2: RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1) @@ -673,12 +673,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"): imports, overwrite = False, ) - + # Patch Trainer exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) exec(f"trl.trainer.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals()) - + # Patch Config exec(f"trl.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals()) exec(f"trl.trainer.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals()) @@ -754,7 +754,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import new_vllm_part, flags = re.MULTILINE | re.DOTALL, ) - + if len(sampling_params) == 1: sampling_params = sampling_params[0] # Fix guided_decoding @@ -768,7 +768,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import sampling_params = \ " "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \ sampling_params # Add spaces - + # count the indentation of last line of sampling_params. last_line = sampling_params.split("\n")[-1] last_prev_line = sampling_params.split("\n")[-2] @@ -844,7 +844,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import r"", source, ) - + # Replace self.llm.generate and self.llm.chat lora_name = trainer_file + "_lora_model" source = re.sub(