Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 63 additions & 2 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,16 @@


if is_peft_available():
from peft import LoraConfig, PeftModel, PromptEncoderConfig, TaskType, get_peft_model
from peft import (
LoraConfig,
PeftModel,
PrefixTuningConfig,
PromptEncoderConfig,
PromptTuningConfig,
PromptTuningInit,
TaskType,
get_peft_model,
)


class DFTLossTester(TrlTestCase):
Expand Down Expand Up @@ -451,7 +460,7 @@ def test_train_model_dtype(self):
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

@require_peft
def test_train_dense_with_peft_config(self):
def test_train_dense_with_peft_config_lora(self):
# Get the base model parameter names
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
Expand Down Expand Up @@ -487,6 +496,58 @@ def test_train_dense_with_peft_config(self):
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

@parameterized.expand(
[
(PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4),),
(PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4),),
(
PromptTuningConfig(
task_type=TaskType.CAUSAL_LM,
prompt_tuning_init=PromptTuningInit.RANDOM,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
prompt_tuning_init=PromptTuningInit.RANDOM,

nit: I think this is the default value

num_virtual_tokens=4,
tokenizer_name_or_path="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
),
),
]
)
@require_peft
def test_train_with_peft_config_prompt_tuning(self, peft_config):
# Get the base model parameter names
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model = AutoModelForCausalLM.from_pretrained(model_id)
base_param_names = [f"base_model.{n}" for n, _ in model.named_parameters()]

# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

# Initialize the trainer, p-tuning doesn't support gradient checkpointing
training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", gradient_checkpointing=False)
peft_config.encoder_hidden_size = model.config.hidden_size

trainer = SFTTrainer(
model=model_id,
args=training_args,
train_dataset=dataset,
peft_config=peft_config,
)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've now drooped unittest, see #4188, you'll have to replace these statements when merging main to your branch


# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
else: # We expect the peft parameters to be different
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

@require_peft
def test_train_moe_with_peft_config(self):
# Get the base model parameter names
Expand Down
23 changes: 14 additions & 9 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@


if is_peft_available():
from peft import PeftConfig, PeftModel
from peft import PeftConfig, PeftModel, PeftType


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -1083,13 +1083,15 @@ def compute_loss(
if not self.args.use_liger_kernel: # liger doesn't return logits
with torch.no_grad():
per_token_entropy = entropy_from_logits(outputs.logits)
# When using Prompt Tuning, skip the virtual tokens in logits before entropy computation, since they
# do not correspond to actual input tokens.
if (
self.num_virtual_tokens > 0
and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING
):
per_token_entropy = per_token_entropy[:, self.num_virtual_tokens :]
if "attention_mask" in inputs:
attention_mask = inputs["attention_mask"]
# When using Prompt Tuning, we need to add attention for the virtual tokens (all set to 1).
virtual_attention_mask = torch.ones(
attention_mask.size(0), self.num_virtual_tokens, device=attention_mask.device
)
attention_mask = torch.cat((virtual_attention_mask, attention_mask), dim=1)
entropy = torch.sum(per_token_entropy * attention_mask) / attention_mask.sum()
elif "position_ids" in inputs:
entropy = torch.mean(per_token_entropy)
Expand Down Expand Up @@ -1124,9 +1126,12 @@ def compute_loss(
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# When using Prompt Tuning, skip the virtual tokens in logits before accuracy computation, since they do
# not correspond to actual input labels.
shift_logits = shift_logits[:, self.num_virtual_tokens :, :]
# Prompt Tuning and P-Tuning output logits for virtual tokens but Prefix-Tuning does not.
if (
self.num_virtual_tokens > 0
and model.peft_config[model.active_adapter].peft_type != PeftType.PREFIX_TUNING
):
shift_logits = shift_logits[:, self.num_virtual_tokens :, :]

# Get predictions
predictions = shift_logits.argmax(dim=-1)
Expand Down
Loading