```diff
@@ -520,13 +520,13 @@ def update_actor(self, data: DataProto):
         metrics["actor/lr"] = lr
         self.lr_scheduler.step()

-        # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info.
+        # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info
         output = DataProto(
             non_tensor_batch={
                 key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items()
             }
         )
-        output = self.ulysses_sharding_manager.postprocess_data(data=output)
+        # Metrics do not need post processing since their batch size is 1

         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)
```
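For readers skimming the hunk above: the dict comprehension wraps scalar metrics into length-1 NumPy arrays so that every `non_tensor_batch` entry is array-valued and can be concatenated downstream. Here is a minimal standalone sketch of that pattern, using plain dicts instead of `DataProto`; the `wrap_metrics` helper and the sample metric values are illustrative only, not part of the change:

```python
import numpy as np

def wrap_metrics(metrics: dict) -> dict:
    # Scalars become length-1 arrays; sequences go through np.array as-is,
    # so every non_tensor_batch value ends up array-valued.
    return {
        key: np.array([value] if np.isscalar(value) else value)
        for key, value in metrics.items()
    }

wrapped = wrap_metrics({"actor/lr": 1e-5, "actor/grad_norm": [0.7, 0.9]})
assert wrapped["actor/lr"].shape == (1,)
assert wrapped["actor/grad_norm"].shape == (2,)
```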
```diff
@@ -677,13 +677,13 @@ def update_critic(self, data: DataProto):
         lr = self.lr_scheduler.get_last_lr()[0]
         metrics["critic/lr"] = lr

-        # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info.
+        # Metrics should be in non_tensor_batch instead of meta_info, as DataProto not concat meta_info
         output = DataProto(
             non_tensor_batch={
-                metric: np.array([value] if np.isscalar(value) else value) for metric, value in metrics.items()
+                key: np.array([value] if np.isscalar(value) else value) for key, value in metrics.items()
             }
         )
-        data = self.ulysses_sharding_manager.postprocess_data(data=output)
+        # Metrics do not need post processing since their batch size is 1

         if self._use_param_offload:
             offload_fsdp_model(self.fsdp_module)
```
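Two observations on the critic hunk: the comprehension's loop variable is renamed from `metric` to `key`, matching the actor version, and the removed call assigned its result to `data` rather than `output`, which suggests the post-processed metrics were likely being discarded anyway. The rationale stated in the comment (`DataProto` does not concatenate `meta_info`) can be illustrated with a toy merge over plain dicts; `toy_concat` below is a hypothetical stand-in that loosely mirrors that stated behavior, not verl's actual implementation:

```python
import numpy as np

def toy_concat(protos: list) -> dict:
    # non_tensor_batch entries are concatenated per key, but meta_info is
    # NOT merged: only the first dict's meta_info survives, so any metric
    # stored there would silently disappear after concatenation.
    keys = protos[0]["non_tensor_batch"].keys()
    return {
        "non_tensor_batch": {
            k: np.concatenate([p["non_tensor_batch"][k] for p in protos])
            for k in keys
        },
        "meta_info": protos[0]["meta_info"],
    }

a = {"non_tensor_batch": {"critic/lr": np.array([1e-5])}, "meta_info": {"rank": 0}}
b = {"non_tensor_batch": {"critic/lr": np.array([2e-5])}, "meta_info": {"rank": 1}}
merged = toy_concat([a, b])
assert merged["non_tensor_batch"]["critic/lr"].shape == (2,)
assert merged["meta_info"] == {"rank": 0}  # b's meta_info was dropped
```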