Commit 0077f3e

[ci] feat: Add CI for checking irregular device api usage (#2089)
### Checklist Before Starting

- [x] Searched for similar PR(s).
- [x] Checked PR Title format
  - In format of: [modules] type: Title
  - modules are in `fsdp, megatron, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data`
  - type is in `feat, fix, refactor, chore, test`
  - can involve multiple modules, separated by `,` or space, like `[megatron, fsdp, doc] feat: xxx`

### What does this PR do?

Add a CI check that detects irregular device API usage and suggests using the API in `verl/utils/device.py` to get the device name or object. The check also runs on non-Linux systems (e.g. Windows), which makes failures easier to debug and locate.

### Test

Not related.

### High-Level Design

Not related.

### Specific Changes

Add a new CI test case that checks for irregular device API usage and suggests using the API in `verl/utils/device.py`.

### API

Not related.

### Usage Example

```shell
python tests\special_sanity\check_device_api_usage.py --directory ./recipe
[CHECK] File D:\workspace\verl\recipe\char_count\create_dataset.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\char_count\reward_function.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\dapo\dapo_ray_trainer.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\dapo\main_dapo.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\prime\main_prime.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\prime\prime_core_algos.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\prime\prime_dp_rm.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\prime\prime_fsdp_workers.py is detected for device api usage check, check result: success.
[SKIP] File D:\workspace\verl\recipe\prime\prime_ray_trainer.py is in device api usage check whitelist, checking is skipped.
[CHECK] File D:\workspace\verl\recipe\prime\__init__.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\r1\data_process.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\r1\main_eval.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\r1\reward_score.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\r1\__init__.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\r1\tasks\gpqa.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\r1\tasks\livecodebench.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\r1\tasks\math.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\r1\tasks\__init__.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\retool\retool_multi_turn_sft_preprocess.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\spin\core_algos.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\spin\dp_actor.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\spin\fsdp_workers.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\spin\main_spin.py is detected for device api usage check, check result: success.
[SKIP] File D:\workspace\verl\recipe\spin\spin_trainer.py is in device api usage check whitelist, checking is skipped.
[CHECK] File D:\workspace\verl\recipe\sppo\dp_actor.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\sppo\main_sppo.py is detected for device api usage check, check result: success.
[SKIP] File D:\workspace\verl\recipe\sppo\sppo_ray_trainer.py is in device api usage check whitelist, checking is skipped.
[CHECK] File D:\workspace\verl\recipe\sppo\sppo_worker.py is detected for device api usage check, check result: success.
[CHECK] File D:\workspace\verl\recipe\sppo\__init__.py is detected for device api usage check, check result: success.
```

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [x] Add `[BREAKING]` to the PR title `description` if it breaks any API.
- [x] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] New CI unit test(s) are added to cover the code path.
- [x] Rely on existing unit tests on CI that covers the code path.

1 parent a44b83c commit 0077f3e

5 files changed: 105 additions, 22 deletions

.github/workflows/sanity.yml

Lines changed: 4 additions & 0 deletions
@@ -89,3 +89,7 @@ jobs:
         run: python3 tests/special_sanity/validate_structure.py
       - name: Assert documentation requirement for functions
         run: python3 tests/special_sanity/validate_imported_docs.py
+      - name: Assert device api usage in verl/recipe
+        run: python3 tests/special_sanity/check_device_api_usage.py --directory ./recipe
+      - name: Assert device api usage in verl/verl
+        run: python3 tests/special_sanity/check_device_api_usage.py --directory ./verl

tests/special_sanity/check_device_api_usage.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This CI test is used for checking whether device api usage is irregular, suggest using api in `verl/utils/device.py`.
+Search targets include .py files in verl/recipe and verl/verl.
+Some files that must contain ".cuda", "cuda" or "nccl" keyword is pre-defined in whitelist below.
+"""
+
+import os
+from argparse import ArgumentParser
+from pathlib import Path
+
+# directory or file path must contain keyword ".cuda" or "cuda"
+CUDA_KEYWORD_CHECK_WHITELIST = [
+    "verl/utils/device.py",
+    "verl/third_party/vllm/vllm_v_0_5_4",
+    "verl/third_party/vllm/vllm_v_0_6_3",
+    "recipe/prime/prime_ray_trainer.py",  # appear in default device_name
+    "recipe/spin/spin_trainer.py",  # appear in default device_name
+    "recipe/sppo/sppo_ray_trainer.py",  # appear in default device_name
+    "verl/utils/debug/nvtx_profile.py",  # appear in NsightSystemsProfiler
+    "verl/utils/kernel/linear_cross_entropy.py",  # appear in nvidia nvtx
+    "verl/utils/rendezvous/ray_backend.py",  # appear in cupy importance
+    "verl/single_controller/ray/base.py",  # appear in default device_name
+    "verl/trainer/ppo/ray_trainer.py",  # appear in default device_name
+    "verl/utils/reward_score/sandbox_fusion/utils.py",  # appear in sandbox language type
+    "verl/workers/reward_model/megatron/reward_model.py",  # appear in default device_name
+]
+
+# directory or file path must contain keyword "nccl"
+NCCL_KEYWORD_CHECK_WHITELIST = [
+    "verl/utils/device.py",
+    "verl/third_party/vllm/vllm_v_0_5_4",
+    "verl/third_party/vllm/vllm_v_0_6_3",
+    "verl/third_party/sglang/parallel_state.py",  # appear in default backend
+]
+
+SEARCH_WHITELIST = CUDA_KEYWORD_CHECK_WHITELIST + NCCL_KEYWORD_CHECK_WHITELIST
+
+SEARCH_KEYWORDS = [".cuda", '"cuda"', '"nccl"']
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser.add_argument("--directory", "-d", required=True, type=str)
+    args = parser.parse_args()
+    directory_in_str = args.directory
+
+    pathlist = Path(directory_in_str).glob("**/*.py")
+    for path in pathlist:
+        path_in_str = str(path.absolute())
+
+        # judge whether current path is in pre-defined search whitelist or not.
+        path_in_whitelist = False
+
+        for sw in SEARCH_WHITELIST:
+            # for easy debugging in non-linux system
+            sw = sw.replace("/", os.sep)
+            if sw in path_in_str:
+                print(f"[SKIP] File {path_in_str} is in device api usage check whitelist, checking is skipped.")
+                path_in_whitelist = True
+                break
+
+        if path_in_whitelist:
+            continue
+
+        with open(path_in_str, encoding="utf-8") as f:
+            file_content = f.read()
+
+            find_invalid_device_management = False
+
+            for sk in SEARCH_KEYWORDS:
+                if sk in file_content:
+                    find_invalid_device_management = True
+                    break
+
+            print(f"[CHECK] File {path_in_str} is detected for device api usage check, check result: {'success' if not find_invalid_device_management else 'failed'}.")
+
+            assert not find_invalid_device_management, f'file {path_in_str} contains .cuda/"cuda"/"nccl" usage, please use api in verl/utils/device.py directly.'
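
The script above stops at the first offending file because of the `assert`. For local debugging it can help to see every offender in one run. The following is a minimal sketch of that variation, not part of this commit: the function name `find_violations` and its parameters are hypothetical, while the keywords and the `os.sep` whitelist normalisation mirror the script.

```python
import os
from pathlib import Path
from typing import Iterable, List

# Same substrings the CI script searches for.
KEYWORDS = (".cuda", '"cuda"', '"nccl"')


def find_violations(
    directory: str,
    whitelist: Iterable[str] = (),
    keywords: Iterable[str] = KEYWORDS,
) -> List[str]:
    """Return absolute paths of .py files under `directory` containing any keyword.

    Hypothetical helper: it collects all offenders instead of asserting on the first.
    """
    # Convert "/" to the platform separator so whitelist entries also match on Windows.
    normalised = [entry.replace("/", os.sep) for entry in whitelist]
    keyword_list = tuple(keywords)
    offenders = []
    for path in sorted(Path(directory).glob("**/*.py")):
        path_str = str(path.absolute())
        # Whitelisted paths are skipped entirely, as in the script.
        if any(entry in path_str for entry in normalised):
            continue
        content = path.read_text(encoding="utf-8")
        if any(keyword in content for keyword in keyword_list):
            offenders.append(path_str)
    return offenders


if __name__ == "__main__":
    for offender in find_violations("./recipe"):
        print(f"[FAILED] {offender}")
```

Because the match is a plain substring test, a single-quoted `'cuda'` literal is not caught, while a comment that mentions `torch.cuda` is.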

verl/utils/checkpoint/checkpoint_manager.py

Lines changed: 5 additions & 9 deletions
@@ -25,7 +25,7 @@
 from omegaconf import DictConfig
 from transformers import PreTrainedTokenizer, ProcessorMixin
 
-from verl.utils.device import is_cuda_available, is_npu_available
+from verl.utils.device import get_device_name, get_torch_device
 
 
 class BaseCheckpointManager:
@@ -169,10 +169,8 @@ def get_rng_state():
             "random": random.getstate(),
         }
 
-        if is_cuda_available:
-            rng_state["cuda"] = torch.cuda.get_rng_state()
-        elif is_npu_available:
-            rng_state["npu"] = torch.npu.get_rng_state()
+        if get_device_name() != "cpu":
+            rng_state[get_device_name()] = get_torch_device().get_rng_state()
 
         return rng_state
 
@@ -182,10 +180,8 @@ def load_rng_state(rng_state):
         np.random.set_state(rng_state["numpy"])
         random.setstate(rng_state["random"])
 
-        if is_cuda_available:
-            torch.cuda.set_rng_state(rng_state["cuda"])
-        elif is_npu_available:
-            torch.npu.set_rng_state(rng_state["npu"])
+        if get_device_name() != "cpu":
+            get_torch_device().set_rng_state(rng_state[get_device_name()])
 
 
 def find_latest_ckpt_path(path, directory_format="global_step_{}"):
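
The change above replaces per-backend branches with a single lookup keyed by the device name. The following is a minimal, torch-free sketch of that pattern under stated assumptions: `FakeDeviceModule`, `save_rng_state`, and `restore_rng_state` are hypothetical names, and the device name and module getter are passed in as parameters standing in for `get_device_name()` and `get_torch_device()`.

```python
import random
from typing import Any, Callable, Dict


class FakeDeviceModule:
    """Hypothetical stand-in for the module returned by a helper like get_torch_device()."""

    def __init__(self, state: int) -> None:
        self.state = state

    def get_rng_state(self) -> int:
        return self.state

    def set_rng_state(self, state: int) -> None:
        self.state = state


def save_rng_state(device_name: str, get_device_module: Callable[[], Any]) -> Dict[str, Any]:
    """Collect RNG state, storing the accelerator entry under its device name."""
    rng_state: Dict[str, Any] = {"random": random.getstate()}
    if device_name != "cpu":
        rng_state[device_name] = get_device_module().get_rng_state()
    return rng_state


def restore_rng_state(
    rng_state: Dict[str, Any], device_name: str, get_device_module: Callable[[], Any]
) -> None:
    """Restore RNG state saved by save_rng_state for the same device name."""
    random.setstate(rng_state["random"])
    if device_name != "cpu":
        get_device_module().set_rng_state(rng_state[device_name])


if __name__ == "__main__":
    module = FakeDeviceModule(state=7)
    saved = save_rng_state("npu", lambda: module)
    module.set_rng_state(0)
    restore_rng_state(saved, "npu", lambda: module)
    print(module.get_rng_state())
```

In this sketch the saved dictionary only holds the key for the device it was created on, so restoring under a different device name raises `KeyError`.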

verl/utils/checkpoint/megatron_checkpoint_manager.py

Lines changed: 5 additions & 9 deletions
@@ -25,7 +25,7 @@
 from transformers import GenerationConfig
 
 from verl.models.weight_loader_registry import get_weight_saver
-from verl.utils.device import is_cuda_available, is_npu_available
+from verl.utils.device import get_device_name, get_torch_device
 from verl.utils.fs import is_non_local
 from verl.utils.logger import log_with_rank
 from verl.utils.megatron_utils import (
@@ -111,10 +111,8 @@ def get_rng_state(self, use_dist_ckpt: bool = False, data_parallel_random_init:
             "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(),
         }
 
-        if is_cuda_available:
-            rng_state["cuda_rng_state"] = torch.cuda.get_rng_state()
-        elif is_npu_available:
-            rng_state["npu_rng_state"] = torch.npu.get_rng_state()
+        if get_device_name() != "cpu":
+            rng_state[f"{get_device_name()}_rng_state"] = get_torch_device().get_rng_state()
 
         rng_state_list = None
         if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init:
@@ -203,10 +201,8 @@ def load_rng_states(self, ckpt_path, data_parallel_random_init=False, use_dist_c
         np.random.set_state(rng_state["np_rng_state"])
         torch.set_rng_state(rng_state["torch_rng_state"])
 
-        if is_cuda_available:
-            torch.cuda.set_rng_state(rng_state["cuda_rng_state"])
-        elif is_npu_available:
-            torch.npu.set_rng_state(rng_state["npu_rng_state"])
+        if get_device_name() != "cpu":
+            get_torch_device().set_rng_state(rng_state[f"{get_device_name()}_rng_state"])
 
         # Check for empty states array
         if not rng_state["rng_tracker_states"]:

verl/workers/rollout/vllm_rollout/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -29,10 +29,6 @@ def get_version(pkg):
 if vllm_package_version is None:
     raise PackageNotFoundError("To use vllm rollout, please ensure the 'vllm' package is properly installed. See https://verl.readthedocs.io/en/latest/start/install.html for more details")
 
-###
-# package_version = get_version(package_name)
-# [SUPPORT AMD:]
-# Do not call any torch.cuda* API here, or ray actor creation import class will fail.
 if "ROCM_PATH" in os.environ:
     import re
 