24 changes: 12 additions & 12 deletions unsloth/models/rl.py
@@ -168,7 +168,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
trainer = eval(f"trl.trainer.{trainer_file}")
except Exception as error:
return

# Get SFTTrainer and SFTConfig names
name = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and trainer_file.split("_")[0] in x.lower()]
config = [x for x in dir(trainer) if x.endswith("Config") and x != "Config" and trainer_file.split("_")[0] in x.lower()]
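As context for the hunk above, the trainer and config classes are discovered by scanning the module namespace. A minimal standalone sketch of that step, assuming TRL is installed (printed names are illustrative):

```python
# Hedged sketch: locate GRPOTrainer / GRPOConfig inside trl.trainer.grpo_trainer
# by filtering dir() for names ending in "Trainer"/"Config" that contain "grpo".
import importlib

trainer_file = "grpo_trainer"
trainer = importlib.import_module(f"trl.trainer.{trainer_file}")
prefix = trainer_file.split("_")[0]  # "grpo"
name   = [x for x in dir(trainer) if x.endswith("Trainer") and x != "Trainer" and prefix in x.lower()]
config = [x for x in dir(trainer) if x.endswith("Config")  and x != "Config"  and prefix in x.lower()]
print(name, config)  # expected: ['GRPOTrainer'] ['GRPOConfig']
```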
@@ -484,7 +484,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
"dataloader_persistent_workers" : True, # Keeps dataloader in RAM
"dataloader_prefetch_factor" : 2,
"dataloader_pin_memory" : True,
"dataloader_num_workers" : 0, # Default is 0 means 1
"dataloader_num_workers" : 1,
}
for k, v in replacements.items():
x = f"{k}( = [^,\n]{{1,}})?,\n"
@@ -565,7 +565,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
pass

# Check GRPO num_generations mismatch
if "per_device_train_batch_size" in call_args and "num_generations" in call_args:
if "per_device_train_batch_size" in call_args and "num_generations" in call_args:
check_num_generations = \
"if (per_device_train_batch_size // num_generations) * num_generations != per_device_train_batch_size:\n"\
" print('Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.\\n"\
@@ -576,7 +576,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
pass

# Check temperature must not be <= 0. Also stop if >= 10
if "temperature" in call_args:
if "temperature" in call_args:
check_temperature = \
"if temperature <= 0:\n"\
" raise MathError('Unsloth: Please set a positive non-zero temperature since your results will be wrong.')\n"\
@@ -625,7 +625,7 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
if "SamplingParams" in old_RLTrainer_source:
RL_pre = RL_pre + "\n" + inspect.getsource(vLLMSamplingParams)
pass

# Selective log softmax
selective_log_softmax_code = inspect.getsource(selective_log_softmax)
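For readers unfamiliar with the helper being inlined here: `selective_log_softmax` gathers only the log-probabilities of the chosen token ids instead of materializing the full vocab-sized log-softmax output. A simplified stand-in (the real TRL implementation is more memory-careful):

```python
import torch

def selective_log_softmax_sketch(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # logits: (batch, seq, vocab); index: (batch, seq) token ids (int64)
    logps = torch.log_softmax(logits.float(), dim=-1)
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
```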

@@ -651,12 +651,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):

selective_log_softmax_code = selective_log_softmax_code,
)

if RLTrainer_name == "SFTTrainer":
original_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]'
new_text = 'self._signature_columns = ["input_ids", "attention_mask", "completion_mask","labels"]'
RLTrainer_source = RLTrainer_source.replace(original_text, new_text)

# Remove multiple doc strings
if __RLConfig_doc__ != "" and RLTrainer_source.count(__RLTrainer_doc__) == 2:
RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1)
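The docstring cleanup above is plain string surgery; a toy example of the same idea (note the diff's guard also requires `__RLConfig_doc__` to be non-empty):

```python
__RLTrainer_doc__ = '"""Trainer docs."""'
RLTrainer_source = f"{__RLTrainer_doc__}\nclass UnslothGRPOTrainer:\n    {__RLTrainer_doc__}\n    ..."

# Drop only the first of the two occurrences.
if RLTrainer_source.count(__RLTrainer_doc__) == 2:
    RLTrainer_source = RLTrainer_source.replace(__RLTrainer_doc__, "", 1)

print(RLTrainer_source.count(__RLTrainer_doc__))  # 1
```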
@@ -673,12 +673,12 @@ def _patch_trl_rl_trainers(trainer_file = "grpo_trainer"):
imports,
overwrite = False,
)

# Patch Trainer
exec(f"trl.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals())
exec(f"trl.trainer.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals())
exec(f"trl.trainer.{trainer_file}.{RLTrainer_name} = created_module.Unsloth{RLTrainer_name}", locals(), globals())

# Patch Config
exec(f"trl.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals())
exec(f"trl.trainer.{RLConfig_name} = created_module.Unsloth{RLConfig_name}", locals(), globals())
@@ -754,7 +754,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
new_vllm_part,
flags = re.MULTILINE | re.DOTALL,
)
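A quick note on the flags used above: `re.DOTALL` lets `.` cross newlines, which is what allows one pattern to swallow a multi-line block of source, and `re.MULTILINE` makes `^`/`$` match per line. A toy demonstration (the pattern is illustrative, not the one in the diff):

```python
import re

source = "before\nsampling_params = SamplingParams(\n    temperature=1.0,\n)\nafter"
out = re.sub(
    r"sampling_params = SamplingParams\(.*?\)",
    "<patched>",
    source,
    flags = re.MULTILINE | re.DOTALL,
)
print(out)  # before\n<patched>\nafter
```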

if len(sampling_params) == 1:
sampling_params = sampling_params[0]
# Fix guided_decoding
@@ -768,7 +768,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
sampling_params = \
" "*12 + "self.llm = model.vllm_engine; self._last_loaded_step = 0; " + \
sampling_params # Add spaces

# count the indentation of last line of sampling_params.
last_line = sampling_params.split("\n")[-1]
last_prev_line = sampling_params.split("\n")[-2]
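The indentation bookkeeping that follows is the usual leading-whitespace measurement; as a standalone sketch with a made-up `sampling_params` string:

```python
sampling_params = "        self.llm = ...\n            stop_token_ids = [],\n        )"
last_line = sampling_params.split("\n")[-1]
indent = len(last_line) - len(last_line.lstrip(" "))
print(indent)  # 8
```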
@@ -844,7 +844,7 @@ def patch_functions(RLTrainer, trainer_file, RLTrainer_name, all_imports, import
r"",
source,
)

# Replace self.llm.generate and self.llm.chat
lora_name = trainer_file + "_lora_model"
source = re.sub(