
Commit a4a4128

[worker] fix fsdp worker (#422)

1 parent: b01da77

1 file changed

verl/workers/fsdp_workers.py

Lines changed: 5 additions & 5 deletions
@@ -520,13 +520,13 @@ def update_actor(self, data: DataProto):
         metrics["actor/lr"] = lr
         self.lr_scheduler.step()
 
-        # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info.
+        # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info
         output = DataProto(
             non_tensor_batch={
                 key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items()
             }
         )
-        output = self.ulysses_sharding_manager.postprocess_data(data=output)
+        # Metrics do not need post processing since their batch size is 1
 
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)
@@ -677,13 +677,13 @@ def update_critic(self, data: DataProto):
         lr = self.lr_scheduler.get_last_lr()[0]
         metrics["critic/lr"] = lr
 
-        # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info.
+        # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info
         output = DataProto(
             non_tensor_batch={
-                metric: np.array([value] if np.isscalar(value) else value) for metric, value in metrics.items()
+                key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items()
             }
         )
-        data = self.ulysses_sharding_manager.postprocess_data(data=output)
+        # Metrics do not need post processing since their batch size is 1
 
         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)
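
For context, the dict comprehension in both hunks promotes scalar metrics to one-element NumPy arrays so that every entry in non_tensor_batch carries a batch dimension, which is why the new comments can note that the metrics' batch size is 1 and the ulysses_sharding_manager.postprocess_data call is unnecessary. A minimal standalone sketch of that wrapping pattern, using a plain dict in place of verl's DataProto (the metric names here are illustrative):

import numpy as np

# Scalars are promoted to one-element arrays; array-like values pass through as-is.
metrics = {"actor/lr": 3e-4, "actor/kl": [0.01, 0.02]}

non_tensor_batch = {
    key: np.array([value] if np.isscalar(value) else value)
    for key, value in metrics.items()
}

print(non_tensor_batch["actor/lr"])  # array([0.0003])
print(non_tensor_batch["actor/kl"])  # array([0.01, 0.02])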
