Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions .github/workflows/e2e_ppo_trainer_megatron_sglang.yml
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,16 @@ jobs:
exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal"
python -m verl.model_merger test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface
python -m verl.model_merger test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface
- name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Deepseek)
- name: Profiling GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Deepseek)
run: |
ray stop --force
ENGINE=sglang ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/special_e2e/run_ppo_trainer_megatron.sh
PROFILE_ENABLE=True ENGINE=sglang ADV_ESTIMATOR=grpo USE_DYNAMIC_BSZ=False MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/special_e2e/run_ppo_trainer_megatron.sh
if [ -z "$( ls -A '/tmp/ray/session_latest/logs/nsight/' )" ]; then
echo "[ERROR] not found any profiling files"
exit 1
else
echo "[SUCCESS] profile success"
fi
- name: clean up
run: |
rm -rf checkpoints
Expand Down
10 changes: 5 additions & 5 deletions docs/ascend_tutorial/ascend_profiling.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ Last updated: 07/24/2025.

通过 ppo_trainer.yaml 中的参数控制采集步数和模式:

- profiler: 控制采集的rank和模式
- global_profiler: 控制采集的rank和模式

- tool: 使用的采集工具,选项有 nsys、npu、torch、torch_memory。
- steps: 此参数可以设置为包含采集步数的列表,例如 [2, 4],表示将采集第2步和第4步。如果设置为 null,则不进行采集。
- save_path: 保存采集数据的路径。默认值为 "outputs/profile"。

通过 ``profiler.tool_config.npu`` 中的参数控制具体采集行为:
通过 ``global_profiler.global_tool_config.npu`` 中的参数控制具体采集行为:

- level: 采集级别—选项有 level_none、level0、level1 和 level2

Expand Down Expand Up @@ -63,15 +63,15 @@ Last updated: 07/24/2025.

.. code:: yaml

profiler:
global_profiler:
steps: null # disable profile

端到端采集
~~~~~~~~~~~~~~~~~~~~~

.. code:: yaml

profiler:
global_profiler:
steps: [1, 2, 5]
discrete: False
actor_rollout_ref:
Expand All @@ -87,7 +87,7 @@ Last updated: 07/24/2025.

.. code:: yaml

profiler:
global_profiler:
discrete: True


Expand Down
10 changes: 5 additions & 5 deletions docs/ascend_tutorial/ascend_profiling_en.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Global collection control
Use parameters in ppo_trainer.yaml to control the collection mode
and steps.

- profiler: Control the ranks and mode of profiling
- global_profiler: Control the ranks and mode of profiling

- tool: The profiling tool to use, options are nsys, npu, torch,
torch_memory.
Expand All @@ -30,7 +30,7 @@ and steps.
- save_path: The path to save the collected data. Default is
"outputs/profile".

Use parameters in ``profiler.tool_config.npu`` to control npu profiler behavior:
Use parameters in ``global_profiler.global_tool_config.npu`` to control npu profiler behavior:

- level: Collection level—options are level_none, level0, level1, and
level2
Expand Down Expand Up @@ -77,15 +77,15 @@ Disabling collection

.. code:: yaml

profiler:
global_profiler:
steps: null # disable profile

End-to-End collection
~~~~~~~~~~~~~~~~~~~~~

.. code:: yaml

profiler:
global_profiler:
steps: [1, 2, 5]
discrete: False
actor_rollout_ref:
Expand All @@ -100,7 +100,7 @@ Discrete Mode Collection

.. code:: yaml

profiler:
global_profiler:
discrete: True


