
Commit d32fbd0

[worker] colocate actor and ref model (#342)
1 parent 6870bf9 commit d32fbd0


6 files changed: +86 −31 lines changed


examples/config.yaml

Lines changed: 1 addition & 0 deletions
@@ -89,5 +89,6 @@ trainer:
   val_generations_to_log: 3
   save_freq: 5 # -1 to disable
   save_limit: 3 # -1 to disable
+  save_model_only: false
   save_checkpoint_path: null
   load_checkpoint_path: null

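Note: the new `save_model_only` flag defaults to `false` here. A minimal sketch of reading it from the YAML (the loader below is illustrative only; the project wires this file into `TrainerConfig` through its own config machinery):

```python
# Minimal sketch, assuming PyYAML is installed; not the project's own config loader.
import yaml

with open("examples/config.yaml") as f:
    cfg = yaml.safe_load(f)

# The flag lives under the `trainer` section and defaults to false.
print(cfg["trainer"].get("save_model_only", False))
```
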
verl/trainer/config.py

Lines changed: 19 additions & 1 deletion
@@ -93,22 +93,40 @@ class AlgorithmConfig:
 
 @dataclass
 class TrainerConfig:
-    total_epochs: int = 10
+    total_epochs: int = 15
+    """total epochs for training"""
     max_steps: Optional[int] = None
+    """max steps for training, if specified, total_epochs is ignored"""
     project_name: str = "easy_r1"
+    """project name for logger"""
     experiment_name: str = "demo"
+    """experiment name for logger"""
     logger: Tuple[str] = ("console", "wandb")
+    """logger type, support `console`, `mlflow`, `swanlab`, `tensorboard`, `wandb`"""
     nnodes: int = 1
+    """number of nodes for training"""
     n_gpus_per_node: int = 8
+    """number of gpus per node for training"""
     critic_warmup: int = 0
+    """critic warmup steps"""
     val_freq: int = -1
+    """validation frequency, -1 means no validation"""
     val_before_train: bool = True
+    """validate before training"""
     val_only: bool = False
+    """validate only, skip training"""
     val_generations_to_log: int = 0
+    """number of generations to log for validation"""
     save_freq: int = -1
+    """save frequency, -1 means no saving"""
     save_limit: int = -1
+    """max number of checkpoints to save, -1 means no limit"""
+    save_model_only: bool = False
+    """save model only, no optimizer state dict"""
     save_checkpoint_path: Optional[str] = None
+    """save checkpoint path, if not specified, use `checkpoints/project_name/experiment_name`"""
     load_checkpoint_path: Optional[str] = None
+    """load checkpoint path"""
 
     def post_init(self):
         if self.save_checkpoint_path is None:

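A hedged sketch of the updated dataclass in use; field names and defaults come from the diff above, while the import path is assumed from the file location:

```python
from verl.trainer.config import TrainerConfig  # import path assumed from the file shown above

config = TrainerConfig(
    save_freq=5,           # save a checkpoint every 5 steps
    save_limit=3,          # keep at most 3 checkpoints
    save_model_only=True,  # new field: skip optimizer and extra training state
)
print(config.save_model_only)  # True
```
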
verl/trainer/ray_trainer.py

Lines changed: 27 additions & 7 deletions
@@ -16,6 +16,7 @@
 This trainer supports model-agonistic model initialization with huggingface
 """
 
+import json
 import os
 import uuid
 from collections import defaultdict
@@ -176,6 +177,10 @@ def __init__(
         self.reward_fn = reward_fn
         self.val_reward_fn = val_reward_fn
 
+        self.val_reward_score = 0.0
+        self.best_val_reward_score = -1.0
+        self.best_global_step = None
+
         self.hybrid_engine = config.worker.hybrid_engine
         self.role_worker_mapping = role_worker_mapping
         self.resource_pool_manager = resource_pool_manager
@@ -258,6 +263,7 @@ def _validate(self) -> Dict[str, Any]:
         # Lists to collect samples for the table
         sample_inputs, sample_outputs, sample_labels, sample_scores = [], [], [], []
         reward_metrics_lst = defaultdict(list)
+        print("Start validation...")
         for batch_dict in self.val_dataloader:
             test_batch = DataProto.from_single_dict(batch_dict)
             # Store original inputs
@@ -295,9 +301,10 @@ def _validate(self) -> Dict[str, Any]:
                 reward_metrics_lst[key].extend(value)
 
         self._maybe_log_val_generations(sample_inputs, sample_outputs, sample_labels, sample_scores)
-        reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
+        self.val_reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
         val_reward_metrics = {f"val/{key}_reward": value for key, value in reduce_metrics(reward_metrics_lst).items()}
-        return {"val/reward_score": reward_score, **val_reward_metrics}
+        print("Finish validation.")
+        return {"val/reward_score": self.val_reward_score, **val_reward_metrics}
 
     def init_workers(self) -> None:
         """Init resource pool and worker group"""
@@ -359,24 +366,37 @@ def init_workers(self) -> None:
 
     def _save_checkpoint(self) -> None:
         # path: {save_checkpoint_path}/global_step_{global_step}/{actor,critic}
+        if self.val_reward_score > self.best_val_reward_score:
+            self.best_val_reward_score = self.val_reward_score
+            self.best_global_step = self.global_step
+
         remove_obsolete_ckpt(
-            self.config.trainer.save_checkpoint_path, self.global_step, self.config.trainer.save_limit
+            self.config.trainer.save_checkpoint_path,
+            self.global_step,
+            self.best_global_step,
+            self.config.trainer.save_limit,
         )
         folder_path = os.path.join(self.config.trainer.save_checkpoint_path, f"global_step_{self.global_step}")
         actor_path = os.path.join(folder_path, "actor")
         self.actor_rollout_ref_wg.save_checkpoint(actor_path)
 
         if self.use_critic:
             critic_path = os.path.join(folder_path, "critic")
-            self.critic_wg.save_checkpoint(critic_path)
+            self.critic_wg.save_checkpoint(critic_path, save_model_only=self.config.trainer.save_model_only)
 
         dataloader_path = os.path.join(folder_path, "dataloader.pt")
         dataloader_state_dict = self.train_dataloader.state_dict()
         torch.save(dataloader_state_dict, dataloader_path)
 
-        last_global_step_path = os.path.join(self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER)
-        with open(last_global_step_path, "w") as f:
-            f.write(str(self.global_step))
+        checkpointer_tracker_info = {
+            "best_global_step": self.best_global_step,
+            "best_val_reward_score": round(self.best_val_reward_score, 4),
+            "last_global_step": self.global_step,
+            "last_actor_path": os.path.abspath(actor_path),
+        }
+        checkpointer_tracker_path = os.path.join(self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER)
+        with open(checkpointer_tracker_path, "w") as f:
+            json.dump(checkpointer_tracker_info, f, ensure_ascii=False, indent=2)
 
     def _load_checkpoint(self) -> None:
         if self.config.trainer.load_checkpoint_path is None:

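With this change the tracker file is JSON rather than a bare step number, and it records the best-scoring checkpoint alongside the latest one. A hedged sketch of reading it back (the checkpoint directory below is an assumed example):

```python
import json
import os

save_checkpoint_path = "checkpoints/easy_r1/demo"  # assumed example path
with open(os.path.join(save_checkpoint_path, "checkpoint_tracker.json")) as f:
    tracker = json.load(f)

# Keys written by _save_checkpoint above:
print(tracker["last_global_step"])       # step of the most recent checkpoint
print(tracker["best_global_step"])       # step with the highest val/reward_score
print(tracker["best_val_reward_score"])  # rounded to 4 decimals
print(tracker["last_actor_path"])        # absolute path of the latest actor checkpoint
```
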
verl/utils/checkpoint/checkpoint_manager.py

Lines changed: 13 additions & 7 deletions
@@ -28,7 +28,7 @@
 from transformers import PreTrainedTokenizer, ProcessorMixin
 
 
-CHECKPOINT_TRACKER = "latest_global_step.txt"
+CHECKPOINT_TRACKER = "checkpoint_tracker.json"
 
 
 class BaseCheckpointManager(ABC):
@@ -135,7 +135,9 @@ def get_checkpoint_tracker_filename(root_path: str) -> str:
     return os.path.join(root_path, CHECKPOINT_TRACKER)
 
 
-def remove_obsolete_ckpt(path: str, global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"):
+def remove_obsolete_ckpt(
+    path: str, global_step: int, best_global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"
+):
     """
     Remove the obsolete checkpoints that exceed the save_limit.
     """
@@ -146,15 +148,19 @@ def remove_obsolete_ckpt(path: str, global_step: int, save_limit: int = -1, directory_format: str = "global_step_{}"):
         return
 
     pattern = re.escape(directory_format).replace(r"\{\}", r"(\d+)")
-    ckpt_folders = []
+    ckpt_global_steps = []
     for folder in os.listdir(path):
         if match := re.match(pattern, folder):
             step = int(match.group(1))
             if step < global_step:
-                ckpt_folders.append((step, folder))
+                ckpt_global_steps.append(step)
 
-    ckpt_folders.sort(reverse=True)
-    for _, folder in ckpt_folders[save_limit - 1 :]:
-        folder_path = os.path.join(path, folder)
+    ckpt_global_steps.sort(reverse=True)
+    if best_global_step in ckpt_global_steps:
+        ckpt_global_steps.remove(best_global_step)
+        save_limit = max(save_limit - 1, 0)
+
+    for step in ckpt_global_steps[save_limit - 1 :]:
+        folder_path = os.path.join(path, directory_format.format(step))
         shutil.rmtree(folder_path, ignore_errors=True)
         print(f"Removed obsolete checkpoint: {folder_path}")

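The pruning helper now takes the best step and exempts it from rotation. A hedged usage sketch (directory layout and step numbers are made up for illustration; the signature comes from the diff above):

```python
from verl.utils.checkpoint.checkpoint_manager import remove_obsolete_ckpt

# Suppose global_step_5, _10, _15 and _20 already exist under `path`, step 10 has
# the best validation reward, and global_step_25 is about to be written.
remove_obsolete_ckpt(
    path="checkpoints/easy_r1/demo",  # assumed example path
    global_step=25,
    best_global_step=10,
    save_limit=3,
)
# Expected effect: global_step_10 is kept outside the rotation, global_step_20 stays
# as the most recent previous checkpoint, and global_step_15 / global_step_5 are
# removed; the upcoming global_step_25 brings the total back to the limit.
```
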
verl/utils/checkpoint/fsdp_checkpoint_manager.py

Lines changed: 24 additions & 14 deletions
@@ -17,7 +17,12 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed.checkpoint.state_dict import StateDictOptions, get_state_dict, set_state_dict
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_model_state_dict,
+    get_state_dict,
+    set_state_dict,
+)
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
 
@@ -77,27 +82,32 @@ def load_checkpoint(self, path: Optional[str] = None):
         if "rng" in extra_state_dict:
             self.load_rng_state(extra_state_dict["rng"])
 
-    def save_checkpoint(self, path: str):
+    def save_checkpoint(self, path: str, save_model_only: bool = False):
         path = self.local_mkdir(path)
         dist.barrier()
 
         # every rank will save its own model and optim shard
-        state_dict_options = StateDictOptions(cpu_offload=True)
-        model_state_dict, optim_state_dict = get_state_dict(self.model, self.optimizer, options=state_dict_options)
-        extra_state_dict = {
-            "lr_scheduler": self.lr_scheduler.state_dict(),
-            "rng": self.get_rng_state(),
-        }
         model_path = os.path.join(path, f"model_world_size_{self.world_size}_rank_{self.rank}.pt")
         optim_path = os.path.join(path, f"optim_world_size_{self.world_size}_rank_{self.rank}.pt")
         extra_path = os.path.join(path, f"extra_state_world_size_{self.world_size}_rank_{self.rank}.pt")
 
-        print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
-        print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.")
-        print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
-        torch.save(model_state_dict, model_path)
-        torch.save(optim_state_dict, optim_path)
-        torch.save(extra_state_dict, extra_path)
+        state_dict_options = StateDictOptions(cpu_offload=True)
+        if save_model_only:
+            model_state_dict = get_model_state_dict(self.model, options=state_dict_options)
+            print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
+            torch.save(model_state_dict, model_path)
+        else:
+            model_state_dict, optim_state_dict = get_state_dict(self.model, self.optimizer, options=state_dict_options)
+            extra_state_dict = {
+                "lr_scheduler": self.lr_scheduler.state_dict(),
+                "rng": self.get_rng_state(),
+            }
+            print(f"[rank-{self.rank}]: Saving model to {os.path.abspath(model_path)}.")
+            print(f"[rank-{self.rank}]: Saving optimizer to {os.path.abspath(optim_path)}.")
+            print(f"[rank-{self.rank}]: Saving extra_state to {os.path.abspath(extra_path)}.")
+            torch.save(model_state_dict, model_path)
+            torch.save(optim_state_dict, optim_path)
+            torch.save(extra_state_dict, extra_path)
 
         # wait for everyone to dump to local
         dist.barrier()

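Conceptually, `save_model_only` keeps only the sharded model weights and skips the optimizer, LR scheduler, and RNG state. A simplified, non-FSDP analogue of the toggle (plain `torch.save`, not the sharded manager code above):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
save_model_only = True

torch.save(model.state_dict(), "model.pt")
if not save_model_only:
    # The full checkpoint additionally carries optimizer and extra training state.
    torch.save(optimizer.state_dict(), "optim.pt")
    torch.save({"rng": torch.get_rng_state()}, "extra_state.pt")
```
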
verl/workers/fsdp_workers.py

Lines changed: 2 additions & 2 deletions
@@ -410,12 +410,12 @@ def init_model(self):
         )
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
-    def save_checkpoint(self, path: str):
+    def save_checkpoint(self, path: str, save_model_only: bool = False):
         assert self._has_actor or self._has_critic
         if self._use_param_offload:
             load_fsdp_model(self.fsdp_module)
 
-        self.checkpoint_manager.save_checkpoint(path)
+        self.checkpoint_manager.save_checkpoint(path, save_model_only)
         dist.barrier()
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)
