68 changes: 68 additions & 0 deletions tests/single_controller/test_nested_worker.py
@@ -0,0 +1,68 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import ray

from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup


class TestActor(Worker):
    # TODO: passing *args and **kwargs is bug-prone and not very convincing
    def __init__(self, x) -> None:
        super().__init__()
        self.a = x

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def get(self):
        return self.a + self.rank


class TestHighLevelActor(Worker):
    def __init__(self, x=None) -> None:
        super().__init__()
        self.test_actor = TestActor(x=x)

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def get(self):
        return self.test_actor.get()


def test_nested_worker():
    ray.init(num_cpus=100)

    # create 4 workers, each holding a GPU
    resource_pool = RayResourcePool([4], use_gpu=True)
    class_with_args = RayClassWithInitArgs(cls=ray.remote(TestActor), x=2)

    worker_group = RayWorkerGroup(
        resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic"
    )

    output = worker_group.get()

    assert output == [2, 3, 4, 5]

    class_with_args = RayClassWithInitArgs(cls=ray.remote(TestHighLevelActor), x=2)
    high_level_worker_group = RayWorkerGroup(
        resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic_2"
    )

    output_1 = high_level_worker_group.get()

    assert output_1 == [2, 3, 4, 5]

    ray.shutdown()
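
For context on what the test exercises: with Dispatch.ONE_TO_ALL, a single worker_group.get() call fans out to every worker in the pool and the per-rank results come back as a list, which is why x=2 across 4 ranks yields [2, 3, 4, 5]. Below is a minimal sketch (not part of this diff) of the same nested-worker pattern on a CPU-only pool; it assumes RayResourcePool accepts use_gpu=False and reuses the TestHighLevelActor defined above.

# Hypothetical CPU-only variant (assumes use_gpu=False is supported by RayResourcePool).
import ray

from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup

ray.init(num_cpus=8)
cpu_pool = RayResourcePool([2], use_gpu=False)  # 2 workers, no GPUs
cls_with_args = RayClassWithInitArgs(cls=ray.remote(TestHighLevelActor), x=2)
cpu_group = RayWorkerGroup(
    resource_pool=cpu_pool, ray_cls_with_init=cls_with_args, name_prefix="worker_group_nested_cpu"
)
print(cpu_group.get())  # expected [2, 3]: each worker returns x + rank
ray.shutdown()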
5 changes: 3 additions & 2 deletions verl/workers/config/__init__.py
@@ -17,6 +17,7 @@
 from .engine import * # noqa
 from .optimizer import * # noqa
 from .rollout import * # noqa
-from . import actor, critic, engine, optimizer, rollout
+from .model import * # noqa
+from . import actor, critic, engine, optimizer, rollout, model

-__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__ + rollout.__all__
+__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__ + rollout.__all__ + model.__all__
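
Since model.__all__ exports only HFModelConfig, the practical effect of this change is that the new config becomes importable from the package root:

# Re-exported via `from .model import *` plus model.__all__ in this diff.
from verl.workers.config import HFModelConfig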
111 changes: 111 additions & 0 deletions verl/workers/config/model.py
@@ -0,0 +1,111 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any, Optional

from omegaconf import MISSING
from transformers import AutoConfig

from verl.base_config import BaseConfig
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.model import get_generation_config, update_model_config

__all__ = ["HFModelConfig"]


@dataclass
class HFModelConfig(BaseConfig):
    # note that we separate model_path, model_config_path and tokenizer_path in case they are different
    _mutable_fields = {
        "hf_config_path",
        "tokenizer_path",
        "hf_config",
        "generation_config",
        "tokenizer",
        "processor",
        "local_path",
    }

    path: str = MISSING
    local_path: Optional[str] = None
    hf_config_path: Optional[str] = None
    tokenizer_path: Optional[str] = None

    hf_config: Any = None
    generation_config: Any = None
    tokenizer: Any = None
    processor: Any = None

    # whether to use shared memory
    use_shm: bool = False
    trust_remote_code: bool = False

    # custom chat template for the model
    custom_chat_template: Optional[str] = None

    external_lib: Optional[str] = None

    override_config: dict = field(default_factory=dict)

    enable_gradient_checkpointing: bool = True
    enable_activation_offload: bool = False

    use_remove_padding: bool = False

    # lora related. We may set up a separate config later
    lora_rank: int = 0
    lora_alpha: int = 16
    target_modules: Optional[str] = "all-linear"

    exclude_modules: Optional[str] = None
    use_liger: bool = False

    use_fused_kernels: bool = False
    fused_kernel_options: dict = field(default_factory=dict)

    def __post_init__(self):
        if self.hf_config_path is None:
            self.hf_config_path = self.path
        if self.tokenizer_path is None:
            self.tokenizer_path = self.path

        # construct tokenizer
        self.local_path = copy_to_local(self.path, use_shm=self.use_shm)
        self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=self.trust_remote_code)
        self.processor = hf_processor(self.local_path, trust_remote_code=self.trust_remote_code)

        self.generation_config = get_generation_config(self.hf_config_path, trust_remote_code=self.trust_remote_code)

        # construct hf_config
        attn_implementation = self.override_config.get("attn_implementation", "flash_attention_2")
        self.hf_config = AutoConfig.from_pretrained(
            self.hf_config_path, trust_remote_code=self.trust_remote_code, attn_implementation=attn_implementation
        )

        override_config_kwargs = {
            "bos_token_id": self.tokenizer.bos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        override_config_kwargs.update(self.override_config)
        update_model_config(self.hf_config, override_config_kwargs=override_config_kwargs)

        # per-model patch
        if getattr(self.hf_config, "model_type", None) == "kimi_vl":
            self.hf_config.text_config.topk_method = "greedy"

    def get_processor(self):
        return self.processor if self.processor is not None else self.tokenizer
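
Not part of the diff: a usage sketch of HFModelConfig, assuming a reachable checkpoint path (the model name below is a placeholder). __post_init__ eagerly resolves the local path and builds the tokenizer, processor, generation config, and hf_config, so those fields are available right after construction.

# Hypothetical usage sketch; the checkpoint path is a placeholder.
from verl.workers.config import HFModelConfig

model_config = HFModelConfig(
    path="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id or local path
    trust_remote_code=False,
    override_config={"attn_implementation": "sdpa"},  # overrides the flash_attention_2 default
)

tokenizer = model_config.tokenizer        # built eagerly in __post_init__
processor = model_config.get_processor()  # falls back to the tokenizer if no processor exists
print(type(model_config.hf_config).__name__, tokenizer.eos_token_id)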
4 changes: 3 additions & 1 deletion verl/workers/config/rollout.py
@@ -15,6 +15,8 @@
 from dataclasses import dataclass, field
 from typing import Optional

+from omegaconf import MISSING
+
 from verl.base_config import BaseConfig
 from verl.utils.profiler import ProfilerConfig

@@ -77,7 +79,7 @@ class TraceConfig(BaseConfig):
 class RolloutConfig(BaseConfig):
     _mutable_fields = {"max_model_len"}

-    name: Optional[str] = None
+    name: Optional[str] = MISSING
     mode: str = "sync"

     temperature: float = 1.0
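
Changing the default from None to omegaconf.MISSING makes rollout.name a mandatory field: reading it before a value is provided raises instead of silently passing None downstream. A small, self-contained illustration of that behavior (plain OmegaConf, not verl's config plumbing):

# Illustrative only: how omegaconf.MISSING behaves in a structured config.
from dataclasses import dataclass
from typing import Optional

from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue


@dataclass
class DemoRolloutConfig:
    name: Optional[str] = MISSING
    mode: str = "sync"


cfg = OmegaConf.structured(DemoRolloutConfig)
try:
    _ = cfg.name  # accessing a MISSING field raises instead of returning None
except MissingMandatoryValue:
    print("rollout name must be set explicitly, e.g. rollout.name=vllm")

cfg.name = "vllm"
print(cfg.name)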