9 changes: 4 additions & 5 deletions examples/config.yaml
@@ -14,17 +14,16 @@ data:

algorithm:
adv_estimator: grpo
kl_coef: 0.0
use_kl_loss: true
kl_penalty: low_var_kl
kl_coef: 1.0e-2

worker:
actor:
global_batch_size: 128
micro_batch_size_per_device_for_update: 4
micro_batch_size_per_device_for_experience: 16
max_grad_norm: 1.0
use_kl_loss: true
kl_loss_coef: 1.0e-2
kl_loss_type: low_var_kl
padding_free: true
ulysses_sequence_parallel_size: 1
model:
@@ -72,7 +71,7 @@ trainer:
total_episodes: 15
logger: ["console", "wandb"]
project_name: easy_r1
experiment_name: qwen2_5_7b_math
experiment_name: qwen2_5_7b_math_grpo
n_gpus_per_node: 8
nnodes: 1
val_freq: 5
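This consolidates the KL settings: the actor-level use_kl_loss, kl_loss_coef, and kl_loss_type keys move into the algorithm section as use_kl_loss, kl_coef, and kl_penalty. The low_var_kl penalty conventionally denotes the non-negative, low-variance k3 estimator of KL(actor || ref); a minimal sketch of that estimator (not necessarily verl's exact implementation) is shown below, with kl_coef (1.0e-2 here) acting as the weight on this term in the policy loss.

import torch

def low_var_kl(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
    # k3 estimator of KL(pi_actor || pi_ref): exp(r) - r - 1, where
    # r = ref_logprob - logprob. Non-negative by construction and lower
    # variance than the naive -r estimator.
    r = ref_logprob - logprob
    return torch.exp(r) - r - 1.0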
2 changes: 1 addition & 1 deletion examples/qwen2_5_vl_3b_geo3k_grpo.sh
@@ -13,5 +13,5 @@ python3 -m verl.trainer.main \
worker.actor.model.model_path=${MODEL_PATH} \
worker.rollout.tensor_parallel_size=1 \
worker.rollout.enable_chunked_prefill=false \
trainer.experiment_name=qwen2_5_vl_3b_geo \
trainer.experiment_name=qwen2_5_vl_3b_geo_grpo \
trainer.n_gpus_per_node=2
2 changes: 1 addition & 1 deletion examples/qwen2_5_vl_7b_geo3k_swanlab.sh
@@ -12,6 +12,6 @@ python3 -m verl.trainer.main \
data.system_prompt="${SYSTEM_PROMPT}" \
worker.actor.model.model_path=${MODEL_PATH} \
worker.rollout.enable_chunked_prefill=false \
trainer.experiment_name=qwen2_5_vl_7b_geo \
trainer.experiment_name=qwen2_5_vl_7b_geo_grpo \
trainer.logger=['console','swanlab'] \
trainer.n_gpus_per_node=8
22 changes: 2 additions & 20 deletions verl/models/transformers/flash_attention_utils.py
@@ -86,9 +86,7 @@ def _custom_flash_attention_forward(
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

if _flash_supports_deterministic:
if deterministic is None:
deterministic = _flash_deterministic_enabled
flash_kwargs["deterministic"] = deterministic
flash_kwargs["deterministic"] = deterministic if deterministic is not None else _flash_deterministic_enabled

if kwargs.get("softcap") is not None:
flash_kwargs["softcap"] = kwargs.pop("softcap")
@@ -114,7 +112,7 @@ def _custom_flash_attention_forward(
batch_size = query_states.size(0)
query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
query_states, key_states, value_states, position_ids
) # remove channel dimension
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
attn_output = flash_attn_varlen_func(
@@ -172,21 +170,6 @@ def flash_attention_forward(
key = key.transpose(1, 2)
value = value.transpose(1, 2)

# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in the correct dtype just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
# in fp32. (usually our RMSNorm modules handle it correctly)
target_dtype = None
if query.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(module.config, "_pre_quantization_dtype"):
target_dtype = module.config._pre_quantization_dtype
else:
target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype

# FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
kwargs.pop("is_causal", None)

@@ -202,7 +185,6 @@ def flash_attention_forward(
sliding_window=sliding_window,
softcap=softcap,
use_top_left_mask=_flash_use_top_left_mask,
target_dtype=target_dtype,
**kwargs,
)

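Besides the one-line deterministic fallback and the removal of the PEFT-era target_dtype re-cast, the surrounding code still routes padding-free batches through prepare_fa2_from_position_ids and flash_attn_varlen_func, which need cumulative sequence lengths (cu_seqlens) rather than an attention mask. A hedged sketch of how cu_seqlens can be derived from packed position_ids that restart at 0 for each sequence (illustrative, not the helper's actual implementation):

import torch
import torch.nn.functional as F

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # Packed positions look like [0,1,2, 0,1, 0,1,2,3]; every 0 starts a sequence.
    flat = position_ids.flatten()
    seq_starts = torch.nonzero(flat == 0, as_tuple=True)[0]
    lengths = torch.diff(seq_starts, append=torch.tensor([flat.numel()], device=flat.device))
    # Prepend 0 to turn per-sequence lengths into cumulative offsets, as FA2 expects.
    return F.pad(torch.cumsum(lengths, dim=0), (1, 0)).to(torch.int32)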
3 changes: 2 additions & 1 deletion verl/protocol.py
@@ -191,7 +191,8 @@ def __len__(self):
def __getitem__(self, item):
tensor_data = self.batch[item]
non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
return_type = DataProto if isinstance(item, slice) else DataProtoItem
return return_type(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)

def __getstate__(self):
buffer = io.BytesIO()
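The __getitem__ change makes indexing type-preserving: an integer index still returns a DataProtoItem, while a slice now comes back as a full DataProto rather than a single item, so sliced batches can be passed to code that expects a DataProto. A hedged usage sketch (import path and from_dict signature assumed from this repository):

import torch
from verl.protocol import DataProto, DataProtoItem

batch = DataProto.from_dict({"obs": torch.zeros(8, 4)})  # constructor assumed
assert isinstance(batch[0], DataProtoItem)   # integer index -> single item
assert isinstance(batch[0:4], DataProto)     # slice -> DataProto (new behavior)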
71 changes: 39 additions & 32 deletions verl/single_controller/base/decorator.py
@@ -15,7 +15,7 @@
from enum import Enum, auto
from functools import wraps
from types import FunctionType
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING, Dict, List, Literal, Union

import ray

@@ -45,16 +45,16 @@ class Execute(Enum):
RANK_ZERO = 1


def _split_args_kwargs_data_proto(chunks, *args, **kwargs):
def _split_args_kwargs_data_proto(chunks: int, *args, **kwargs):
splitted_args = []
for arg in args:
assert isinstance(arg, (DataProto, DataProtoFuture))
splitted_args.append(arg.chunk(chunks=chunks))

splitted_kwargs = {}
for key, val in kwargs.items():
assert isinstance(val, (DataProto, DataProtoFuture))
splitted_kwargs[key] = val.chunk(chunks=chunks)
for key, value in kwargs.items():
assert isinstance(value, (DataProto, DataProtoFuture))
splitted_kwargs[key] = value.chunk(chunks=chunks)

return splitted_args, splitted_kwargs

@@ -73,34 +73,34 @@ def collect_all_to_all(worker_group: "WorkerGroup", output):
return output


def _concat_data_proto_or_future(output: List):
def _concat_data_proto_or_future(outputs: List[DataProto]) -> DataProto:
# make sure all the elements in output has the same type
for o in output:
assert type(o) is type(output[0])
for output in outputs:
assert type(output) is type(outputs[0])

o = output[0]
output = outputs[0]

if isinstance(o, DataProto):
return DataProto.concat(output)
elif isinstance(o, ray.ObjectRef):
return DataProtoFuture.concat(output)
if isinstance(output, DataProto):
return DataProto.concat(outputs)
elif isinstance(output, ray.ObjectRef):
return DataProtoFuture.concat(outputs)
else:
raise NotImplementedError


def dispatch_dp_compute(worker_group: "WorkerGroup", *args, **kwargs):
for value in args:
assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size
for arg in args:
assert isinstance(arg, (tuple, list)) and len(arg) == worker_group.world_size

for value in kwargs.values():
assert isinstance(value, (tuple, list)) and len(value) == worker_group.world_size

return args, kwargs


def collect_dp_compute(worker_group: "WorkerGroup", output):
assert len(output) == worker_group.world_size
return output
def collect_dp_compute(worker_group: "WorkerGroup", outputs: List[DataProto]) -> List[DataProto]:
assert len(outputs) == worker_group.world_size
return outputs


def dispatch_dp_compute_data_proto(worker_group: "WorkerGroup", *args, **kwargs):
@@ -115,15 +115,15 @@ def dispatch_dp_compute_data_proto_with_func(worker_group: "WorkerGroup", *args,
return splitted_args_with_func, splitted_kwargs


def collect_dp_compute_data_proto(worker_group: "WorkerGroup", output):
for o in output:
assert isinstance(o, (DataProto, ray.ObjectRef)), f"expecting {o} to be DataProto, but got {type(o)}."
def collect_dp_compute_data_proto(worker_group: "WorkerGroup", outputs: List[DataProto]) -> DataProto:
for output in outputs:
assert isinstance(output, (DataProto, ray.ObjectRef)), f"Expect a DataProto, but got {type(output)}"

output = collect_dp_compute(worker_group, output)
return _concat_data_proto_or_future(output)
outputs = collect_dp_compute(worker_group, outputs)
return _concat_data_proto_or_future(outputs)


def get_predefined_dispatch_fn(dispatch_mode):
def get_predefined_dispatch_fn(dispatch_mode: Dispatch):
predefined_dispatch_mode_fn = {
Dispatch.ONE_TO_ALL: {
"dispatch_fn": dispatch_one_to_all,
@@ -133,7 +133,10 @@ def get_predefined_dispatch_fn(dispatch_mode):
"dispatch_fn": dispatch_all_to_all,
"collect_fn": collect_all_to_all,
},
Dispatch.DP_COMPUTE: {"dispatch_fn": dispatch_dp_compute, "collect_fn": collect_dp_compute},
Dispatch.DP_COMPUTE: {
"dispatch_fn": dispatch_dp_compute,
"collect_fn": collect_dp_compute,
},
Dispatch.DP_COMPUTE_PROTO: {
"dispatch_fn": dispatch_dp_compute_data_proto,
"collect_fn": collect_dp_compute_data_proto,
@@ -142,12 +145,15 @@ def get_predefined_dispatch_fn(dispatch_mode):
"dispatch_fn": dispatch_dp_compute_data_proto_with_func,
"collect_fn": collect_dp_compute_data_proto,
},
Dispatch.DP_COMPUTE_METRIC: {"dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute},
Dispatch.DP_COMPUTE_METRIC: {
"dispatch_fn": dispatch_dp_compute_data_proto,
"collect_fn": collect_dp_compute,
},
}
return predefined_dispatch_mode_fn[dispatch_mode]


def get_predefined_execute_fn(execute_mode):
def get_predefined_execute_fn(execute_mode: Execute):
"""
Note that here we only asks execute_all and execute_rank_zero to be implemented
Leave the choice of how these two functions handle argument 'blocking' to users
@@ -159,7 +165,7 @@ def get_predefined_execute_fn(execute_mode):
return predefined_execute_mode_fn[execute_mode]


def _check_dispatch_mode(dispatch_mode):
def _check_dispatch_mode(dispatch_mode: Union[Dispatch, Dict[Literal["dispatch_fn", "collect_fn"], FunctionType]]):
assert isinstance(dispatch_mode, (Dispatch, dict)), (
f"dispatch_mode must be a Dispatch or a Dict. Got {dispatch_mode}"
)
@@ -169,7 +175,7 @@ def _check_dispatch_mode(dispatch_mode):
assert key in dispatch_mode, f"key {key} should be in dispatch_mode if it is a dictionary"


def _check_execute_mode(execute_mode):
def _check_execute_mode(execute_mode: Execute):
assert isinstance(execute_mode, Execute), f"execute_mode must be a Execute. Got {execute_mode}"


@@ -180,9 +186,10 @@ def _materialize_futures(*args, **kwargs):
arg = arg.get()
# add more type to materialize
new_args.append(arg)
for k, v in kwargs.items():
if isinstance(v, DataProtoFuture):
kwargs[k] = v.get()

for key, value in kwargs.items():
if isinstance(value, DataProtoFuture):
kwargs[key] = value.get()

new_args = tuple(new_args)
return new_args, kwargs
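Functionally the registry is unchanged: each Dispatch mode pairs a dispatch_fn that splits inputs across the worker group with a collect_fn that merges the outputs (DP_COMPUTE_PROTO chunks and re-concatenates DataProto objects; DP_COMPUTE_METRIC dispatches DataProto but returns the raw per-rank list). A hedged sketch of how a worker method is typically bound to one of these modes through the register decorator (keyword name assumed):

from verl.protocol import DataProto
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.worker import Worker

class ExampleWorker(Worker):
    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_log_probs(self, data: DataProto) -> DataProto:
        # dispatch_dp_compute_data_proto chunks `data` across ranks before the call;
        # collect_dp_compute_data_proto concatenates the per-rank results afterwards.
        return data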
40 changes: 29 additions & 11 deletions verl/single_controller/base/worker.py
@@ -18,8 +18,10 @@
import os
import socket
from dataclasses import dataclass
from typing import Tuple

import ray
import torch

from .decorator import Dispatch, Execute, register
from .register_center.ray import create_worker_group_register_center
@@ -40,7 +42,7 @@ class DistGlobalInfo:


class WorkerHelper:
def _get_node_ip(self):
def _get_node_ip(self) -> str:
host_ipv4 = os.getenv("MY_HOST_IP", None)
host_ipv6 = os.getenv("MY_HOST_IPV6", None)
host_ip_by_env = host_ipv4 or host_ipv6
@@ -49,12 +51,12 @@ def _get_node_ip(self):
host_ip = host_ip_by_env or host_ip_by_sdk
return host_ip

def _get_free_port(self):
def _get_free_port(self) -> int:
with socket.socket() as sock:
sock.bind(("", 0))
return sock.getsockname()[1]

def get_availale_master_addr_port(self):
def get_availale_master_addr_port(self) -> Tuple[str, str]:
return self._get_node_ip(), str(self._get_free_port())

def _get_pid(self):
@@ -81,16 +83,26 @@ def to_dict(self):

# we assume that in each WorkerGroup, there is a Master Worker
class Worker(WorkerHelper):
"""A (distributed) worker."""

_world_size: int
_rank: int
_local_world_size: int
_local_rank: int
_master_addr: str
_master_port: str
_cuda_visible_devices: str

def __new__(cls, *args, **kwargs):
instance = super().__new__(cls)

# note that here we use int to distinguish
disable_worker_init = int(os.environ.get("DISABLE_WORKER_INIT", 0))
disable_worker_init = int(os.getenv("DISABLE_WORKER_INIT", 0))
if disable_worker_init:
return instance

rank = os.environ.get("RANK", None)
worker_group_prefix = os.environ.get("WG_PREFIX", None)
rank = os.getenv("RANK", None)
worker_group_prefix = os.getenv("WG_PREFIX", None)

# when decorator @ray.remote applies, __new__ will be called while we don't want to apply _configure_before_init
if None not in [rank, worker_group_prefix] and "ActorClass(" not in cls.__name__:
@@ -112,13 +124,19 @@ def _configure_before_init(self, register_center_name: str, rank: int):

def __init__(self, cuda_visible_devices=None) -> None:
# construct a meta from envrionment variable. Note that the import must be inside the class because it is executed remotely
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])
world_size = int(os.getenv("WORLD_SIZE"))
rank = int(os.getenv("RANK"))
self._rank = rank
self._world_size = world_size

master_addr = os.environ["MASTER_ADDR"]
master_port = os.environ["MASTER_PORT"]
if "AMD" in torch.cuda.get_device_name():
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES")
os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK")
cuda_visible_devices = os.getenv("LOCAL_RANK", "0")
torch.cuda.set_device(int(cuda_visible_devices))

master_addr = os.getenv("MASTER_ADDR")
master_port = os.getenv("MASTER_PORT")

local_world_size = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
@@ -158,7 +176,7 @@ def get_master_addr_port(self):
return self._master_addr, self._master_port

def get_cuda_visible_devices(self):
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "not set")
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "not set")
return cuda_visible_devices

def print_rank0(self, *args, **kwargs):
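The substantive addition here is the AMD branch in __init__: on ROCm machines, Ray exposes the visible GPUs through ROCR_VISIBLE_DEVICES and the local rank through RAY_LOCAL_RANK, so the worker mirrors them into the CUDA_VISIBLE_DEVICES and LOCAL_RANK names the rest of the stack reads before pinning the device; the remaining edits swap os.environ lookups for os.getenv and add type annotations. A standalone sketch of that branch, with fallback defaults added so it does not fail when the variables are unset (the defaults are an assumption of this sketch, not part of the PR):

import os
import torch

def map_rocm_env_to_cuda() -> None:
    # Mirror ROCm/Ray device variables into the CUDA-style names used downstream.
    if torch.cuda.is_available() and "AMD" in torch.cuda.get_device_name():
        os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("ROCR_VISIBLE_DEVICES", "")
        os.environ["LOCAL_RANK"] = os.getenv("RAY_LOCAL_RANK", "0")
        torch.cuda.set_device(int(os.getenv("LOCAL_RANK", "0")))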
10 changes: 5 additions & 5 deletions verl/single_controller/base/worker_group.py
@@ -19,15 +19,17 @@
import signal
import threading
import time
from typing import Any, Callable, Dict, List
from typing import Any, Callable, Dict, List, Optional

from .decorator import MAGIC_ATTR, Dispatch, get_predefined_dispatch_fn, get_predefined_execute_fn


class ResourcePool:
"""The resource pool with meta info such as world_size."""
"""The resource pool with meta info such as world size."""

def __init__(self, process_on_nodes=None, max_collocate_count: int = 10, n_gpus_per_node=8) -> None:
def __init__(
self, process_on_nodes: Optional[Any] = None, max_collocate_count: int = 10, n_gpus_per_node: int = 8
) -> None:
if process_on_nodes is None:
process_on_nodes = []

@@ -76,8 +78,6 @@ def __call__(self) -> Any:


def check_workers_alive(workers: List, is_alive: Callable, gap_time: float = 1) -> None:
import time

while True:
for worker in workers:
if not is_alive(worker):
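ResourcePool only gains type hints: process_on_nodes remains a per-node list of process counts and n_gpus_per_node the GPU count of each node, so a pool's total world size is the sum of the per-node entries; the redundant local import of time in check_workers_alive is dropped in favor of the module-level import. A hedged construction example (argument semantics inferred from the signature above):

from verl.single_controller.base.worker_group import ResourcePool

# Two nodes, 8 worker processes on each, 8 GPUs per node (illustrative values).
pool = ResourcePool(process_on_nodes=[8, 8], n_gpus_per_node=8)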