diff --git a/.github/workflows/sanity.yml b/.github/workflows/sanity.yml index 107e4eb9b6f..0bfcd73aaf0 100644 --- a/.github/workflows/sanity.yml +++ b/.github/workflows/sanity.yml @@ -89,3 +89,7 @@ jobs: run: python3 tests/special_sanity/validate_structure.py - name: Assert documentation requirement for functions run: python3 tests/special_sanity/validate_imported_docs.py + - name: Assert device api usage in verl/recipe + run: python3 tests/special_sanity/check_device_api_usage.py --directory ./recipe + - name: Assert device api usage in verl/verl + run: python3 tests/special_sanity/check_device_api_usage.py --directory ./verl diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py new file mode 100644 index 00000000000..2ded113e789 --- /dev/null +++ b/tests/special_sanity/check_device_api_usage.py @@ -0,0 +1,91 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This CI test checks whether device api usage is irregular; if so, it suggests using the api in `verl/utils/device.py`. +Search targets include .py files in verl/recipe and verl/verl. +Files that legitimately contain the ".cuda", "cuda" or "nccl" keyword are pre-defined in the whitelist below.
+""" + +import os +from argparse import ArgumentParser +from pathlib import Path + +# directory or file path must contain keyword ".cuda" or "cuda" +CUDA_KEYWORD_CHECK_WHITELIST = [ + "verl/utils/device.py", + "verl/third_party/vllm/vllm_v_0_5_4", + "verl/third_party/vllm/vllm_v_0_6_3", + "recipe/prime/prime_ray_trainer.py", # appear in default device_name + "recipe/spin/spin_trainer.py", # appear in default device_name + "recipe/sppo/sppo_ray_trainer.py", # appear in default device_name + "verl/utils/debug/nvtx_profile.py", # appear in NsightSystemsProfiler + "verl/utils/kernel/linear_cross_entropy.py", # appear in nvidia nvtx + "verl/utils/rendezvous/ray_backend.py", # appear in cupy import + "verl/single_controller/ray/base.py", # appear in default device_name + "verl/trainer/ppo/ray_trainer.py", # appear in default device_name + "verl/utils/reward_score/sandbox_fusion/utils.py", # appear in sandbox language type + "verl/workers/reward_model/megatron/reward_model.py", # appear in default device_name +] + +# directory or file path must contain keyword "nccl" +NCCL_KEYWORD_CHECK_WHITELIST = [ + "verl/utils/device.py", + "verl/third_party/vllm/vllm_v_0_5_4", + "verl/third_party/vllm/vllm_v_0_6_3", + "verl/third_party/sglang/parallel_state.py", # appear in default backend +] + +SEARCH_WHITELIST = CUDA_KEYWORD_CHECK_WHITELIST + NCCL_KEYWORD_CHECK_WHITELIST + +SEARCH_KEYWORDS = [".cuda", '"cuda"', '"nccl"'] + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("--directory", "-d", required=True, type=str) + args = parser.parse_args() + directory_in_str = args.directory + + pathlist = Path(directory_in_str).glob("**/*.py") + for path in pathlist: + path_in_str = str(path.absolute()) + + # judge whether the current path is in the pre-defined search whitelist or not.
+ path_in_whitelist = False + + for sw in SEARCH_WHITELIST: + # for easy debugging in non-linux system + sw = sw.replace("/", os.sep) + if sw in path_in_str: + print(f"[SKIP] File {path_in_str} is in device api usage check whitelist, checking is skipped.") + path_in_whitelist = True + break + + if path_in_whitelist: + continue + + with open(path_in_str, encoding="utf-8") as f: + file_content = f.read() + + find_invalid_device_management = False + + for sk in SEARCH_KEYWORDS: + if sk in file_content: + find_invalid_device_management = True + break + + print(f"[CHECK] File {path_in_str} is detected for device api usage check, check result: {'success' if not find_invalid_device_management else 'failed'}.") + + assert not find_invalid_device_management, f'file {path_in_str} contains .cuda/"cuda"/"nccl" usage, please use api in verl/utils/device.py directly.' diff --git a/verl/utils/checkpoint/checkpoint_manager.py b/verl/utils/checkpoint/checkpoint_manager.py index b7a7d97c459..2e45e667596 100644 --- a/verl/utils/checkpoint/checkpoint_manager.py +++ b/verl/utils/checkpoint/checkpoint_manager.py @@ -25,7 +25,7 @@ from omegaconf import DictConfig from transformers import PreTrainedTokenizer, ProcessorMixin -from verl.utils.device import is_cuda_available, is_npu_available +from verl.utils.device import get_device_name, get_torch_device class BaseCheckpointManager: @@ -169,10 +169,8 @@ def get_rng_state(): "random": random.getstate(), } - if is_cuda_available: - rng_state["cuda"] = torch.cuda.get_rng_state() - elif is_npu_available: - rng_state["npu"] = torch.npu.get_rng_state() + if get_device_name() != "cpu": + rng_state[get_device_name()] = get_torch_device().get_rng_state() return rng_state @@ -182,10 +180,8 @@ def load_rng_state(rng_state): np.random.set_state(rng_state["numpy"]) random.setstate(rng_state["random"]) - if is_cuda_available: - torch.cuda.set_rng_state(rng_state["cuda"]) - elif is_npu_available: - torch.npu.set_rng_state(rng_state["npu"]) + if 
get_device_name() != "cpu": + get_torch_device().set_rng_state(rng_state[get_device_name()]) def find_latest_ckpt_path(path, directory_format="global_step_{}"): diff --git a/verl/utils/checkpoint/megatron_checkpoint_manager.py b/verl/utils/checkpoint/megatron_checkpoint_manager.py index 489253374e2..faf563f4cd6 100644 --- a/verl/utils/checkpoint/megatron_checkpoint_manager.py +++ b/verl/utils/checkpoint/megatron_checkpoint_manager.py @@ -25,7 +25,7 @@ from transformers import GenerationConfig from verl.models.weight_loader_registry import get_weight_saver -from verl.utils.device import is_cuda_available, is_npu_available +from verl.utils.device import get_device_name, get_torch_device from verl.utils.fs import is_non_local from verl.utils.logger import log_with_rank from verl.utils.megatron_utils import ( @@ -111,10 +111,8 @@ def get_rng_state(self, use_dist_ckpt: bool = False, data_parallel_random_init: "rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(), } - if is_cuda_available: - rng_state["cuda_rng_state"] = torch.cuda.get_rng_state() - elif is_npu_available: - rng_state["npu_rng_state"] = torch.npu.get_rng_state() + if get_device_name() != "cpu": + rng_state[f"{get_device_name()}_rng_state"] = get_torch_device().get_rng_state() rng_state_list = None if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init: @@ -203,10 +201,8 @@ def load_rng_states(self, ckpt_path, data_parallel_random_init=False, use_dist_c np.random.set_state(rng_state["np_rng_state"]) torch.set_rng_state(rng_state["torch_rng_state"]) - if is_cuda_available: - torch.cuda.set_rng_state(rng_state["cuda_rng_state"]) - elif is_npu_available: - torch.npu.set_rng_state(rng_state["npu_rng_state"]) + if get_device_name() != "cpu": + get_torch_device().set_rng_state(rng_state[f"{get_device_name()}_rng_state"]) # Check for empty states array if not rng_state["rng_tracker_states"]: diff --git 
a/verl/workers/rollout/vllm_rollout/__init__.py b/verl/workers/rollout/vllm_rollout/__init__.py index 450da77d5b9..1028780cb18 100644 --- a/verl/workers/rollout/vllm_rollout/__init__.py +++ b/verl/workers/rollout/vllm_rollout/__init__.py @@ -29,10 +29,6 @@ def get_version(pkg): if vllm_package_version is None: raise PackageNotFoundError("To use vllm rollout, please ensure the 'vllm' package is properly installed. See https://verl.readthedocs.io/en/latest/start/install.html for more details") -### -# package_version = get_version(package_name) -# [SUPPORT AMD:] -# Do not call any torch.cuda* API here, or ray actor creation import class will fail. if "ROCM_PATH" in os.environ: import re