Commit f50e5c2
[sglang] feat: add preparation for sglang+verl (#3506)
### What does this PR do?

Support NPU for verl + sglang:

```bash
bash examples/grpo_trainer/run_qwen3_8b_grpo_sglang_1k_npu.sh
```

### Accuracy test

8b:
<img width="747" height="842" alt="8b" src="https://github.com/user-attachments/assets/f36ef25a-b32f-4c76-97d0-2e5fe53ff183" />

30b:
<img width="759" height="850" alt="30b" src="https://github.com/user-attachments/assets/97979002-7ebf-47fa-ae57-3e9b6637f12c" />

### Test

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Signed-off-by: lbk-sys <[email protected]>
Co-authored-by: 1StepForever <[email protected]>
1 parent aa19c1a commit f50e5c2

File tree

8 files changed: +395 −32 lines changed
docs/ascend_tutorial/ascend_sglang_quick_start.rst

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
verl x Ascend
===================================

Last updated: 09/25/2025.

We have added support for Huawei Ascend devices to verl.

Hardware Support
-----------------------------------

Atlas 200T A2 Box16

Atlas 900 A2 PODc

Atlas 800T A3


Installation
-----------------------------------

Environment Preparation
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+-----------+-------------+
| software  | version     |
+-----------+-------------+
| Python    | == 3.11     |
+-----------+-------------+
| CANN      | == 8.3.RC1  |
+-----------+-------------+
| HDK       | == 25.3.RC1 |
+-----------+-------------+
| torch     | == 2.6.0    |
+-----------+-------------+
| torch_npu | == 2.6.0    |
+-----------+-------------+

**Currently the sglang NPU backend in verl only supports the HDK, CANN, and PTA versions listed above; commercially released versions are expected in October 2025.**

To use sglang with verl, install sglang, torch_memory_saver, and verl with the commands below.
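Before installing anything else, the version pins in the table above can be sanity-checked programmatically. The helper below is a minimal sketch (the `REQUIRED` mapping and function name are illustrative, not part of verl):

```python
# Minimal sketch: compare installed distribution versions against the pins above.
# The REQUIRED mapping is illustrative; extend it for your environment.
from importlib.metadata import PackageNotFoundError, version

REQUIRED = {"torch": "2.6.0", "torch_npu": "2.6.0"}

def check_versions(required):
    """Return {name: (found, expected)} for every pin that does not match."""
    mismatches = {}
    for name, expected in required.items():
        try:
            found = version(name)
        except PackageNotFoundError:
            found = None  # distribution is not installed at all
        if found is None or not found.startswith(expected):
            mismatches[name] = (found, expected)
    return mismatches
```

An empty return value means every pin matched; anything else tells you which package to reinstall.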
Installing sglang
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: bash

   # sglang
   git clone https://github.com/sgl-project/sglang.git
   cd sglang
   mv python/pyproject.toml python/pyproject.toml.backup
   mv python/pyproject_other.toml python/pyproject.toml
   pip install -e "python[srt_npu]"
Installing torch_memory_saver
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: bash

   # torch_memory_saver
   git clone https://github.com/sgl-project/sgl-kernel-npu.git
   cd sgl-kernel-npu
   bash build.sh -a memory-saver
   pip install output/torch_memory_saver*.whl
Installing verl
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. code-block:: bash

   git clone https://github.com/volcengine/verl.git
   cd verl
   pip install --no-deps -e .
   pip install -r requirements-npu.txt
Notes on Other Third-Party Libraries
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

+---------------+-----------+
| software      | version   |
+---------------+-----------+
| transformers  | v4.56.1   |
+---------------+-----------+
| triton_ascend | v3.2.0    |
+---------------+-----------+

1. sglang depends on transformers v4.56.1.
2. sglang depends on triton_ascend v3.2.0.
3. Multimodal models are not yet supported; uninstall the related packages torchvision and timm.

.. code-block:: bash

   pip uninstall torchvision
   pip uninstall timm
   pip uninstall triton

   pip install transformers==4.56.1
   pip install -i https://test.pypi.org/simple/ triton-ascend==3.2.0.dev20250925
Quick Start
-----------------------------------

Before production use, we recommend verifying your environment preparation and installation by training Qwen3-8B with GRPO.

1. Download the dataset and preprocess it into parquet format so that it contains the fields required for computing RL rewards:

.. code-block:: bash

   python3 examples/data_preprocess/gsm8k.py --local_save_dir ~/data/gsm8k
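The preprocessing step keeps the final GSM8K answer as the reward ground truth. The core of that extraction boils down to a pattern like the sketch below (the function name and regex mirror the idea of ``examples/data_preprocess/gsm8k.py``, but this is an illustration, not the script itself):

```python
import re

def extract_solution(solution_str: str) -> str:
    """GSM8K solutions end in '#### <answer>'; keep the final number as ground truth."""
    match = re.search(r"#### (\-?[0-9\.\,]+)", solution_str)
    if match is None:
        raise ValueError("no '#### <answer>' marker found")
    # Strip thousands separators so '1,234' and '1234' compare equal.
    return match.group(1).replace(",", "")
```

The resulting parquet rows carry this ground-truth string so the rule-based reward function can score rollouts by exact match.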
2. Launch training:

.. code-block:: bash

   bash verl/examples/grpo_trainer/run_qwen3_8b_grpo_sglang_1k_npu.sh

docs/index.rst

Lines changed: 1 addition & 0 deletions
@@ -133,6 +133,7 @@ verl is fast with:
    ascend_tutorial/ascend_quick_start.rst
    ascend_tutorial/ascend_profiling_zh.rst
    ascend_tutorial/ascend_profiling_en.rst
+   ascend_tutorial/ascend_sglang_quick_start.rst
 
 .. toctree::
    :maxdepth: 1
examples/grpo_trainer/run_qwen3_8b_grpo_sglang_1k_npu.sh

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
set -x
export HCCL_CONNECT_TIMEOUT=1500
export HCCL_HOST_SOCKET_PORT_RANGE=60000-60050
export HCCL_NPU_SOCKET_PORT_RANGE=61000-61050

# WORKSPACE_HOME and DATA_HOME support custom path configuration.
WORKSPACE_HOME=$(pwd)
DATA_HOME=$(pwd)

sp_size=4
num_npu=4
tp_size=4
train_prompt_bsz=16
train_prompt_mini_bsz=16

max_prompt_length=512
max_response_length=1024

CKPTS_DIR=$WORKSPACE_HOME/logs/ckpt/qwen3_8b
model_path=$DATA_HOME/models/Qwen3-8B
train_data=$DATA_HOME/datasets/processed_gsm8k/train.parquet
valid_data=$DATA_HOME/datasets/processed_gsm8k/test.parquet

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$train_data \
    data.val_files=$valid_data \
    data.train_batch_size=$train_prompt_bsz \
    data.max_prompt_length=$max_prompt_length \
    data.max_response_length=$max_response_length \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=$model_path \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=$train_prompt_mini_bsz \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.use_torch_compile=False \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$tp_size \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend="ascend" \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 \
    actor_rollout_ref.nccl_timeout=1800 \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=console \
    trainer.val_before_train=False \
    trainer.project_name='verl_grpo_example_512_1024_gsm8k' \
    trainer.experiment_name='qwen3_8b_function_rm' \
    trainer.n_gpus_per_node=$num_npu \
    trainer.nnodes=1 \
    trainer.save_freq=1000 \
    trainer.test_freq=10000 \
    trainer.total_epochs=5 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    trainer.device=npu "$@"
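How the batch knobs in the script above combine can be checked with a little arithmetic; the values are taken from the script itself, while the variable names below are just for illustration:

```python
# Illustrative arithmetic for the GSM8K script's batch settings.
train_prompt_bsz = 16   # data.train_batch_size: prompts sampled per training step
rollout_n = 5           # actor_rollout_ref.rollout.n: responses per prompt (GRPO group size)
num_npu = 4             # trainer.n_gpus_per_node

trajectories_per_step = train_prompt_bsz * rollout_n  # total sampled responses per step
per_device = trajectories_per_step // num_npu         # trajectories per NPU before micro-batching
```

With `ppo_micro_batch_size_per_gpu=1`, each NPU then processes its share one trajectory at a time during the actor update.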
Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
set -x
export HCCL_CONNECT_TIMEOUT=1500
export HCCL_HOST_SOCKET_PORT_RANGE=60000-60050
export HCCL_NPU_SOCKET_PORT_RANGE=61000-61050

# WORKSPACE_HOME and DATA_HOME support custom path configuration.
WORKSPACE_HOME=$(pwd)
DATA_HOME=$(pwd)

sp_size=4
num_gpu=8
tp_size=4
train_prompt_bsz=16
train_prompt_mini_bsz=16

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 32))

CKPTS_DIR=$WORKSPACE_HOME/logs/ckpt/qwen3_8b
model_path=$DATA_HOME/models/Qwen3-8B
train_data=$DATA_HOME/datasets/dapo/dapo-math-17k.parquet
valid_data=$DATA_HOME/datasets/dapo/aime-2024.parquet

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$train_data \
    data.val_files=$valid_data \
    data.train_batch_size=$train_prompt_bsz \
    data.max_prompt_length=$max_prompt_length \
    data.max_response_length=$max_response_length \
    data.filter_overlong_prompts=False \
    data.truncation='error' \
    actor_rollout_ref.model.path=$model_path \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=$train_prompt_mini_bsz \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.use_torch_compile=False \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=$tp_size \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.rollout.engine_kwargs.sglang.attention_backend="ascend" \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.nccl_timeout=3600 \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger=console \
    trainer.val_before_train=False \
    trainer.project_name='verl_grpo_example_2k_32k' \
    trainer.experiment_name='qwen3_8b_function_rm' \
    trainer.n_gpus_per_node=$num_gpu \
    trainer.nnodes=1 \
    trainer.save_freq=1000 \
    trainer.test_freq=10000 \
    trainer.total_epochs=5 \
    trainer.default_local_dir="${CKPTS_DIR}" \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
    actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
    trainer.device=npu "$@"
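This second script targets long-context rollouts (2k-token prompts, 32k-token responses). The numbers below come straight from the script; the variable names and the per-rank split are only an illustration of how Ulysses sequence parallelism divides a sequence:

```python
# Illustrative check of the long-context settings in the script above.
max_prompt_length = 1024 * 2      # 2048-token prompts
max_response_length = 1024 * 32   # 32768-token responses
sp_size = 4                       # ulysses_sequence_parallel_size

total_len = max_prompt_length + max_response_length  # tokens per full sequence
per_sp_rank = total_len // sp_size                   # tokens each SP rank holds
```

Splitting each 34k-token sequence across 4 SP ranks is what keeps activation memory per NPU manageable at this response length.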

verl/utils/attention_utils.py

Lines changed: 1 addition & 14 deletions
@@ -27,20 +27,7 @@ def _get_attention_functions() -> tuple[Callable, Callable, Callable, Callable]:
     if is_cuda_available:
         from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input
     elif is_npu_available:
-        try:
-            from transformers.integrations.npu_flash_attention import (
-                index_first_axis,
-                pad_input,
-                rearrange,
-                unpad_input,
-            )
-        except ImportError:
-            # Since transformers v4.55.1, index_first_axis, pad_input, and unpad_input
-            # have been consolidated into `transformers.modeling_flash_attention_utils`.
-            from einops import rearrange
-            from transformers.modeling_flash_attention_utils import _index_first_axis as index_first_axis
-            from transformers.modeling_flash_attention_utils import _pad_input as pad_input
-            from transformers.modeling_flash_attention_utils import _unpad_input as unpad_input
+        from verl.utils.npu_utils import index_first_axis, pad_input, rearrange, unpad_input
 
     _index_first_axis, _pad_input, _rearrange, _unpad_input = index_first_axis, pad_input, rearrange, unpad_input
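The diff above moves a try/except import dance into `verl.utils.npu_utils`. The underlying pattern — try one import location, fall back to the next — can be sketched generically as follows (the helper name `import_first` is illustrative, not verl's API):

```python
from importlib import import_module

def import_first(candidates):
    """Return the first attribute importable from a list of (module, attribute)
    fallbacks, mirroring the try/except pattern the diff consolidates."""
    for module, name in candidates:
        try:
            return getattr(import_module(module), name)
        except (ImportError, AttributeError):
            continue  # this location is unavailable; try the next one
    raise ImportError(f"none of {candidates} could be imported")
```

Centralizing the fallback in one module keeps call sites like `attention_utils.py` down to a single import line even as upstream libraries move symbols between releases.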
