Skip to content

Commit c300cb5

Browse files
Tavish9 and PopSoda2002
authored and committed
[ray] feat: add support for ray init kwargs (volcengine#3049)
### What does this PR do? This PR adds support for passing parameters to `ray.init`. Users can now dynamically configure settings such as `address`, `port`, `_temp_dir`, and more based on their specific needs. ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test ```bash # when /tmp/ray/ is used by others # when ray is initialized at 6379 by others # when the dashboard is not accessible at localhost # ... bash examples/grpo_trainer/run_qwen2_5_vl-7b.sh \ +ray_kwargs.ray_init._temp_dir=/tmp/ray/my_dir \ +ray_kwargs.ray_init.address=127.0.0.1:6378 \ +ray_kwargs.ray_init.dashboard_host=0.0.0.0 ``` ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). 
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
1 parent 810a230 commit c300cb5

20 files changed

+114
-71
lines changed

examples/split_placement/config/ppo_trainer_split.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,5 +183,6 @@ trainer:
183183
default_hdfs_dir: null
184184
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
185185

186-
ray_init:
187-
num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.
186+
ray_kwargs:
187+
ray_init:
188+
num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.

examples/split_placement/main_ppo_split.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import hydra
1919
import ray
2020
import torch
21+
from omegaconf import OmegaConf
2122
from split_monkey_patch import fit
2223

2324
from verl import DataProto
@@ -94,10 +95,13 @@ def __call__(self, data: DataProto, return_dict: bool = False):
9495
def main(config):
9596
if not ray.is_initialized():
9697
# this is for local ray cluster
97-
ray.init(
98-
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}},
99-
num_cpus=config.ray_init.num_cpus,
100-
)
98+
default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}
99+
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
100+
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
101+
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
102+
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
103+
print(f"ray init kwargs: {ray_init_kwargs}")
104+
ray.init(**OmegaConf.to_container(ray_init_kwargs))
101105

102106
ray.get(main_task.remote(config))
103107

recipe/dapo/main_dapo.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,15 @@ def main(config):
3636
def run_ppo(config) -> None:
3737
if not ray.is_initialized():
3838
# this is for local ray cluster
39-
ray.init(
40-
runtime_env={
41-
"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
42-
},
43-
num_cpus=config.ray_init.num_cpus,
44-
)
39+
default_runtime_env = {
40+
"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
41+
}
42+
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
43+
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
44+
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
45+
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
46+
print(f"ray init kwargs: {ray_init_kwargs}")
47+
ray.init(**OmegaConf.to_container(ray_init_kwargs))
4548

