@@ -181,72 +181,70 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te
         # We make recompute_old_log_prob by default here.
         # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be
         # handled by user outside
-        recompute_old_log_prob = self.config.get("recompute_old_log_prob", True)
-
         entropys = torch.Tensor()
-        if recompute_old_log_prob:
-            select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
-            batch = data.select(batch_keys=select_keys).batch
-            input_ids = batch["input_ids"]
-            batch_size = input_ids.size(0)
-            response = batch["responses"]
-            response_length = response.size(1)
-            with torch.no_grad():
-                output = self.forward_backward_batch(
-                    data,
-                    forward_only=True,
-                    calculate_entropy=calculate_entropy,
-                    use_dynamic_bsz=use_dynamic_bsz,
-                    micro_batch_size=micro_batch_size,
-                    max_token_len=max_token_len,
-                )
-                if mpu.is_pipeline_last_stage(ignore_virtual=True):
-                    # only on last rank. It should be on every tp rank
-                    log_probs = [o["log_probs"] for o in output["output"]]  # (bs, seq_size)
-                    log_probs = torch.cat(log_probs, dim=0).to(torch.float32)
 
+        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
+        batch = data.select(batch_keys=select_keys).batch
+        input_ids = batch["input_ids"]
+        batch_size = input_ids.size(0)
+        response = batch["responses"]
+        response_length = response.size(1)
+        with torch.no_grad():
+            output = self.forward_backward_batch(
+                data,
+                forward_only=True,
+                calculate_entropy=calculate_entropy,
+                use_dynamic_bsz=use_dynamic_bsz,
+                micro_batch_size=micro_batch_size,
+                max_token_len=max_token_len,
+            )
+            if mpu.is_pipeline_last_stage(ignore_virtual=True):
+                # only on last rank. It should be on every tp rank
+                log_probs = [o["log_probs"] for o in output["output"]]  # (bs, seq_size)
+                log_probs = torch.cat(log_probs, dim=0).to(torch.float32)
+
+                if calculate_entropy:
+                    entropys = torch.cat([o["entropy"] for o in output["output"]], dim=0)
+                    entropys = entropys.to(torch.float32)
+
+                if use_dynamic_bsz:
+                    indices = output["indices"]
+                    indices = list(itertools.chain.from_iterable(indices))
+                    assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
+                    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
+                    log_probs = log_probs[revert_indices]
                     if calculate_entropy:
-                        entropys = torch.cat([o["entropy"] for o in output["output"]], dim=0)
-                        entropys = entropys.to(torch.float32)
-
-                    if use_dynamic_bsz:
-                        indices = output["indices"]
-                        indices = list(itertools.chain.from_iterable(indices))
-                        assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
-                        revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
-                        log_probs = log_probs[revert_indices]
-                        if calculate_entropy:
-                            assert len(indices) == entropys.size(0), f"{len(indices)} vs. {entropys.size()}"
-                            entropys = entropys[revert_indices]
-                else:
-                    # other pp ranks
-                    log_probs = torch.empty(
+                        assert len(indices) == entropys.size(0), f"{len(indices)} vs. {entropys.size()}"
+                        entropys = entropys[revert_indices]
+            else:
+                # other pp ranks
+                log_probs = torch.empty(
+                    size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
+                )
+                if calculate_entropy:
+                    entropys = torch.empty(
                         size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
                     )
-                    if calculate_entropy:
-                        entropys = torch.empty(
-                            size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
-                        )
 
-                log_probs = log_probs.to(get_device_id())
-                # broadcast across pp ranks
+            log_probs = log_probs.to(get_device_id())
+            # broadcast across pp ranks
+            torch.distributed.broadcast(
+                tensor=log_probs,
+                src=mpu.get_pipeline_model_parallel_last_rank(),
+                group=mpu.get_pipeline_model_parallel_group(),
+                async_op=False,
+            )
+            log_probs = log_probs.to("cpu")
+
+            if calculate_entropy:
+                entropys = entropys.to(get_device_id())
                 torch.distributed.broadcast(
-                    tensor=log_probs,
+                    tensor=entropys,
                     src=mpu.get_pipeline_model_parallel_last_rank(),
                     group=mpu.get_pipeline_model_parallel_group(),
                     async_op=False,
                 )
-                log_probs = log_probs.to("cpu")
-
-                if calculate_entropy:
-                    entropys = entropys.to(get_device_id())
-                    torch.distributed.broadcast(
-                        tensor=entropys,
-                        src=mpu.get_pipeline_model_parallel_last_rank(),
-                        group=mpu.get_pipeline_model_parallel_group(),
-                        async_op=False,
-                    )
-                    entropys = entropys.to("cpu")
+                entropys = entropys.to("cpu")
 
         # add empty cache after each compute
         get_torch_device().empty_cache()
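For context on the hunk above: the dynamic-batch-size path reorders sequences so that micro-batches carry balanced token counts, then restores the original sample order with `get_reverse_idx` before returning. A minimal sketch of that invert-the-permutation step, using a hand-rolled stand-in with the same contract as the helper and made-up toy data:

```python
import torch


def get_reverse_idx(idx_map):
    """Invert a permutation: reverse_idx[idx_map[i]] = i."""
    reverse_idx_map = [0] * len(idx_map)
    for i, idx in enumerate(idx_map):
        reverse_idx_map[idx] = i
    return reverse_idx_map


# Toy example: 4 sequences were reordered as [2, 0, 3, 1] for balanced micro-batches,
# so permuted row j holds original sample indices[j].
indices = [2, 0, 3, 1]
log_probs = torch.tensor([[0.2], [0.0], [0.3], [0.1]])  # rows in permuted order

revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
restored = log_probs[revert_indices]
assert torch.equal(restored, torch.tensor([[0.0], [0.1], [0.2], [0.3]]))
```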
@@ -309,13 +307,10 @@ def compute_ppo_loss(self, model_output, data):
         metrics = {}
 
         response_mask = data["response_mask"].to(bool)
-        loss_agg_mode = self.config.loss_agg_mode
-
         # compute policy loss
         old_log_prob = data["old_log_probs"]
         advantages = data["advantages"]
 
-        entropy_coeff = self.config.entropy_coeff
         loss_agg_mode = self.config.loss_agg_mode
 
         loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
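The hunk above keeps a single read of `loss_agg_mode`, which is reused by the entropy term below. To illustrate what such a mode switch typically controls, here is a simplified stand-in for `agg_loss` (not verl's implementation; the two mode names are shown only as examples):

```python
import torch


def agg_loss_sketch(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str) -> torch.Tensor:
    """Reduce a (batch, response_length) per-token loss to a scalar; two modes sketched."""
    loss_mask = loss_mask.float()
    if loss_agg_mode == "token-mean":
        # average over every valid token in the batch, so long responses weigh more
        return (loss_mat * loss_mask).sum() / loss_mask.sum().clamp(min=1.0)
    if loss_agg_mode == "seq-mean-token-mean":
        # average within each sequence first, then across sequences,
        # so every sequence contributes equally regardless of length
        per_seq = (loss_mat * loss_mask).sum(-1) / loss_mask.sum(-1).clamp(min=1.0)
        return per_seq.mean()
    raise NotImplementedError(loss_agg_mode)
```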
@@ -344,7 +339,7 @@ def compute_ppo_loss(self, model_output, data):
         if entropy is not None:
             entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
             entropy_coeff = self.config.entropy_coeff
-            policy_loss = pg_loss - entropy_coeff * entropy_loss
+            policy_loss -= entropy_coeff * entropy_loss
 
         # add kl loss
         if self.config.use_kl_loss:
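With the change above, the entropy bonus is subtracted from whatever `policy_loss` already holds (presumably initialized from `pg_loss` earlier in the rewritten function, outside this hunk) instead of overwriting it, so the KL term in the next hunk composes additively. A toy calculation of the sign convention, with illustrative numbers only:

```python
import torch

pg_loss = torch.tensor(0.50)       # policy-gradient loss
entropy_loss = torch.tensor(2.00)  # aggregated per-token entropy
entropy_coeff = 0.01

policy_loss = pg_loss.clone()
policy_loss -= entropy_coeff * entropy_loss
print(policy_loss.item())  # 0.48: higher entropy lowers the loss, rewarding exploration
```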
@@ -353,7 +348,7 @@ def compute_ppo_loss(self, model_output, data):
             kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
             kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)
 
-            policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
+            policy_loss += kl_loss * self.config.kl_loss_coef
             metrics["actor/kl_loss"] = kl_loss.detach().item()
             metrics["actor/kl_coef"] = self.config.kl_loss_coef
 
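For reference, `kl_penalty` selects a per-token KL estimator via `kl_loss_type`. A sketch of two common choices (a simplified stand-in, not a copy of verl's code; clamping and other details are omitted):

```python
import torch


def kl_penalty_sketch(logprob: torch.Tensor, ref_logprob: torch.Tensor, kl_penalty: str) -> torch.Tensor:
    """Per-token penalty between actor and reference log-probs; two estimators shown."""
    if kl_penalty == "kl":
        # plain difference: single-sample estimate of KL(pi || pi_ref)
        return logprob - ref_logprob
    if kl_penalty == "low_var_kl":
        # the k3 estimator from http://joschu.net/blog/kl-approx.html:
        # exp(r) - r - 1 with r = ref_logprob - logprob; non-negative, lower variance
        r = ref_logprob - logprob
        return torch.exp(r) - r - 1
    raise NotImplementedError(kl_penalty)
```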