Commit e9065e7

fix position ids for latest transformers
1 parent e302711 commit e9065e7

File tree

5 files changed: +28 additions, -11 deletions
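Overview of the change: newer transformers releases (>= 4.54.0, see huggingface/transformers#39447) expect the plain text position ids to be concatenated with the 3-D mrope (temporal / height / width) position ids for Qwen2-VL, so every position-id tensor in the pipeline grows from 3 rows to 4. A minimal sketch of the resulting layout, assuming a dummy text-only sequence with no real image grid:

```python
import torch

seq_length = 16

# rows 1..3: the 3-D mrope (temporal / height / width) positions produced by
# get_rope_index; for a text-only prompt they all reduce to 0..seq_length-1
vision_position_ids = torch.arange(seq_length).unsqueeze(0).expand(3, -1)  # (3, seq_length)

# row 0: ordinary 1-D text positions, now prepended by the dataset
text_position_ids = torch.arange(seq_length).unsqueeze(0)  # (1, seq_length)

position_ids = torch.cat((text_position_ids, vision_position_ids), dim=0)
print(position_ids.shape)  # torch.Size([4, 16])
```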

tests/test_dataset.py

Lines changed: 7 additions & 5 deletions
@@ -12,16 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import pytest
 import torch
 from PIL.Image import Image
 
 from verl.utils.dataset import RLHFDataset
 from verl.utils.tokenizer import get_processor, get_tokenizer
 
 
-def test_image_dataset():
-    tokenizer = get_tokenizer("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True)
-    processor = get_processor("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=True)
+@pytest.mark.parametrize("use_fast", [True, False])
+def test_image_dataset(use_fast: bool):
+    tokenizer = get_tokenizer("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=use_fast)
+    processor = get_processor("Qwen/Qwen2.5-VL-7B-Instruct", use_fast=use_fast)
     dataset = RLHFDataset(
         data_path="hiyouga/geometry3k@test",
         tokenizer=tokenizer,
@@ -44,8 +46,8 @@ def test_image_dataset():
     }
     assert torch.all(dataset[0]["input_ids"] == torch.tensor(token_ids))
     assert torch.all(dataset[0]["attention_mask"] == torch.ones(16))
-    assert torch.all(dataset[0]["position_ids"] == torch.arange(16).unsqueeze(0).expand(3, -1))
-    assert list(dataset[0]["position_ids"].size()) == [3, 16]  # avoid fake positive caused by broadcasting
+    assert torch.all(dataset[0]["position_ids"] == torch.arange(16).unsqueeze(0).expand(4, -1))
+    assert list(dataset[0]["position_ids"].size()) == [4, 16]  # avoid fake positive caused by broadcasting
     assert dataset[0]["raw_prompt_ids"] == token_ids
     assert dataset[0]["ground_truth"] == "48"
     assert isinstance(dataset[0]["multi_modal_data"]["images"][0], Image)
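Note on the shape assertion above: `torch.all(a == b)` alone can pass even when the position-id tensor has the wrong number of rows, because `==` broadcasts. A small illustration (not part of the commit):

```python
import torch

expected = torch.arange(16).unsqueeze(0).expand(4, -1)  # (4, 16)
wrong = torch.arange(16).unsqueeze(0)                   # (1, 16) -- wrong row count

# broadcasting makes the elementwise comparison succeed anyway ...
assert torch.all(wrong == expected)

# ... so only the explicit size check catches the mismatch
assert list(wrong.size()) != [4, 16]
```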

verl/models/transformers/qwen2_vl.py

Lines changed: 15 additions & 2 deletions
@@ -147,6 +147,19 @@ def get_rope_index(
     return position_ids
 
 
+def process_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
+    if position_ids.dim() != 3 or position_ids.size(0) != 4:
+        # we concat the text position ids with the 3D vision position ids by default
+        # see https://github.com/huggingface/transformers/pull/39447
+        raise ValueError("position_ids should be a 3D tensor of shape (4, batch_size, seq_length).")
+
+    if not is_transformers_version_greater_than("4.54.0"):
+        # transformers < 4.54.0 only accepts vision position ids, so we discard the text position ids here
+        position_ids = position_ids[1:]
+
+    return position_ids
+
+
 def qwen2_vl_attn_forward(
     self: "Qwen2VLAttention",
     hidden_states: torch.Tensor,
@@ -272,7 +285,7 @@ def qwen2_vl_forward_old(
     outputs = self.model(
         input_ids=None,
         attention_mask=attention_mask,
-        position_ids=position_ids,
+        position_ids=process_position_ids(position_ids),
         inputs_embeds=inputs_embeds,
         **kwargs,
     )
@@ -306,7 +319,7 @@ def qwen2_vl_base_forward_new(
     )
     outputs = self.language_model(
         input_ids=None,
-        position_ids=position_ids,
+        position_ids=process_position_ids(position_ids),
         attention_mask=attention_mask,
         inputs_embeds=inputs_embeds,
         **kwargs,
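A rough, self-contained re-implementation of what the new helper does, for illustration only; the real `process_position_ids` consults `is_transformers_version_greater_than("4.54.0")` instead of taking a boolean flag:

```python
import torch

def process_position_ids_sketch(position_ids: torch.Tensor, transformers_ge_4_54: bool) -> torch.Tensor:
    # the pipeline always hands over text + 3-D vision positions stacked into 4 rows
    if position_ids.dim() != 3 or position_ids.size(0) != 4:
        raise ValueError("position_ids should be a 3D tensor of shape (4, batch_size, seq_length).")
    # older transformers only accept the three vision rows, so the text row is dropped
    return position_ids if transformers_ge_4_54 else position_ids[1:]

pos = torch.zeros(4, 2, 16, dtype=torch.long)            # (4, batch_size, seq_length)
print(process_position_ids_sketch(pos, True).shape)      # torch.Size([4, 2, 16])
print(process_position_ids_sketch(pos, False).shape)     # torch.Size([3, 2, 16])
```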

verl/utils/dataset.py

Lines changed: 3 additions & 1 deletion
@@ -266,14 +266,16 @@ def __getitem__(self, index):
 
         if self.processor is not None and "Qwen2VLImageProcessor" in self.processor.image_processor.__class__.__name__:
             # qwen2vl mrope
-            position_ids = get_rope_index(
+            vision_position_ids = get_rope_index(
                 self.processor,
                 input_ids=input_ids,
                 image_grid_thw=model_inputs.get("image_grid_thw", None),
                 video_grid_thw=model_inputs.get("video_grid_thw", None),
                 second_per_grid_ts=model_inputs.get("second_per_grid_ts", None),
                 attention_mask=attention_mask,
             )  # (3, seq_length)
+            text_position_ids = torch.arange(len(input_ids)).unsqueeze(0)  # (1, seq_length)
+            position_ids = torch.cat((text_position_ids, vision_position_ids), dim=0)  # (4, seq_length)
         else:
             position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None)  # (seq_length,)
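For comparison, the unchanged `else` branch still emits a single row of positions; a quick illustration of the cumsum/clip trick on a left-padded attention mask (dummy values, not from the commit):

```python
import torch

attention_mask = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1])  # left padding, then 5 real tokens
position_ids = torch.clip(attention_mask.cumsum(dim=0) - 1, min=0, max=None)
print(position_ids)  # tensor([0, 0, 0, 0, 1, 2, 3, 4])
```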

verl/workers/actor/dp_actor.py

Lines changed: 2 additions & 2 deletions
@@ -77,7 +77,7 @@ def _forward_micro_batch(self, micro_batch: dict[str, torch.Tensor], temperature
         responses = micro_batch["responses"]
         response_length = responses.size(-1)
         if position_ids.dim() == 3:  # qwen2vl mrope
-            position_ids = position_ids.transpose(0, 1)  # (bsz, 3, seqlen) -> (3, bsz, seqlen)
+            position_ids = position_ids.transpose(0, 1)  # (bsz, 4, seqlen) -> (4, bsz, seqlen)
 
         multi_modal_inputs = defaultdict(list)
         if "multi_modal_inputs" in micro_batch:
@@ -96,7 +96,7 @@ def _forward_micro_batch(self, micro_batch: dict[str, torch.Tensor], temperature
                     index_first_axis(rearrange(position_ids, "c b s ... -> (b s) c ..."), indices)
                     .transpose(0, 1)
                     .unsqueeze(1)
-                )  # (3, bsz, seqlen) -> (3, 1, bsz * seqlen)
+                )  # (4, bsz, seqlen) -> (4, 1, bsz * seqlen)
         else:
             position_ids_rmpad = index_first_axis(
                 rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
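The actor's remove-padding path only needed its shape comments updated, but the flow is easy to misread; here is a sketch of the same reshaping with dummy shapes, using plain indexing in place of flash-attn's `index_first_axis` (the `indices` values are made up):

```python
import torch
from einops import rearrange

bsz, seqlen = 2, 8
position_ids = torch.zeros(bsz, 4, seqlen, dtype=torch.long)     # as stored in the micro batch
position_ids = position_ids.transpose(0, 1)                      # (bsz, 4, seqlen) -> (4, bsz, seqlen)

indices = torch.tensor([0, 1, 2, 9, 10])                          # hypothetical non-padding token slots
flat = rearrange(position_ids, "c b s ... -> (b s) c ...")        # (bsz * seqlen, 4)
position_ids_rmpad = flat[indices].transpose(0, 1).unsqueeze(1)   # (4, 1, num_tokens)
print(position_ids_rmpad.shape)                                    # torch.Size([4, 1, 5])
```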

verl/workers/rollout/vllm_rollout_spmd.py

Lines changed: 1 addition & 1 deletion
@@ -210,7 +210,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
         delta_position_id = torch.arange(1, response_length + 1, device=position_ids.device)
         delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)
         if position_ids.dim() == 3:  # qwen2vl mrope
-            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 3, -1)
+            delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 4, -1)
 
         # prompt: left pad + response: right pad
         # attention_mask: [0,0,0,0,1,1,1,1 | 1,1,1,0,0,0,0,0]
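The rollout change follows the same logic: response positions are built by offsetting from the last prompt position, and that offset now has to broadcast over four rows instead of three. A sketch with made-up sizes:

```python
import torch

batch_size, response_length = 2, 5
prompt_position_ids = torch.zeros(batch_size, 4, 10, dtype=torch.long)  # (bsz, 4, prompt_len)

delta_position_id = torch.arange(1, response_length + 1)
delta_position_id = delta_position_id.view(1, -1).expand(batch_size, -1)                  # (bsz, resp_len)
delta_position_id = delta_position_id.view(batch_size, 1, -1).expand(batch_size, 4, -1)   # (bsz, 4, resp_len)

response_position_ids = prompt_position_ids[..., -1:] + delta_position_id
print(response_position_ids.shape)  # torch.Size([2, 4, 5])
```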