Expand Down
15 changes: 7 additions & 8 deletions docs/perf/nsight_profiling.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,23 @@ Nsight Systems version is important, please reference `docker/Dockerfile.vllm.sg

verl has one single controller process and multiple worker processes. Both controller and worker processes can be profiled. Since the controller process can be executed in any nodes in the cluster, there is a message printed in the logging to indicate the controller process node hostname and process id.

In `profiler`, three new config entries control the profiler behaviors:
In `global_profiler`, three new config entries control the profiler behaviors:

* **`profiler.steps`**. List of step numbers at which profiling should be performed. For example: [1, 2, 5] will profile steps 1, 2, and 5. And ``null`` means no profiling.
* **`global_profiler.steps`**. List of step numbers at which profiling should be performed. For example: [1, 2, 5] will profile steps 1, 2, and 5. And ``null`` means no profiling.

* **`profiler.profile_continuous_steps`**. If true, and the following `profiler.discrete==False`, then the continuous steps in `profiler.steps` will be combined into one database. For example the above step 1 and 2 are in one database, and 5 in another. If false, every step occupies at least one database. The reason for this config is to observe the program behaviors between steps.
* **`global_profiler.profile_continuous_steps`**. If true, and the following `global_profiler.discrete==False`, then the continuous steps in `global_profiler.steps` will be combined into one database. For example the above step 1 and 2 are in one database, and 5 in another. If false, every step occupies at least one database. The reason for this config is to observe the program behaviors between steps.

Nsys options in controller nodes and worker nodes are configured in `trainer`:
Nsys options in controller nodes and worker nodes are configured in `global_profiler.global_tool_config.nsys`:

* **`trainer.controller_nsight_options`**. This config group is for the single controller. All fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. `ppo_trainer.yaml` provides a workable example. Users can reference [Nsight Systems manual](https://docs.nvidia.com/nsight-systems/UserGuide/index.html) and [Ray user guide](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html) for more details.
* **`trainer.worker_nsight_options`**. This config group is for the worker processes. Similarly all fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. Capture range is used to control the profiler when to start and stop. So `capture-range: "cudaProfilerApi"` is fixed and does not change it. Users can change `capture-range-end` with some accurate calculation or just leave it `null`.
* **`global_profiler.global_tool_config.nsys.controller_nsight_options`**. This config group is for the single controller. All fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. `ppo_trainer.yaml` provides a workable example. Users can reference [Nsight Systems manual](https://docs.nvidia.com/nsight-systems/UserGuide/index.html) and [Ray user guide](https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html) for more details.
* **`global_profiler.global_tool_config.nsys.worker_nsight_options`**. This config group is for the worker processes. Similarly all fields in this config group will be just sent to Nsight Systems when Ray starts the controller process. Capture range is used to control the profiler when to start and stop. So `capture-range: "cudaProfilerApi"` is fixed and does not change it. Users can change `capture-range-end` with some accurate calculation or just leave it `null`.

### Worker process profiling

Verl manages mulitiple RL roles, _Actor_, _Ref_, _Rollout_, _Critic_, _Reward_, which are implemented in different Worker classes. And these workers can be combined into one Ray Actor, running in a process group. Each RL role has its own profiling config group, `profiler`, which consists of three fields:

* **`all_ranks` and `ranks`**. When `all_ranks` is set `True` then all ranks will be profiled; when set `False`, `ranks` will be profiled. By default, verl profiles the whole training process in a series ` worker_process_<PID>.<RID>.nsys-rep` files for each process rank. PID is the process ID; RID is the capture range ID.
* **`discrete`**. When set `False`, all the roles actions in one training step will be dumped in one database. When set `True`, the actions annotated by `DistProfiler.annotate` will be dumped into a discrete database. In this case, each role's action occupies one `<RID>`.
* **`actor_rollout_ref`**. This Worker can be configured to contain at most 3 roles and executes together. So `actor_rollout_ref` has a `profiler` config and all the inside roles inherit it.
* **Verl collocate mode**. Verl can combine two Worker sub classes to one Worker Actor. In this case, the user should take care that the combined Workers have consistent `discrete`. The Nsight Systems profiler uses a `torch.cuda.profiler.start()` and `stop()` pair to dump a `<step>` database anyway.

### where to find the profiling data
Expand All @@ -56,7 +55,7 @@ To enable profiling for specific components and steps, modify your ppo_trainer.y
### Enable profiler and one database for one training step

```yaml
profiler:
global_profiler:
steps: [1, 2, 5]
discrete: False
actor_rollout_ref:
Expand Down
14 changes: 7 additions & 7 deletions examples/grpo_trainer/run_qwen2_5_7b_grpo_discrete_prof_npu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,11 @@ python3 -m verl.trainer.main_ppo \
trainer.test_freq=5 \
trainer.total_epochs=5 \
trainer.device=npu \
profiler.tool=npu \
profiler.steps=$PROFILE_STEPS \
profiler.save_path=$SAVE_PATH \
profiler.tool_config.npu.discrete=$DISCRETE \
profiler.tool_config.npu.contents=$CONTENTS \
profiler.tool_config.npu.level=$LEVEL \
profiler.tool_config.npu.analysis=$ANALYSIS
global_profiler.tool=npu \
global_profiler.steps=$PROFILE_STEPS \
global_profiler.save_path=$SAVE_PATH \
global_profiler.global_tool_config.npu.discrete=$DISCRETE \
global_profiler.global_tool_config.npu.contents=$CONTENTS \
global_profiler.global_tool_config.npu.level=$LEVEL \
global_profiler.global_tool_config.npu.analysis=$ANALYSIS
$@
14 changes: 7 additions & 7 deletions examples/grpo_trainer/run_qwen2_5_7b_grpo_e2e_prof_npu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ python3 -m verl.trainer.main_ppo \
trainer.test_freq=5 \
trainer.total_epochs=5 \
trainer.device=npu \
profiler.tool=npu \
profiler.steps=$PROFILE_STEPS \
profiler.save_path=$SAVE_PATH \
profiler.tool_config.npu.discrete=$DISCRETE \
profiler.tool_config.npu.contents=$CONTENTS \
profiler.tool_config.npu.level=$LEVEL \
profiler.tool_config.npu.analysis=$ANALYSIS \
global_profiler.tool=npu \
global_profiler.steps=$PROFILE_STEPS \
global_profiler.save_path=$SAVE_PATH \
global_profiler.global_tool_config.npu.discrete=$DISCRETE \
global_profiler.global_tool_config.npu.contents=$CONTENTS \
global_profiler.global_tool_config.npu.level=$LEVEL \
global_profiler.global_tool_config.npu.analysis=$ANALYSIS
$@
6 changes: 3 additions & 3 deletions examples/ppo_trainer/run_deepseek_math_gsm8k_megatron_nsys.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat
trainer.test_freq=-1 \
trainer.total_epochs=100 \
trainer.total_training_steps=1 \
profiler.tool=nsys \
profiler.steps=$PROFILE_STEPS \
profiler.tool_config.nsys.discrete=$DISCRETE $@
global_profiler.tool=nsys \
global_profiler.steps=$PROFILE_STEPS \
global_profiler.global_tool_config.nsys.discrete=$DISCRETE $@
8 changes: 4 additions & 4 deletions examples/ppo_trainer/run_qwen2-7b_rm_seq_balance_nsys.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ python3 -m verl.trainer.main_ppo \
trainer.test_freq=-1 \
trainer.total_epochs=15 \
trainer.total_training_steps=6 \
profiler.profile_continuous_steps=True \
profiler.tool=nsys \
profiler.steps=$PROFILE_STEPS \
profiler.tool_config.nsys.discrete=$DISCRETE $@
global_profiler.profile_continuous_steps=True \
global_profiler.tool=nsys \
global_profiler.steps=$PROFILE_STEPS \
global_profiler.global_tool_config.nsys.discrete=$DISCRETE $@
9 changes: 5 additions & 4 deletions recipe/one_step_off_policy/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,12 @@ def init_workers(self):
wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
if OmegaConf.select(self.config.global_profiler, "steps") is not None:
wg_kwargs["profile_steps"] = OmegaConf.select(self.config.trainer, "steps")
assert OmegaConf.select(self.config.global_profiler, "worker_nsight_options") is not None, (
"worker_nsight_options must be set when profile_steps is set"
)
assert (
OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
is not None
), "worker_nsight_options must be set when profile_steps is set"
wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
OmegaConf.select(self.config.global_profiler, "worker_nsight_options")
OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
)

