7 files changed: +23 −60 lines
==== changed file ====
       - "examples/data_preprocess/gsm8k.py"
       - "examples/data_preprocess/geo3k.py"
       - "tests/special_e2e/ppo_trainer"
+      - "tests/special_npu"
       - "verl/trainer/main_ppo.py"
       - "verl/trainer/config/ppo_trainer.yaml"

@@ -123,6 +124,11 @@ jobs:
         run: |
           ray stop --force
           python3 examples/data_preprocess/geo3k.py
+      - name: Running gsm8k e2e qwen3 training tests with PPO on ASCEND NPU
+        run: |
+          ray stop --force
+          bash tests/special_npu/run_qwen3_06b_ppo.sh
+          rm -rf $HOME/ckpts
       - name: Running gsm8k e2e training tests with peft sft on ASCEND NPU
         run: |
           ray stop --force
@@ -143,16 +149,6 @@ jobs:
           ray stop --force
           bash tests/special_npu/run_qwen2_5_05b_dapo.sh
           rm -rf $HOME/ckpts
-      - name: Running gsm8k e2e qwen3 training tests with GRPO on ASCEND NPU
-        run: |
-          ray stop --force
-          bash tests/special_npu/run_qwen3_06b_grpo.sh
-          rm -rf $HOME/ckpts
-      - name: Running gsm8k e2e qwen3 training tests with PPO on ASCEND NPU
-        run: |
-          ray stop --force
-          bash tests/special_npu/run_qwen3_06b_ppo.sh
-          rm -rf $HOME/ckpts
       - name: Running gsm8k e2e training tests with GRPO MindSpeed on ASCEND NPU
         run: |
           ray stop --force
==== changed file ====
 #!/usr/bin/env bash
 set -xeuo pipefail

-NUM_GPUS=${NUM_GPUS:-8}
+NUM_GPUS=${NUM_GPUS:-16}

 MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}
 MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
==== changed file ====
@@ -36,7 +36,7 @@ python3 -m verl.trainer.main_ppo \
     trainer.logger=console \
     trainer.project_name='verl_grpo_example_gsm8k' \
     trainer.experiment_name='qwen2_7b_function_rm' \
-    trainer.n_gpus_per_node=8 \
+    trainer.n_gpus_per_node=16 \
     trainer.nnodes=1 \
     trainer.save_freq=-1 \
     trainer.test_freq=5 \
==== changed file ====
@@ -59,7 +59,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     trainer.logger=console \
     trainer.project_name='verl_grpo_example_gsm8k' \
     trainer.experiment_name='qwen2_7b_function_rm' \
-    trainer.n_gpus_per_node=8 \
+    trainer.n_gpus_per_node=16 \
     trainer.nnodes=1 \
     trainer.save_freq=-1 \
     trainer.test_freq=5 \
==== changed file ====
@@ -18,7 +18,7 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.model.path=Qwen/Qwen2.5-VL-3B-Instruct \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.model.use_remove_padding=True \
-    actor_rollout_ref.actor.ppo_mini_batch_size=16 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
     actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
     actor_rollout_ref.actor.use_kl_loss=True \
     actor_rollout_ref.actor.kl_loss_coef=0.01 \
@@ -44,7 +44,7 @@ python3 -m verl.trainer.main_ppo \
     trainer.logger=console \
     trainer.project_name='verl_grpo_example_geo3k' \
     trainer.experiment_name='qwen2_5_vl_3b_function_rm' \
-    trainer.n_gpus_per_node=8 \
+    trainer.n_gpus_per_node=16 \
     trainer.nnodes=1 \
     trainer.save_freq=-1 \
     trainer.test_freq=-1 \
==== changed file ====
This file was deleted.
Original file line number Diff line number Diff line change 3636 get_ulysses_sequence_parallel_world_size ,
3737 validate_ulysses_config ,
3838)
39+ from verl .utils .device import is_npu_available
3940
4041logger = logging .getLogger (__file__ )
4142logger .setLevel (os .getenv ("VERL_LOGGING_LEVEL" , "WARN" ))
4647
4748 _flash_supports_window_size = "window_size" in inspect .signature (flash_attn_func ).parameters
4849 _flash_supports_deterministic = "deterministic" in inspect .signature (flash_attn_func ).parameters
49- _flash_deterministic_enabled = os .getenv ("FLASH_ATTENTION_DETERMINISTIC" , "0" ) == "1"
5050 _flash_use_top_left_mask = not is_flash_attn_greater_or_equal_2_10 ()
5151
52+ if is_npu_available :
53+ from transformers .integrations .npu_flash_attention import npu_flash_attn_func as flash_attn_func
54+ from transformers .integrations .npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
55+ from transformers .modeling_flash_attention_utils import flash_attn_supports_top_left_mask
56+
57+ _flash_supports_window_size = "window_size" in inspect .signature (flash_attn_func ).parameters
58+ _flash_supports_deterministic = "deterministic" in inspect .signature (flash_attn_func ).parameters
59+ _flash_use_top_left_mask = flash_attn_supports_top_left_mask ()
60+
61+ _flash_deterministic_enabled = os .getenv ("FLASH_ATTENTION_DETERMINISTIC" , "0" ) == "1"
62+
5263
5364def get_rope_index (
5465 processor ,
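For context, the guarded block above follows a common backend-selection pattern: probe the flash-attention features once at import time, and rebind the module-level kernel names when running on Ascend NPU. Below is a minimal, self-contained sketch of that pattern, not a drop-in copy of the patched file; it uses the names that appear in the diff (verl.utils.device.is_npu_available, transformers.integrations.npu_flash_attention) and assumes they are importable in the installed verl/transformers versions, falling back gracefully when they are not.

# Sketch only: backend selection for flash attention.
import inspect
import os

try:
    from verl.utils.device import is_npu_available
except ImportError:
    is_npu_available = False  # assumption: no verl installed means no NPU path

try:
    from flash_attn import flash_attn_func  # default GPU kernel
except ImportError:
    flash_attn_func = None

if is_npu_available:
    # Rebind the module-level name to the Ascend NPU kernel so code that
    # calls flash_attn_func later stays backend-agnostic.
    from transformers.integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func

# Feature probing works the same way for either backend.
if flash_attn_func is not None:
    _flash_supports_window_size = "window_size" in inspect.signature(flash_attn_func).parameters
    _flash_supports_deterministic = "deterministic" in inspect.signature(flash_attn_func).parameters

# The deterministic switch only reads an environment variable, which is why the
# diff hoists it out of the GPU-only block and computes it once for both backends.
_flash_deterministic_enabled = os.getenv("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"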