Commit 08eb12b

houminz, vermouth1992
authored and committed
[misc] feat: prototype deprecate DataProto and replace with Tensordict: part 2 (volcengine#3567)
### What does this PR do?

This PR continues the work started in PR volcengine#2733. It adds support for variable sequence lengths in MultiTurnSFTDataset by introducing a `no_padding` option for `pad_mode`. When this mode is active, sequences are not padded to a fixed length.

- Implement no-padding mode for the FSDP engine using nested tensors in the SFT trainer
- Add tests for no-padding mode with `use_remove_padding` both enabled and disabled
- Fix FSDP2 gradnorm issue

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: zhangchi.usc1992 <[email protected]>
1 parent a756e67 commit 08eb12b

File tree

12 files changed: +449 -97 lines changed

tests/special_e2e/sft/run_sft_engine_gsm8k.sh

Lines changed: 8 additions & 3 deletions
@@ -29,6 +29,10 @@ PP_SIZE=${PP_SIZE:-1}
 VPP_SIZE=${VPP_SIZE:-null}
 CP_SIZE=${CP_SIZE:-1}
 
+PAD_MODE=${PAD_MODE:-left_right}
+
+USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}
+
 FSDP_ENGINE_CONFIG="\
     engine=${backend} \
     optim=${backend} \
@@ -63,11 +67,11 @@ MEGATRON_ENGINE_CONFIG="\
 if [ "$backend" = "fsdp" ]; then
     ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
     echo "Using fsdp engine"
-    exp_name=gsm8k-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp${FSDP_SIZE}
+    exp_name=gsm8k-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}
 else
     ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
     echo "Using megatron engine"
-    exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}
+    exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}
 fi
 
 mkdir -p "${ckpts_home}"
@@ -78,12 +82,13 @@ torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \
     data.train_batch_size=256 \
     data.max_prompt_length=1024 \
     data.max_response_length=1024 \
-    data.pad_mode=left_right \
+    data.pad_mode=${PAD_MODE} \
     data.truncation=error \
     data.use_dynamic_bsz=True \
     data.max_token_len_per_gpu=8192 \
     data.messages_key=messages \
     model.path=$MODEL_PATH \
+    model.use_remove_padding=${USE_REMOVE_PADDING} \
     ${ENGINE_CONFIG} \
     trainer.test_freq=after_each_epoch \
     trainer.save_freq=-1 \

tests/special_e2e/sft/test_sft_engine_all.sh

Lines changed: 35 additions & 18 deletions
@@ -9,26 +9,43 @@ echo "run with single gpu as golden"
 BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 
 # test with fsdp 1
-echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp"
-BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp"
-BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp"
-BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp"
-BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
+BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
+BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+
+echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
+BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
+BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+
+# test use_remove_padding and pad_mode left_right/no_padding
+echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right use_remove_padding False"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+
 
 # test with fsdp 2
-echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2"
-BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-
-# TODO: toggle the follow tests when the grad norm of fsdp is fixed
-# echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
-# BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-# echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
-# BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-# BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-# BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode left_right"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode no_padding"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+
+echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
+BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 
 # test with megatron
 echo "run with tp1 pp1 cp1 num_gpus1"

tests/test_protocol_v2_on_cpu.py

Lines changed: 70 additions & 0 deletions
@@ -88,6 +88,45 @@ def test_tensor_dict_constructor():
     assert data["name"] == "abdce"
 
 
+def test_index_select_tensor_dict():
+    vocab_size = 128
+    a = torch.randint(low=0, high=vocab_size, size=(11,))
+    b = torch.randint(low=0, high=vocab_size, size=(13,))
+    c = torch.randint(low=0, high=vocab_size, size=(12,))
+    d = torch.randint(low=0, high=vocab_size, size=(15,))
+    input_ids = [a, b, c, d]
+    input_ids = torch.nested.as_nested_tensor(input_ids, layout=torch.jagged)
+
+    padded_tensor = torch.randn(4, 10)
+    non_tensor_dict = {"global_batch_size": "4"}
+
+    data = tu.get_tensordict(
+        tensor_dict={
+            "input_ids": input_ids,
+            "padded_tensor": padded_tensor,
+        },
+        non_tensor_dict=non_tensor_dict,
+    )
+
+    assert data.batch_size == torch.Size([4])
+
+    # test index select
+    indices = torch.tensor([1, 3])
+    selected_data = tu.index_select_tensor_dict(data, indices)
+
+    assert selected_data.batch_size == torch.Size([2])
+
+    target_input_ids = torch.nested.as_nested_tensor([input_ids[idx] for idx in indices], layout=torch.jagged)
+    target_select_data = tu.get_tensordict(
+        tensor_dict={
+            "input_ids": target_input_ids,
+            "padded_tensor": padded_tensor[indices],
+        },
+        non_tensor_dict=non_tensor_dict,
+    )
+    tu.assert_tensordict_eq(selected_data, target_select_data)
+
+
 def test_tensordict_with_images():
     # each sample contains a sequence with multiple images of different sizes
     vocab_size = 128
@@ -173,6 +212,37 @@ def test_tensordict_eq():
     with pytest.raises(AssertionError):
         tu.assert_tensordict_eq(data, data2)
 
+    tensor_list = [
+        torch.tensor([1, 2, 3, 3, 2]),
+        torch.tensor([4, 5]),
+        torch.tensor([7, 8, 10, 14]),
+        torch.tensor([10, 11, 12]),
+        torch.tensor([13, 14, 15, 18]),
+        torch.tensor([16, 17]),
+    ]
+    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
+    data_sources = ["abc", "def", "abc", "def", "pol", "klj"]
+    non_tensor_dict = {"train_sample_kwargs": {"top_p": 1.0}, "val_sample_kwargs": {"top_p": 0.7}}
+    data3 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
+
+    tensor_list[0] = torch.tensor([1, 2, 3, 3, 2])
+    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
+    data4 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
+    tu.assert_tensordict_eq(data3, data4)
+
+    tensor_list[0] = torch.tensor([1, 2, 4])
+    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
+    data5 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
+    with pytest.raises(AssertionError):
+        tu.assert_tensordict_eq(data3, data5)
+
+    tensor_list[0] = torch.tensor([4, 5])
+    tensor_list[1] = torch.tensor([1, 2, 3, 3, 2])
+    obs = torch.nested.as_nested_tensor(tensor_list, layout=torch.jagged)
+    data6 = tu.get_tensordict({"obs": obs, "data_sources": data_sources}, non_tensor_dict=non_tensor_dict)
+    with pytest.raises(AssertionError):
+        tu.assert_tensordict_eq(data3, data6)
+
 
 def test_tensor_dict_make_iterator():
     obs = torch.tensor([1, 2, 3, 4, 5, 6])

verl/trainer/sft_trainer.py

Lines changed: 13 additions & 1 deletion
@@ -32,6 +32,7 @@
 
 from verl.utils import tensordict_utils as tu
 from verl.utils.checkpoint import CheckpointHandler
+from verl.utils.dataset.dataset_utils import SFTTensorCollator
 from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
 from verl.utils.device import get_device_name, is_cuda_available, is_npu_available
 from verl.utils.distributed import destroy_global_process_group
@@ -167,11 +168,13 @@ def _build_dataloader(self):
 
         self.global_batch_size = config.data.train_batch_size
         self.train_batch_size_per_dp = self.global_batch_size // dp_size
+        self.collate_fn = SFTTensorCollator(config.data.pad_mode)
 
         self.train_dataloader = StatefulDataLoader(
             dataset=self.train_dataset,
             batch_size=self.train_batch_size_per_dp,
             sampler=self.train_sampler,
+            collate_fn=self.collate_fn,
             num_workers=8,
             pin_memory=True,
             drop_last=True,
@@ -185,6 +188,7 @@ def _build_dataloader(self):
             dataset=self.val_dataset,
             batch_size=self.train_batch_size_per_dp,
             sampler=self.val_sampler,
+            collate_fn=self.collate_fn,
             num_workers=8,
             pin_memory=True,
             drop_last=True,
@@ -227,11 +231,14 @@ def fit(self):
         start_epoch = global_step // self.steps_per_epoch
 
         meta_info = {
+            "use_remove_padding": self.config.model.use_remove_padding,
             "use_dynamic_bsz": self.config.data.use_dynamic_bsz,
             "max_token_len_per_gpu": self.config.data.max_token_len_per_gpu,
             "micro_batch_size_per_gpu": self.config.data.micro_batch_size_per_gpu,
             "temperature": 1.0,
             "global_batch_size": self.global_batch_size,
+            "pad_mode": self.config.data.pad_mode,
+            "pad_token_id": self.model_config.tokenizer.pad_token_id,
         }
 
         train_time = 0
@@ -263,7 +270,12 @@ def fit(self):
                     loss = torch.mean(torch.tensor(metrics["loss"], device=self.device_name))
 
                     # mean over dp group
-                    batch_seqlens = data["attention_mask"].sum(dim=-1).to(self.device_name)  # (global_bsz // dp)
+                    is_nested = data["input_ids"].is_nested
+                    if is_nested:
+                        batch_seqlens: torch.Tensor = data["input_ids"].offsets().diff()
+                    else:
+                        batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1)
+                    batch_seqlens = batch_seqlens.to(self.device_name)  # (global_bsz // dp)
 
                     output_tensor = torch.randint(
                         0,
verl/utils/dataset/dataset_utils.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from enum import Enum
+
+import torch
+
+
+class DatasetPadMode(str, Enum):
+    """Padding mode for dataset"""
+
+    RIGHT = "right"
+    LEFT_RIGHT = "left_right"
+    NO_PADDING = "no_padding"
+
+
+class SFTTensorCollator:
+    """
+    A custom collate_fn that handles batching of sequences.
+    1. for variable-length sequences, convert them into NestedTensors.
+    2. for fixed-length sequences, use default_collate.
+    """
+
+    def __init__(self, pad_mode: DatasetPadMode = DatasetPadMode.LEFT_RIGHT):
+        self.pad_mode = pad_mode
+
+    def __call__(self, batch: list[dict[str, any]]) -> dict[str, any]:
+        if self.pad_mode == DatasetPadMode.NO_PADDING:
+            return self.collate_variable_batch(batch)
+        elif self.pad_mode in [DatasetPadMode.RIGHT, DatasetPadMode.LEFT_RIGHT]:
+            from torch.utils.data import default_collate
+
+            return default_collate(batch)
+        else:
+            raise NotImplementedError(f"pad_mode {self.pad_mode} not implemented")
+
+    def collate_variable_batch(self, batch: list[dict[str, any]]) -> dict[str, any]:
+        """
+        Collates a list of samples into a single batch.
+
+        Args:
+            batch: A list of dictionary samples from the dataset.
+
+        Returns:
+            A dictionary representing the batched data, with variable-length
+            sequences converted to NestedTensors.
+        """
+
+        final_batch = {}
+
+        tensor_keys = [key for key in batch[0].keys() if isinstance(batch[0][key], torch.Tensor)]
+
+        # Handle tensor values by creating a NestedTensor.
+        for key in tensor_keys:
+            tensors = [item[key] for item in batch]
+            final_batch[key] = torch.nested.as_nested_tensor(tensors, layout=torch.jagged)
+
+        return final_batch

verl/utils/dataset/multiturn_sft_dataset.py

Lines changed: 18 additions & 4 deletions
@@ -27,6 +27,7 @@
 from transformers import PreTrainedTokenizer
 
 from verl.utils import hf_tokenizer
+from verl.utils.dataset.dataset_utils import DatasetPadMode
 from verl.utils.fs import copy_local_path_from_hdfs
 from verl.utils.model import compute_position_id_with_mask
 from verl.utils.torch_functional import pad_sequence_to_length, postprocess_data
@@ -54,8 +55,8 @@ def __init__(self, parquet_files: str | list[str], tokenizer, config=None):
         # Set defaults and extract parameters from config if provided
         config = config or {}
         self.pad_mode = config.get("pad_mode", "right")
-        assert self.pad_mode in ["right", "left_right"], (
-            f"Expect pad_mode to be 'right' or 'left_right'. Got {self.pad_mode}"
+        assert self.pad_mode in ["right", "left_right", "no_padding"], (
+            f"Expect pad_mode to be 'right', 'left_right' or 'no_padding'. Got {self.pad_mode}"
         )
         self.truncation = config.get("truncation", "error")
         # for right padding
@@ -328,7 +329,7 @@ def __getitem__(self, item):
 
         sequence_length = input_ids.shape[0]
         # Handle sequence length
-        if self.pad_mode == "right":
+        if self.pad_mode == DatasetPadMode.RIGHT:
             if sequence_length < self.max_length:
                 # Pad sequences
                 pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
@@ -364,7 +365,7 @@ def __getitem__(self, item):
                 "position_ids": position_ids,
                 "loss_mask": loss_mask,
             }
-        elif self.pad_mode == "left_right":
+        elif self.pad_mode == DatasetPadMode.LEFT_RIGHT:
             assert self.truncation == "error", "Only support error truncation for left_right pad mode"
             prompt_str = self.tokenizer.apply_chat_template(
                 messages[:prompt_message_length],
@@ -426,3 +427,16 @@ def __getitem__(self, item):
                 "responses": response_ids,
                 "response_mask": response_loss_mask,
             }
+        elif self.pad_mode == DatasetPadMode.NO_PADDING:
+            # truncate input_ids if it is longer than max_length
+            if len(input_ids) > self.max_length:
+                input_ids = input_ids[: self.max_length]
+                loss_mask = loss_mask[: self.max_length]
+            # create position IDs
+            position_ids = torch.arange(len(input_ids), dtype=torch.long)
+            # return nested tensor without padding
+            return {
+                "input_ids": input_ids,
+                "position_ids": position_ids,
+                "loss_mask": loss_mask,
+            }
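The `no_padding` branch above truncates over-long samples and builds position IDs without adding pad tokens. A minimal standalone sketch of that step follows; `build_no_padding_sample` is a hypothetical function name, and the inputs are synthetic rather than tokenizer output.

```python
import torch


def build_no_padding_sample(input_ids: torch.Tensor, loss_mask: torch.Tensor, max_length: int) -> dict:
    """Hypothetical stand-in for the NO_PADDING branch of __getitem__."""
    # Truncate only when the sample exceeds max_length; never pad.
    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
        loss_mask = loss_mask[:max_length]
    # Position IDs simply count tokens from 0, since there is no left padding.
    position_ids = torch.arange(len(input_ids), dtype=torch.long)
    return {"input_ids": input_ids, "position_ids": position_ids, "loss_mask": loss_mask}


long_sample = build_no_padding_sample(
    torch.tensor([5, 6, 7, 8, 9, 10]), torch.tensor([0, 0, 1, 1, 1, 1]), max_length=4
)
short_sample = build_no_padding_sample(
    torch.tensor([5, 6]), torch.tensor([0, 1]), max_length=4
)
```

Each sample keeps its own length, so a batch of such samples can only be stacked by a collator that produces nested tensors.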
