From 04002ae4826060286e2ce68032306b1b24a9663d Mon Sep 17 00:00:00 2001 From: "zhangchi.usc1992" Date: Sun, 25 May 2025 15:14:26 +0800 Subject: [PATCH 1/8] return rollout log probs --- verl/trainer/ppo/ray_trainer.py | 23 ++++++++++++++++++- .../rollout/vllm_rollout/vllm_rollout.py | 8 ++++--- .../rollout/vllm_rollout/vllm_rollout_spmd.py | 12 ++++++++-- 3 files changed, 37 insertions(+), 6 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 5ebd8df7619..d8bd529267e 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -57,7 +57,7 @@ reduce_metrics, ) from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance -from verl.utils.torch_functional import masked_mean +from verl.utils.torch_functional import masked_mean, masked_var from verl.utils.tracking import ValidationGenerationsLogger from verl.workers.rollout.async_server import AsyncLLMServerManager @@ -978,6 +978,27 @@ def fit(self): old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) + if 'rollout_log_probs' in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + rollout_old_log_probs = batch.batch['rollout_log_probs'] + actor_old_log_probs = batch.batch['old_log_probs'] + attention_mask = batch.batch["attention_mask"] + responses = batch.batch["responses"] + response_length = responses.size(1) + response_mask = attention_mask[:, -response_length:] + log_probs_diff = torch.abs(rollout_old_log_probs - actor_old_log_probs) + log_probs_diff = torch.masked_select(log_probs_diff, response_mask.bool()) + + log_probs_diff_max = torch.max(log_probs_diff) + log_probs_diff_mean = torch.mean(log_probs_diff) + log_probs_diff_std = torch.std(log_probs_diff) + + metrics.update({ + "training/log_probs_diff_max": log_probs_diff_max.detach().item(), + "training/log_probs_diff_mean": log_probs_diff_mean.detach().item(), + "training/log_probs_diff_std": log_probs_diff_std.detach().item(), + }) + if self.use_reference_policy: # compute reference log_prob with _timer("ref", timing_raw): diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 37a39a5ee82..8b6a91ed937 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -229,11 +229,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # TODO(sgm): disable logprob when recompute_log_prob is enable # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) response = output[0].to(idx.device) - # log_probs = output[1].to(idx.device) + log_probs = output[1].to(idx.device) if response.shape[1] < self.config.response_length: response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) - # log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id) + log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id) # utilize current sampling params if self.sampling_params.n > 1 and do_sample: @@ -256,13 +256,15 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) + breakpoint() + # all the tp ranks should contain the same data here. 
data in all ranks are valid batch = TensorDict( { "prompts": idx, "responses": response, "input_ids": seq, # here input_ids become the whole sentences - # 'old_log_probs': log_probs, # we will recompute old log prob with actor + 'rollout_log_probs': log_probs, # we will recompute old log prob with actor "attention_mask": attention_mask, "position_ids": position_ids, }, diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py index e8ae44437dd..e6a162ef791 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py @@ -282,11 +282,19 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # if n = 1: (bs, response_length) ; if n > 1: (bs * n, response_length) response = [] + rollout_log_probs = [] for output in outputs: for sample_id in range(len(output.outputs)): - response.append(output.outputs[sample_id].token_ids) + response_ids = output.outputs[sample_id].token_ids + response.append(response_ids) + curr_log_prob = [] + for i, logprob in enumerate(output.outputs[sample_id].logprobs): + curr_log_prob.append(logprob[response_ids[i]].logprob) + rollout_log_probs.append(curr_log_prob) response = pad_2d_list_to_length(response, self.pad_token_id, max_length=self.config.response_length).to(idx.device) + rollout_log_probs = pad_2d_list_to_length(rollout_log_probs, -1, max_length=self.config.response_length).to(idx.device) + rollout_log_probs = rollout_log_probs.to(torch.float32) if self.sampling_params.n > 1 and do_sample: idx = _repeat_interleave(idx, self.sampling_params.n) @@ -322,7 +330,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: "prompts": idx, "responses": response, "input_ids": seq, # here input_ids become the whole sentences - # 'old_log_probs': log_probs, # we will recompute old log prob with actor + 'rollout_log_probs': rollout_log_probs, # we will recompute old log prob with actor "attention_mask": attention_mask, "position_ids": position_ids, }, From 1a300ea0140c7e728334cb285498704bba7c60e6 Mon Sep 17 00:00:00 2001 From: "zhangchi.usc1992" Date: Tue, 27 May 2025 08:27:07 +0800 Subject: [PATCH 2/8] remove breakpoint --- verl/workers/rollout/vllm_rollout/vllm_rollout.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 8b6a91ed937..06817b5d50f 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -256,8 +256,6 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: response_attention_mask = get_response_mask(response_id=response, eos_token=eos_token_id, dtype=attention_mask.dtype) attention_mask = torch.cat((attention_mask, response_attention_mask), dim=-1) - breakpoint() - # all the tp ranks should contain the same data here. 
data in all ranks are valid batch = TensorDict( { From 3f4d9bb16e5fad245b8681b3afd0cda37cbb612c Mon Sep 17 00:00:00 2001 From: "zhangchi.usc1992" Date: Tue, 27 May 2025 09:37:38 +0800 Subject: [PATCH 3/8] add sglang rollout logprobs --- verl/trainer/ppo/ray_trainer.py | 19 ++++++++++--------- .../sglang_rollout/async_sglang_rollout.py | 6 +++--- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index d8bd529267e..df45c4dfd6d 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -986,17 +986,18 @@ def fit(self): responses = batch.batch["responses"] response_length = responses.size(1) response_mask = attention_mask[:, -response_length:] - log_probs_diff = torch.abs(rollout_old_log_probs - actor_old_log_probs) - log_probs_diff = torch.masked_select(log_probs_diff, response_mask.bool()) - - log_probs_diff_max = torch.max(log_probs_diff) - log_probs_diff_mean = torch.mean(log_probs_diff) - log_probs_diff_std = torch.std(log_probs_diff) + rollout_probs = torch.exp(rollout_old_log_probs) + actor_probs = torch.exp(actor_old_log_probs) + rollout_probs_diff = torch.abs(rollout_probs - actor_probs) + rollout_probs_diff = torch.masked_select(rollout_probs_diff, response_mask.bool()) + rollout_probs_diff_max = torch.max(rollout_probs_diff) + rollout_probs_diff_mean = torch.mean(rollout_probs_diff) + rollout_probs_diff_std = torch.std(rollout_probs_diff) metrics.update({ - "training/log_probs_diff_max": log_probs_diff_max.detach().item(), - "training/log_probs_diff_mean": log_probs_diff_mean.detach().item(), - "training/log_probs_diff_std": log_probs_diff_std.detach().item(), + "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(), + "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(), + "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(), }) if self.use_reference_policy: diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py index 3e8102483b8..46fdb15dff9 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py @@ -390,11 +390,11 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: out = _post_process_outputs(self.tokenizer, output) response = out[0].to(idx.device) - # log_probs = out[1].to(idx.device) + rollout_log_probs = out[1].to(idx.device) if response.shape[1] < self.config.response_length: response = pad_sequence_to_length(response, self.config.response_length, self.pad_token_id) - # log_probs = pad_sequence_to_length(log_probs, self.config.response_length, self.pad_token_id) + rollout_log_probs = pad_sequence_to_length(rollout_log_probs, self.config.response_length, self.pad_token_id) # utilize current sampling params if self.sampling_params.get("n", 1) > 1 and do_sample: @@ -428,7 +428,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: "prompts": idx, "responses": response, "input_ids": seq, # here input_ids become the whole sentences - # 'old_log_probs': log_probs, # we will recompute old log prob with actor + 'rollout_log_probs': rollout_log_probs, # we will recompute old log prob with actor "attention_mask": attention_mask, "position_ids": position_ids, }, From 663031dee61d57712d2e68b98ac2b5e724bd29db Mon Sep 17 00:00:00 2001 From: "zhangchi.usc1992" Date: Tue, 27 May 2025 09:53:23 +0800 Subject: [PATCH 4/8] 
remove verbose --- verl/workers/rollout/sglang_rollout/async_sglang_rollout.py | 2 +- verl/workers/rollout/sglang_rollout/sglang_rollout.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py index 46fdb15dff9..b501305e1a0 100644 --- a/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/async_sglang_rollout.py @@ -365,7 +365,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # users can customize different sampling_params at different run with self.update_sampling_params(**kwargs): - print(f"{self.sampling_params=}") + # print(f"{self.sampling_params=}") if self._tp_rank == 0: loop = asyncio.get_event_loop() output = loop.run_until_complete( diff --git a/verl/workers/rollout/sglang_rollout/sglang_rollout.py b/verl/workers/rollout/sglang_rollout/sglang_rollout.py index ed852f769f0..af30a568df5 100644 --- a/verl/workers/rollout/sglang_rollout/sglang_rollout.py +++ b/verl/workers/rollout/sglang_rollout/sglang_rollout.py @@ -307,7 +307,7 @@ def generate_sequences(self, prompts: DataProto, **kwargs) -> DataProto: # users can customize different sampling_params at different run with self.update_sampling_params(**kwargs): - print(f"{self.sampling_params=}") + # print(f"{self.sampling_params=}") output = self.inference_engine.generate( prompt=None, # because we have already convert it to prompt token id sampling_params=self.sampling_params, From 0623c381f2821519878ffffcd397d64233a0487c Mon Sep 17 00:00:00 2001 From: Chi Zhang Date: Tue, 27 May 2025 11:16:02 +0800 Subject: [PATCH 5/8] update --- verl/trainer/ppo/ray_trainer.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index df45c4dfd6d..fd298f95da7 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -57,7 +57,7 @@ reduce_metrics, ) from verl.utils.seqlen_balancing import get_seqlen_balanced_partitions, log_seqlen_unbalance -from verl.utils.torch_functional import masked_mean, masked_var +from verl.utils.torch_functional import masked_mean from verl.utils.tracking import ValidationGenerationsLogger from verl.workers.rollout.async_server import AsyncLLMServerManager @@ -978,10 +978,10 @@ def fit(self): old_log_prob.batch.pop("entropys") batch = batch.union(old_log_prob) - if 'rollout_log_probs' in batch.batch.keys(): + if "rollout_log_probs" in batch.batch.keys(): # TODO: we may want to add diff of probs too. 
- rollout_old_log_probs = batch.batch['rollout_log_probs'] - actor_old_log_probs = batch.batch['old_log_probs'] + rollout_old_log_probs = batch.batch["rollout_log_probs"] + actor_old_log_probs = batch.batch["old_log_probs"] attention_mask = batch.batch["attention_mask"] responses = batch.batch["responses"] response_length = responses.size(1) @@ -994,11 +994,13 @@ def fit(self): rollout_probs_diff_max = torch.max(rollout_probs_diff) rollout_probs_diff_mean = torch.mean(rollout_probs_diff) rollout_probs_diff_std = torch.std(rollout_probs_diff) - metrics.update({ - "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(), - "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(), - "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(), - }) + metrics.update( + { + "training/rollout_probs_diff_max": rollout_probs_diff_max.detach().item(), + "training/rollout_probs_diff_mean": rollout_probs_diff_mean.detach().item(), + "training/rollout_probs_diff_std": rollout_probs_diff_std.detach().item(), + } + ) if self.use_reference_policy: # compute reference log_prob From aa70dac5b1867293bb245ff60b6525f4b4516065 Mon Sep 17 00:00:00 2001 From: "zhangchi.usc1992" Date: Tue, 27 May 2025 13:41:09 +0800 Subject: [PATCH 6/8] update dapo scripts --- recipe/dapo/test_dapo_7b_math.sh | 16 +-- recipe/dapo/test_dapo_qwen3_30b_math.sh | 126 ++++++++++++++++++++++++ 2 files changed, 135 insertions(+), 7 deletions(-) create mode 100644 recipe/dapo/test_dapo_qwen3_30b_math.sh diff --git a/recipe/dapo/test_dapo_7b_math.sh b/recipe/dapo/test_dapo_7b_math.sh index 824cdad566f..051799aa95a 100644 --- a/recipe/dapo/test_dapo_7b_math.sh +++ b/recipe/dapo/test_dapo_7b_math.sh @@ -2,7 +2,7 @@ set -xeuo pipefail project_name='DAPO' -exp_name='DAPO-Qwen2.5-7b-MATH-0519a1' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1' adv_estimator=grpo @@ -27,10 +27,11 @@ n_resp_per_prompt=16 train_prompt_mini_bsz=32 # Ray -RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} -WORKING_DIR=${WORKING_DIR:-"${PWD}"} -RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} -NNODES=${NNODES:-4} +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} # Paths RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} @@ -113,12 +114,13 @@ python3 -m verl.trainer.main_ppo \ trainer.logger=['console','wandb'] \ trainer.project_name="${project_name}" \ trainer.experiment_name="${exp_name}" \ - trainer.n_gpus_per_node=8 \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ trainer.nnodes="${NNODES}" \ - trainer.val_before_train=False \ + trainer.val_before_train=True \ trainer.test_freq=10 \ trainer.save_freq=10 \ trainer.total_epochs=10 \ + trainer.total_training_steps=200 \ trainer.default_local_dir="${CKPTS_DIR}" \ trainer.resume_mode=auto \ trainer.log_val_generations=10 diff --git a/recipe/dapo/test_dapo_qwen3_30b_math.sh b/recipe/dapo/test_dapo_qwen3_30b_math.sh new file mode 100644 index 00000000000..56ebd0397ef --- /dev/null +++ b/recipe/dapo/test_dapo_qwen3_30b_math.sh @@ -0,0 +1,126 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen3-30B-A3B-Base-MATH-0527a1' + +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + 
+max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +loss_agg_mode="token-mean" + +train_prompt_bsz=512 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +NNODES=${NNODES:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +sp_size=4 +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +offload=True +gen_tp=4 +fsdp_size=32 + +python3 -m verl.trainer.main_ppo \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + 
actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','wandb'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + trainer.nnodes="${NNODES}" \ + trainer.val_before_train=True \ + trainer.test_freq=10 \ + trainer.save_freq=10 \ + trainer.total_epochs=10 \ + trainer.total_training_steps=300 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.log_val_generations=10 From 3592251ee1a5c2fae6bddcf3e51440cab63c07bd Mon Sep 17 00:00:00 2001 From: "zhangchi.usc1992" Date: Tue, 27 May 2025 13:49:39 +0800 Subject: [PATCH 7/8] fix max_position_embeddings --- recipe/dapo/test_dapo_7b_math.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/recipe/dapo/test_dapo_7b_math.sh b/recipe/dapo/test_dapo_7b_math.sh index 051799aa95a..cf1e5b96ac9 100644 --- a/recipe/dapo/test_dapo_7b_math.sh +++ b/recipe/dapo/test_dapo_7b_math.sh @@ -72,6 +72,7 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ actor_rollout_ref.actor.clip_ratio_c=10.0 \ actor_rollout_ref.model.use_remove_padding=True \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ From 3dcb6d578b5d71a712134a99666e1f7446b8f465 Mon Sep 17 00:00:00 2001 From: "zhangchi.usc1992" Date: Tue, 27 May 2025 13:55:00 +0800 Subject: [PATCH 8/8] add comments --- recipe/dapo/test_dapo_7b_math.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/recipe/dapo/test_dapo_7b_math.sh b/recipe/dapo/test_dapo_7b_math.sh index cf1e5b96ac9..39918ac2d4b 100644 --- a/recipe/dapo/test_dapo_7b_math.sh +++ b/recipe/dapo/test_dapo_7b_math.sh @@ -54,6 +54,8 @@ offload=True gen_tp=4 fsdp_size=32 +# remember to set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 for this model + python3 -m verl.trainer.main_ppo \ data.train_files="${TRAIN_FILE}" \ data.val_files="${TEST_FILE}" \
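
The trainer-side metric introduced in this series compares, token by token, the probabilities reported by the rollout engine against the ones recomputed by the actor, measured in probability space rather than log space. A minimal standalone sketch of that computation follows (assuming two [batch, response_length] log-prob tensors and a response mask taken from the last response_length columns of the attention mask, as in the patch; the function and argument names are illustrative, not part of the verl API):

    import torch

    def rollout_probs_diff_metrics(rollout_log_probs, actor_log_probs, response_mask):
        # Convert log-probs to probabilities so the gap is measured in probability
        # space, which is easier to interpret than a gap in log space.
        rollout_probs = torch.exp(rollout_log_probs)
        actor_probs = torch.exp(actor_log_probs)
        diff = torch.abs(rollout_probs - actor_probs)
        # Keep only tokens that belong to the response; padded positions are dropped.
        diff = torch.masked_select(diff, response_mask.bool())
        return {
            "training/rollout_probs_diff_max": diff.max().item(),
            "training/rollout_probs_diff_mean": diff.mean().item(),
            "training/rollout_probs_diff_std": diff.std().item(),
        }

A persistently large mean or max here usually indicates that the rollout engine's sampling-time numerics (kernels, dtypes, parallelism layout) diverge from the actor's recomputed log probs, which is what the training/rollout_probs_diff_* metrics are meant to surface.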