
Commit fd0c68a (1 parent: 69b0127)

feat: remove left-right padding mode in FSDP engine

1. add transformation layer in actor/critic worker
2. support nested tensor in megatron engine
3. pass test_engine.py test for megatron/fsdp engines
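The headline change ("remove left-right padding mode") and item 2 ("support nested tensor in megatron engine") go together: instead of padding every batch to a rectangle, sequences travel as variable-length jagged tensors. A minimal sketch of that representation, with hypothetical lengths and vocab size:

    import torch

    # Hypothetical batch of three sequences with different lengths; in
    # no_padding mode they are carried as one jagged nested tensor instead
    # of a pad-to-max rectangle.
    seqs = [torch.randint(0, 1000, (n,)) for n in (5, 9, 3)]
    batch = torch.nested.nested_tensor(seqs, layout=torch.jagged)

    print(batch.shape)    # torch.Size([3, j1]); the second dim is ragged
    print(batch.size(0))  # 3 sequences, no pad tokens stored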

File tree: 17 files changed, +494 −273 lines

tests/models/test_engine.py
Lines changed: 13 additions & 2 deletions

@@ -84,6 +84,8 @@ def test_actor_engine(strategy):
     # init model
     wg.init_model()
 
+    print(f"test_actor_engine strategy: {strategy}, config: {config} after init_model")
+
     batch_size = 8
     seqlen = 32

@@ -100,8 +102,6 @@ def test_actor_engine(strategy):
 
     global_token_num = torch.sum(attention_mask, dim=-1).tolist()
 
-    print(input_ids.float().mean(), attention_mask.float().mean())
-
     responses = input_ids[:, response_length:]
     response_mask = attention_mask[:, response_length:]

@@ -129,6 +129,7 @@ def test_actor_engine(strategy):
     hf_logprobs = logprobs_from_logits_naive(
         hf_output.logits[:, -response_length - 1 : -1, :].float(), input_ids[:, -response_length:]
     )
+
     hf_logprobs_mean = torch.mean(hf_logprobs * response_mask)
     mcore_logprobs_mean = torch.mean(output.batch["old_log_probs"] * response_mask)

@@ -173,6 +174,8 @@ def create_model():
 def test_critic_engine(strategy):
     ray.init()
 
+    torch.autograd.set_detect_anomaly(True)
+
     path = create_model()
     model_config = HFModelConfig(path=path, load_tokenizer=False)

@@ -209,6 +212,8 @@ def test_critic_engine(strategy):
     # init model
     wg.init_model()
 
+    print(f"test_critic_engine strategy: {strategy}, config: {config}")
+
     batch_size = 8
     seqlen = 32

@@ -257,6 +262,9 @@ def test_critic_engine(strategy):
 
     engine_values = torch.mean(output.batch["values"] * response_mask)
 
+    print(f"engine_values: {output.batch['values']}")
+    print(f"hf_values_mean: {hf_values_mean}, engine_values: {engine_values}")
+
     torch.testing.assert_close(hf_values_mean, engine_values, atol=1e-2, rtol=1e-2)
 
     data = data.union(output)

@@ -265,6 +273,7 @@ def test_critic_engine(strategy):
     data.batch["values"] = torch.rand_like(responses, dtype=torch.float32)
     data.batch["returns"] = torch.rand_like(responses, dtype=torch.float32)
 
+    print(f"before update critic: {data}")
     # update again
     ppo_metrics = wg.update_critic(data)
     print(ppo_metrics)

@@ -354,6 +363,8 @@ def test_per_tensor_generator(world_size, tmp_path, config, strategy):
     os.makedirs(os.path.dirname(rendezvous_file), exist_ok=True)
     # create a model
     model_path = create_actor_model(tmp_path, config)
+
+    print(f"test_per_tensor_generator world_size: {world_size}, strategy: {strategy}, config: {config}")
     # spawn workers
     mp.spawn(
         fn=_worker,
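The actor test above checks the engine against a plain HuggingFace forward by comparing masked means of response log-probs. A self-contained sketch of that comparison, with random tensors standing in for the two models' outputs:

    import torch

    # Stand-ins for the real outputs: batch of 8, response length 16,
    # engine output perturbed slightly to mimic numeric drift.
    response_mask = torch.randint(0, 2, (8, 16)).float()
    hf_logprobs = torch.randn(8, 16)
    engine_logprobs = hf_logprobs + 1e-4 * torch.randn(8, 16)

    # Masked means, as in the test; only response tokens contribute.
    hf_mean = torch.mean(hf_logprobs * response_mask)
    engine_mean = torch.mean(engine_logprobs * response_mask)
    torch.testing.assert_close(hf_mean, engine_mean, atol=1e-2, rtol=1e-2)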

tests/special_e2e/sft/run_sft_engine_gsm8k.sh
Lines changed: 1 addition & 3 deletions

@@ -29,7 +29,7 @@ PP_SIZE=${PP_SIZE:-1}
 VPP_SIZE=${VPP_SIZE:-null}
 CP_SIZE=${CP_SIZE:-1}
 
-PAD_MODE=${PAD_MODE:-left_right}
+PAD_MODE=${PAD_MODE:-no_padding}
 
 USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}
 

@@ -80,8 +80,6 @@ torchrun --standalone --nnodes=1 --nproc_per_node=${NUM_GPUS} ${ENTRYPOINT} \
     data.train_files="${TRAIN_FILES}" \
     data.val_files="${VAL_FILES}" \
     data.train_batch_size=256 \
-    data.max_prompt_length=1024 \
-    data.max_response_length=1024 \
     data.pad_mode=${PAD_MODE} \
     data.truncation=error \
     data.use_dynamic_bsz=True \

tests/special_e2e/sft/test_sft_engine_all.sh
Lines changed: 4 additions & 16 deletions

@@ -9,15 +9,6 @@ echo "run with single gpu as golden"
 BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp VERL_FILE_LOGGER_PATH=~/verl/test/log/golden.jsonl bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 
 # test with fsdp 1
-echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
-BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
-BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-echo "run with sp2 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
-BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right"
-BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-
 echo "run with sp1 fsdp_size2 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
 BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"

@@ -27,18 +18,14 @@ BACKEND=fsdp SP_SIZE=2 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_pa
 echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding"
 BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 
-# test use_remove_padding and pad_mode left_right/no_padding
-echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode left_right use_remove_padding False"
-BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=left_right USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+# test use_remove_padding and pad_mode no_padding
 echo "run with sp4 fsdp_size4 num_gpus8 fsdp_strategy fsdp pad_mode no_padding use_remove_padding False"
 BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp PAD_MODE=no_padding USE_REMOVE_PADDING=False bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 
 
 # test with fsdp 2
-echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode left_right"
-BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=left_right bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
-echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2 pad_mode no_padding"
-BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 PAD_MODE=no_padding bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+echo "run with sp1 fsdp_size1 num_gpus1 fsdp_strategy fsdp2"
+BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=1 NUM_GPUS=1 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
 
 echo "run with sp1 fsdp_size-1 num_gpus8 fsdp_strategy fsdp2"
 BACKEND=fsdp SP_SIZE=1 FSDP_SIZE=-1 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh

@@ -50,6 +37,7 @@ BACKEND=fsdp SP_SIZE=4 FSDP_SIZE=4 NUM_GPUS=8 FSDP_STRATEGY=fsdp2 bash tests/spe
 # test with megatron
 echo "run with tp1 pp1 cp1 num_gpus1"
 BACKEND=megatron TP_SIZE=1 PP_SIZE=1 CP_SIZE=1 NUM_GPUS=1 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh
+
 echo "run with tp2 pp2 vpp2 cp1 num_gpus8"
 BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=2 CP_SIZE=1 NUM_GPUS=8 bash tests/special_e2e/sft/run_sft_engine_gsm8k.sh

tests/utils/dataset/test_multiturn_sft_dataset_on_cpu.py
Lines changed: 8 additions & 10 deletions

@@ -177,28 +177,26 @@ def test_multiturn_sft_dataset():
     assert torch.all(padded_item["attention_mask"][actual_length:] == 0), "Attention mask not set correctly for padding"
     assert torch.all(padded_item["loss_mask"][actual_length:] == 0), "Loss mask not set correctly for padding"
 
-    # test left right padding
+    # test no-padding
     config = {
         "max_length": 512,
         "truncation": "error",
         "multiturn": {"messages_key": "messages"},
-        "pad_mode": "left_right",
-        "max_prompt_length": 64,
-        "max_response_length": 64,
+        "pad_mode": "no_padding",
     }
     dataset = MultiTurnSFTDataset(parquet_files=test_file, tokenizer=tokenizer, config=config)
 
     item0 = dataset[0]
 
-    # make sure all the input_ids with attention_mask == 0 are all padding
-    assert torch.all(item0["input_ids"][item0["attention_mask"] == 0] == tokenizer.pad_token_id)
+    # Verify that the output contains expected keys for no-padding mode
+    required_keys = ["input_ids", "position_ids", "loss_mask"]
+    for key in required_keys:
+        assert key in item0, f"Missing key {key} in no-padding mode dataset item"
+        assert isinstance(item0[key], torch.Tensor), f"Expected torch.Tensor for {key} in no-padding mode"
 
     # make sure assistant_text matches with expected
-    assistant_text = tokenizer.decode(item0["responses"][item0["response_mask"] == 1])
+    assistant_text = tokenizer.decode(item0["input_ids"][item0["loss_mask"] == 1])
     assert assistant_text == "2+2 equals 4.<|im_end|>\n4+4 equals 8.<|im_end|>\n"
 
-    # make sure responses are part of input_ids
-    assert torch.all(item0["input_ids"][-item0["responses"].shape[0] :] == item0["responses"])
-
     print("All tests passed!")
     print("Starting test...")

verl/models/mcore/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -16,6 +16,7 @@
 from .registry import (
     get_mcore_forward_fn,
     get_mcore_forward_fused_fn,
+    get_mcore_forward_no_padding_fn,
     get_mcore_weight_converter,
     hf_to_mcore_config,
     init_mcore_model,

@@ -27,4 +28,5 @@
     "get_mcore_forward_fn",
     "get_mcore_weight_converter",
     "get_mcore_forward_fused_fn",
+    "get_mcore_forward_no_padding_fn",
 ]

verl/models/mcore/model_forward.py
Lines changed: 63 additions & 1 deletion

@@ -16,7 +16,14 @@
 
 from verl.utils.megatron_utils import unwrap_model
 
-from .util import postprocess_packed_seqs, preprocess_packed_seqs, recover_left_padding, remove_left_padding
+from .util import (
+    postprocess_packed_seqs,
+    postprocess_packed_seqs_no_padding,
+    preprocess_packed_seqs,
+    preprocess_packed_seqs_no_padding,
+    recover_left_padding,
+    remove_left_padding,
+)
 
 
 def gptmodel_forward(

@@ -37,13 +44,16 @@ def gptmodel_forward(
     if pack_seqs:
         batch_size, seq_len = attention_mask.shape[:2]
         input_ids_rmpad, packed_seq_params = preprocess_packed_seqs(input_ids, attention_mask, pre_process=pre_process)
+        print(f"input_ids_rmpad shape: {input_ids_rmpad.shape}, packed_seq_params: {packed_seq_params}")
         input_ids_rmpad = input_ids_rmpad.contiguous()
         output_orig = model(
             input_ids=input_ids_rmpad,
             attention_mask=None,
             position_ids=position_ids,
             packed_seq_params=packed_seq_params,
         )
+        print(f"output_orig: {output_orig}")
+
         if post_process and logits_processor is not None:
             args = {
                 k: preprocess_packed_seqs(v, attention_mask, pre_process=True)[0]

@@ -146,3 +156,55 @@ def gptmodel_forward_qwen2_5_vl(
     if value_model and post_process:
         output = output[..., 0]
     return output
+
+
+def gptmodel_forward_no_padding(
+    model,
+    input_ids,
+    value_model=False,
+    pack_seqs=True,
+    logits_processor=None,
+    logits_processor_args: dict = None,
+    **kwargs,
+):
+    """Forward pass for GPT models in no-padding mode; only packed sequences are supported."""
+    pre_process = unwrap_model(model).pre_process
+    post_process = unwrap_model(model).post_process
+    if pack_seqs:
+        batch_size = input_ids.shape[0]
+        input_ids_rmpad, packed_seq_params = preprocess_packed_seqs_no_padding(input_ids, pre_process=pre_process)
+        input_ids_rmpad = input_ids_rmpad.contiguous()
+        output_orig = model(
+            input_ids=input_ids_rmpad,
+            attention_mask=None,
+            position_ids=None,
+            packed_seq_params=packed_seq_params,
+        )
+
+        if post_process and logits_processor is not None:
+            args = {
+                k: preprocess_packed_seqs_no_padding(v, pre_process=True)[0] for k, v in logits_processor_args.items()
+            }
+            output_dict = logits_processor(output_orig, **args)
+            output = {
+                k: postprocess_packed_seqs_no_padding(
+                    v, packed_seq_params, input_ids, batch_size, post_process=post_process
+                )
+                for k, v in output_dict.items()
+            }
+        else:
+            output = postprocess_packed_seqs_no_padding(
+                output_orig, packed_seq_params, input_ids, batch_size, post_process=post_process
+            )
+    else:
+        raise NotImplementedError("gptmodel_forward_no_padding only supports packed sequences")
+
+    if value_model and post_process:
+        # output = output[..., 0]
+        # while using nested tensor, the advanced indexing operation above will result in an error at backward, i.e.
+        # ValueError: NestedTensor _nested_select_backward_default(grad_output: t, self: jt_all, dim: any, index: any)
+        # so we use `squeeze` to remove the last dimension
+        output = output.squeeze(-1)
+
+    return output
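The comment in the value-model branch is the interesting part: with torch's jagged nested tensors, trailing-dim advanced indexing breaks at backward, while squeeze does not. A standalone repro sketch of the workaround (assumes the torch.jagged layout; these are not the engine's actual tensors):

    import torch

    # Two variable-length value sequences with a trailing head dim of 1,
    # packed as a jagged nested tensor with gradients enabled.
    values = torch.nested.nested_tensor(
        [torch.randn(5, 1), torch.randn(3, 1)],
        layout=torch.jagged,
        requires_grad=True,
    )

    # values[..., 0] would hit _nested_select_backward_default at backward;
    # squeeze(-1) drops the size-1 dim without advanced indexing.
    out = values.squeeze(-1)
    out.values().sum().backward()  # .values() exposes the flat jagged buffer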

verl/models/mcore/registry.py
Lines changed: 27 additions & 0 deletions

@@ -35,6 +35,7 @@
 )
 from .model_forward import (
     gptmodel_forward,
+    gptmodel_forward_no_padding,
     gptmodel_forward_qwen2_5_vl,
 )
 from .model_forward_fused import (

@@ -122,6 +123,23 @@ class SupportedModel(Enum):
     SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward,
 }
 
+# Registry for no-padding model forward functions
+MODEL_FORWARD_NOPAD_REGISTRY: dict[SupportedModel, Callable] = {
+    SupportedModel.LLAMA: gptmodel_forward_no_padding,
+    SupportedModel.QWEN2: gptmodel_forward_no_padding,
+    SupportedModel.QWEN2_MOE: gptmodel_forward_no_padding,
+    SupportedModel.MIXTRAL: gptmodel_forward_no_padding,
+    SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding,
+    SupportedModel.QWEN2_5_VL: gptmodel_forward_no_padding,
+    SupportedModel.LLAMA4: gptmodel_forward_no_padding,
+    SupportedModel.QWEN3: gptmodel_forward_no_padding,
+    SupportedModel.QWEN3_MOE: gptmodel_forward_no_padding,
+    # SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
+    SupportedModel.GLM4_MOE: gptmodel_forward_no_padding,
+    SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward_no_padding,
+}
+
 # Registry for model forward functions
 MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {
     SupportedModel.LLAMA: fused_forward_gptmodel,

@@ -227,6 +245,15 @@ def get_mcore_forward_fn(hf_config: PretrainedConfig) -> Callable:
     return MODEL_FORWARD_REGISTRY[model]
 
 
+def get_mcore_forward_no_padding_fn(hf_config: PretrainedConfig) -> Callable:
+    """
+    Get the no-padding forward function for a given model architecture.
+    """
+    assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
+    model = get_supported_model(hf_config.architectures[0])
+    return MODEL_FORWARD_NOPAD_REGISTRY[model]
+
+
 def get_mcore_forward_fused_fn(hf_config: PretrainedConfig) -> Callable:
     """
     Get the forward function for given model architecture.
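A usage sketch of the new dispatch helper; the checkpoint name is hypothetical and just needs a config whose architecture appears in MODEL_FORWARD_NOPAD_REGISTRY:

    from transformers import AutoConfig

    from verl.models.mcore import get_mcore_forward_no_padding_fn

    # Hypothetical Qwen2 checkpoint; any registered architecture works.
    hf_config = AutoConfig.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    forward_fn = get_mcore_forward_no_padding_fn(hf_config)
    assert forward_fn.__name__ == "gptmodel_forward_no_padding"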
