Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions verl/workers/actor/megatron_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,13 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
if self.config.tis_imp_ratio_cap > 0:
assert "rollout_log_probs" in data.batch.keys(), (
"Truncated Importance Sampling (TIS) requires to configure "
"`actor_rollout_ref.rollout.calculate_log_probs=True` "
"and is not currently supported in Server mode (agent loop)."
)
Comment on lines +292 to +296
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Calling data.batch.keys() can raise an AttributeError when data.batch is None, since the DataProto class definition allows batch to be None. To prevent a potential crash, check that data.batch is not None before accessing its keys. Using the in operator directly on the TensorDict is also more idiomatic than calling .keys().

Suggested change
assert "rollout_log_probs" in data.batch.keys(), (
"Truncated Importance Sampling (TIS) requires to configure "
"`actor_rollout_ref.rollout.calculate_log_probs=True` "
"and is not currently supported in Server mode (agent loop)."
)
assert data.batch is not None and "rollout_log_probs" in data.batch, (
"Truncated Importance Sampling (TIS) requires to configure "
"`actor_rollout_ref.rollout.calculate_log_probs=True` "
"and is not currently supported in Server mode (agent loop)."
)

select_keys.append("rollout_log_probs")
self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
if self.has_multi_modal_inputs:
data = data.select(select_keys, ["multi_modal_inputs"])
Expand All @@ -309,6 +316,7 @@ def compute_ppo_loss(self, model_output, data):
response_mask = data["response_mask"].to(bool)
# compute policy loss
old_log_prob = data["old_log_probs"]
rollout_log_probs = data["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None
advantages = data["advantages"]

loss_agg_mode = self.config.loss_agg_mode
Expand All @@ -323,6 +331,7 @@ def compute_ppo_loss(self, model_output, data):
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_log_probs=rollout_log_probs,
)

metrics.update(
Expand Down
Loading