 This trainer supports model-agnostic model initialization with huggingface
 """
 
+import json
 import os
 import uuid
 from collections import defaultdict
@@ -176,6 +177,10 @@ def __init__(
         self.reward_fn = reward_fn
         self.val_reward_fn = val_reward_fn
 
+        self.val_reward_score = 0.0
+        self.best_val_reward_score = -1.0
+        self.best_global_step = None
+
         self.hybrid_engine = config.worker.hybrid_engine
         self.role_worker_mapping = role_worker_mapping
         self.resource_pool_manager = resource_pool_manager
@@ -258,6 +263,7 @@ def _validate(self) -> Dict[str, Any]:
         # Lists to collect samples for the table
         sample_inputs, sample_outputs, sample_labels, sample_scores = [], [], [], []
         reward_metrics_lst = defaultdict(list)
+        print("Start validation...")
         for batch_dict in self.val_dataloader:
             test_batch = DataProto.from_single_dict(batch_dict)
             # Store original inputs
@@ -295,9 +301,10 @@ def _validate(self) -> Dict[str, Any]:
                 reward_metrics_lst[key].extend(value)
 
         self._maybe_log_val_generations(sample_inputs, sample_outputs, sample_labels, sample_scores)
-        reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
+        self.val_reward_score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
         val_reward_metrics = {f"val/{key}_reward": value for key, value in reduce_metrics(reward_metrics_lst).items()}
-        return {"val/reward_score": reward_score, **val_reward_metrics}
+        print("Finish validation.")
+        return {"val/reward_score": self.val_reward_score, **val_reward_metrics}
 
     def init_workers(self) -> None:
         """Init resource pool and worker group"""
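
Editor's note: the reduction stored in self.val_reward_score concatenates the per-batch reward tensors, sums token-level rewards per sample, and averages over the whole validation set. A tiny worked example under an assumed [batch, response_len] shape (illustrative only, not part of this commit):

# Worked example of the val_reward_score reduction used in _validate.
# Assumption: each entry of reward_tensor_lst is a [batch, response_len]
# tensor of token-level rewards for one validation batch.
import torch

reward_tensor_lst = [
    torch.tensor([[0.0, 1.0], [0.0, 0.0]]),  # per-sample returns: 1.0, 0.0
    torch.tensor([[1.0, 1.0]]),              # per-sample return: 2.0
]
# cat over batches -> [3, 2]; sum(-1) -> per-sample returns; mean -> scalar
score = torch.cat(reward_tensor_lst, dim=0).sum(-1).mean().item()
print(score)  # (1.0 + 0.0 + 2.0) / 3 = 1.0
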
@@ -359,24 +366,37 @@ def init_workers(self) -> None:
 
     def _save_checkpoint(self) -> None:
         # path: {save_checkpoint_path}/global_step_{global_step}/{actor,critic}
+        if self.val_reward_score > self.best_val_reward_score:
+            self.best_val_reward_score = self.val_reward_score
+            self.best_global_step = self.global_step
+
         remove_obsolete_ckpt(
-            self.config.trainer.save_checkpoint_path, self.global_step, self.config.trainer.save_limit
+            self.config.trainer.save_checkpoint_path,
+            self.global_step,
+            self.best_global_step,
+            self.config.trainer.save_limit,
         )
         folder_path = os.path.join(self.config.trainer.save_checkpoint_path, f"global_step_{self.global_step}")
         actor_path = os.path.join(folder_path, "actor")
         self.actor_rollout_ref_wg.save_checkpoint(actor_path)
 
         if self.use_critic:
             critic_path = os.path.join(folder_path, "critic")
-            self.critic_wg.save_checkpoint(critic_path)
+            self.critic_wg.save_checkpoint(critic_path, save_model_only=self.config.trainer.save_model_only)
 
         dataloader_path = os.path.join(folder_path, "dataloader.pt")
         dataloader_state_dict = self.train_dataloader.state_dict()
         torch.save(dataloader_state_dict, dataloader_path)
 
-        last_global_step_path = os.path.join(self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER)
-        with open(last_global_step_path, "w") as f:
-            f.write(str(self.global_step))
+        checkpointer_tracker_info = {
+            "best_global_step": self.best_global_step,
+            "best_val_reward_score": round(self.best_val_reward_score, 2),
+            "last_global_step": self.global_step,
+            "last_actor_path": os.path.abspath(actor_path),
+        }
+        checkpointer_tracker_path = os.path.join(self.config.trainer.save_checkpoint_path, CHECKPOINT_TRACKER)
+        with open(checkpointer_tracker_path, "w") as f:
+            json.dump(checkpointer_tracker_info, f, ensure_ascii=False, indent=2)
 
     def _load_checkpoint(self) -> None:
         if self.config.trainer.load_checkpoint_path is None:
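
Editor's note: the tracker file now stores JSON instead of a bare step number, so any code that resumes training or evaluates a saved model has to parse it accordingly. Below is a minimal sketch of reading it back, assuming the same CHECKPOINT_TRACKER filename constant and global_step_{n} folder layout used above; the helper read_checkpoint_tracker and the example paths are hypothetical, not part of this commit.

# Hypothetical helper for consuming the JSON tracker written by _save_checkpoint.
import json
import os
from typing import Optional


def read_checkpoint_tracker(save_checkpoint_path: str, tracker_filename: str) -> Optional[dict]:
    """Return the tracker dict, or None if no checkpoint has been saved yet."""
    tracker_path = os.path.join(save_checkpoint_path, tracker_filename)
    if not os.path.exists(tracker_path):
        return None
    with open(tracker_path) as f:
        return json.load(f)


# Example usage: choose the latest or the best-scoring actor checkpoint.
tracker = read_checkpoint_tracker("/path/to/ckpts", "checkpoint_tracker.json")  # filename is illustrative
if tracker is not None:
    last_actor = tracker["last_actor_path"]
    best_actor = os.path.join("/path/to/ckpts", f"global_step_{tracker['best_global_step']}", "actor")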