generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Fix entropy and accuracy calculation for prompt_tuning techniques. #4196
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
pramodith
merged 7 commits into
huggingface:main
from
pramodith:pramodith/sft_for_prompt_tuning
Oct 8, 2025
Merged
Changes from 2 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
a67d1fd
Fix entropy and accuracy calculation for prompt_tuning techniques.
pramodith 5da98d1
Merge branch 'main' into pramodith/sft_for_prompt_tuning
pramodith 302e560
Merge branch 'huggingface:main' into pramodith/sft_for_prompt_tuning
pramodith dd2b6b6
Address comments
pramodith 4adf41b
revert
pramodith da6bb8b
Comments
pramodith 4dcc0a5
Merge branch 'main' into pramodith/sft_for_prompt_tuning
pramodith File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -30,7 +30,16 @@ | |||
|
|
||||
|
|
||||
| if is_peft_available(): | ||||
| from peft import LoraConfig, PeftModel, PromptEncoderConfig, TaskType, get_peft_model | ||||
| from peft import ( | ||||
| LoraConfig, | ||||
| PeftModel, | ||||
| PrefixTuningConfig, | ||||
| PromptEncoderConfig, | ||||
| PromptTuningConfig, | ||||
| PromptTuningInit, | ||||
| TaskType, | ||||
| get_peft_model, | ||||
| ) | ||||
|
|
||||
|
|
||||
| class DFTLossTester(TrlTestCase): | ||||
|
|
@@ -451,7 +460,7 @@ def test_train_model_dtype(self): | |||
| self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") | ||||
|
|
||||
| @require_peft | ||||
| def test_train_dense_with_peft_config(self): | ||||
| def test_train_dense_with_peft_config_lora(self): | ||||
| # Get the base model parameter names | ||||
| model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" | ||||
| model = AutoModelForCausalLM.from_pretrained(model_id) | ||||
|
|
@@ -487,6 +496,58 @@ def test_train_dense_with_peft_config(self): | |||
| elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) | ||||
| self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") | ||||
|
|
||||
| @parameterized.expand( | ||||
| [ | ||||
| (PromptEncoderConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4),), | ||||
| (PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=4),), | ||||
| ( | ||||
| PromptTuningConfig( | ||||
| task_type=TaskType.CAUSAL_LM, | ||||
| prompt_tuning_init=PromptTuningInit.RANDOM, | ||||
|
||||
| prompt_tuning_init=PromptTuningInit.RANDOM, |
nit: I think this is the default value
Outdated
Member
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We've now drooped unittest, see #4188, you'll have to replace these statements when merging main to your branch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.