Commit ef7f8f5

fix(rollout_corr): prevent silent failure when rollout_rs enabled without threshold
Fixes critical bug where setting rollout_rs="token" but rollout_rs_threshold=None would silently skip rejection sampling instead of raising an error.

Changes:
- Raise ValueError if rollout_rs is set but rollout_rs_threshold is explicitly None
- Update docs to clarify rollout_rs_threshold is required when rollout_rs is enabled
- Default value (2.0) still works when parameter is omitted entirely
1 parent 36e0511 commit ef7f8f5

2 files changed, 9 additions and 4 deletions

docs/advance/rollout_corr.md

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ algorithm:
   rollout_is: token # IS weights: "token", "sequence", or null
   rollout_is_threshold: 2.0 # Upper threshold for IS weights
   rollout_rs: null # Rejection sampling: "token", "sequence", "geometric", or null
-  rollout_rs_threshold: null # RS upper threshold (uses rollout_is_threshold if null)
+  rollout_rs_threshold: null # RS upper threshold (required if rollout_rs is enabled)
   rollout_rs_threshold_lower: null # RS lower threshold (auto-reciprocal if null)
   rollout_token_veto_threshold: null # Per-token veto threshold (null = disabled)

verl/trainer/ppo/mismatch_helper.py

Lines changed: 8 additions & 3 deletions
@@ -550,8 +550,8 @@ def compute_rollout_correction_and_rejection_mask(
             default 2.0.
         rollout_rs: Rejection sampling aggregation level (see compute_rollout_rejection_mask for options).
             Set to None to disable rejection sampling.
-        rollout_rs_threshold: Upper threshold for rejection sampling (used if rollout_rs is set),
-            default 2.0.
+        rollout_rs_threshold: Upper threshold for rejection sampling. Required if rollout_rs is enabled.
+            Default 2.0.
         rollout_rs_threshold_lower: Lower threshold for rejection sampling (used if rollout_rs is set).
             Defaults to 1/rollout_rs_threshold if None.
         rollout_token_veto_threshold: Minimum allowed token-level IS weight. Sequences containing
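
The docstring above says the lower threshold defaults to the reciprocal of the upper one. A minimal sketch of that rule follows, assuming a hypothetical helper named `resolve_rs_bounds`. Neither this function nor its positivity check is part of verl; they only illustrate the documented default.

```python
from typing import Optional, Tuple


def resolve_rs_bounds(upper: float, lower: Optional[float] = None) -> Tuple[float, float]:
    """Hypothetical helper: return (lower, upper) rejection-sampling bounds."""
    # Assumption for this sketch: a ratio threshold must be positive.
    if upper <= 0:
        raise ValueError("upper threshold must be positive")
    # Documented behavior: lower defaults to 1 / upper when not given.
    if lower is None:
        lower = 1.0 / upper
    return lower, upper


print(resolve_rs_bounds(2.0))       # (0.5, 2.0)
print(resolve_rs_bounds(2.0, 0.8))  # (0.8, 2.0)
```

With the default upper threshold of 2.0, the accepted range under this rule is [0.5, 2.0].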
@@ -599,7 +599,12 @@ def compute_rollout_correction_and_rejection_mask(

     # Step 3: Compute rejection mask (if enabled)
     modified_response_mask: torch.Tensor = response_mask.clone()
-    if rollout_rs is not None and rollout_rs_threshold is not None:
+    if rollout_rs is not None:
+        if rollout_rs_threshold is None:
+            raise ValueError(
+                "rollout_rs_threshold must be explicitly provided when rollout_rs is enabled. "
+                "Set rollout_rs_threshold to the desired threshold value."
+            )
         modified_response_mask, rs_metrics = compute_rollout_rejection_mask(
             log_ratio=log_ratio,
             response_mask=response_mask,
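
The three cases the commit message distinguishes (threshold omitted, threshold explicitly None, rejection sampling disabled) can be sketched in isolation. The function `check_rs_threshold` below is a hypothetical stand-in that copies only the guard from this hunk. It is not the real `compute_rollout_correction_and_rejection_mask`, and its boolean return value is an assumption made for illustration.

```python
from typing import Optional


def check_rs_threshold(
    rollout_rs: Optional[str] = None,
    rollout_rs_threshold: Optional[float] = 2.0,
) -> bool:
    """Hypothetical stand-in for the guard added in this commit.

    Returns True if rejection sampling would run, False if it is disabled.
    """
    if rollout_rs is None:
        return False  # rejection sampling disabled; threshold is not checked
    if rollout_rs_threshold is None:
        raise ValueError(
            "rollout_rs_threshold must be explicitly provided when rollout_rs is enabled. "
            "Set rollout_rs_threshold to the desired threshold value."
        )
    return True


# Threshold omitted entirely: the default of 2.0 applies and the guard passes.
assert check_rs_threshold(rollout_rs="token") is True

# Rejection sampling disabled: a None threshold is harmless.
assert check_rs_threshold(rollout_rs=None, rollout_rs_threshold=None) is False

# Enabled with an explicit None: fails loudly instead of silently skipping.
try:
    check_rs_threshold(rollout_rs="token", rollout_rs_threshold=None)
except ValueError as exc:
    print(f"rejected: {exc}")
```

Before this change, the combined condition `rollout_rs is not None and rollout_rs_threshold is not None` made the third case fall through without any rejection mask being computed.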
