18 changes: 9 additions & 9 deletions .github/workflows/e2e_ppo_trainer.yml
@@ -101,15 +101,15 @@ jobs:
ray stop --force
python3 examples/data_preprocess/gsm8k.py
# HF sanity
- name: Running GSM8K E2E training tests on 1 L20 GPU with hf for sanity
run: |
ray stop --force
bash tests/special_e2e/ppo_trainer/run_single_gpu.sh
# HF sanity
- name: Running GSM8K E2E training tests on 1 L20 GPU with engine interface for sanity.
run: |
ray stop --force
bash tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh
# - name: Running GSM8K E2E training tests on 1 L20 GPU with hf for sanity
# run: |
# ray stop --force
# bash tests/special_e2e/ppo_trainer/run_single_gpu.sh
# # HF sanity
# - name: Running GSM8K E2E training tests on 1 L20 GPU with engine interface for sanity.
# run: |
# ray stop --force
# bash tests/special_e2e/ppo_trainer/run_single_gpu_with_engine.sh
# Function RM
- name: Running GSM8K E2E training tests on 8 L20 GPUs with rmpad using function rm with validation and saving (FSDP_SIZE=8)
run: |
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -127,7 +127,7 @@ verl is fast with:
amd_tutorial/amd_build_dockerfile_page.rst
amd_tutorial/amd_vllm_page.rst
ascend_tutorial/ascend_quick_start.rst
ascend_tutorial/ascend_profiling.rst
ascend_tutorial/ascend_profiling_zh.rst
ascend_tutorial/ascend_profiling_en.rst

.. toctree::
51 changes: 47 additions & 4 deletions recipe/one_step_off_policy/fsdp_workers.py
@@ -26,8 +26,8 @@
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
from verl.utils import hf_processor, hf_tokenizer, omega_conf_to_dataclass
from verl.utils.debug import DistProfiler, DistProfilerExtension, log_gpu_memory_usage
from verl.utils.device import (
get_device_id,
get_device_name,
get_nccl_backend,
get_torch_device,
@@ -38,7 +38,8 @@
)
from verl.utils.import_utils import import_external_libs
from verl.utils.model import get_generation_config, update_model_config
from verl.utils.profiler import ProfilerConfig
from verl.utils.profiler import DistProfiler, DistProfilerExtension, ProfilerConfig, log_gpu_memory_usage, simple_timer
from verl.utils.profiler.performance import reduce_timing, topk_reduce_ratio_min_max
from verl.workers.fsdp_workers import ActorRolloutRefWorker as ARRWorker
from verl.workers.fsdp_workers import CriticWorker

@@ -231,8 +232,50 @@ def init_model(self):
self.rollout_sharding_manager = rollout_sharding_manager

@register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="rollout"), blocking=False)
def async_generate_sequences(self, *args, **kwargs):
return super().generate_sequences(*args, **kwargs)
def async_generate_sequences(self, prompts):
# Support all hardware
prompts = prompts.to(get_device_id())

assert self._is_rollout

meta_info = {
"eos_token_id": self.generation_config.eos_token_id
if self.generation_config is not None
else self.tokenizer.eos_token_id,
"pad_token_id": self.generation_config.pad_token_id
if self.generation_config is not None
else self.tokenizer.pad_token_id,
}
prompts.meta_info.update(meta_info)
timing_generate = {}
with self.rollout_sharding_manager:
log_gpu_memory_usage("After entering rollout sharding manager", logger=logger)

with simple_timer("generate_sequences", timing_generate):
output = self.rollout.generate_sequences(prompts=prompts)

log_gpu_memory_usage("After rollout generation", logger=logger)

timing_generate.update(self.rollout_sharding_manager.timing)
# We average the timing across all ranks
# so that meta_info["timing"] is identical on every rank
timing_generate_topk_ratio, timing_generate_min, timing_generate_max = topk_reduce_ratio_min_max(
timing_generate["generate_sequences"]
)
timing_generate = reduce_timing(timing_generate)
timing_generate.update(
{
"generation_timing/max": timing_generate_max,
"generation_timing/min": timing_generate_min,
"generation_timing/topk_ratio": timing_generate_topk_ratio,
}
)
output.meta_info["timing"] = timing_generate
output = output.to("cpu")

# clear kv cache
get_torch_device().empty_cache()
return output

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def set_actor_weights_info(self, weights_info):
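For context on the timing block added above: simple_timer appears to be used as a context manager that accumulates the elapsed wall-clock time of its body into the supplied dict, and the reduced values end up in output.meta_info["timing"]. A minimal, self-contained sketch of that pattern (the helper below is an illustrative stand-in, not the verl.utils.profiler implementation):

import time
from contextlib import contextmanager

# Illustrative stand-in (assumption) for how simple_timer is used in
# async_generate_sequences: record the elapsed seconds of the block
# under `name` in the `timing_raw` dict.
@contextmanager
def simple_timer_sketch(name, timing_raw):
    start = time.perf_counter()
    try:
        yield
    finally:
        timing_raw[name] = timing_raw.get(name, 0.0) + time.perf_counter() - start

timing_generate = {}
with simple_timer_sketch("generate_sequences", timing_generate):
    time.sleep(0.01)  # stands in for self.rollout.generate_sequences(prompts=prompts)
print(timing_generate)  # {'generate_sequences': ~0.01}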
68 changes: 68 additions & 0 deletions tests/single_controller/test_nested_worker.py
@@ -0,0 +1,68 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import ray

from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.worker import Worker
from verl.single_controller.ray.base import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup


class TestActor(Worker):
# TODO: passing *args and **kwargs is bug-prone and not very convincing
def __init__(self, x) -> None:
super().__init__()
self.a = x

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def get(self):
return self.a + self.rank


class TestHighLevelActor(Worker):
def __init__(self, x=None) -> None:
super().__init__()
self.test_actor = TestActor(x=x)

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def get(self):
return self.test_actor.get()


def test_nested_worker():
ray.init(num_cpus=100)

# create 4 workers, each hold a GPU
resource_pool = RayResourcePool([4], use_gpu=True)
class_with_args = RayClassWithInitArgs(cls=ray.remote(TestActor), x=2)

worker_group = RayWorkerGroup(
resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic"
)

output = worker_group.get()

assert output == [2, 3, 4, 5]

class_with_args = RayClassWithInitArgs(cls=ray.remote(TestHighLevelActor), x=2)
high_level_worker_group = RayWorkerGroup(
resource_pool=resource_pool, ray_cls_with_init=class_with_args, name_prefix="worker_group_basic_2"
)

output_1 = high_level_worker_group.get()

assert output_1 == [2, 3, 4, 5]

ray.shutdown()
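For reference, the expected outputs in this test follow from ONE_TO_ALL dispatch plus per-rank offsets: each of the 4 workers returns x + rank, and the nested TestActor constructed inside TestHighLevelActor runs in the same worker process, so (per the test's assertion) it sees the same rank. A quick sanity check of that arithmetic:

# With x=2 on 4 workers (ranks 0..3), each get() call returns x + rank.
expected = [2 + rank for rank in range(4)]
assert expected == [2, 3, 4, 5]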
5 changes: 3 additions & 2 deletions verl/workers/config/__init__.py
@@ -17,6 +17,7 @@
from .engine import * # noqa
from .optimizer import * # noqa
from .rollout import * # noqa
from . import actor, critic, engine, optimizer, rollout
from .model import * # noqa
from . import actor, critic, engine, optimizer, rollout, model

__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__ + rollout.__all__
__all__ = actor.__all__ + critic.__all__ + engine.__all__ + optimizer.__all__ + rollout.__all__ + model.__all__
111 changes: 111 additions & 0 deletions verl/workers/config/model.py
@@ -0,0 +1,111 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any, Optional

from omegaconf import MISSING
from transformers import AutoConfig

from verl.base_config import BaseConfig
from verl.utils import hf_processor, hf_tokenizer
from verl.utils.fs import copy_to_local
from verl.utils.model import get_generation_config, update_model_config

__all__ = ["HFModelConfig"]


@dataclass
class HFModelConfig(BaseConfig):
# note that we separate model_path, model_config_path and tokenizer_path in case they are different
_mutable_fields = {
"hf_config_path",
"tokenizer_path",
"hf_config",
"generation_config",
"tokenizer",
"processor",
"local_path",
}

path: str = MISSING
local_path: Optional[str] = None
hf_config_path: Optional[str] = None
tokenizer_path: Optional[str] = None

hf_config: Any = None
generation_config: Any = None
tokenizer: Any = None
processor: Any = None

# whether to use shared memory
use_shm: bool = False
trust_remote_code: bool = False

# custom chat template for the model
custom_chat_template: Optional[str] = None

external_lib: Optional[str] = None

override_config: dict = field(default_factory=dict)

enable_gradient_checkpointing: bool = True
enable_activation_offload: bool = False

use_remove_padding: bool = False

# lora related. We may setup a separate config later
lora_rank: int = 0
lora_alpha: int = 16
target_modules: Optional[str] = "all-linear"

exclude_modules: Optional[str] = None
use_liger: bool = False

use_fused_kernels: bool = False
fused_kernel_options: dict = field(default_factory=dict)

def __post_init__(self):
if self.hf_config_path is None:
self.hf_config_path = self.path
if self.tokenizer_path is None:
self.tokenizer_path = self.path

# construct tokenizer
self.local_path = copy_to_local(self.path, use_shm=self.use_shm)
self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=self.trust_remote_code)
self.processor = hf_processor(self.local_path, trust_remote_code=self.trust_remote_code)

self.generation_config = get_generation_config(self.hf_config_path, trust_remote_code=self.trust_remote_code)

# construct hf_config
attn_implementation = self.override_config.get("attn_implementation", "flash_attention_2")
self.hf_config = AutoConfig.from_pretrained(
self.hf_config_path, trust_remote_code=self.trust_remote_code, attn_implementation=attn_implementation
)

override_config_kwargs = {
"bos_token_id": self.tokenizer.bos_token_id,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
override_config_kwargs.update(self.override_config)
update_model_config(self.hf_config, override_config_kwargs=override_config_kwargs)

# per model patch
if getattr(self.hf_config, "model_type", None) == "kimi_vl":
self.hf_config.text_config.topk_method = "greedy"

def get_processor(self):
return self.processor if self.processor is not None else self.tokenizer
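For illustration, a hedged sketch of how the new config might be constructed directly (the checkpoint id is a placeholder and loading it needs network or local access; __post_init__ then resolves the local path, tokenizer, processor, generation config, and patched hf_config as shown above):

from verl.workers.config import HFModelConfig

# Placeholder checkpoint; any local directory or HF hub id should work.
model_config = HFModelConfig(path="Qwen/Qwen2.5-0.5B-Instruct")

tokenizer = model_config.tokenizer          # built in __post_init__
processor = model_config.get_processor()    # falls back to the tokenizer if no processor exists
print(model_config.hf_config.eos_token_id)  # synced with the tokenizer via update_model_config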
4 changes: 3 additions & 1 deletion verl/workers/config/rollout.py
@@ -15,6 +15,8 @@
from dataclasses import dataclass, field
from typing import Optional

from omegaconf import MISSING

from verl.base_config import BaseConfig
from verl.utils.profiler import ProfilerConfig

@@ -77,7 +79,7 @@ class TraceConfig(BaseConfig):
class RolloutConfig(BaseConfig):
_mutable_fields = {"max_model_len"}

name: Optional[str] = None
name: Optional[str] = MISSING
mode: str = "sync"

temperature: float = 1.0
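For context on the name default changing from None to MISSING: under OmegaConf structured configs, reading a MISSING field that the user never set raises MissingMandatoryValue instead of silently returning None, which forces callers to pick a rollout backend explicitly. A minimal, self-contained sketch of that behavior (the dataclass below is an illustrative stand-in, not RolloutConfig itself):

from dataclasses import dataclass
from typing import Optional

from omegaconf import MISSING, OmegaConf
from omegaconf.errors import MissingMandatoryValue

@dataclass
class _RolloutConfigSketch:
    name: Optional[str] = MISSING
    mode: str = "sync"

cfg = OmegaConf.structured(_RolloutConfigSketch)
try:
    _ = cfg.name  # never set by the user
except MissingMandatoryValue:
    print("rollout.name is now mandatory instead of defaulting to None")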