Commit c604ea9
follow gemini
1 parent 1f97e1d
File tree: 1 file changed (+55, -60 lines)

verl/workers/actor/megatron_actor.py

Lines changed: 55 additions & 60 deletions
@@ -181,72 +181,70 @@ def compute_log_prob(self, data: DataProto, calculate_entropy=False) -> torch.Te
         # We make recompute_old_log_prob by default here.
         # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be
         # handled by user outside
-        recompute_old_log_prob = self.config.get("recompute_old_log_prob", True)
-
         entropys = torch.Tensor()
-        if recompute_old_log_prob:
-            select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
-            batch = data.select(batch_keys=select_keys).batch
-            input_ids = batch["input_ids"]
-            batch_size = input_ids.size(0)
-            response = batch["responses"]
-            response_length = response.size(1)
-            with torch.no_grad():
-                output = self.forward_backward_batch(
-                    data,
-                    forward_only=True,
-                    calculate_entropy=calculate_entropy,
-                    use_dynamic_bsz=use_dynamic_bsz,
-                    micro_batch_size=micro_batch_size,
-                    max_token_len=max_token_len,
-                )
-                if mpu.is_pipeline_last_stage(ignore_virtual=True):
-                    # only on last rank. It should be on every tp rank
-                    log_probs = [o["log_probs"] for o in output["output"]]  # (bs, seq_size)
-                    log_probs = torch.cat(log_probs, dim=0).to(torch.float32)
 
+        select_keys = ["responses", "input_ids", "attention_mask", "position_ids"]
+        batch = data.select(batch_keys=select_keys).batch
+        input_ids = batch["input_ids"]
+        batch_size = input_ids.size(0)
+        response = batch["responses"]
+        response_length = response.size(1)
+        with torch.no_grad():
+            output = self.forward_backward_batch(
+                data,
+                forward_only=True,
+                calculate_entropy=calculate_entropy,
+                use_dynamic_bsz=use_dynamic_bsz,
+                micro_batch_size=micro_batch_size,
+                max_token_len=max_token_len,
+            )
+            if mpu.is_pipeline_last_stage(ignore_virtual=True):
+                # only on last rank. It should be on every tp rank
+                log_probs = [o["log_probs"] for o in output["output"]]  # (bs, seq_size)
+                log_probs = torch.cat(log_probs, dim=0).to(torch.float32)
+
+                if calculate_entropy:
+                    entropys = torch.cat([o["entropy"] for o in output["output"]], dim=0)
+                    entropys = entropys.to(torch.float32)
+
+                if use_dynamic_bsz:
+                    indices = output["indices"]
+                    indices = list(itertools.chain.from_iterable(indices))
+                    assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
+                    revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
+                    log_probs = log_probs[revert_indices]
                     if calculate_entropy:
-                        entropys = torch.cat([o["entropy"] for o in output["output"]], dim=0)
-                        entropys = entropys.to(torch.float32)
-
-                    if use_dynamic_bsz:
-                        indices = output["indices"]
-                        indices = list(itertools.chain.from_iterable(indices))
-                        assert len(indices) == log_probs.size(0), f"{len(indices)} vs. {log_probs.size()}"
-                        revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
-                        log_probs = log_probs[revert_indices]
-                        if calculate_entropy:
-                            assert len(indices) == entropys.size(0), f"{len(indices)} vs. {entropys.size()}"
-                            entropys = entropys[revert_indices]
-                else:
-                    # other pp ranks
-                    log_probs = torch.empty(
+                        assert len(indices) == entropys.size(0), f"{len(indices)} vs. {entropys.size()}"
+                        entropys = entropys[revert_indices]
+            else:
+                # other pp ranks
+                log_probs = torch.empty(
+                    size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
+                )
+                if calculate_entropy:
+                    entropys = torch.empty(
                         size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
                     )
-                    if calculate_entropy:
-                        entropys = torch.empty(
-                            size=(batch_size, response_length), dtype=torch.float32, device=input_ids.device
-                        )
 
-                log_probs = log_probs.to(get_device_id())
-                # broadcast across pp ranks
+            log_probs = log_probs.to(get_device_id())
+            # broadcast across pp ranks
+            torch.distributed.broadcast(
+                tensor=log_probs,
+                src=mpu.get_pipeline_model_parallel_last_rank(),
+                group=mpu.get_pipeline_model_parallel_group(),
+                async_op=False,
+            )
+            log_probs = log_probs.to("cpu")
+
+            if calculate_entropy:
+                entropys = entropys.to(get_device_id())
                 torch.distributed.broadcast(
-                    tensor=log_probs,
+                    tensor=entropys,
                     src=mpu.get_pipeline_model_parallel_last_rank(),
                     group=mpu.get_pipeline_model_parallel_group(),
                     async_op=False,
                 )
-                log_probs = log_probs.to("cpu")
-
-                if calculate_entropy:
-                    entropys = entropys.to(get_device_id())
-                    torch.distributed.broadcast(
-                        tensor=entropys,
-                        src=mpu.get_pipeline_model_parallel_last_rank(),
-                        group=mpu.get_pipeline_model_parallel_group(),
-                        async_op=False,
-                    )
-                    entropys = entropys.to("cpu")
+                entropys = entropys.to("cpu")
 
         # add empty cache after each compute
         get_torch_device().empty_cache()
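Aside (not part of the diff): a minimal, self-contained sketch of the inverse-permutation step performed in the use_dynamic_bsz branch above. get_reverse_idx is a verl helper, so the inline construction of revert_indices here is only an assumed equivalent; the sample values are placeholders.

import torch

# Order in which dynamic batching visited the original samples (placeholder values).
indices = [2, 0, 3, 1]

# Build the inverse permutation: original sample j sits at position
# revert_indices[j] of the reordered tensor.
revert_indices = torch.empty(len(indices), dtype=torch.long)
revert_indices[torch.tensor(indices)] = torch.arange(len(indices))

x = torch.tensor([10.0, 11.0, 12.0, 13.0])  # original batch order
reordered = x[torch.tensor(indices)]        # order produced by dynamic batching
restored = reordered[revert_indices]        # back to the original order
assert torch.equal(restored, x)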
@@ -309,13 +307,10 @@ def compute_ppo_loss(self, model_output, data):
         metrics = {}
 
         response_mask = data["response_mask"].to(bool)
-        loss_agg_mode = self.config.loss_agg_mode
-
         # compute policy loss
         old_log_prob = data["old_log_probs"]
         advantages = data["advantages"]
 
-        entropy_coeff = self.config.entropy_coeff
         loss_agg_mode = self.config.loss_agg_mode
 
         loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
@@ -344,7 +339,7 @@ def compute_ppo_loss(self, model_output, data):
         if entropy is not None:
             entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
             entropy_coeff = self.config.entropy_coeff
-            policy_loss = pg_loss - entropy_coeff * entropy_loss
+            policy_loss -= entropy_coeff * entropy_loss
 
         # add kl loss
         if self.config.use_kl_loss:
@@ -353,7 +348,7 @@ def compute_ppo_loss(self, model_output, data):
             kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type)
             kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode)
 
-            policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
+            policy_loss += kl_loss * self.config.kl_loss_coef
             metrics["actor/kl_loss"] = kl_loss.detach().item()
             metrics["actor/kl_coef"] = self.config.kl_loss_coef
 
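Aside (not part of the diff): a minimal sketch of how the aggregated loss terms compose after these two hunks. pg_loss is the base value of policy_loss, the entropy bonus is subtracted, and the KL penalty is added. The scalar values are placeholders and the agg_loss / kl_penalty aggregation is elided; only the accumulation pattern mirrors the diff.

import torch

# Placeholder scalars standing in for the masked/aggregated losses in the diff.
pg_loss = torch.tensor(0.50)       # clipped policy-gradient loss
entropy_loss = torch.tensor(0.20)  # aggregated token entropy
kl_loss = torch.tensor(0.05)       # aggregated KL(actor || reference)

entropy_coeff = 0.01
kl_loss_coef = 0.1

policy_loss = pg_loss.clone()                # pg_loss is the base term
policy_loss -= entropy_coeff * entropy_loss  # entropy bonus lowers the loss
policy_loss += kl_loss_coef * kl_loss        # KL penalty raises the loss

print(policy_loss.item())  # ~0.503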