diff --git a/setup.sh b/setup.sh
index adcde953..366ccf70 100644
--- a/setup.sh
+++ b/setup.sh
@@ -10,6 +10,3 @@ pip install flash-attn --no-build-isolation
 
 # vLLM support
 pip install vllm==0.7.2
-
-# fix transformers version
-pip install git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef
\ No newline at end of file
diff --git a/src/r1-v/setup.py b/src/r1-v/setup.py
index a847d9eb..6207002f 100644
--- a/src/r1-v/setup.py
+++ b/src/r1-v/setup.py
@@ -61,7 +61,7 @@
     "safetensors>=0.3.3",
     "sentencepiece>=0.1.99",
     "torch>=2.5.1",
-    "transformers @ git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef",
+    "transformers @ git+https://github.com/huggingface/transformers.git@main",
     "trl==0.14.0",
     "vllm==0.6.6.post1",
     "wandb>=0.19.1",
diff --git a/src/r1-v/src/open_r1/sft.py b/src/r1-v/src/open_r1/sft.py
index 41011e4e..5150e9cb 100644
--- a/src/r1-v/src/open_r1/sft.py
+++ b/src/r1-v/src/open_r1/sft.py
@@ -243,10 +243,17 @@ def main(script_args, training_args, model_args):
         quantization_config=quantization_config,
     )
     # training_args.model_init_kwargs = model_kwargs
-    from transformers import Qwen2VLForConditionalGeneration
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_args.model_name_or_path, **model_kwargs
-    )
+    from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
+    if "Qwen2-VL" in model_args.model_name_or_path:
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_args.model_name_or_path, **model_kwargs
+        )
+    elif "Qwen2.5-VL" in model_args.model_name_or_path:
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_args.model_name_or_path, **model_kwargs
+        )
+    else:
+        # NOTE(review): was `assert False, ...` — asserts are stripped under `python -O`,
+        # which would leave `model` unbound; raise an explicit error instead.
+        raise ValueError(f"Model {model_args.model_name_or_path} not supported")
 
     ############################
     # Initialize the SFT Trainer
     ############################