4649
if (
4750
is_cuda_available

recipe/entropy/main_entropy.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import hydra
1919
import ray
20+
from omegaconf import OmegaConf
2021

2122
from .entropy_ray_trainer import RayEntropyTrainer
2223
from .reward import load_reward_manager
@@ -30,17 +31,20 @@ def main(config):
3031
def run_ppo(config) -> None:
3132
if not ray.is_initialized():
3233
# this is for local ray cluster
33-
ray.init(
34-
runtime_env={
35-
"env_vars": {
36-
"TOKENIZERS_PARALLELISM": "true",
37-
"NCCL_DEBUG": "WARN",
38-
"VLLM_LOGGING_LEVEL": "WARN",
39-
"WANDB_API_KEY": "YOUR_WANDB_API_KEY",
40-
}
41-
},
42-
num_cpus=config.ray_init.num_cpus,
43-
)
34+
default_runtime_env = {
35+
"env_vars": {
36+
"TOKENIZERS_PARALLELISM": "true",
37+
"NCCL_DEBUG": "WARN",
38+
"VLLM_LOGGING_LEVEL": "WARN",
39+
"WANDB_API_KEY": "YOUR_WANDB_API_KEY",
40+
}
41+
}
42+
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
43+
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
44+
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
45+
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
46+
print(f"ray init kwargs: {ray_init_kwargs}")
47+
ray.init(**OmegaConf.to_container(ray_init_kwargs))
4448

4549
runner = TaskRunner.remote()
4650
ray.get(runner.run.remote(config))

recipe/one_step_off_policy/main_ppo.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,13 @@ def run_ppo(config) -> None:
4343
# Set environment variables in the runtime environment to control tokenizer parallelism,
4444
# NCCL debug level, VLLM logging level, and allow runtime LoRA updating
4545
# `num_cpus` specifies the number of CPU cores Ray can use, obtained from the configuration
46-
ray.init(
47-
runtime_env=get_ppo_ray_runtime_env(),
48-
num_cpus=config.ray_init.num_cpus,
49-
)
46+
default_runtime_env = get_ppo_ray_runtime_env()
47+
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
48+
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
49+
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
50+
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
51+
print(f"ray init kwargs: {ray_init_kwargs}")
52+
ray.init(**OmegaConf.to_container(ray_init_kwargs))
5053

5154
# Create a remote instance of the TaskRunner class, and
5255
# Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
@@ -63,7 +66,7 @@ def run_ppo(config) -> None:
6366

6467
# [Optional] get the path of the timeline trace file from the configuration, default to None
6568
# This file is used for performance analysis
66-
timeline_json_file = config.ray_init.get("timeline_json_file", None)
69+
timeline_json_file = config.ray_kwargs.get("timeline_json_file", None)
6770
if timeline_json_file:
6871
ray.timeline(filename=timeline_json_file)
6972

recipe/prime/main_prime.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
import hydra
3333
import ray
34+
from omegaconf import OmegaConf
3435

3536
from .prime_ray_trainer import RayPRIMETrainer
3637

@@ -42,11 +43,14 @@ def main(config):
4243

4344
def run_prime(config, compute_score=None):
4445
if not ray.is_initialized():
46+
default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}}
47+
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
48+
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
49+
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
50+
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
51+
print(f"ray init kwargs: {ray_init_kwargs}")
4552
# this is for local ray cluster
46-
ray.init(
47-
runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN"}},
48-
num_cpus=config.ray_init.num_cpus,
49-
)
53+
ray.init(**OmegaConf.to_container(ray_init_kwargs))
5054

5155
ray.get(main_task.remote(config, compute_score))
5256

recipe/r1/config/evaluation.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ custom_reward_function:
99
path: null
1010
name: compute_score
1111

12-
ray_init:
13-
num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.
12+
ray_kwargs:
13+
ray_init:
14+
num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.

recipe/r1/main_eval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import numpy as np
2424
import pandas as pd
2525
import ray
26+
from omegaconf import OmegaConf
2627
from tqdm import tqdm
2728

2829
from verl.trainer.ppo.reward import get_custom_reward_fn
@@ -49,7 +50,7 @@ def main(config):
4950

5051
# Initialize Ray
5152
if not ray.is_initialized():
52-
ray.init(num_cpus=config.ray_init.num_cpus)
53+
ray.init(**OmegaConf.to_container(config.ray_kwargs.get("ray_init", {})))
5354

5455
# evaluate test_score based on data source
5556
data_source_reward = defaultdict(list)

recipe/sppo/main_sppo.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import hydra
2323
import ray
24+
from omegaconf import OmegaConf
2425

2526
from verl.trainer.ppo.reward import load_reward_manager
2627

@@ -38,12 +39,15 @@ def run_ppo(config) -> None:
3839
os.environ["ENSURE_CUDA_VISIBLE_DEVICES"] = os.environ.get("CUDA_VISIBLE_DEVICES", "")
3940
if not ray.is_initialized():
4041
# this is for local ray cluster
41-
ray.init(
42-
runtime_env={
43-
"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
44-
},
45-
num_cpus=config.ray_init.num_cpus,
46-
)
42+
default_runtime_env = {
43+
"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_LOGGING_LEVEL": "WARN"}
44+
}
45+
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
46+
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
47+
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
48+
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
49+
print(f"ray init kwargs: {ray_init_kwargs}")
50+
ray.init(**OmegaConf.to_container(ray_init_kwargs))
4751

4852
runner = TaskRunner.remote()
4953
ray.get(runner.run.remote(config))

tests/trainer/config/legacy_ppo_megatron_trainer.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ trainer:
457457
with_stack: False
458458
analysis: True
459459

460-
ray_init:
461-
num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.
460+
ray_kwargs:
461+
ray_init:
462+
num_cpus: null # `None` means using all CPUs, which might cause hang if limited in systems like SLURM. Please set to a number allowed then.
462463
timeline_json_file: null

0 commit comments

Comments (0)