[Bug] RuntimeError with use_remove_padding=True in GRPO training for Qwen2.5-VL due to tensor shape mismatch #2958
Description
I am trying to run the GRPO training example for the Qwen2.5-VL model using the provided script. Due to VRAM limitations on my machine, I have modified the launch script to use the Qwen2.5-VL-3B-Instruct model instead of the default 7B version.
However, the training process fails at step 0 with a RuntimeError related to a tensor dimension mismatch inside the model's forward pass.
Further debugging shows that this error occurs only when use_remove_padding=True is set in the configuration. With use_remove_padding=False, training proceeds without this error.
Steps to Reproduce
1. Use the official launch script: run_qwen2_5_vl-7b_lora.sh.
2. Change the actor_rollout_ref.model.path argument in the script to point to the 3B model, for example: /path/to/Qwen2.5-VL-3B-Instruct.
3. Run the script.
Observed Behavior
The training process crashes immediately. The full error log has been uploaded to Pastebin for clarity:
https://pastebin.com/r3Xp7Jjs
The core error message is:
RuntimeError: expand(torch.cuda.LongTensor{[3, 1, 12087]}, size=[1, -1]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)
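The error is raised in transformers/masking_utils.py while the causal mask is being created. The tensor in the message has three dimensions ([3, 1, 12087]) but is expanded with only two sizes, which suggests that the position_ids passed to the model have an unexpected shape.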
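To make the rule behind the message concrete, here is a minimal pure-Python sketch of the dimension-count check that the error describes. The helper name `check_expand` is hypothetical and is not part of torch or transformers; it only mirrors the rule quoted in the message (the number of target sizes must be at least the number of tensor dimensions). It also assumes, as suggested above, that the `[3, 1, 12087]` tensor is position_ids.

```python
def check_expand(shape, sizes):
    """Mimic the dimension-count rule quoted in the error message.

    Hypothetical helper for illustration only; not a torch API.
    """
    if len(sizes) < len(shape):
        raise ValueError(
            f"the number of sizes provided ({len(sizes)}) must be greater or "
            f"equal to the number of dimensions in the tensor ({len(shape)})"
        )
    return True


# A 2-D shape (batch, seq_len) satisfies the rule for size=[1, -1].
ok_2d = check_expand((1, 12087), (1, -1))

# The 3-D shape from the log does not: 2 sizes < 3 dimensions.
try:
    check_expand((3, 1, 12087), (1, -1))
    failed_3d = False
except ValueError as exc:
    failed_3d = True
    message = str(exc)
```

With torch installed, calling `.expand(1, -1)` on a tensor of shape `(3, 1, 12087)` should raise the same kind of RuntimeError, which may be a quick way to confirm the diagnosis outside the training loop.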
Expected Behavior
The training process should start successfully without runtime errors, as the 3B model is a smaller variant of the officially supported 7B model.
Environment and Dependencies
OS & Kernel: Linux dev-7ed99d0c-dc3b-4cc6-97f4-03d3ae19b28e-rr4k4 4.18.0-372.9.1.el8.x86_64 #1 SMP Tue May 10 14:48:47 UTC 2022 x86_64 x86_64 x86_64 GNU/Linux
GPU & NVIDIA Driver:
GPU: 4x NVIDIA H100 80GB HBM3
Driver Version: 570.86.10
CUDA Version: 12.8
Python Version: Python 3.10.12
Key Libraries:
verl: git+https://github.com/volcengine/verl@8e1fc24
transformers: 4.53.2
torch: 2.7.0
peft: 0.16.0
accelerate: 1.9.0
ray: 2.47.1
flash_attn: 2.7.4
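For anyone trying to reproduce this, the library versions above can be gathered with a small script. This is a sketch using only the standard library; the distribution names in `PACKAGES` are assumptions taken from the list above and may differ from the names registered in a given environment (for example, an editable or git install).

```python
from importlib import metadata

# Distribution names are assumptions based on the list above; adjust as needed.
PACKAGES = ["verl", "transformers", "torch", "peft", "accelerate", "ray", "flash_attn"]


def collect_versions(packages=PACKAGES):
    """Return {package: version}, or 'not installed' when lookup fails."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}: {version}")
```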