Conversation

SP1029 (Contributor) commented Sep 3, 2025

Issue #2757 is about replacing usage of the from_legacy_cache method of transformers Cache classes in peft, ahead of its scheduled removal in transformers v4.58.0. The from_legacy_cache method converts data in the legacy format (iterables of key/value tensor pairs) into Cache objects.

This PR addresses the issue by replacing the method calls with the corresponding Cache class constructors, which accept the same data.
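
To illustrate, here is a minimal before/after sketch (the actual call sites are in peft's prefix tuning code; this assumes a transformers version whose DynamicCache constructor accepts the legacy iterable):

import torch
from transformers import DynamicCache

# Legacy cache format: an iterable of (key, value) tensor tuples, one per layer
past_key_values = [(torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8))]

# Before (from_legacy_cache is slated for removal in transformers v4.58.0):
# cache = DynamicCache.from_legacy_cache(past_key_values)

# After: the constructor handles the same data directly
cache = DynamicCache(past_key_values)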

SP1029 marked this pull request as draft September 3, 2025 13:54
SP1029 (Contributor, Author) commented Sep 3, 2025

The testing script below passes locally for me. @BenjaminBossan, would this script be fine, or would we want to add anything else?

import pickle
import torch
from peft import PrefixTuningConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

torch.manual_seed(0)

# Constants
CODEBASE = "old" # new/old

# Test definitions

TEST = [
    "Decoder",
    "Encoder-Decoder",
    "Gemma"
]

MODEL = [
    AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM"),
    AutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration"),
    AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-Gemma3ForCausalLM"),
]

PEFT_CONFIG = [
    PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM"),
    PrefixTuningConfig(num_virtual_tokens=10, task_type="SEQ_2_SEQ_LM"),
    PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
]

# Supporting functions

def save_cache(model, test):
    batch_size = 1

    with open(f"{test}_{CODEBASE}.pkl", "wb") as f:
        pickle.dump(model.get_prompt(batch_size), f, pickle.HIGHEST_PROTOCOL)

# For the cache data of a single layer
def is_equal_layer(layer1, layer2):
    for attr in layer1.__dict__:
        if not bool(torch.all(torch.as_tensor(layer1.__dict__[attr] == layer2.__dict__[attr]))):
            return False

    return True

# For DynamicCache objects
def is_equal_DynamicCache(cache1, cache2):
    if len(cache1.layers) != len(cache2.layers):
        return False

    for i in range(len(cache1.layers)):
        if not is_equal_layer(cache1.layers[i], cache2.layers[i]):
            return False

    return True

# For the DynamicCache objects inside an EncoderDecoderCache
def is_equal_EncoderDecoderCache(cache1, cache2):
    if not is_equal_DynamicCache(cache1.self_attention_cache, cache2.self_attention_cache):
        return False

    if not is_equal_DynamicCache(cache1.cross_attention_cache, cache2.cross_attention_cache):
        return False

    return True

def is_equal_cache(cache1, cache2):
    cache_type = type(cache1).__name__

    if cache_type == "DynamicCache":
        return is_equal_DynamicCache(cache1, cache2)
    elif cache_type == "EncoderDecoderCache":
        return is_equal_EncoderDecoderCache(cache1, cache2)
    raise ValueError(f"Unexpected cache type: {cache_type}")

# Executing tests

for i in range(3):
    model = get_peft_model(MODEL[i], PEFT_CONFIG[i])
    test = TEST[i]
    save_cache(model, test)

# Comparing (requires a prior run with CODEBASE = "old" and one with CODEBASE = "new")

for i in range(3):
    test = TEST[i]
    with open(f"{test}_old.pkl", "rb") as f:
        cache1 = pickle.load(f)
    with open(f"{test}_new.pkl", "rb") as f:
        cache2 = pickle.load(f)

    print(test, is_equal_cache(cache1, cache2))

SP1029 marked this pull request as ready for review September 3, 2025 14:57
HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

BenjaminBossan (Member) left a comment

Thanks for working on this PR and especially for checking the results. This check is more in-depth than what I had in mind (I would have simply compared model outputs) but it works too.

I have one comment concerning backwards compatibility, otherwise the PR looks good.
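
For reference, a minimal sketch of that simpler output-comparison check (the CODEBASE flag and file names mirror the script above and are illustrative assumptions, not part of the PR):

import torch
from peft import PrefixTuningConfig, get_peft_model
from transformers import AutoModelForCausalLM

CODEBASE = "old"  # run once with "old" on main, then with "new" on this branch

torch.manual_seed(0)
base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
model = get_peft_model(base, PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM"))

input_ids = torch.arange(10).unsqueeze(0)
with torch.no_grad():
    logits = model(input_ids=input_ids).logits

# Save the logits of the current run, then compare against the "old" run
torch.save(logits, f"logits_{CODEBASE}.pt")
if CODEBASE == "new":
    reference = torch.load("logits_old.pt")
    print("outputs match:", torch.allclose(logits, reference))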

BenjaminBossan (Member) left a comment

Thanks for working on this task, LGTM.

Just to confirm: the 4.58.0 in the TODO was a typo, right, since the check is for 4.56.0? Or was it intended?

Indeed. The deprecation is scheduled for 4.58.0, but the support for from_legacy_cache can still be dropped once we move beyond 4.56.0, good catch.
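
A hypothetical sketch of the version gate being discussed (the actual peft code may differ; to_cache is an illustrative helper, not a peft function):

import transformers
from packaging import version
from transformers import DynamicCache

def to_cache(past_key_values):
    # TODO: drop the fallback branch once transformers < 4.56.0 is no longer
    # supported (from_legacy_cache itself is removed in v4.58.0)
    if version.parse(transformers.__version__) >= version.parse("4.56.0"):
        return DynamicCache(past_key_values)
    return DynamicCache.from_legacy_cache(past_key_values)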

BenjaminBossan merged commit ed5c6ea into huggingface:main Sep 8, 2025
26 of 27 checks passed
SP1029 deleted the replace-from-legacy-cache branch September 9, 2025 06:46
