Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/sanity.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,7 @@ jobs:
run: python3 tests/special_sanity/validate_structure.py
- name: Assert documentation requirement for functions
run: python3 tests/special_sanity/validate_imported_docs.py
- name: Assert device api usage in verl/recipe
run: python3 tests/special_sanity/check_device_api_usage.py --directory ./recipe
- name: Assert device api usage in verl/verl
run: python3 tests/special_sanity/check_device_api_usage.py --directory ./verl
91 changes: 91 additions & 0 deletions tests/special_sanity/check_device_api_usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This CI test is used for checking whether device api usage is irregular, suggest using api in `verl/utils/device.py`.
Search targets include .py files in verl/recipe and verl/verl.
Some files that must contain ".cuda", "cuda" or "nccl" keyword is pre-defined in whitelist below.
"""

import os
from argparse import ArgumentParser
from pathlib import Path

# directory or file path must contain keyword ".cuda" or "cuda"
CUDA_KEYWORD_CHECK_WHITELIST = [
"verl/utils/device.py",
"verl/third_party/vllm/vllm_v_0_5_4",
"verl/third_party/vllm/vllm_v_0_6_3",
"recipe/prime/prime_ray_trainer.py", # appear in default device_name
"recipe/spin/spin_trainer.py", # appear in default device_name
"recipe/sppo/sppo_ray_trainer.py", # appear in default device_name
"verl/utils/debug/nvtx_profile.py", # appear in NsightSystemsProfiler
"verl/utils/kernel/linear_cross_entropy.py", # appear in nvidia nvtx
"verl/utils/rendezvous/ray_backend.py", # appear in cupy importance
"verl/single_controller/ray/base.py", # appear in default device_name
"verl/trainer/ppo/ray_trainer.py", # appear in default device_name
"verl/utils/reward_score/sandbox_fusion/utils.py", # appear in sandbox language type
"verl/workers/reward_model/megatron/reward_model.py", # appear in default device_name
]

# directory or file path must contain keyword "nccl"
NCCL_KEYWORD_CHECK_WHITELIST = [
"verl/utils/device.py",
"verl/third_party/vllm/vllm_v_0_5_4",
"verl/third_party/vllm/vllm_v_0_6_3",
"verl/third_party/sglang/parallel_state.py", # appear in default backend
]

SEARCH_WHITELIST = CUDA_KEYWORD_CHECK_WHITELIST + NCCL_KEYWORD_CHECK_WHITELIST

SEARCH_KEYWORDS = [".cuda", '"cuda"', '"nccl"']


if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--directory", "-d", required=True, type=str)
args = parser.parse_args()
directory_in_str = args.directory

pathlist = Path(directory_in_str).glob("**/*.py")
for path in pathlist:
path_in_str = str(path.absolute())

# judge whether current path is in pre-defined search whitelist or not.
path_in_whitelist = False

for sw in SEARCH_WHITELIST:
# for easy debugging in non-linux system
sw = sw.replace("/", os.sep)
if sw in path_in_str:
print(f"[SKIP] File {path_in_str} is in device api usage check whitelist, checking is skipped.")
path_in_whitelist = True
break

if path_in_whitelist:
continue

with open(path_in_str, encoding="utf-8") as f:
file_content = f.read()

find_invalid_device_management = False

for sk in SEARCH_KEYWORDS:
if sk in file_content:
find_invalid_device_management = True
break

print(f"[CHECK] File {path_in_str} is detected for device api usage check, check result: {'success' if not find_invalid_device_management else 'failed'}.")

assert not find_invalid_device_management, f'file {path_in_str} contains .cuda/"cuda"/"nccl" usage, please use api in verl/utils/device.py directly.'
14 changes: 5 additions & 9 deletions verl/utils/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from omegaconf import DictConfig
from transformers import PreTrainedTokenizer, ProcessorMixin

from verl.utils.device import is_cuda_available, is_npu_available
from verl.utils.device import get_device_name, get_torch_device


class BaseCheckpointManager:
Expand Down Expand Up @@ -169,10 +169,8 @@ def get_rng_state():
"random": random.getstate(),
}

if is_cuda_available:
rng_state["cuda"] = torch.cuda.get_rng_state()
elif is_npu_available:
rng_state["npu"] = torch.npu.get_rng_state()
if get_device_name() != "cpu":
rng_state[get_device_name()] = get_torch_device().get_rng_state()

return rng_state

Expand All @@ -182,10 +180,8 @@ def load_rng_state(rng_state):
np.random.set_state(rng_state["numpy"])
random.setstate(rng_state["random"])

if is_cuda_available:
torch.cuda.set_rng_state(rng_state["cuda"])
elif is_npu_available:
torch.npu.set_rng_state(rng_state["npu"])
if get_device_name() != "cpu":
get_torch_device().set_rng_state(rng_state[get_device_name()])


def find_latest_ckpt_path(path, directory_format="global_step_{}"):
Expand Down
14 changes: 5 additions & 9 deletions verl/utils/checkpoint/megatron_checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from transformers import GenerationConfig

from verl.models.weight_loader_registry import get_weight_saver
from verl.utils.device import is_cuda_available, is_npu_available
from verl.utils.device import get_device_name, get_torch_device
from verl.utils.fs import is_non_local
from verl.utils.logger import log_with_rank
from verl.utils.megatron_utils import (
Expand Down Expand Up @@ -111,10 +111,8 @@ def get_rng_state(self, use_dist_ckpt: bool = False, data_parallel_random_init:
"rng_tracker_states": tensor_parallel.get_cuda_rng_tracker().get_states(),
}

if is_cuda_available:
rng_state["cuda_rng_state"] = torch.cuda.get_rng_state()
elif is_npu_available:
rng_state["npu_rng_state"] = torch.npu.get_rng_state()
if get_device_name() != "cpu":
rng_state[f"{get_device_name()}_rng_state"] = get_torch_device().get_rng_state()

rng_state_list = None
if torch.distributed.is_initialized() and mpu.get_data_parallel_world_size() > 1 and data_parallel_random_init:
Expand Down Expand Up @@ -203,10 +201,8 @@ def load_rng_states(self, ckpt_path, data_parallel_random_init=False, use_dist_c
np.random.set_state(rng_state["np_rng_state"])
torch.set_rng_state(rng_state["torch_rng_state"])

if is_cuda_available:
torch.cuda.set_rng_state(rng_state["cuda_rng_state"])
elif is_npu_available:
torch.npu.set_rng_state(rng_state["npu_rng_state"])
if get_device_name() != "cpu":
get_torch_device().set_rng_state(rng_state[f"{get_device_name()}_rng_state"])

# Check for empty states array
if not rng_state["rng_tracker_states"]:
Expand Down
4 changes: 0 additions & 4 deletions verl/workers/rollout/vllm_rollout/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@ def get_version(pkg):
if vllm_package_version is None:
raise PackageNotFoundError("To use vllm rollout, please ensure the 'vllm' package is properly installed. See https://verl.readthedocs.io/en/latest/start/install.html for more details")

###
# package_version = get_version(package_name)
# [SUPPORT AMD:]
# Do not call any torch.cuda* API here, or ray actor creation import class will fail.
if "ROCM_PATH" in os.environ:
import re

Expand Down
Loading