for resource_pool, class_dict in self.resource_pool_to_cls.items():
Expand Down
21 changes: 20 additions & 1 deletion tests/special_e2e/run_ppo_trainer_megatron.sh
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,12 @@ fi

OPTIM_MEMORY_EFFICIENT=${OPTIM_MEMORY_EFFICIENT:-False}

PROFILE_ENABLE=${PROFILE_ENABLE:-False}
PROFILE_STEPS=${PROFILE_STEPS:-[1]}
PROFILE_RANKS_ALL=${PROFILE_RANKS_ALL:-True}
PROFILE_RANKS=${PROFILE_RANKS:-[0,1,2,3]}
DISCRETE=${DISCRETE:-True} # or True

python3 -m verl.trainer.main_ppo --config-path=config \
--config-name='ppo_megatron_trainer.yaml'\
algorithm.adv_estimator="${ADV_ESTIMATOR}" \
Expand Down Expand Up @@ -176,6 +182,9 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.checkpoint.save_contents=$CHECKPOINT_CONTENTS \
actor_rollout_ref.actor.profiler.enable=$PROFILE_ENABLE \
actor_rollout_ref.actor.profiler.ranks=$PROFILE_RANKS \
actor_rollout_ref.actor.profiler.all_ranks=$PROFILE_RANKS_ALL \
actor_rollout_ref.rollout.name="${ENGINE}" ${ROLLOUT_MODE_ARG}\
actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
Expand Down Expand Up @@ -214,6 +223,9 @@ python3 -m verl.trainer.main_ppo --config-path=config \
critic.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \
critic.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \
critic.checkpoint.save_contents=$CHECKPOINT_CONTENTS \
critic.profiler.enable=$PROFILE_ENABLE \
critic.profiler.ranks=$PROFILE_RANKS \
critic.profiler.all_ranks=$PROFILE_RANKS_ALL \
reward_model.enable=True \
reward_model.model.path="${MODEL_PATH}" \
reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
Expand All @@ -227,6 +239,9 @@ python3 -m verl.trainer.main_ppo --config-path=config \
reward_model.megatron.param_offload=${RM_PARAM_OFFLOAD} \
reward_model.megatron.use_dist_checkpointing=${USE_DIST_CKPT} \
reward_model.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \
reward_model.profiler.enable=$PROFILE_ENABLE \
reward_model.profiler.ranks=$PROFILE_RANKS \
reward_model.profiler.all_ranks=$PROFILE_RANKS_ALL \
algorithm.use_kl_in_reward=False \
algorithm.kl_penalty=kl \
algorithm.kl_ctrl.kl_coef=0.001 \
Expand All @@ -241,4 +256,8 @@ python3 -m verl.trainer.main_ppo --config-path=config \
trainer.save_freq="${SAVE_FREQ}" \
trainer.resume_mode="${RESUME_MODE}" \
trainer.total_epochs=2 \
trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" $@
trainer.total_training_steps="${TOTAL_TRAIN_STEPS}" \
global_profiler.profile_continuous_steps=True \
global_profiler.tool=nsys \
global_profiler.steps=$PROFILE_STEPS \
global_profiler.global_tool_config.nsys.discrete=$DISCRETE $@
9 changes: 5 additions & 4 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ actor_rollout_ref:
save_path: ${oc.select:global_profiler.save_path,null}
tool_config:
nsys:
_target_: verl.utils.profiler.config.NsightToolConfig
discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
npu:
_target_: verl.utils.profiler.config.NPUToolConfig
Expand Down Expand Up @@ -118,7 +119,7 @@ actor_rollout_ref:
all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false}
ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]}
save_path: ${oc.select:global_profiler.save_path,null}
tool_config: ${oc.select:actor_rollout_ref.actor.tool_config,null}
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
megatron:
_target_: verl.workers.config.MegatronEngineConfig
param_offload: false
Expand Down Expand Up @@ -206,7 +207,7 @@ actor_rollout_ref:
all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false}
ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]}
save_path: ${oc.select:global_profiler.save_path,null}
tool_config: ${oc.select:actor_rollout_ref.actor.tool_config,null}
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
enable_chunked_prefill: false
load_format: dummy_megatron
layer_name_map:
Expand Down Expand Up @@ -315,7 +316,7 @@ critic:
all_ranks: false
ranks: []
save_path: ${oc.select:global_profiler.save_path,null}
tool_config: ${oc.select:actor_rollout_ref.actor.tool_config,null}
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
nccl_timeout: 600
megatron:
_target_: verl.workers.config.McoreEngineConfig
Expand Down Expand Up @@ -364,7 +365,7 @@ reward_model:
all_ranks: false
ranks: []
save_path: ${oc.select:global_profiler.save_path,null}
tool_config: ${oc.select:actor_rollout_ref.actor.tool_config,null}
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
nccl_timeout: 600
megatron:
_target_: verl.workers.config.MegatronEngineConfig
Expand Down
9 changes: 5 additions & 4 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ actor_rollout_ref:
save_path: ${oc.select:global_profiler.save_path,null}
tool_config:
nsys:
_target_: verl.utils.profiler.config.NsightToolConfig
discrete: ${oc.select:global_profiler.global_tool_config.nsys.discrete}
npu:
_target_: verl.utils.profiler.config.NPUToolConfig
Expand Down Expand Up @@ -99,7 +100,7 @@ actor_rollout_ref:
all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false}
ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]}
save_path: ${oc.select:global_profiler.save_path,null}
tool_config: ${oc.select:actor_rollout_ref.actor.tool_config,null}
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
model: null
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
Expand Down Expand Up @@ -181,7 +182,7 @@ actor_rollout_ref:
all_ranks: ${oc.select:actor_rollout_ref.actor.profiler.all_ranks,false}
ranks: ${oc.select:actor_rollout_ref.actor.profiler.ranks,[]}
save_path: ${oc.select:global_profiler.save_path,null}
tool_config: ${oc.select:actor_rollout_ref.actor.tool_config,null}
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
enable_chunked_prefill: true
load_format: dummy_dtensor
layered_summon: false
Expand Down Expand Up @@ -301,7 +302,7 @@ critic:
all_ranks: false
ranks: []
save_path: ${oc.select:global_profiler.save_path,null}
tool_config: ${oc.select:actor_rollout_ref.actor.tool_config,null}
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
forward_micro_batch_size: ${oc.select:.ppo_micro_batch_size,null}
forward_micro_batch_size_per_gpu: ${oc.select:.ppo_micro_batch_size_per_gpu,null}
ulysses_sequence_parallel_size: 1
Expand Down Expand Up @@ -343,7 +344,7 @@ reward_model:
all_ranks: false
ranks: []
save_path: ${oc.select:global_profiler.save_path,null}
tool_config: ${oc.select:actor_rollout_ref.actor.tool_config,null}
tool_config: ${oc.select:actor_rollout_ref.actor.profiler.tool_config,null}
ulysses_sequence_parallel_size: 1
custom_reward_function:
path: null
Expand Down
Loading