Upgrade trl fix #2544
Conversation
…se autoSequenceClassification
Trl upgrade fix
For TRL 0.18.0 (the main branch of TRL at the time, since the released version is 0.17.0), the SFT trainer deletes the labels column for some reason, and Unsloth's internal loss functions need that column for the calculations, so I add it back in like this.
Trl update, particularly small fixes on SFT trainer
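A rough sketch of the idea (not the exact patch; `batch` and the hook point are illustrative):

```python
# Sketch only: TRL 0.18's SFTTrainer drops the "labels" column during
# preprocessing, but Unsloth's internal loss functions expect it, so it is
# restored before the loss is computed.
def add_labels_back(batch):
    # Assumption: labels mirror input_ids for causal LM training; padding
    # positions would normally be masked with -100 elsewhere.
    if "labels" not in batch:
        batch["labels"] = batch["input_ids"].clone()
    return batch
```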
Signed-off-by: datta0 <[email protected]>
unsloth/models/rl.py
Outdated
```python
if trl_version >= "0.18":
    # Replace LLM init with already existing vLLM engine for colocate mode
    vllm_llm_init_pattern = r"self\.llm\s*=\s*LLM\([^)]*\)*\)"
    vllm_llm_repalcement = "self.llm = model.vllm_engine\n"
```
Spelling error :)
unsloth/models/rl.py
Outdated
```python
vllm_llm_repalcement = "self.llm = model.vllm_engine\n"
new_vllm_part = re.sub(
    vllm_llm_init_pattern,
    vllm_llm_repalcement,
```
Same
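For context, a minimal sketch of how this regex patch behaves on a toy snippet (the simplified pattern and toy source are illustrative, not the actual TRL trainer code):

```python
import re

# Swap the LLM(...) construction for Unsloth's already-initialised vLLM
# engine in colocate mode. The real pattern in the diff also tolerates
# extra closing parentheses; this is a simplified variant.
vllm_llm_init_pattern = r"self\.llm\s*=\s*LLM\([^)]*\)"
vllm_llm_replacement = "self.llm = model.vllm_engine"

toy_source = "self.llm = LLM(model=model.name_or_path, gpu_memory_utilization=0.3)"
patched = re.sub(vllm_llm_init_pattern, vllm_llm_replacement, toy_source)
print(patched)  # -> self.llm = model.vllm_engine
```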
unsloth/models/rl_replacements.py
Outdated
```python
if not hasattr(self, '_autocast_dtype'):
    self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
    if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
```
Wait, this forcefully returns hidden states, right? Doesn't this make the GRPO loss use more memory?
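For reference, the dtype-selection logic above sketched as a standalone helper (how the trainer then consumes the dtype, e.g. via `torch.autocast`, is an assumption for illustration):

```python
import os
import torch

# fp16 when ACCELERATE_MIXED_PRECISION is fp16, otherwise bf16, with the
# UNSLOTH_FORCE_FLOAT32 override on top.
def pick_autocast_dtype():
    dtype = (
        torch.float16
        if os.environ.get("ACCELERATE_MIXED_PRECISION", "fp16") == "fp16"
        else torch.bfloat16
    )
    if os.environ.get("UNSLOTH_FORCE_FLOAT32", "0") == "1":
        dtype = torch.float16
    return dtype

# Typical downstream use (illustrative):
# with torch.autocast(device_type="cuda", dtype=pick_autocast_dtype()):
#     loss = compute_loss(model, batch)
```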
unsloth/models/rl_replacements.py
Outdated
```python
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

logits = logits.to(torch.float32)
```
I don't think this is correct, since we auto-upcast inside the torch.compile function, so this uses 2x more memory.
Yeah, I think so.
Will remove this.
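A back-of-the-envelope illustration of the memory concern (batch shape and vocabulary size are hypothetical):

```python
# Upcasting a half-precision logits tensor of shape (B, L, V) to float32
# materialises a second copy at twice the bytes per element.
B, L, V = 4, 1024, 152_064            # e.g. a Qwen-sized vocabulary (assumption)
n = B * L * V
print(n * 2 / 2**20, "MiB in fp16")   # ~1188 MiB
print(n * 4 / 2**20, "MiB in fp32")   # ~2376 MiB extra for the .to(torch.float32) copy
```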
unsloth/models/rl_replacements.py
Outdated
```python
logits = logits[:, -logits_to_keep:]
return logits
# return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
# return logits
```
Yes, this is correct for non-GRPO, but a reminder: in Unsloth GRPO we specifically calculate the logits on the fly to save VRAM.
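For readers following along, a naive sketch of what the commented-out `selective_log_softmax` call would compute (not TRL's implementation, just the idea; the real helper is more careful about memory):

```python
import torch

# Per-token log-probabilities of the chosen ids.
def per_token_logps(logits: torch.Tensor, ids: torch.Tensor) -> torch.Tensor:
    # logits: (B, L, V) floats, ids: (B, L) int64 token ids
    logps = torch.log_softmax(logits.float(), dim=-1)         # (B, L, V)
    return logps.gather(-1, ids.unsqueeze(-1)).squeeze(-1)    # (B, L)
```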
Signed-off-by: datta0 <[email protected]>
Depends on: unslothai/unsloth-zoo#140
Tested: GRPO with qwen-3-4B, ORPO and DPO (to make sure nothing breaks there)
TODO:
- [ ] Compare performance