Merged
7 changes: 5 additions & 2 deletions tests/test_grpo_trainer.py
@@ -703,7 +703,10 @@ def test_training_beta_non_zero(self):
new_param = trainer.model.get_parameter(n)
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

def test_training_with_cast_lm_head_to_fp32(self):
@pytest.mark.parametrize(
"model_name", ["trl-internal-testing/tiny-Qwen3ForCausalLM", "trl-internal-testing/tiny-Gemma2ForCausalLM"]
Member

Qwen3 has tied word embeddings and Gemma 2 doesn't, correct? If so, I'd just add a small comment so that we remember why we test these two cases.

Collaborator Author

It's the other way around: Qwen3 has untied embeddings and Gemma 2 has tied ones (a quick config check is sketched after this diff).

)
def test_training_with_cast_lm_head_to_fp32(self, model_name):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
training_args = GRPOConfig(
output_dir=self.tmp_dir,
@@ -715,7 +718,7 @@ def test_training_with_cast_lm_head_to_fp32(self):
cast_lm_head_to_fp32=True,
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
model=model_name,
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
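For the record, which of these two tiny checkpoints ties its word embeddings can be confirmed straight from the configs. A minimal sketch (the AutoConfig lookup and Hub access are assumptions; the printed values depend on the actual checkpoints):

```python
# Check the tie_word_embeddings flag of the two tiny test models parametrized above.
from transformers import AutoConfig

for name in [
    "trl-internal-testing/tiny-Qwen3ForCausalLM",
    "trl-internal-testing/tiny-Gemma2ForCausalLM",
]:
    config = AutoConfig.from_pretrained(name)
    print(f"{name}: tie_word_embeddings={config.tie_word_embeddings}")
```

Parametrizing over one untied and one tied checkpoint exercises both branches of the casting logic changed below.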
36 changes: 23 additions & 13 deletions trl/trainer/grpo_trainer.py
@@ -479,21 +479,31 @@ def __init__(

# Cast LM Head To FP32
if args.cast_lm_head_to_fp32:
if not model.config.tie_word_embeddings:

def cast_inputs_to_fp32(module, input):
return (input[0].float(),)
def _cast_lm_head_to_fp32(target_model: PreTrainedModel):
"""Cast lm_head to fp32 while preserving embedding output dtype if tied."""

model.lm_head = model.lm_head.float()
model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
if self.ref_model is not None:
self.ref_model.lm_head = self.ref_model.lm_head.float()
self.ref_model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
else:
raise NotImplementedError(
"`cast_lm_head_to_fp32=True` is only supported when the model has untied word embedding and language modeling head layers"
"i.e. `tie_word_embeddings` in the model config is False."
)
def cast_inputs_to_fp32(module, inputs):
# Preserve other positional args and kwargs untouched
if not inputs:
return inputs
return (inputs[0].to(torch.float32),) + inputs[1:]

original_dtype_local = target_model.lm_head.weight.dtype
target_model.lm_head = target_model.lm_head.float()
Member

For the record, float() is in-place, so in theory you could just have

target_model.lm_head.float()

It happens that .float() returns self, so target_model.lm_head = target_model.lm_head.float() works as well. I personally prefer the current way; the assignment makes it more explicit.

target_model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)

if target_model.config.tie_word_embeddings:

def cast_outputs_to_original_dtype(module, args, output):
return output.to(original_dtype_local)

# Only cast activations; weights are now fp32 (intentional for numerical stability of logits)
target_model.model.embed_tokens.register_forward_hook(cast_outputs_to_original_dtype)

_cast_lm_head_to_fp32(model)
if self.ref_model is not None:
_cast_lm_head_to_fp32(self.ref_model)

# Liger loss
if self.use_liger_kernel:
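To see the new casting logic end to end outside the trainer, here is a standalone sketch of the same hook-based approach. The checkpoint, dtype, and dummy input are illustrative, and `model.model.embed_tokens` assumes the usual Hugging Face causal-LM layout; this is a sketch of the technique, not the trainer code itself.

```python
# Standalone sketch of the fp32 lm_head casting added in this diff.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-Gemma2ForCausalLM",  # illustrative checkpoint (tied embeddings)
    torch_dtype=torch.bfloat16,
)

original_dtype = model.lm_head.weight.dtype  # captured before the cast (bf16 here)


def cast_inputs_to_fp32(module, inputs):
    # Forward pre-hook: upcast only the hidden states fed to lm_head.
    if not inputs:
        return inputs
    return (inputs[0].to(torch.float32),) + inputs[1:]


model.lm_head = model.lm_head.float()  # nn.Module.float() converts in place and returns self
model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)

if model.config.tie_word_embeddings:
    # With tied weights, upcasting lm_head also upcasts the shared embedding matrix,
    # so cast the embedding activations back to keep the rest of the model in bf16.
    def cast_outputs_to_original_dtype(module, args, output):
        return output.to(original_dtype)

    model.model.embed_tokens.register_forward_hook(cast_outputs_to_original_dtype)

with torch.no_grad():
    logits = model(torch.tensor([[1, 2, 3]])).logits
print(logits.dtype)  # expected: torch.float32
```

The logits are computed in fp32 for numerical stability, while the embedding outputs, and therefore the transformer blocks, stay in the original dtype.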