12 changes: 0 additions & 12 deletions .github/workflows/vllm.yml
@@ -130,18 +130,6 @@ jobs:
- name: Test the latest vLLM on model with rope scaling
run: |
torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s tests/workers/rollout/rollout_vllm/test_vllm_model_rope_scaling.py
- name: Run Qwen 0.5B generation test
run: |
cd tests/special_e2e/generation
export OUTPUT_PATH="${HOME}/data/gen/qwen_05_gen_test.parquet"
MODEL_ID=${HOME}/models/Qwen/Qwen2.5-0.5B-Instruct NGPUS_PER_NODE=4 GEN_TP=2 bash ./run_gen_qwen05.sh
rm -rf "${OUTPUT_PATH}"
- name: Run Qwen 0.5B generation test when world_size == 1
run: |
cd tests/special_e2e/generation
export OUTPUT_PATH="${HOME}/data/gen/qwen_05_gen_test.parquet"
MODEL_ID=${HOME}/models/Qwen/Qwen2.5-0.5B-Instruct NGPUS_PER_NODE=1 GEN_TP=1 bash ./run_gen_qwen05.sh
rm -rf "${OUTPUT_PATH}"
# Note(haibin.lin): for any new test, please update gpu_unit_tests.yaml to avoid repeated tests

cleanup:
167 changes: 167 additions & 0 deletions recipe/gspo/test_gspo_qwen30b_a3b_ep.sh
@@ -0,0 +1,167 @@
#!/usr/bin/env bash
set -xeuo pipefail

export NCCL_DEBUG=WARN
# export VERL_LOGGING_LEVEL=DEBUG

project_name='DAPO'
exp_name='GSPO-Qwen3-30B-A3B-Base-MATH'

adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=3e-4
clip_ratio_high=4e-4

max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

loss_agg_mode="token-mean"
loss_mode=gspo

train_prompt_bsz=256
n_resp_per_prompt=16
train_prompt_mini_bsz=32

# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
# Paths
# RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-30B-A3B-Base"}
# CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
# TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
# TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

MODEL_PATH=$HDFS_ROOT/model/Qwen3-30B-A3B-Base
CKPTS_DIR=$DATA_ROOT/checkpoint/${project_name}/${exp_name}
TRAIN_FILE=$DATA_ROOT/dataset/BytedTsinghua-SIA/DAPO-Math-17k/data/dapo-math-17k.parquet
aime24_test_path=$DATA_ROOT/dataset/aime-2024.parquet

TEST_FILE="['$aime24_test_path']"

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 1))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
offload=True

# gen
rollout_name=vllm # vllm or sglang
gen_tp=1
gen_dp=4
gen_ep=4

# train
train_tp=4
train_pp=1
EP=4
ETP=1

python3 -m verl.trainer.main_ppo \
--config-path=config \
--config-name='ppo_megatron_trainer.yaml' \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.return_raw_chat=True \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.policy_loss.loss_mode=${loss_mode} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.optim.clip_grad=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.megatron.param_offload=${offload} \
actor_rollout_ref.actor.megatron.optimizer_offload=${offload} \
actor_rollout_ref.actor.megatron.grad_offload=${offload} \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \
actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.data_parallel_size=${gen_dp} \
actor_rollout_ref.rollout.expert_parallel_size=${gen_ep} \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \
actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \
actor_rollout_ref.ref.megatron.param_offload=${offload} \
actor_rollout_ref.actor.megatron.use_mbridge=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger='["console","wandb"]' \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}-tp${gen_tp}-ep${gen_ep}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=False \
trainer.test_freq=10 \
trainer.save_freq=30 \
trainer.total_epochs=10 \
trainer.total_training_steps=300 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.log_val_generations=10
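
A quick sanity check on the parallelism numbers above (a sketch, not part of the recipe): the rollout uses gen_tp=1, gen_dp=4, gen_ep=4, which satisfies the EP = TP x DP constraint this PR adds to RolloutConfig, and the 2 x 8 = 16 trainer GPUs split evenly into 4-rank rollout engines, assuming the hybrid rollout spans every trainer GPU.

gen_tp, gen_dp, gen_ep = 1, 4, 4            # rollout settings from the script above
nnodes, ngpus_per_node = 2, 8               # NNODES, NGPUS_PER_NODE defaults

assert gen_ep == gen_tp * gen_dp            # new RolloutConfig check: EP == TP * DP

world_size = nnodes * ngpus_per_node        # 16 GPUs in total
ranks_per_engine = gen_tp * gen_dp          # 4 GPUs per vLLM engine
assert world_size % ranks_per_engine == 0   # -> 4 rollout engines across the job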
71 changes: 66 additions & 5 deletions tests/experimental/agent_loop/test_standalone_rollout.py
@@ -17,8 +17,9 @@
import pytest
import ray
from omegaconf import DictConfig
from openai import AsyncOpenAI
from openai import AsyncOpenAI, OpenAI

from tests.experimental.agent_loop.agent_utils import init_agent_loop_manager
from verl.workers.rollout.replica import get_rollout_replica_class


@@ -31,17 +32,18 @@ def init_config() -> DictConfig:

config.trainer.n_gpus_per_node = 4
config.trainer.nnodes = 2
config.actor_rollout_ref.actor.use_dynamic_bsz = True
config.actor_rollout_ref.model.path = os.path.expanduser("~/models/Qwen/Qwen2.5-1.5B-Instruct")
config.actor_rollout_ref.rollout.name = os.environ["ROLLOUT_NAME"]
config.actor_rollout_ref.rollout.load_format = "auto"
config.actor_rollout_ref.rollout.enforce_eager = True
config.actor_rollout_ref.rollout.mode = "async"
config.actor_rollout_ref.rollout.skip_tokenizer_init = False

return config


@pytest.mark.asyncio
@pytest.mark.parametrize("tp_size", [2, 4])
async def test_standalone_(init_config, tp_size):
async def test_standalone_rollout(init_config, tp_size):
"""Test standalone rollout single node and multi nodes."""
ray.init(
runtime_env={
@@ -54,7 +56,6 @@ async def test_standalone_(init_config, tp_size):
}
)

init_config.actor_rollout_ref.rollout.skip_tokenizer_init = False
init_config.actor_rollout_ref.rollout.tensor_model_parallel_size = tp_size
num_replicas = (init_config.trainer.n_gpus_per_node * init_config.trainer.nnodes) // tp_size

@@ -87,3 +88,63 @@ async def test_standalone_(init_config, tp_size):
print(completion.choices[0].message.content)

ray.shutdown()


@pytest.mark.skip(reason="local test only")
def test_hybrid_rollout_with_ep(init_config):
"""Test hybrid rollout with expert parallelism, DP=2, TP=4, EP=8."""
ray.init(
runtime_env={
"env_vars": {
"TOKENIZERS_PARALLELISM": "true",
"NCCL_DEBUG": "WARN",
"VLLM_LOGGING_LEVEL": "INFO",
"VLLM_USE_V1": "1",
}
}
)

model_path = os.path.expanduser("~/models/Qwen/Qwen3-30B-A3B-Instruct-2507")
init_config.actor_rollout_ref.model.path = model_path

# parallelism config
init_config.actor_rollout_ref.rollout.tensor_model_parallel_size = 2
init_config.actor_rollout_ref.rollout.data_parallel_size = 4
init_config.actor_rollout_ref.rollout.expert_parallel_size = 8

# 1. init hybrid worker: FSDP+rollout
# - build FSDP model and optimizer
# - offload FSDP model and optimizer, build rollout
# - sleep rollout and load FSDP model and optimizer
agent_loop_manager = init_agent_loop_manager(init_config)

# 2. wake up rollout
# - wake_up weights
# - load_weights from FSDP
# - wake_up kv_cache
agent_loop_manager.wake_up()

# 3. test async openai call
server_address = agent_loop_manager.server_addresses[0]
client = OpenAI(
api_key="123-abc",
base_url=f"http://{server_address}/v1",
)

sampling_params = {
"temperature": 1.0,
"top_p": 1.0,
"max_tokens": 512,
}

response = client.chat.completions.create(
model=model_path,
messages=[{"role": "user", "content": "What can you do?"}],
**sampling_params,
)

completion = response.choices[0].message.content
print(f"response: {completion}")

print("Test passed!")
ray.shutdown()
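
For the hybrid EP test above, the rank layout implied by the settings is worth spelling out (illustrative arithmetic only): with trainer.n_gpus_per_node=4 and trainer.nnodes=2 from init_config, the 8 GPUs form a single rollout engine of TP=2 x DP=4 ranks, and EP=8 shards the MoE experts across all of them, which is why the test only queries server_addresses[0].

tp, dp, ep = 2, 4, 8                     # parallelism config set in the test
ranks_per_engine = tp * dp               # 8 ranks per vLLM engine
assert ep == ranks_per_engine            # experts span every rank of the engine
n_gpus = 4 * 2                           # n_gpus_per_node * nnodes from init_config
assert n_gpus // ranks_per_engine == 1   # exactly one engine, one server address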
5 changes: 4 additions & 1 deletion verl/experimental/agent_loop/agent_loop.py
@@ -779,7 +779,10 @@ def batch_fn(data_list: list[DataProto]) -> list[torch.Tensor]:
self.sleep()

def _initialize_llm_servers(self):
rollout_world_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
rollout_world_size = (
self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
* self.config.actor_rollout_ref.rollout.data_parallel_size
)
world_size = (
self.worker_group.world_size
if self.worker_group
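With data-parallel rollout, each LLM server replica now owns tensor_model_parallel_size x data_parallel_size ranks, so the number of standalone servers follows directly from the worker-group size. A minimal sketch of that arithmetic (the real logic lives in _initialize_llm_servers; the helper name here is illustrative):

def num_rollout_replicas(world_size: int, tp: int, dp: int) -> int:
    rollout_world_size = tp * dp
    assert world_size % rollout_world_size == 0, (
        f"world_size {world_size} is not divisible by tp*dp {rollout_world_size}"
    )
    return world_size // rollout_world_size

assert num_rollout_replicas(world_size=16, tp=1, dp=4) == 4   # GSPO recipe above
assert num_rollout_replicas(world_size=8, tp=2, dp=4) == 1    # hybrid EP test above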
2 changes: 2 additions & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
@@ -188,6 +188,8 @@ actor_rollout_ref:
cudagraph_capture_sizes: null
free_cache_engine: true
tensor_model_parallel_size: 2
data_parallel_size: 1
expert_parallel_size: 1
max_num_batched_tokens: 8192
max_model_len: null
max_num_seqs: 1024
2 changes: 2 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
@@ -175,6 +175,8 @@ actor_rollout_ref:
cudagraph_capture_sizes: null
free_cache_engine: true
tensor_model_parallel_size: 2
data_parallel_size: 1
expert_parallel_size: 1
max_num_batched_tokens: 8192
max_model_len: null
max_num_seqs: 1024
6 changes: 6 additions & 0 deletions verl/trainer/config/rollout/rollout.yaml
@@ -49,6 +49,12 @@ free_cache_engine: True
# TP size for rollout. Not effective for hf
tensor_model_parallel_size: 2

# DP size for rollout
data_parallel_size: 1

# EP size for rollout
expert_parallel_size: 1

# max number of tokens in a batch
max_num_batched_tokens: 8192

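Both new keys default to 1, so existing configs keep TP-only rollout. In the full trainer config they live under actor_rollout_ref.rollout, and the GSPO recipe above sets them via Hydra overrides; a small OmegaConf sketch (values illustrative) of what those overrides amount to:

from omegaconf import OmegaConf

# Defaults mirroring the rollout.yaml snippet above (only the relevant keys).
rollout = OmegaConf.create({
    "tensor_model_parallel_size": 2,
    "data_parallel_size": 1,
    "expert_parallel_size": 1,
})

# Roughly what actor_rollout_ref.rollout.data_parallel_size=4 etc. do on the CLI.
rollout.merge_with(OmegaConf.from_dotlist([
    "tensor_model_parallel_size=2",
    "data_parallel_size=4",
    "expert_parallel_size=8",
]))
print(OmegaConf.to_yaml(rollout))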
8 changes: 4 additions & 4 deletions verl/utils/distributed.py
@@ -28,12 +28,12 @@ def set_numa_affinity():
# TODO (FightingZhen) libnuma.so is not available in e2e_ascend CI image, remove this code after image update.
return

libnuma = ctypes.CDLL("libnuma.so")
if libnuma.numa_available() < 0:
return

initialized = False
try:
libnuma = ctypes.CDLL("libnuma.so")
if libnuma.numa_available() < 0:
return

import pynvml

pynvml.nvmlInit()
7 changes: 7 additions & 0 deletions verl/workers/config/rollout.py
@@ -175,3 +175,10 @@ class RolloutConfig(BaseConfig):
limit_images: Optional[int] = None

skip_tokenizer_init: bool = False

def __post_init__(self):
"""Validate the rollout config"""
if self.expert_parallel_size > 1:
assert self.expert_parallel_size == (self.tensor_model_parallel_size * self.data_parallel_size), (
"expert_parallel_size must be equal to tensor_model_parallel_size * data_parallel_size"
)
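The new check only fires when expert parallelism is actually enabled (EP > 1). A stripped-down stand-in for RolloutConfig (the real class has many more fields) shows which combinations pass:

from dataclasses import dataclass

@dataclass
class MiniRolloutConfig:  # illustration only, not the real verl class
    tensor_model_parallel_size: int = 2
    data_parallel_size: int = 1
    expert_parallel_size: int = 1

    def __post_init__(self):
        if self.expert_parallel_size > 1:
            assert self.expert_parallel_size == (
                self.tensor_model_parallel_size * self.data_parallel_size
            ), "expert_parallel_size must be equal to tensor_model_parallel_size * data_parallel_size"

MiniRolloutConfig()                                                                            # EP=1: no constraint
MiniRolloutConfig(tensor_model_parallel_size=2, data_parallel_size=4, expert_parallel_size=8)  # 8 == 2*4, ok
try:
    MiniRolloutConfig(tensor_model_parallel_size=2, data_parallel_size=2, expert_parallel_size=8)
except AssertionError as err:
    print(err)                                                                                 # 8 != 2*2, rejected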
2 changes: 1 addition & 1 deletion verl/workers/fsdp_workers.py
@@ -562,7 +562,7 @@ def _build_rollout(self, trust_remote_code=False):
self.model_config = model_config

# 2. build rollout device mesh
infer_tp = self.config.rollout.tensor_model_parallel_size
infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
dp = self.world_size // infer_tp
assert self.world_size % infer_tp == 0, (
f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"
2 changes: 1 addition & 1 deletion verl/workers/megatron_workers.py
@@ -396,7 +396,7 @@ def _build_rollout(self, trust_remote_code=False):
model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)

# 2. build rollout device mesh
infer_tp = self.config.rollout.tensor_model_parallel_size
infer_tp = self.config.rollout.tensor_model_parallel_size * self.config.rollout.data_parallel_size
dp = self.world_size // infer_tp
assert self.world_size % infer_tp == 0, (
f"rollout world_size: {self.world_size} is not divisible by infer_tp: {infer_tp}"