Bug: AutoModelForVision2Seq.from_config() Compatibility Issue #3267

@Maxwell-Jia

Description

When using FSDPCheckpointManager to save checkpoints for Vision Language Models (e.g., Qwen2.5-VL-3B-Instruct) with the hf_model save option enabled, a TypeError is raised because from_config() rejects the torch_dtype keyword argument that the checkpoint manager passes to it.

Error Message

TypeError: AutoModelForVision2Seq.from_config() got an unexpected keyword argument 'torch_dtype'
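This is the message Python produces when a subclass overrides a classmethod with a narrower signature than its parent. The sketch below illustrates that mechanism with hypothetical stand-in classes (BaseAutoClass, DeprecatedAutoClass, call_with_dtype are not transformers names). It assumes, without quoting the transformers source, that the deprecated AutoModelForVision2Seq wraps from_config in an override that does not accept **kwargs.

```python
import warnings


class BaseAutoClass:
    """Hypothetical stand-in for an auto class whose from_config forwards extra kwargs."""

    @classmethod
    def from_config(cls, config, **kwargs):
        # Echo what was received so the kwarg forwarding is observable.
        return {"config": config, **kwargs}


class DeprecatedAutoClass(BaseAutoClass):
    """Hypothetical deprecation shim that overrides from_config with a narrower signature."""

    @classmethod
    def from_config(cls, config):  # no **kwargs: any extra keyword raises TypeError
        warnings.warn("DeprecatedAutoClass is deprecated", FutureWarning)
        return super().from_config(config)


def call_with_dtype(auto_cls, config, dtype):
    """Mimic the failing call site: from_config(config, torch_dtype=dtype)."""
    return auto_cls.from_config(config, torch_dtype=dtype)
```

call_with_dtype(BaseAutoClass, ...) succeeds, while call_with_dtype(DeprecatedAutoClass, ...) fails with "got an unexpected keyword argument 'torch_dtype'" before the override body (and its warning) ever runs.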

Error Location

File: verl/utils/checkpoint/fsdp_checkpoint_manager.py
Line: 331

save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16)
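One possible version-tolerant replacement for this call is sketched below. The helper name from_config_with_dtype is hypothetical (it is not part of verl or transformers). It passes torch_dtype only when the class's from_config signature can accept it, and otherwise builds the model with the default dtype and casts it afterwards with the standard torch.nn.Module.to(dtype).

```python
import inspect


def from_config_with_dtype(auto_model_cls, model_config, dtype):
    """Hypothetical helper: build a model from a config in the requested dtype.

    Passes torch_dtype only if from_config declares it explicitly or takes
    **kwargs; otherwise instantiates with the default dtype and casts.
    """
    params = inspect.signature(auto_model_cls.from_config).parameters.values()
    accepts_dtype = any(
        p.name == "torch_dtype" or p.kind is inspect.Parameter.VAR_KEYWORD
        for p in params
    )
    if accepts_dtype:
        return auto_model_cls.from_config(model_config, torch_dtype=dtype)
    # Fallback: instantiate first, then cast parameters and buffers to dtype.
    return auto_model_cls.from_config(model_config).to(dtype)
```

In the checkpoint manager this would read save_model = from_config_with_dtype(auto_model_cls, model_config, torch.bfloat16). Inspecting the signature, rather than wrapping the call in try/except TypeError, avoids masking unrelated TypeErrors raised during model construction. The cast fallback briefly materializes the model in the default dtype (usually float32), which costs extra memory for large models.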

Environment

  • transformers version: 4.55.4
  • Python version: 3.10
  • PyTorch version: 2.7.1
  • CUDA version: 12.6
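The versions above can be collected with a short script like the sketch below, which uses only importlib.metadata and platform from the standard library, plus torch.version.cuda when PyTorch is importable. The function name environment_report is illustrative, not an existing utility.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def environment_report(packages=("transformers", "torch")):
    """Collect the version information requested in this issue template."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = "not installed"
    try:
        import torch  # only needed for the CUDA version PyTorch was built with

        report["cuda"] = torch.version.cuda or "cpu-only build"
    except ImportError:
        report["cuda"] = "torch not installed"
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```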

Steps to Reproduce

  1. Load a Vision Language Model (e.g., Qwen2.5-VL-3B-Instruct)
  2. Wrap the model with FSDP
  3. Create FSDPCheckpointManager with hf_model in save configuration:
    checkpoint_config = DictConfig({
        'save_contents': ['model', 'hf_model'],
        'load_contents': ['model']
    })
  4. Call checkpoint_manager.save_checkpoint()

Minimal Reproduction Script

#!/usr/bin/env python3
import os
import tempfile
import torch
import torch.distributed as dist
import warnings
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from transformers import AutoConfig, AutoModelForVision2Seq, AutoProcessor
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from omegaconf import DictConfig

def setup_distributed():
    if not dist.is_initialized():
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = '12355'
        os.environ['RANK'] = '0'
        os.environ['WORLD_SIZE'] = '1'
        dist.init_process_group(backend='nccl' if torch.cuda.is_available() else 'gloo')

def main():
    setup_distributed()
    
    model_name = "Qwen/Qwen2.5-VL-3B-Instruct"  # or local path
    
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        model = AutoModelForVision2Seq.from_pretrained(
            model_name, 
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map={"": device}
        )
    
    auto_wrap_policy = size_based_auto_wrap_policy
    fsdp_model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=None,
        device_id=0 if torch.cuda.is_available() else None,
        use_orig_params=True,
    )
    
    checkpoint_config = DictConfig({
        'save_contents': ['model', 'hf_model'],
        'load_contents': ['model']
    })
    
    checkpoint_manager = FSDPCheckpointManager(
        model=fsdp_model,
        optimizer=None,
        lr_scheduler=None,
        processing_class=processor,
        checkpoint_config=checkpoint_config
    )
    
    import shutil
    temp_dir = tempfile.mkdtemp()
    try:
        checkpoint_path = os.path.join(temp_dir, "test_checkpoint")
        checkpoint_manager.save_checkpoint(local_path=checkpoint_path, global_step=1)
    finally:
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
    
    if dist.is_initialized():
        dist.destroy_process_group()

if __name__ == "__main__":
    main()

Root Cause

The issue occurs at fsdp_checkpoint_manager.py:331:

save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16)

In transformers>=4.54.0, AutoModelForVision2Seq.from_config() does not accept the torch_dtype keyword argument. AutoModelForVision2Seq is also deprecated and will be removed in v5.0, so AutoModelForImageTextToText should be used instead.
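A minimal sketch of selecting the replacement class while staying compatible with releases that predate it: it prefers AutoModelForImageTextToText and falls back to AutoModelForVision2Seq. The function pick_vision_auto_class is hypothetical; it takes the transformers module as an argument so the selection logic can be exercised without the library installed.

```python
def pick_vision_auto_class(transformers_module):
    """Hypothetical helper: prefer the non-deprecated auto class when it exists."""
    for name in ("AutoModelForImageTextToText", "AutoModelForVision2Seq"):
        # getattr with a default also covers lazily-loaded modules that raise
        # AttributeError for names they do not export.
        cls = getattr(transformers_module, name, None)
        if cls is not None:
            return cls
    raise ImportError(
        "neither AutoModelForImageTextToText nor AutoModelForVision2Seq is available"
    )
```

With transformers installed, call it as auto_model_cls = pick_vision_auto_class(transformers). That AutoModelForImageTextToText.from_config() accepts torch_dtype is an assumption (it is not the deprecated class) worth confirming on the installed version.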

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests