Replace from_legacy_cache method with constructors #2767
Conversation
The testing script below passes locally for me. @BenjaminBossan, would this script be fine, or would we want to add anything else?

```python
import pickle

import torch
from peft import PrefixTuningConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

torch.manual_seed(0)

# Constants
CODEBASE = "old"  # new/old

# Test definitions
TEST = [
    "Decoder",
    "Encoder-Decoder",
    "Gemma",
]
MODEL = [
    AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM"),
    AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration"),
    AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-Gemma3ForCausalLM"),
]
PEFT_CONFIG = [
    PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM"),
    PrefixTuningConfig(num_virtual_tokens=10, task_type="SEQ_2_SEQ_LM"),
    PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM"),
]


# Supporting functions
def save_cache(model, test):
    batch_size = 1
    with open(f"{test}_{CODEBASE}.pkl", "wb") as f:
        pickle.dump(model.get_prompt(batch_size), f, pickle.HIGHEST_PROTOCOL)


# For cache data of a single layer
def is_equal_layer(layer1, layer2):
    for attr in layer1.__dict__:
        check = bool(float(torch.tensor(layer1.__dict__[attr] == layer2.__dict__[attr]).min()))
        if not check:
            return False
    return True


# For DynamicCache objects
def is_equal_DynamicCache(cache1, cache2):
    if len(cache1.layers) != len(cache2.layers):
        return False
    for i in range(len(cache1.layers)):
        if not is_equal_layer(cache1.layers[i], cache2.layers[i]):
            return False
    return True


# For the DynamicCache objects of an EncoderDecoderCache
def is_equal_EncoderDecoderCache(cache1, cache2):
    if not is_equal_DynamicCache(cache1.self_attention_cache, cache2.self_attention_cache):
        return False
    if not is_equal_DynamicCache(cache1.cross_attention_cache, cache2.cross_attention_cache):
        return False
    return True


def is_equal_cache(cache1, cache2):
    cache_type = type(cache1).__name__
    if cache_type == "DynamicCache":
        return is_equal_DynamicCache(cache1, cache2)
    elif cache_type == "EncoderDecoderCache":
        return is_equal_EncoderDecoderCache(cache1, cache2)


# Executing tests
for i in range(3):
    model = get_peft_model(MODEL[i], PEFT_CONFIG[i])
    test = TEST[i]
    save_cache(model, test)

# Comparing (run once with CODEBASE = "old" and once with CODEBASE = "new" first)
for i in range(3):
    test = TEST[i]
    with open(f"{test}_old.pkl", "rb") as f:
        cache1 = pickle.load(f)
    with open(f"{test}_new.pkl", "rb") as f:
        cache2 = pickle.load(f)
    print(test, is_equal_cache(cache1, cache2))
```
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
BenjaminBossan left a comment
Thanks for working on this PR and especially for checking the results. This check is more in-depth than what I had in mind (I would have simply compared model outputs) but it works too.
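A minimal sketch of the simpler output-comparison approach mentioned above: run the same script under the old and new codebase, pickle the outputs each time, then compare the two files. Plain Python lists of floats stand in for model output tensors here, and the helper names (`save_outputs`, `outputs_match`) are illustrative, not part of peft.

```python
import math
import pickle
import tempfile
from pathlib import Path


def save_outputs(outputs, tag, directory):
    # Persist outputs produced under one codebase version (e.g. "old" or "new").
    path = Path(directory) / f"outputs_{tag}.pkl"
    with open(path, "wb") as f:
        pickle.dump(outputs, f, pickle.HIGHEST_PROTOCOL)
    return path


def outputs_match(path_a, path_b, tol=1e-6):
    # Compare two pickled lists of floats element-wise within a tolerance.
    with open(path_a, "rb") as f:
        a = pickle.load(f)
    with open(path_b, "rb") as f:
        b = pickle.load(f)
    return len(a) == len(b) and all(math.isclose(x, y, abs_tol=tol) for x, y in zip(a, b))


# In practice the saved values would come from model(**inputs) under each codebase.
with tempfile.TemporaryDirectory() as d:
    p_old = save_outputs([0.1, 0.2, 0.3], "old", d)
    p_new = save_outputs([0.1, 0.2, 0.3], "new", d)
    match = outputs_match(p_old, p_new)
    mismatch = outputs_match(p_old, save_outputs([0.1, 0.2, 0.9], "bad", d))
print(match, mismatch)  # True False
```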
I have one comment concerning backwards compatibility, otherwise the PR looks good.
BenjaminBossan left a comment
Thanks for working on this task, LGTM.
Just to confirm: the 4.58.0 in the TODO was a typo, right, since the check is for 4.56.0? Or was it intended?
Indeed. The deprecation is scheduled for 4.58.0, but the support for from_legacy_cache can still be dropped once we move beyond 4.56.0. Good catch.
Issue #2757 is about replacing the `from_legacy_cache` method of `transformers` Cache classes in `peft`, due to its removal from transformers in v4.58.0. The `from_legacy_cache` method is used to convert data in the form of iterables into Cache objects. The issue is addressed by replacing the method with the corresponding constructors of the Cache classes, since they handle the data in a similar way.