Commit cfb0dff

pre dapo
1 parent dafd854

8 files changed, +210 -158 lines changed


examples/config.yaml

Lines changed: 5 additions & 4 deletions
@@ -7,7 +7,8 @@ data:
   image_dir: null
   max_prompt_length: 2048
   max_response_length: 2048
-  rollout_batch_size: 512
+  rollout_batch_size: 512 # equivalent to verl's data.train_batch_size
+  mini_rollout_batch_size: null # equivalent to verl's data.gen_batch_size
   val_batch_size: 1024
   format_prompt: ./examples/format_prompt/math_format.jinja
   override_chat_template: null
@@ -26,9 +27,9 @@ algorithm:

 worker:
   actor:
-    global_batch_size: 128
-    micro_batch_size_per_device_for_update: 4
-    micro_batch_size_per_device_for_experience: 16
+    global_batch_size: 128 # equivalent to verl's actor.ppo_mini_batch_size
+    micro_batch_size_per_device_for_update: 4 # equivalent to verl's actor.ppo_micro_batch_size_per_gpu
+    micro_batch_size_per_device_for_experience: 16 # equivalent to verl's rollout.log_prob_micro_batch_size_per_gpu
     max_grad_norm: 1.0
     padding_free: true
     ulysses_size: 1

verl/trainer/config.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ class DataConfig:
     max_prompt_length: int = 512
     max_response_length: int = 512
     rollout_batch_size: int = 512
+    mini_rollout_batch_size: Optional[int] = None
     val_batch_size: int = -1
     format_prompt: Optional[str] = None
     override_chat_template: Optional[str] = None

verl/trainer/data_loader.py

Lines changed: 13 additions & 2 deletions
@@ -47,9 +47,14 @@ def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, proces
     else:
         sampler = SequentialSampler(data_source=train_dataset)

+    if config.mini_rollout_batch_size is not None:
+        train_batch_size = config.mini_rollout_batch_size
+    else:
+        train_batch_size = config.rollout_batch_size
+
     train_dataloader = StatefulDataLoader(
         dataset=train_dataset,
-        batch_size=config.rollout_batch_size,
+        batch_size=train_batch_size,
         sampler=sampler,
         num_workers=8,
         collate_fn=collate_fn,
@@ -72,9 +77,15 @@ def create_dataloader(config: DataConfig, tokenizer: PreTrainedTokenizer, proces
         max_pixels=config.max_pixels,
         filter_overlong_prompts=config.filter_overlong_prompts,
     )
+
+    if config.val_batch_size == -1:
+        val_batch_size = len(val_dataset)
+    else:
+        val_batch_size = config.val_batch_size
+
     val_dataloader = StatefulDataLoader(
         dataset=val_dataset,
-        batch_size=len(val_dataset) if config.val_batch_size == -1 else config.val_batch_size,
+        batch_size=val_batch_size,
         shuffle=False,
         num_workers=8,
         collate_fn=collate_fn,
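
The dataloader change above only decides how many prompts one fetch yields: when mini_rollout_batch_size is set it takes effect (the counterpart of verl's data.gen_batch_size), otherwise the loader keeps using rollout_batch_size as before. Presumably this prepares for DAPO-style dynamic sampling (the commit is titled "pre dapo"), where the trainer can keep requesting additional generation mini-batches until a full rollout batch of usable samples is collected. A minimal standalone sketch of the fallback, with SimpleNamespace standing in for DataConfig (not the repo's code):

from types import SimpleNamespace

def pick_train_batch_size(config):
    # Prefer the new mini_rollout_batch_size (verl's data.gen_batch_size);
    # fall back to rollout_batch_size (verl's data.train_batch_size).
    if config.mini_rollout_batch_size is not None:
        return config.mini_rollout_batch_size
    return config.rollout_batch_size

print(pick_train_batch_size(SimpleNamespace(rollout_batch_size=512, mini_rollout_batch_size=None)))  # 512
print(pick_train_batch_size(SimpleNamespace(rollout_batch_size=512, mini_rollout_batch_size=128)))   # 128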

verl/trainer/ray_trainer.py

Lines changed: 152 additions & 123 deletions
Large diffs are not rendered by default.

verl/workers/actor/dp_actor.py

Lines changed: 3 additions & 3 deletions
@@ -192,7 +192,7 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor:
         )
         log_probs_lst = []
         if self.rank == 0:
-            micro_batches = tqdm(micro_batches, desc="Compute log probs", position=2)
+            micro_batches = tqdm(micro_batches, desc="Compute log probs", position=1)

         for micro_batch in micro_batches:
             model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
@@ -217,15 +217,15 @@ def update_policy(self, data: DataProto) -> Dict[str, Any]:
         metrics = defaultdict(list)
         for _ in range(self.config.ppo_epochs):
             if self.rank == 0:
-                mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=2)
+                mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=1)

             for mini_batch in mini_batches:
                 gradient_accumulation = (
                     self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
                 )
                 micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
                 if self.rank == 0:
-                    micro_batches = tqdm(micro_batches, desc="Update policy", position=3)
+                    micro_batches = tqdm(micro_batches, desc="Update policy", position=2)

                 for micro_batch in micro_batches:
                     model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
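
The only change here is cosmetic: the nested progress bars move up one terminal row, to positions 1 and 2 instead of 2 and 3 (the same shift is applied to the critic below). For reference, a toy illustration of tqdm's position argument, which pins each bar to its own row so nested loops do not overwrite each other (not the repo's code):

from time import sleep
from tqdm import tqdm

for _ in tqdm(range(2), desc="Train mini-batches", position=1, leave=False):
    for _ in tqdm(range(3), desc="Update policy", position=2, leave=False):
        sleep(0.01)  # stands in for one optimization step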

verl/workers/critic/dp_critic.py

Lines changed: 3 additions & 3 deletions
@@ -149,7 +149,7 @@ def compute_values(self, data: DataProto) -> torch.Tensor:
         )
         values_lst = []
         if self.rank == 0:
-            micro_batches = tqdm(micro_batches, desc="Compute values", position=2)
+            micro_batches = tqdm(micro_batches, desc="Compute values", position=1)

         for micro_batch in micro_batches:
             model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
@@ -176,15 +176,15 @@ def update_critic(self, data: DataProto) -> Dict[str, Any]:
         metrics = defaultdict(list)
         for _ in range(self.config.ppo_epochs):
             if self.rank == 0:
-                mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=2)
+                mini_batches = tqdm(mini_batches, desc="Train mini-batches", position=1)

             for mini_batch in mini_batches:
                 gradient_accumulation = (
                     self.config.global_batch_size_per_device // self.config.micro_batch_size_per_device_for_update
                 )
                 micro_batches = mini_batch.split(self.config.micro_batch_size_per_device_for_update)
                 if self.rank == 0:
-                    micro_batches = tqdm(micro_batches, desc="Update critic", position=3)
+                    micro_batches = tqdm(micro_batches, desc="Update critic", position=2)

                 for micro_batch in micro_batches:
                     model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}

verl/workers/fsdp_workers.py

Lines changed: 2 additions & 10 deletions
@@ -339,6 +339,7 @@ def _build_rollout(self) -> None:
             module=self.fsdp_module,
             inference_engine=self.rollout.inference_engine,
             device_mesh=rollout_device_mesh,
+            use_param_offload=self._use_param_offload,
         )
         print_gpu_memory_usage("After vllm init")

@@ -518,9 +519,6 @@ def update_actor(self, data: DataProto):
     def generate_sequences(self, prompts: DataProto):
         assert self._has_rollout

-        if self._use_param_offload:
-            load_fsdp_model(self.fsdp_module)
-
         meta_info = {
             "eos_token_id": self.generation_config.eos_token_id
             if self.generation_config is not None
@@ -530,14 +528,8 @@ def generate_sequences(self, prompts: DataProto):
             else self.tokenizer.pad_token_id,
         }
         prompts.meta_info.update(meta_info)
+        self.rollout_sharding_manager.skip_vllm_sync_once = prompts.meta_info.get("skip_vllm_sync_once", False)
         with self.rollout_sharding_manager:
-            # after parameters sync with rollout, offload actor model to CPU
-            if self._use_param_offload:
-                offload_fsdp_model(self.fsdp_module)
-
-            if self._use_optimizer_offload:
-                offload_fsdp_optimizer(optimizer=self.optimizer)
-
             prompts = self.rollout_sharding_manager.preprocess_data(prompts)
             output = self.rollout.generate_sequences(prompts=prompts)
             output = self.rollout_sharding_manager.postprocess_data(output)
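
Two things change in the worker: it no longer loads and offloads the FSDP parameters around generation itself (that logic moves into the sharding manager, next file), and it forwards a per-call skip_vllm_sync_once flag from prompts.meta_info to the sharding manager so a caller can run one generation pass without re-syncing weights into vLLM. The caller side lives in ray_trainer.py, whose diff is not rendered above; a hypothetical sketch of the intent, with a plain dict standing in for DataProto.meta_info:

# Hypothetical caller-side usage (illustrative only, not this commit's trainer code):
# the policy has not been updated since the last sync, so reuse the weights
# already loaded into vLLM for this extra generation pass.
meta_info = {"skip_vllm_sync_once": True}
print(meta_info.get("skip_vllm_sync_once", False))  # True, copied onto the sharding manager by the worker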

verl/workers/sharding_manager/fsdp_vllm.py

Lines changed: 31 additions & 13 deletions
@@ -27,7 +27,8 @@
 from vllm.distributed import parallel_state as vllm_ps

 from ...protocol import DataProto, all_gather_data_proto
-from ...utils.model_utils import print_gpu_memory_usage
+from ...utils.fsdp_utils import load_fsdp_model, offload_fsdp_model
+from ...utils.model_utils import is_rank0, print_gpu_memory_usage
 from .base import BaseShardingManager


@@ -37,10 +38,13 @@ def __init__(
         module: FSDP,
         inference_engine: LLM,
         device_mesh: DeviceMesh,
+        use_param_offload: bool,
     ):
         self.module = module
         self.inference_engine = inference_engine
         self.device_mesh = device_mesh
+        self.use_param_offload = use_param_offload
+        self.skip_vllm_sync_once = False

         self.world_size = dist.get_world_size()
         self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
@@ -85,6 +89,24 @@ def _make_weight_iterator(
         for name, tensor in actor_weights.items():
             yield name, tensor.full_tensor() if self.world_size != 1 else tensor

+    def _sync_weight_to_vllm(self):
+        if self.use_param_offload:
+            load_fsdp_model(self.module)
+
+        actor_weights = get_model_state_dict(self.module)
+        actor_weights = self._rename_weight_keys(actor_weights, self.module._fsdp_wrapped_module)
+        print_gpu_memory_usage("After gather model weights in sharding manager")
+
+        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
+        model.load_weights(self._make_weight_iterator(actor_weights))
+
+        del actor_weights
+        if self.use_param_offload:
+            offload_fsdp_model(self.module)
+
+        torch.cuda.empty_cache()
+        print_gpu_memory_usage("After sync model weights in sharding manager")
+
     def __enter__(self):
         # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
         # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
@@ -94,27 +116,23 @@ def __enter__(self):
         # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
         # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
         torch.cuda.empty_cache()
-        print_gpu_memory_usage("Before state_dict() in sharding manager")
-        actor_weights = get_model_state_dict(self.module)
-        actor_weights = self._rename_weight_keys(actor_weights, self.module._fsdp_wrapped_module)
-        print_gpu_memory_usage("After state_dict() in sharding manager")
-
+        print_gpu_memory_usage("Before vllm wake up in sharding manager")
         if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
             self.inference_engine.wake_up(tags=["weights"])
         else:
             self.inference_engine.wake_up()

-        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
-        model.load_weights(self._make_weight_iterator(actor_weights))
-        print_gpu_memory_usage("After sync model weights in sharding manager")
-
-        del actor_weights
-        torch.cuda.empty_cache()
+        if self.skip_vllm_sync_once:
+            self.skip_vllm_sync_once = False  # reset the flag
+            if is_rank0():
+                print("Skip vllm weight sync in sharding manager once.")
+        else:
+            self._sync_weight_to_vllm()

         if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
             self.inference_engine.wake_up(tags=["kv_cache"])

-        print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
+        print_gpu_memory_usage("After vllm wake up in sharding manager")
         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
             self.torch_random_states = torch.cuda.get_rng_state()
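
Weight gathering and load_weights now live in _sync_weight_to_vllm, which also takes over the parameter load/offload the worker used to handle, and __enter__ consumes the one-shot skip flag between waking the vLLM weights and the KV cache. A self-contained toy sketch of the resulting sync-or-skip handshake (illustrative names only, none of the repo's classes):

class ToyShardingManager:
    # Stand-in for the sharding manager above: counts weight syncs instead of doing them.
    def __init__(self):
        self.skip_vllm_sync_once = False
        self.sync_count = 0

    def __enter__(self):
        if self.skip_vllm_sync_once:
            self.skip_vllm_sync_once = False  # one-shot: reset the flag, keep current vLLM weights
        else:
            self.sync_count += 1  # stands in for self._sync_weight_to_vllm()
        return self

    def __exit__(self, *exc_info):
        return False


def generate_sequences(manager, meta_info):
    # Mirrors the worker change above: the per-call flag is copied onto the manager.
    manager.skip_vllm_sync_once = meta_info.get("skip_vllm_sync_once", False)
    with manager:
        pass  # rollout would happen here


manager = ToyShardingManager()
generate_sequences(manager, {})                             # first pass: weights are synced
generate_sequences(manager, {"skip_vllm_sync_once": True})  # extra pass: sync is skipped once
generate_sequences(manager, {})                             # next step: weights are synced again
print(manager.sync_count)  # 2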
