
Commit 16a13d8

[misc] feat: support logging rollout prob vs. actor probs for debugging purpose (#1712)
### Checklist Before Starting

- [X] Search for similar PR(s).

### What does this PR do?

- Support logging rollout probs vs. actor probs for debugging purposes
- Support both vLLM and SGLang async rollout

### High-Level Design

> Demonstrate the high-level design if this PR is complex.

### Specific Changes

> List the specific changes.

### API

> Demonstrate how the API changes if any.

### Usage Example

> Provide usage example(s) for easier usage.

```python
# Add code snippet or script demonstrating how to use this
```

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### Additional Info.

- **Issue Number**: Fixes issue # or discussion # if any.
- **Training**: [Note which backend this PR will affect: FSDP, Megatron, both, or none]
- **Inference**: [Note which backend this PR will affect: vLLM, SGLang, both, or none]

### Checklist Before Submitting

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add CI test(s) if necessary.
1 parent 34e409b commit 16a13d8
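In short, the new diagnostic compares, token by token, the probabilities the rollout engine reported at sampling time with the probabilities the actor recomputes during training. Below is a minimal sketch of that computation on random tensors (illustrative only; the real code lives in the `verl/trainer/ppo/ray_trainer.py` hunk further down and reads these tensors from the training batch):

```python
import torch

# Illustrative shapes only: a batch of 2 responses, 4 response tokens each.
rollout_log_probs = torch.randn(2, 4).clamp(max=0)    # log-probs returned by the rollout engine
actor_old_log_probs = torch.randn(2, 4).clamp(max=0)  # log-probs recomputed by the actor
response_mask = torch.tensor([[1, 1, 1, 0],
                              [1, 1, 0, 0]])           # 1 = real response token, 0 = padding

# Same statistics the trainer logs under training/rollout_probs_diff_*
rollout_probs = torch.exp(rollout_log_probs)
actor_probs = torch.exp(actor_old_log_probs)
rollout_probs_diff = torch.abs(rollout_probs - actor_probs)
rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())

metrics = {
    "training/rollout_probs_diff_max": rollout_probs_diff.max().item(),
    "training/rollout_probs_diff_mean": rollout_probs_diff.mean().item(),
    "training/rollout_probs_diff_std": rollout_probs_diff.std().item(),
}
print(metrics)
```

Large values of these statistics typically point to a mismatch between the inference engine's sampling-time probabilities and the training engine's recomputed ones (e.g. precision or kernel differences), which is what this logging is meant to surface.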

File tree

7 files changed (+180, -17 lines changed)

recipe/dapo/test_dapo_7b_math.sh

Lines changed: 12 additions & 7 deletions
@@ -2,7 +2,7 @@
 set -xeuo pipefail
 
 project_name='DAPO'
-exp_name='DAPO-Qwen2.5-7b-MATH-0519a1'
+exp_name='DAPO-Qwen2.5-7b-MATH-0527a1'
 
 adv_estimator=grpo
 
@@ -27,10 +27,11 @@ n_resp_per_prompt=16
 train_prompt_mini_bsz=32
 
 # Ray
-RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
-WORKING_DIR=${WORKING_DIR:-"${PWD}"}
-RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
-NNODES=${NNODES:-4}
+# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
+# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
+# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
+NNODES=${NNODES:-8}
+NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
 # Paths
 RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
 MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
@@ -53,6 +54,8 @@ offload=True
 gen_tp=4
 fsdp_size=32
 
+# remember to set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for this model
+
 python3 -m verl.trainer.main_ppo \
     data.train_files="${TRAIN_FILE}" \
     data.val_files="${TEST_FILE}" \
@@ -71,6 +74,7 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
     actor_rollout_ref.actor.clip_ratio_c=10.0 \
     actor_rollout_ref.model.use_remove_padding=True \
+    +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
     actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
     actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
     actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
@@ -113,12 +117,13 @@ python3 -m verl.trainer.main_ppo \
     trainer.logger=['console','wandb'] \
     trainer.project_name="${project_name}" \
     trainer.experiment_name="${exp_name}" \
-    trainer.n_gpus_per_node=8 \
+    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
     trainer.nnodes="${NNODES}" \
-    trainer.val_before_train=False \
+    trainer.val_before_train=True \
     trainer.test_freq=10 \
     trainer.save_freq=10 \
     trainer.total_epochs=10 \
+    trainer.total_training_steps=200 \
     trainer.default_local_dir="${CKPTS_DIR}" \
     trainer.resume_mode=auto \
     trainer.log_val_generations=10
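Usage note (not part of the diff): the node and GPU counts are now plain environment-variable overrides, and the in-script comment asks for VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 to go with the larger max_position_embeddings. A hypothetical invocation, with cluster sizes as placeholders:

```bash
# Hypothetical launch example; adjust NNODES/NGPUS_PER_NODE to your cluster.
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
NNODES=8 NGPUS_PER_NODE=8 bash recipe/dapo/test_dapo_7b_math.sh
```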
Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+#!/usr/bin/env bash
+set -xeuo pipefail
+
+project_name='DAPO'
+exp_name='DAPO-Qwen3-30B-A3B-Base-MATH-0527a1'
+
+adv_estimator=grpo
+
+use_kl_in_reward=False
+kl_coef=0.0
+use_kl_loss=False
+kl_loss_coef=0.0
+
+clip_ratio_low=0.2
+clip_ratio_high=0.28
+
+max_prompt_length=$((1024 * 2))
+max_response_length=$((1024 * 8))
+enable_overlong_buffer=True
+overlong_buffer_len=$((1024 * 4))
+overlong_penalty_factor=1.0
+
+loss_agg_mode="token-mean"
+
+train_prompt_bsz=512
+n_resp_per_prompt=16
+train_prompt_mini_bsz=32
+
+# Ray
+# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
+# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
+# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
+NNODES=${NNODES:-8}
+NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
+# Paths
+RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
+MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"}
+CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
+TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
+TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
+
+# Algorithm
+temperature=1.0
+top_p=1.0
+top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
+val_top_p=0.7
+
+# Performance Related Parameter
+sp_size=4
+use_dynamic_bsz=True
+actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
+infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
+offload=True
+gen_tp=4
+fsdp_size=32
+
+python3 -m verl.trainer.main_ppo \
+    data.train_files="${TRAIN_FILE}" \
+    data.val_files="${TEST_FILE}" \
+    data.prompt_key=prompt \
+    data.truncation='left' \
+    data.max_prompt_length=${max_prompt_length} \
+    data.max_response_length=${max_response_length} \
+    data.train_batch_size=${train_prompt_bsz} \
+    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
+    algorithm.adv_estimator=${adv_estimator} \
+    algorithm.use_kl_in_reward=${use_kl_in_reward} \
+    algorithm.kl_ctrl.kl_coef=${kl_coef} \
+    actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
+    actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
+    actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
+    actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
+    actor_rollout_ref.actor.clip_ratio_c=10.0 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
+    actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
+    actor_rollout_ref.model.path="${MODEL_PATH}" \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
+    actor_rollout_ref.actor.optim.weight_decay=0.1 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
+    actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.grad_clip=1.0 \
+    actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
+    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
+    actor_rollout_ref.rollout.enable_chunked_prefill=True \
+    actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
+    actor_rollout_ref.rollout.temperature=${temperature} \
+    actor_rollout_ref.rollout.top_p=${top_p} \
+    actor_rollout_ref.rollout.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
+    actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
+    actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
+    actor_rollout_ref.rollout.val_kwargs.do_sample=True \
+    actor_rollout_ref.rollout.val_kwargs.n=1 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
+    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
+    actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
+    reward_model.reward_manager=dapo \
+    +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+    +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+    +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+    +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+    +reward_model.reward_kwargs.max_resp_len=${max_response_length} \
+    trainer.logger=['console','wandb'] \
+    trainer.project_name="${project_name}" \
+    trainer.experiment_name="${exp_name}" \
+    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
+    trainer.nnodes="${NNODES}" \
+    trainer.val_before_train=True \
+    trainer.test_freq=10 \
+    trainer.save_freq=10 \
+    trainer.total_epochs=10 \
+    trainer.total_training_steps=300 \
+    trainer.default_local_dir="${CKPTS_DIR}" \
+    trainer.resume_mode=auto \
+    trainer.log_val_generations=10

verl/trainer/ppo/ray_trainer.py

Lines changed: 24 additions & 0 deletions
@@ -1041,6 +1041,30 @@ def fit(self):
 old_log_prob.batch.pop("entropys")
 batch = batch.union(old_log_prob)
 
+if "rollout_log_probs" in batch.batch.keys():
+    # TODO: we may want to add diff of probs too.
+    rollout_old_log_probs = batch.batch["rollout_log_probs"]
+    actor_old_log_probs = batch.batch["old_log_probs"]
+    attention_mask = batch.batch["attention_mask"]
+    responses = batch.batch["responses"]
+    response_length = responses.size(1)
+    response_mask = attention_mask[:, -response_length:]
+
+    rollout_probs = torch.exp(rollout_old_log_probs)
+    actor_probs = torch.exp(actor_old_log_probs)
+    rollout_probs_diff = torch.abs(rollout_probs - actor_probs)
+    rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool())
+    rollout_probs_diff_max = torch.max(rollout_probs_diff)
+    rollout_probs_diff_mean = torch.mean(rollout_probs_diff)
+    rollout_probs_diff_std = torch.std(rollout_probs_diff)
+    metrics.update(
+        {
+            "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(),
+            "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(),
+            "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(),
+        }
+    )
+
 if self.use_reference_policy:
     # compute reference log_prob
     with _timer("ref", timing_raw):

verl/workers/rollout/sglang_rollout/async_sglang_rollout.py

Lines changed: 4 additions & 4 deletions
@@ -365,7 +365,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
 
 # users can customize different sampling_params at different run
 with self.update_sampling_params(**kwargs):
-    print(f"{self.sampling_params=}")
+    # print(f"{self.sampling_params=}")
     if self._tp_rank == 0:
         loop = asyncio.get_event_loop()
         output = loop.run_until_complete(
@@ -390,11 +390,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
 out = _post_process_outputs(self.tokenizer, output)
 
 response = out[0].to(idx.device)
-# log_probs = out[1].to(idx.device)
+rollout_log_probs = out[1].to(idx.device)
 
 if response.shape[1] < self.config.response_length:
     response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
-    # log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)
+    rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id)
 
 # utilize current sampling params
 if self.sampling_params.get("n", 1) > 1 and do_sample:
@@ -428,7 +428,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
 "prompts": idx,
 "responses": response,
 "input_ids": seq, # here input_ids become the whole sentences
-# 'old_log_probs': log_probs, # we will recompute old log prob with actor
+'rollout_log_probs': rollout_log_probs, # we will recompute old log prob with actor
 "attention_mask": attention_mask,
 "position_ids": position_ids,
 },

verl/workers/rollout/sglang_rollout/sglang_rollout.py

Lines changed: 1 addition & 1 deletion
@@ -307,7 +307,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
 
 # users can customize different sampling_params at different run
 with self.update_sampling_params(**kwargs):
-    print(f"{self.sampling_params=}")
+    # print(f"{self.sampling_params=}")
     output = self.inference_engine.generate(
         prompt=None, # because we have already convert it to prompt token id
         sampling_params=self.sampling_params,

verl/workers/rollout/vllm_rollout/vllm_rollout.py

Lines changed: 3 additions & 3 deletions
@@ -229,11 +229,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
 # TODO(sgm): disable logprob when recompute_log_prob is enable
 # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
 response = output[0].to(idx.device)
-# log_probs = output[1].to(idx.device)
+log_probs = output[1].to(idx.device)
 
 if response.shape[1] < self.config.response_length:
     response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id)
-    # log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)
+    log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id)
 
 # utilize current sampling params
 if self.sampling_params.n > 1 and do_sample:
@@ -262,7 +262,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
 "prompts": idx,
 "responses": response,
 "input_ids": seq, # here input_ids become the whole sentences
-# 'old_log_probs': log_probs, # we will recompute old log prob with actor
+'rollout_log_probs': log_probs, # we will recompute old log prob with actor
 "attention_mask": attention_mask,
 "position_ids": position_ids,
 },

verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py

Lines changed: 10 additions & 2 deletions
@@ -282,11 +282,19 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
 # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length)
 
 response = []
+rollout_log_probs = []
 for output in outputs:
     for sample_id in range(len(output.outputs)):
-        response.append(output.outputs[sample_id].token_ids)
+        response_ids = output.outputs[sample_id].token_ids
+        response.append(response_ids)
+        curr_log_prob = []
+        for i, logprob in enumerate(output.outputs[sample_id].logprobs):
+            curr_log_prob.append(logprob[response_ids[i]].logprob)
+        rollout_log_probs.append(curr_log_prob)
 
 response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device)
+rollout_log_probs = pad_2d_list_to_length(rollout_log_probs, -1, max_length=self.config.response_length).to(idx.device)
+rollout_log_probs = rollout_log_probs.to(torch.float32)
 
 if self.sampling_params.n > 1 and do_sample:
     idx = _repeat_interleave(idx, self.sampling_params.n)
@@ -322,7 +330,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto:
 "prompts": idx,
 "responses": response,
 "input_ids": seq, # here input_ids become the whole sentences
-# 'old_log_probs': log_probs, # we will recompute old log prob with actor
+'rollout_log_probs': rollout_log_probs, # we will recompute old log prob with actor
 "attention_mask": attention_mask,
 "position_ids": position_ids,
 },
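For context (not part of the diff): `output.outputs[sample_id].logprobs` is only populated when the rollout's vLLM sampling parameters request log-probs (e.g. `SamplingParams(logprobs=1)`, which also includes the sampled token). A rough standalone sketch of the same extraction against the public vLLM API, with a placeholder model path chosen only for illustration:

```python
from vllm import LLM, SamplingParams

# Placeholder model; any locally available model would do for this illustration.
llm = LLM(model="Qwen/Qwen2.5-0.5B-Instruct")
params = SamplingParams(max_tokens=8, temperature=1.0, logprobs=1)

for output in llm.generate(["1 + 1 ="], params):
    for sample in output.outputs:
        # sample.logprobs[i] maps token_id -> Logprob for step i; picking the sampled
        # token's entry mirrors logprob[response_ids[i]].logprob in the hunk above.
        rollout_log_probs = [
            sample.logprobs[i][tid].logprob for i, tid in enumerate(sample.token_ids)
        ]
        print(sample.token_ids, rollout_log_probs)
```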
