Conversation

Datta0 (Collaborator) commented on May 15, 2025

Depends on: unslothai/unsloth-zoo#140

Tested: GRPO with Qwen3-4B, plus ORPO and DPO (to make sure nothing breaks there)

TODO:

[ ] Compare performance

pluesclues and others added 9 commits May 18, 2025 11:30
For TRL 0.18.0 (the main branch of TRL at the time, since the latest release was 0.17.0), the SFT trainer deletes the labels column for some reason, and Unsloth's internal loss functions need that column for their calculations, so I add it back in like this (see the sketch after the commit list).
TRL update, particularly small fixes on the SFT trainer
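
A minimal sketch of the workaround described above, assuming labels are simply the input ids with padding masked out; the helper name and the -100 ignore index are illustrative, not the PR's exact code:

```python
import torch

def add_labels_back(batch):
    # TRL 0.18's SFTTrainer can drop the "labels" column; Unsloth's internal
    # loss functions expect it, so rebuild it from input_ids here.
    labels = batch["input_ids"].clone()
    if "attention_mask" in batch:
        labels[batch["attention_mask"] == 0] = -100  # ignore padding positions
    batch["labels"] = labels
    return batch
```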
Datta0 force-pushed the trl_upgrade_fix branch from 8b89b09 to 5096de8 on May 24, 2025 07:54
Datta0 force-pushed the trl_upgrade_fix branch from 5096de8 to dd5128d on May 24, 2025 07:54
```python
if trl_version >= "0.18":
    # Replace LLM init with already existing vLLM engine for colocate mode
    vllm_llm_init_pattern = r"self\.llm\s*=\s*LLM\([^)]*\)*\)"
    vllm_llm_repalcement = "self.llm = model.vllm_engine\n"
```
Contributor: Spelling error :)

```python
    vllm_llm_repalcement = "self.llm = model.vllm_engine\n"
    new_vllm_part = re.sub(
        vllm_llm_init_pattern,
        vllm_llm_repalcement,
```
Contributor: Same
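
For reference, a hedged sketch of what the source patching above does. The pattern and replacement are copied from the hunk (including the reviewer-flagged misspelling); the trainer_source string is invented for illustration, since the value re.sub is actually applied to is truncated in the diff:

```python
import re

vllm_llm_init_pattern = r"self\.llm\s*=\s*LLM\([^)]*\)*\)"
vllm_llm_repalcement = "self.llm = model.vllm_engine\n"

# Illustrative input only: a line resembling the trainer's original LLM init.
trainer_source = 'self.llm = LLM(model=model_id, gpu_memory_utilization=0.6)'
patched = re.sub(vllm_llm_init_pattern, vllm_llm_repalcement, trainer_source)
print(patched)  # -> self.llm = model.vllm_engine
```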

```python
if not hasattr(self, '_autocast_dtype'):
    self._autocast_dtype = torch.float16 if os.environ.get('ACCELERATE_MIXED_PRECISION', 'fp16') == 'fp16' else torch.bfloat16
    if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1': self._autocast_dtype = torch.float16
os.environ["UNSLOTH_RETURN_HIDDEN_STATES"] = "0"
```
Contributor: Wait, this forcefully returns hidden states, right? Doesn't this make the GRPO loss use more memory?
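
A rough back-of-the-envelope illustration of the memory concern being raised; the batch size, sequence length, vocab and hidden sizes below are assumed (Qwen3-4B-ish), not taken from the PR:

```python
# Full logits are (B, L, V) while hidden states are (B, L, H), so returning
# hidden states and projecting them chunk-wise is far cheaper than
# materialising the whole logits tensor.
B, L, V, H = 1, 1024, 151_936, 2560   # assumed Qwen3-4B-like sizes
bytes_per_elem = 2                    # fp16 / bf16
logits_mib = B * L * V * bytes_per_elem / 2**20
hidden_mib = B * L * H * bytes_per_elem / 2**20
print(f"logits ~= {logits_mib:.0f} MiB vs hidden states ~= {hidden_mib:.0f} MiB")
```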

```python
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

logits = logits.to(torch.float32)
```
Contributor: I don't think this is correct since we auto upcast inside the torch.compile function, so this uses 2x more memory.

Datta0 (Author): Yeah, I think so. Will remove this.
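
A minimal sketch of the pattern the reviewer is describing, assuming the loss lives inside a torch.compile'd function (names are illustrative): keep the logits in half precision and upcast inside the compiled region, instead of materialising a second float32 copy up front.

```python
import torch
import torch.nn.functional as F

@torch.compile
def token_nll(logits_half, labels):
    # Upcast inside the compiled function so the cast fuses with the
    # log-softmax rather than allocating a separate fp32 logits tensor.
    logp = F.log_softmax(logits_half.float(), dim=-1)
    return F.nll_loss(logp.view(-1, logp.size(-1)), labels.view(-1), ignore_index=-100)
```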

```python
logits = logits[:, -logits_to_keep:]
return logits
# return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
# return logits
```
Contributor: Yes, this is correct for non-GRPO, but reminder: in Unsloth GRPO, we specifically calculate the logits on the fly to save VRAM.
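
For context, a sketch of what the commented-out selective_log_softmax call computes (per-token log-probabilities of the supplied ids), written out naively; TRL ships its own, more memory-careful implementation, so this is illustrative only:

```python
import torch

def selective_log_softmax_sketch(logits, index):
    # logits: (B, L, V); index: (B, L) token ids whose log-probs we want.
    logps = torch.log_softmax(logits.float(), dim=-1)
    return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
```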

Datta0 force-pushed the trl_upgrade_fix branch from a7eec30 to 1aa7aa1 on May 25, 2025 13:11
Datta0 force-pushed the trl_upgrade_fix branch from 2961e2d to 1aa2e37 on May 25, 2025 18:31
Datta0 marked this pull request as ready for review on May 26, 2025 04:22