verl/workers/actor/dp_actor.py — 8 changes: 7 additions & 1 deletion

@@ -386,6 +386,8 @@ def update_policy(self, data: DataProto):
         # See PPO paper for details. https://arxiv.org/abs/1707.06347
         mini_batches = data.split(self.config.ppo_mini_batch_size)
 
+        on_policy = len(mini_batches) == 1 and self.config.ppo_epochs == 1
+
         metrics = {}
         for _ in range(self.config.ppo_epochs):
             for batch_idx, mini_batch in enumerate(mini_batches):
@@ -405,7 +407,6 @@ def update_policy(self, data: DataProto):
                 micro_batch_metrics = {}
                 model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
                 response_mask = model_inputs["response_mask"]
-                old_log_prob = model_inputs["old_log_probs"]
                 advantages = model_inputs["advantages"]
 
                 entropy_coeff = self.config.entropy_coeff
@@ -438,6 +439,11 @@ def update_policy(self, data: DataProto):
                     config=self.config,
                 )
 
+                if on_policy:
+                    old_log_prob = log_prob.detach()
+                else:
+                    old_log_prob = model_inputs["old_log_probs"]
+
                 if entropy_coeff != 0:
                     entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
 
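A note on why `old_log_prob = log_prob.detach()` is sound here (my reading of the change, not text from the PR): when the rollout data is consumed in a single mini-batch for a single PPO epoch, the policy has not been updated since the samples were generated, so the freshly computed `log_prob` equals the stored `old_log_probs` numerically. Detaching it keeps the importance ratio at exactly 1 while still letting gradients flow through `log_prob`, which makes the surrogate objective behave like the plain policy-gradient loss and avoids relying on the separately stored old log-probs. A minimal sketch of that equivalence, using made-up tensors rather than verl's actual DataProto:

```python
import torch

# Hypothetical toy values standing in for per-token log-probs and advantages;
# verl's real tensors come from the actor forward pass and the rollout buffer.
log_prob = torch.tensor([-1.2, -0.7, -2.3], requires_grad=True)
advantages = torch.tensor([0.5, -0.3, 1.1])

# On-policy case: old_log_prob is just the detached copy of the current log_prob.
old_log_prob = log_prob.detach()

# PPO importance ratio: exp(log_prob - old_log_prob) == 1 in value,
# but it still carries gradient d(ratio)/d(log_prob) = ratio = 1.
ratio = torch.exp(log_prob - old_log_prob)
surrogate = -(ratio * advantages).mean()
surrogate.backward()

# With ratio == 1 the clip range is never active, and the gradient matches the
# vanilla policy-gradient loss -(log_prob * advantages).mean().
print(ratio)          # tensor([1., 1., 1.], grad_fn=...)
print(log_prob.grad)  # equals -advantages / 3
```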