Skip to content

Commit 4480199

Browse files
chenyingshuzhtmike
and committed
[data] feat: Add dataset for Qwen-Image (verl-project#6)
* add entroypoint (verl-project#1) * add training engine (verl-project#2) * add training engine * fix init * fix typs * move folders & make for two-forward pass in training loop (verl-project#4) * Add diffusion reward loop (verl-project#3) * init reward; add ocr reward * update disrm input * add unit test * pass ut * fix typos/bugs * update copyright * [fix] update customized reward func in UT (verl-project#5) * init reward; add ocr reward * update disrm input * add unit test * pass ut * fix typos/bugs * update copyright * update customized reward_fn * init dataset for Qwen-Image * pass UT * update return, update UT * pass UT * align with rl_dataset * pass UT * update filter long prompts * debug * clean code --------- Co-authored-by: Cheung Ka Wai <zhtmike@gmail.com>
1 parent 4d0a8d8 commit 4480199

5 files changed

Lines changed: 439 additions & 3 deletions

File tree

tests/experimental/reward_loop/test_diffusion_reward_model_genrm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def create_data_samples(tokenizer) -> DataProto:
4747

4848
data = DataProto.from_dict(
4949
tensors={
50-
"prompts": prompt_ids,
50+
"input_ids": prompt_ids,
5151
"responses": responses,
5252
},
5353
non_tensors={
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
import torch
17+
from omegaconf import OmegaConf
18+
from torch.utils.data import DataLoader
19+
20+
from verl import DataProto
21+
from verl.utils import hf_tokenizer
22+
from verl.utils.dataset import QwenDataset
23+
from verl.utils.dataset.rl_dataset import collate_fn
24+
25+
26+
def get_ocr_data():
    """Return the expected local path of the OCR train split, ensuring its folder exists.

    Note: only the directory is created here; the ``train.txt`` file itself is
    assumed to have been placed there beforehand (external test fixture).
    """
    folder = os.path.expanduser("~/data/ocr/")
    os.makedirs(folder, exist_ok=True)
    return os.path.join(folder, "train.txt")
32+
33+
34+
def test_qwen_dataset():
    """A collated QwenDataset batch should round-trip into a DataProto with token fields."""
    tokenizer = hf_tokenizer(os.path.expanduser("~/models/Qwen/Qwen-Image"), trust_remote_code=True)
    local_path = get_ocr_data()
    config = OmegaConf.create(
        {
            "max_prompt_length": 1024,
            "filter_overlong_prompts": True,
            "data_source": "ocr",
        }
    )
    dataset = QwenDataset(data_files=local_path, tokenizer=tokenizer, config=config)

    loader = DataLoader(dataset=dataset, batch_size=16, shuffle=True, drop_last=True, collate_fn=collate_fn)

    batch = next(iter(loader))

    # Partition the collated batch into tensor / non-tensor entries for DataProto.
    tensors = {key: val for key, val in batch.items() if isinstance(val, torch.Tensor)}
    non_tensors = {key: val for key, val in batch.items() if not isinstance(val, torch.Tensor)}

    data_proto = DataProto.from_dict(tensors=tensors, non_tensors=non_tensors)
    assert len(data_proto) == 16
    assert "input_ids" in data_proto.batch
    assert "attention_mask" in data_proto.batch
63+
64+
65+
def test_qwen_dataset_with_max_samples():
    """``max_samples`` should cap the dataset length, and ``split`` should honor it."""
    tokenizer = hf_tokenizer(os.path.expanduser("~/models/Qwen/Qwen-Image"), trust_remote_code=True)
    data_path = get_ocr_data()
    cfg_dict = {
        "max_prompt_length": 1024,
        "filter_overlong_prompts": True,
        "data_source": "ocr",
    }
    config = OmegaConf.create(cfg_dict)

    capped = QwenDataset(data_files=data_path, tokenizer=tokenizer, config=config, max_samples=5)
    assert len(capped) == 5

    # Splitting at the full (capped) size should return all 5 samples.
    first_split = capped.split(5)
    assert len(first_split) == 5

verl/experimental/reward_loop/reward_manager/diffusion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ def __init__(self, config, tokenizer, compute_score=None, reward_router_address=
4545
async def run_single(self, data: DataProto) -> dict:
4646
assert len(data) == 1, "Only support single data item"
4747
data_item = data[0]
48-
# prompt_str = self.tokenizer.decode(data_item.batch["prompts"], skip_special_tokens=True)
4948
response_image = data_item.batch["responses"]
5049
data_source = data_item.non_tensor_batch["data_source"]
5150
ground_truth = data_item.non_tensor_batch["reward_model"]["ground_truth"]

verl/utils/dataset/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from .qwen_dataset import QwenDataset
1516
from .rl_dataset import RLHFDataset
1617
from .rm_dataset import RMDataset
1718
from .sft_dataset import SFTDataset
1819

19-
__all__ = ["RLHFDataset", "RMDataset", "SFTDataset"]
20+
__all__ = ["RLHFDataset", "RMDataset", "SFTDataset", "QwenDataset"]

0 commit comments

Comments
 (0)