-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Closed
Description
Description
When using FSDPCheckpointManager to save checkpoints for Vision Language Models (e.g., Qwen2.5-VL-3B-Instruct) with the hf_model save option enabled, a TypeError occurs due to incompatible parameters in the from_config() method.
Error Message
TypeError: AutoModelForVision2Seq.from_config() got an unexpected keyword argument 'torch_dtype'
Error Location
File: verl/utils/checkpoint/fsdp_checkpoint_manager.py
Line: 331
save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16)
Environment
- transformers version: 4.55.4
- Python version: 3.10
- PyTorch version: 2.7.1
- CUDA version: 12.6
Steps to Reproduce
- Load a Vision Language Model (e.g., Qwen2.5-VL-3B-Instruct)
- Wrap the model with FSDP
- Create an FSDPCheckpointManager with hf_model in the save configuration:
  checkpoint_config = DictConfig({ 'save_contents': ['model', 'hf_model'], 'load_contents': ['model'] })
- Call checkpoint_manager.save_checkpoint()
Minimal Reproduction Script
#!/usr/bin/env python3
import os
import tempfile
import torch
import torch.distributed as dist
import warnings
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from transformers import AutoConfig, AutoModelForVision2Seq, AutoProcessor
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from omegaconf import DictConfig
def setup_distributed():
    """Initialise a ``torch.distributed`` process group for the reproduction.

    Does nothing when a process group is already initialised. Otherwise the
    rendezvous environment variables are filled in with single-process
    defaults and ``dist.init_process_group`` is called with the ``nccl``
    backend when CUDA is available and ``gloo`` otherwise.
    """
    if dist.is_initialized():
        return

    # Use setdefault rather than plain assignment so that values already
    # provided by a launcher or by the user (a different port, a specific
    # GPU, a real rank/world size) are respected instead of being clobbered.
    single_process_defaults = {
        'CUDA_VISIBLE_DEVICES': '0',
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12355',
        'RANK': '0',
        'WORLD_SIZE': '1',
    }
    for name, value in single_process_defaults.items():
        os.environ.setdefault(name, value)

    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend)
def main():
    """Reproduce the ``hf_model`` checkpoint-save failure on a vision-language model.

    Steps performed:

    1. Set up the process group via ``setup_distributed()``.
    2. Load ``Qwen/Qwen2.5-VL-3B-Instruct`` and its processor in bfloat16.
    3. Wrap the model in FSDP with a size-based auto-wrap policy.
    4. Build an ``FSDPCheckpointManager`` whose ``save_contents`` include
       ``'hf_model'`` and save one checkpoint into a temporary directory.

    The temporary directory is removed and the process group is destroyed
    even when any step raises, so the reported ``TypeError`` propagates to
    the caller without leaving state behind.
    """
    setup_distributed()
    try:
        model_name = "Qwen/Qwen2.5-VL-3B-Instruct"  # or local path
        processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Silence warnings emitted while the model is being loaded.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model = AutoModelForVision2Seq.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map={"": device},
            )

        fsdp_model = FSDP(
            model,
            auto_wrap_policy=size_based_auto_wrap_policy,
            mixed_precision=None,
            device_id=0 if torch.cuda.is_available() else None,
            use_orig_params=True,
        )

        # 'hf_model' in save_contents is the option that triggers the failure.
        checkpoint_config = DictConfig({
            'save_contents': ['model', 'hf_model'],
            'load_contents': ['model'],
        })
        checkpoint_manager = FSDPCheckpointManager(
            model=fsdp_model,
            optimizer=None,
            lr_scheduler=None,
            processing_class=processor,
            checkpoint_config=checkpoint_config,
        )

        # TemporaryDirectory replaces the manual mkdtemp/rmtree pair;
        # ignore_cleanup_errors (Python 3.10+) keeps the cleanup best-effort.
        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir:
            checkpoint_path = os.path.join(temp_dir, "test_checkpoint")
            checkpoint_manager.save_checkpoint(local_path=checkpoint_path, global_step=1)
    finally:
        # Always tear the process group down, including on failure paths.
        if dist.is_initialized():
            dist.destroy_process_group()
if __name__ == "__main__":
    main()
Root Cause
The issue occurs at fsdp_checkpoint_manager.py:331:
save_model = auto_model_cls.from_config(model_config, torch_dtype=torch.bfloat16)
In transformers>=4.54.0, the AutoModelForVision2Seq.from_config() method does not accept the torch_dtype parameter. In addition, AutoModelForVision2Seq is deprecated and will be removed in v5.0, so AutoModelForImageTextToText should be used instead.
Metadata
Metadata
Assignees
Labels
No labels