Commit 5d378b5

[rollout] refactor: rename "clip" mode back to "mask" mode (#3750)
# Rollout Importance Sampling Framework

related to #3694

## Summary

This PR introduces a comprehensive **Rollout Importance Sampling (IS)** framework to correct distribution mismatch between data-collecting (rollout) and training policies, a critical factor for ensuring stable and efficient model training in RL fine-tuning.

This work is motivated by the analysis in our blog post, [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda). If you find this implementation useful in your research, please consider citing:

```bibtex
@misc{liu-li-2025,
  title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch},
  url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda},
  author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen},
  year = {2025},
  month = {September},
}
```

---

## Problem Statement

When using different policies for rollout generation (e.g., vLLM with BFloat16) and training (e.g., FSDP with FP32), distribution mismatch occurs, leading to:

- Biased gradient estimates
- Training instability and collapse
- Reduced sample efficiency
- Poor convergence properties

This framework addresses these issues through principled importance sampling correction.

---

## Key Features & Improvements

### 1. **Flexible Aggregation Levels**

Three methods for calculating IS weights:

- **`token`**: Per-token importance ratios
- **`sequence`**: Product of per-token ratios
- **`geometric`**: Geometric mean of ratios

### 2. **Advanced Bounding Modes**

Two strategies to control weight variance:

- **`truncate`** (TIS): Caps weights at upper threshold only, preserving gradients
- **`mask`** (MIS): Zeros out weights outside bounds, more aggressive filtering

### 3. **Comprehensive Diagnostics**

Detailed metrics to monitor distribution mismatch and training health:

**Rollout IS Metrics** (automatically prefixed with `mismatch/`):

- Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean`
- Distribution statistics: `rollout_is_p25`, `rollout_is_p50`, `rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`, `rollout_is_min`, `rollout_is_std`
- Diagnostics: `rollout_is_veto_fraction`, `rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction` (mask mode)
- Sequence-level statistics (for sequence/geometric modes): `rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`, `rollout_is_seq_min`, etc.

**Mismatch Metrics** (computed efficiently within IS weight computation):

- KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3 estimator for stability)
- Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`, `mismatch_ppl_ratio`
- Log perplexity statistics: `mismatch_log_ppl_diff`, `mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`, `mismatch_log_ppl_diff_min`

### 4. **Outlier Mitigation**

- **Veto mechanism**: Automatically discards samples with catastrophic importance weights (per-token ratios below threshold)
- Prevents gradient corruption from extreme outliers
- Configurable threshold (default: 1e-4)

### 5. **Numerical Stability**

- All core computations in **log-space** to prevent underflow/overflow
- Carefully designed clamping and bounding to maintain numerical precision
- Safe handling of edge cases (zero probabilities, extreme ratios)

### 6. **Memory Efficiency**

- Optimized computation to minimize CUDA memory usage
- Efficient metric aggregation without large intermediate tensors
- Suitable for large-scale distributed training

### 7. **Metrics-Only Mode**

- Compute and monitor mismatch metrics **without** applying IS weights
- Useful for:
  - Understanding distribution mismatch before intervention
  - Deciding whether IS correction is needed
  - A/B testing IS impact
- Controlled by `algorithm.rollout_is` flag (independent of weight computation)

### 8. **Universal PPO Support**

- Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov, KL-Cov, geo_mean
- Consistent interface across different policy loss functions
- Automatic weight application when enabled

---

## API and Configuration Changes

### Migration from Legacy TIS

#### ❌ **Before (REMOVED)**

```yaml
# Old TIS configuration - NO LONGER SUPPORTED
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0  # Removed from actor config
```

The legacy implementation:

- Only supported token-level truncation
- No metrics tracking
- Lacked numerical stability
- Limited configurability

#### ✅ **After (New Framework)**

Configuration moved to `algorithm` section for better organization:

```yaml
algorithm:
  # Main on/off switch: null = disabled, float = enabled
  rollout_is_threshold: 2.0

  # Control weight application (independent of metrics computation)
  rollout_is: true  # true = apply weights, false = metrics only

  # Optional: lower threshold (defaults to 1/upper if null)
  rollout_is_threshold_lower: null

  # Aggregation level: "token", "sequence", or "geometric"
  rollout_is_level: token

  # Bounding mode: "truncate" or "mask"
  rollout_is_mode: truncate

  # Veto threshold for catastrophic outliers (null = disabled)
  rollout_is_veto_threshold: 1e-4

# REQUIRED: Enable log probability calculation
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

### Configuration Examples

**1. Token-level truncation (recommended starting point)**

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate
```

**2. Sequence-level masking (more aggressive)**

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: sequence
  rollout_is_mode: mask
```

**3. Metrics-only mode (monitoring without correction)**

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: false  # Compute metrics but don't apply weights
  rollout_is_level: token
  rollout_is_mode: truncate
```

**Example script:** `bash examples/rollout_importance_sampling/run_with_rollout_is.sh`

---

## Code Changes Overview

### New Files (4 files, 1,442 lines)

1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines)
   - Core implementation of IS weight computation
   - Three aggregation levels: token, sequence, geometric
   - Two bounding modes: truncate, mask
   - Veto mechanism for outlier detection
   - Comprehensive metrics computation (IS + mismatch)
   - All computations in log-space for numerical stability
   - Memory-efficient design

2. **`docs/advance/rollout_is_migration.md`** (642 lines)
   - Comprehensive migration guide from legacy TIS
   - Detailed explanation of all configuration options
   - Recommended threshold ranges for each aggregation level
   - Troubleshooting guide and best practices
   - Metrics interpretation guide

3. **`examples/rollout_importance_sampling/README.md`** (242 lines)
   - Quick start guide with working examples
   - Configuration templates for common scenarios
   - Threshold tuning guidelines
   - Metrics monitoring instructions

4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99 lines)
   - Complete working example script
   - Demonstrates token-level and sequence-level configurations
   - Ready to run with minimal modifications

### Modified Core Files (9 files)

1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed)
   - Removed legacy TIS logic (`tis_imp_ratio_cap`)
   - Added `rollout_is_weights` parameter to all policy loss functions
   - Unified IS weight application interface across all PPO variants:
     - `compute_policy_loss_vanilla`
     - `compute_policy_loss_gspo`
     - `compute_policy_loss_gpg`
     - `compute_policy_loss_clip_cov`
     - `compute_policy_loss_kl_cov`
     - `compute_policy_loss_geo_mean`
   - Special handling for `geo_mean` (sequence-level aggregation)

2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added)
   - New method: `compute_rollout_importance_weights_and_add_to_batch()`
   - Centralized IS computation (once per batch, on driver)
   - Conditional weight distribution to workers based on `algorithm.rollout_is`
   - Metrics collection and aggregation
   - Integration with existing training loop

3. **`verl/trainer/config/algorithm.py`** (+18 lines)
   - Added 6 new Rollout IS parameters:
     - `rollout_is_threshold` (main on/off switch)
     - `rollout_is` (weight application control)
     - `rollout_is_threshold_lower`
     - `rollout_is_level`
     - `rollout_is_mode`
     - `rollout_is_veto_threshold`
   - Comprehensive docstrings explaining each parameter

4. **`verl/workers/config/actor.py`** (-1 line)
   - Removed deprecated `tis_imp_ratio_cap` parameter

5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

7. **Configuration Files** (4 files updated)
   - `verl/trainer/config/ppo_trainer.yaml`
   - `verl/trainer/config/ppo_megatron_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml`
   - Added default Rollout IS configuration section with explanatory comments

### Testing (2 files, 530 lines)

1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines)
   - Unit tests for `mismatch_helper.py`
   - Coverage for all aggregation levels (token, sequence, geometric)
   - Coverage for all bounding modes (truncate, mask)
   - Veto mechanism tests
   - Edge case handling (zeros, extremes, empty sequences)
   - Numerical stability verification
   - Metrics correctness validation

2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines)
   - Integration tests with PPO training loop
   - End-to-end workflow validation
   - Batch processing tests
   - Configuration validation
   - Metrics collection verification
   - Compatibility with distributed training

### Updated Recipes (2 files)

1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines)
   - Updated imports to use new framework

2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed)
   - Migrated from legacy TIS to new Rollout IS configuration
   - Updated documentation and comments

### Documentation Updates (2 files)

1. **`docs/examples/config.rst`** (~22 lines changed)
   - Updated configuration examples
   - Added Rollout IS section

2. **`docs/index.rst`** (+1 line)
   - Added link to Rollout IS migration guide

---

## Implementation Highlights

### Centralized Architecture

The new design follows a clean separation of concerns:

```
ray_trainer.py (driver)
└─> compute_rollout_importance_weights_and_add_to_batch()
    └─> mismatch_helper.compute_rollout_importance_weights()
        ├─> Computes IS weights (token/sequence/geometric)
        ├─> Applies bounding (truncate/mask)
        ├─> Veto mechanism for outliers
        ├─> Computes IS metrics
        └─> Computes mismatch metrics (KL, PPL)
    └─> Conditionally adds weights to batch (if rollout_is=True)
    └─> Distributes batch to workers

actor workers (dp_actor, megatron_actor)
└─> Receive batch with rollout_is_weights (if enabled)
└─> Pass weights to policy loss function

core_algos.py
└─> All policy loss functions accept rollout_is_weights
└─> Apply weights if provided: pg_losses *= rollout_is_weights
```

### Key Design Decisions

1. **Centralized Computation**: IS weights computed once on driver, not per worker
   - Reduces redundant computation
   - Ensures consistency across workers
   - Simplifies debugging and metrics collection

2. **Configuration in Algorithm**: Moved from actor config to algorithm config
   - Better conceptual organization (algorithm-level concern, not worker-level)
   - Easier to manage and validate
   - Consistent with other algorithm parameters

3. **Two-Level Control**:
   - `rollout_is_threshold`: Enables/disables entire system (null = off)
   - `rollout_is`: Controls weight application (true = apply, false = metrics only)
   - Allows flexible monitoring and gradual rollout

4. **Metrics Consolidation**: Mismatch metrics computed within IS weight computation
   - Eliminates duplicate computation
   - Reduces memory overhead
   - Maintains metric accuracy

5. **Universal PPO Support**: Single interface for all PPO variants
   - Minimal code changes required
   - Consistent behavior across algorithms
   - Easy to add new variants

---

## Migration Guide

### For Users of Legacy TIS

**Step 1: Update your configuration file**

```yaml
# OLD (remove this)
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0

# NEW (add this)
algorithm:
  rollout_is_threshold: 2.0  # Use same value as old tis_imp_ratio_cap
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

# REQUIRED (add if not present)
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

**Step 2: Monitor metrics**

The first time you run with the new configuration, check these metrics:

- `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size
- `mismatch/rollout_is_veto_fraction`: Should be < 5%
- `mismatch/rollout_is_mean`: Should be close to 1.0

**Step 3: Tune if needed**

If effective sample size is too low:

- Increase `rollout_is_threshold`
- Try `rollout_is_mode: mask` with appropriate lower bound
- Consider `rollout_is_level: sequence` for more aggressive correction

For detailed guidance, see `docs/advance/rollout_is_migration.md`.

### For New Users

Start with recommended defaults:

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

Run the example script to see it in action:

```bash
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```

---

## Testing

### Unit Tests

- **289 lines** of comprehensive unit tests in `test_rollout_is.py`
- Covers all aggregation levels, bounding modes, and edge cases
- Validates numerical stability and correctness
- Fast execution (~1-2 seconds)

### Integration Tests

- **241 lines** of integration tests in `test_rollout_is_integration.py`
- End-to-end workflow with PPO training loop
- Distributed training compatibility
- Metrics collection validation
- Moderate execution time (~10-20 seconds)

### Running Tests

```bash
# Run all Rollout IS tests
pytest tests/trainer/ppo/test_rollout_is.py -v
pytest tests/trainer/ppo/test_rollout_is_integration.py -v

# Run specific test
pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v
```

---

## Metrics Reference

### Rollout IS Metrics (all prefixed with `mismatch/`)

| Metric | Description | Ideal Range |
|--------|-------------|-------------|
| `rollout_is_eff_sample_size` | Effective number of samples after IS | > 80% of batch |
| `rollout_is_mean` | Mean IS weight | ~1.0 |
| `rollout_is_std` | Standard deviation of IS weights | Low variance |
| `rollout_is_p25` | 25th percentile | ~0.8-1.0 |
| `rollout_is_p50` | Median IS weight | ~1.0 |
| `rollout_is_p75` | 75th percentile | ~1.0-1.2 |
| `rollout_is_p95` | 95th percentile | < threshold |
| `rollout_is_p99` | 99th percentile | < threshold |
| `rollout_is_max` | Maximum weight | ≤ threshold |
| `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) |
| `rollout_is_veto_fraction` | % sequences vetoed | < 5% |
| `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | < 1% |
| `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable |

### Mismatch Metrics (all prefixed with `mismatch/`)

| Metric | Description | What It Means |
|--------|-------------|---------------|
| `mismatch_kl` | Forward KL divergence | Distribution difference (rollout vs training) |
| `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small divergences |
| `mismatch_training_ppl` | Training policy perplexity | Prediction difficulty of training policy |
| `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction difficulty of rollout policy |
| `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative prediction difficulty |
| `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level PPL mismatch |
| `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude of mismatch |
| `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case mismatch |
| `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case mismatch |
| `mismatch_training_log_ppl` | Log of training PPL | Log-scale training perplexity |
| `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout perplexity |

---

## Performance Impact

### Memory

- Minimal overhead: ~1-2% increase in peak memory usage
- Efficient log-space computation
- No large intermediate tensors

### Computation

- Negligible impact on training speed: < 1% overhead
- Centralized computation on driver (no per-worker redundancy)
- Optimized tensor operations

### Training Stability

- Significant improvement in stability when distribution mismatch exists
- Faster convergence in many scenarios
- Reduced risk of training collapse

---

## Breaking Changes

> [!IMPORTANT]
> This PR contains **BREAKING CHANGES** to the configuration API.

### Removed

- `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported

### Migration Required

All users of the legacy TIS implementation must update their configuration files.
See the migration guide above or `docs/advance/rollout_is_migration.md` for detailed instructions.

### Backward Compatibility

- No backward compatibility with legacy TIS
- Configuration files with `tis_imp_ratio_cap` will raise validation errors
- Affected recipes have been updated in this PR

---

## Pre-Submission Checklist

- [x] Search for similar PRs: [https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling)
- [x] Format PR title as `[{modules}] {type}: {description}` (checked by CI)
  - **Suggested title:** `[BREAKING][rollout, trainer, algo] feat: implement comprehensive Rollout Importance Sampling framework`
- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md)
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting)
- [x] Add/update [documentation](https://github.com/volcengine/verl/tree/main/docs) (3 new docs, 2 updated)
- [x] Add unit and integration tests (530 lines of tests)
- [x] Once PR is ready for CI, send message in `ci-request` channel

---

## References

- **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)
- **Migration guide:** `docs/advance/rollout_is_migration.md`
- **Examples:** `examples/rollout_importance_sampling/`
- **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
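
To make the three aggregation levels and two bounding modes from the description concrete, here is a minimal pure-Python sketch for a single response. It is not the verl implementation (which operates on batched tensors); the function names `is_weights` and `bound_weights` and the `safety_bound` value of 20.0 are assumptions made for illustration only.

```python
import math


def is_weights(train_logp, rollout_logp, level="token", safety_bound=20.0):
    """Return one IS weight per token for a single response.

    train_logp / rollout_logp: per-token log-probabilities under the
    training and rollout policies. safety_bound is a hypothetical clamp
    on the log-ratio to avoid overflow in exp().
    """
    if not train_logp or len(train_logp) != len(rollout_logp):
        raise ValueError("log-prob lists must be non-empty and equal length")
    log_ratio = [t - r for t, r in zip(train_logp, rollout_logp)]
    n = len(log_ratio)
    if level == "token":
        logs = log_ratio                      # per-token ratio
    elif level == "sequence":
        logs = [sum(log_ratio)] * n           # product of ratios, in log-space
    elif level == "geometric":
        logs = [sum(log_ratio) / n] * n       # geometric mean of ratios
    else:
        raise ValueError(f"Invalid level: {level}")
    return [math.exp(max(-safety_bound, min(safety_bound, x))) for x in logs]


def bound_weights(weights, mode, upper, lower=None):
    """Apply 'truncate' (cap at upper) or 'mask' (zero outside [lower, upper])."""
    lower = 1.0 / upper if lower is None else lower
    if mode == "truncate":
        return [min(w, upper) for w in weights]
    if mode == "mask":
        return [w if lower <= w <= upper else 0.0 for w in weights]
    raise ValueError(f"Invalid rollout_is_mode: {mode}. Must be 'truncate' or 'mask'.")
```

Working in log-space means the sequence-level product is a plain sum, which is why the description stresses log-space computation for long responses.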
1 parent 21271aa commit 5d378b5

10 files changed, with 45 additions and 45 deletions.
docs/advance/rollout_is_migration.md

Lines changed: 12 additions & 12 deletions
@@ -55,7 +55,7 @@ actor_rollout_ref:
 
 The new implementation:
 - ✅ Three aggregation levels: token, sequence, geometric
-- ✅ Two bounding modes: truncate, clip
+- ✅ Two bounding modes: truncate, mask
 - ✅ Dual threshold support (upper/lower)
 - ✅ Veto mechanism for catastrophic outliers
 - ✅ 30+ comprehensive metrics
@@ -150,7 +150,7 @@ Aggregation level for IS weights:
 ### `algorithm.rollout_is_mode` (str)
 Bounding mode:
 - `"truncate"`: Cap weights at upper threshold only
-- `"clip"`: Zero out weights outside [lower, upper]
+- `"mask"`: Zero out weights outside [lower, upper]
 
 ### `algorithm.rollout_is_veto_threshold` (float)
 Per-token veto threshold. If any token ratio < this, entire sequence is rejected.
@@ -199,7 +199,7 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
 - **`rollout_is_min`**: Minimum IS weight observed
   - Shows the most underweighted token/sequence
 
-- **`rollout_is_max`**: Maximum IS weight observed (before clipping)
+- **`rollout_is_max`**: Maximum IS weight observed (before truncation/masking)
   - Shows the most overweighted token/sequence
   - Compare with `rollout_is_threshold` to see truncation impact
 
@@ -235,11 +235,11 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
 #### **Threshold Exceedance Metrics**
 
 - **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold
-  - Shows how often truncation/clipping occurs on high end
+  - Shows how often truncation/masking occurs on high end
   - **Ideal value**: < 0.1 (most weights within bounds)
 
 - **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold
-  - Shows how often clipping occurs on low end (clip mode only)
+  - Shows how often masking occurs on low end (mask mode only)
   - **Ideal value**: < 0.1
 
 #### **Sequence-Level Metrics** (for sequence/geometric modes)
@@ -261,14 +261,14 @@ All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appear
 
 - **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold
 
-#### **Clipping Metrics** (clip mode only)
+#### **Masking Metrics** (mask mode only)
 
-- **`rollout_is_clipped_fraction`**: Fraction of tokens clipped (set to zero)
+- **`rollout_is_masked_fraction`**: Fraction of tokens masked (set to zero)
   - **Ideal value**: < 0.1
   - **Warning**: > 0.3 means losing too much data
 
-- **`rollout_is_seq_clipped_fraction`**: Fraction of sequences with at least one clipped token
-  - Shows sequence-level impact of clipping
+- **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one masked token
+  - Shows sequence-level impact of masking
 
 #### **Distribution Mismatch Metrics** (Training vs Rollout Policy)
 
@@ -456,14 +456,14 @@ algorithm:
   rollout_is_mode: truncate
 ```
 
-### Example 3: Geometric Mean with Clip
+### Example 3: Geometric Mean with Mask
 ```yaml
 algorithm:
   rollout_is_threshold: 1.0002
   rollout_is: true
   rollout_is_threshold_lower: 0.9998
   rollout_is_level: geometric
-  rollout_is_mode: clip
+  rollout_is_mode: mask
 ```
 
 ### Example 4: Asymmetric Thresholds
@@ -473,7 +473,7 @@ algorithm:
   rollout_is: true
   rollout_is_threshold_lower: 0.8
   rollout_is_level: token
-  rollout_is_mode: clip
+  rollout_is_mode: mask
 ```
 
 ## Troubleshooting
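
The renamed masking metrics in the migration guide can be illustrated with a small pure-Python sketch. It mirrors the definitions in the diff above (fraction of valid tokens set to zero, and fraction of sequences with at least one masked token) for token-level weights, but the function name `masking_metrics` and the list-of-lists layout are assumptions for illustration, not the verl tensor code.

```python
def masking_metrics(weights, response_mask, lower, upper):
    """Compute mask-mode metrics for token-level IS weights.

    weights / response_mask: lists of per-sequence lists; a mask entry
    of 0 marks padding, which is ignored in both metrics.
    """
    valid_tokens = 0
    masked_tokens = 0
    seqs_with_masked = 0
    for w_row, m_row in zip(weights, response_mask):
        row_masked = 0
        for w, m in zip(w_row, m_row):
            if not m:
                continue  # skip padding positions
            valid_tokens += 1
            if not (lower <= w <= upper):
                row_masked += 1
        masked_tokens += row_masked
        if row_masked > 0:
            seqs_with_masked += 1
    return {
        "rollout_is_masked_fraction": masked_tokens / valid_tokens if valid_tokens else 0.0,
        "rollout_is_seq_masked_fraction": seqs_with_masked / len(weights) if weights else 0.0,
    }
```

Under the guide's thresholds, a `rollout_is_masked_fraction` above 0.3 from such a computation would suggest widening the bounds.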

docs/examples/config.rst

Lines changed: 2 additions & 2 deletions
@@ -123,7 +123,7 @@ Actor/Rollout/Reference Policy
 rollout_is_threshold: null # Upper threshold for IS weights (null to disable)
 rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper)
 rollout_is_level: token # Aggregation: token/sequence/geometric
-rollout_is_mode: truncate # Bounding: truncate/clip
+rollout_is_mode: truncate # Bounding: truncate/mask
 rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold
 use_torch_compile: True # False to disable torch compile
 kl_loss_coef: 0.001 # for grpo
@@ -527,7 +527,7 @@ Algorithm
 - ``rollout_is_threshold``: Upper threshold for IS weights. Set to ``null`` to disable IS completely.
 - ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper).
 - ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental).
-- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``clip`` (zero outside bounds).
+- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``mask`` (zero outside bounds).
 - ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4.
 Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``.
 

examples/rollout_importance_sampling/README.md

Lines changed: 5 additions & 5 deletions
@@ -86,15 +86,15 @@ algorithm:
   rollout_is_mode: truncate
 ```
 
-### Example 3: Geometric Mean with Clip
+### Example 3: Geometric Mean with Mask
 
 ```yaml
 algorithm:
   rollout_is_threshold: 1.0002
   rollout_is: true
   rollout_is_threshold_lower: 0.9998
   rollout_is_level: geometric
-  rollout_is_mode: clip
+  rollout_is_mode: mask
   rollout_is_veto_threshold: 1e-4
 ```
 
@@ -118,7 +118,7 @@ algorithm:
   rollout_is: true
   rollout_is_threshold_lower: 0.8
   rollout_is_level: token
-  rollout_is_mode: clip
+  rollout_is_mode: mask
 ```
 
 ## Monitoring Metrics
@@ -183,9 +183,9 @@ These metrics help diagnose the distribution mismatch between rollout and traini
 2. Verify rollout_log_probs are correctly passed
 3. Check for systematic bias in rollout vs training
 
-### Issue: Too Much Data Discarded (Clip Mode)
+### Issue: Too Much Data Discarded (Mask Mode)
 
-**Symptoms**: `rollout_is_clipped_fraction` > 0.5
+**Symptoms**: `rollout_is_masked_fraction` > 0.5
 
 **Solutions**:
 1. Widen thresholds

examples/rollout_importance_sampling/run_with_rollout_is.sh

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ rollout_is_threshold_lower=null
 # Aggregation level: token | sequence | geometric (experimental)
 rollout_is_level=token
 
-# Bounding mode: truncate (cap upper) | clip (zero outside bounds)
+# Bounding mode: truncate (cap upper) | mask (zero outside bounds)
 rollout_is_mode=truncate
 
 # Catastrophic outlier veto threshold

tests/trainer/ppo/test_rollout_is.py

Lines changed: 2 additions & 2 deletions
@@ -97,14 +97,14 @@ def test_basic_rollout_is():
         rollout_log_prob=rollout_log_prob,
         response_mask=eos_mask,
         rollout_is_level="geometric",
-        rollout_is_mode="clip",
+        rollout_is_mode="mask",
         rollout_is_threshold=1.5,
         rollout_is_threshold_lower=0.5,
         rollout_is_veto_threshold=1e-4,
     )
 
     print(f" Mean weight: {metrics_geo['mismatch/rollout_is_mean']:.4f}")
-    print(f" Clipped fraction: {metrics_geo['mismatch/rollout_is_clipped_fraction']:.4f}")
+    print(f" Masked fraction: {metrics_geo['mismatch/rollout_is_masked_fraction']:.4f}")
     print(" ✓ Geometric mean mode passed")
 
     # Test veto mechanism

tests/trainer/ppo/test_rollout_is_integration.py

Lines changed: 2 additions & 2 deletions
@@ -132,8 +132,8 @@ def test_all_aggregation_levels(self, sample_data):
             assert "mismatch/rollout_is_mean" in metrics
 
     def test_both_bounding_modes(self, sample_data):
-        """Test both truncate and clip modes."""
-        modes = ["truncate", "clip"]
+        """Test both truncate and mask modes."""
+        modes = ["truncate", "mask"]
 
         for mode in modes:
             _, metrics = compute_rollout_importance_weights(

verl/trainer/config/algorithm.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ class AlgoConfig(BaseConfig):
             float value = enabled (compute weights and metrics). This is the main on/off switch.
         rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper.
         rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric".
-        rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "clip" (zero outside bounds).
+        rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "mask" (zero outside bounds).
         rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers.
         rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights,
             False = compute metrics only (useful for monitoring before enabling correction). Default: False.
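
Since this commit changes an accepted config value, a stale `rollout_is_mode: clip` in a user config will no longer match. The sketch below shows one way such fields could be validated with a plain dataclass. `RolloutISConfig` is a hypothetical stand-in, not verl's `AlgoConfig`; the field names and defaults follow the docstring and YAML shown in this commit, and the error text mirrors the `ValueError` in `mismatch_helper.py`.

```python
from dataclasses import dataclass
from typing import Optional

VALID_LEVELS = ("token", "sequence", "geometric")
VALID_MODES = ("truncate", "mask")


@dataclass
class RolloutISConfig:
    """Hypothetical stand-in for the rollout IS fields of the algorithm config."""

    rollout_is_threshold: Optional[float] = None  # None = IS disabled
    rollout_is_threshold_lower: Optional[float] = None
    rollout_is_level: str = "token"
    rollout_is_mode: str = "truncate"
    rollout_is_veto_threshold: Optional[float] = 1e-4
    rollout_is: bool = False  # False = metrics only

    def __post_init__(self):
        if self.rollout_is_level not in VALID_LEVELS:
            raise ValueError(f"Invalid rollout_is_level: {self.rollout_is_level}.")
        if self.rollout_is_mode not in VALID_MODES:
            raise ValueError(
                f"Invalid rollout_is_mode: {self.rollout_is_mode}. Must be 'truncate' or 'mask'."
            )

    @property
    def enabled(self) -> bool:
        # Weights and metrics are computed only when a threshold is set.
        return self.rollout_is_threshold is not None

    @property
    def applies_weights(self) -> bool:
        # Weights reach the policy loss only when both switches are on.
        return self.enabled and self.rollout_is
```

The two properties reflect the two-level control described in the PR: the threshold acts as the main switch, and `rollout_is` decides whether weights are applied or only monitored.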

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ algorithm:
   # Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
   rollout_is_level: token
 
-  # Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds)
+  # Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
   rollout_is_mode: truncate
 
   # Per-token veto threshold for catastrophic outliers

verl/trainer/config/ppo_trainer.yaml

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ algorithm:
   # Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
   rollout_is_level: token
 
-  # Bounding mode: "truncate" (cap upper only), "clip" (zero outside bounds)
+  # Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
   rollout_is_mode: truncate
 
   # Per-token veto threshold for catastrophic outliers

verl/trainer/ppo/mismatch_helper.py

Lines changed: 18 additions & 18 deletions
@@ -20,7 +20,7 @@
 
 Key Features:
 1. Three aggregation levels: token, sequence, geometric
-2. Two handling modes: truncate (TIS), clip (CIS)
+2. Two handling modes: truncate (TIS), mask (MIS)
 3. Per-token veto mechanism for catastrophic outliers
 4. Memory-efficient computation to prevent CUDA OOM
 5. Comprehensive metrics tracking
@@ -77,9 +77,9 @@ def compute_rollout_importance_weights(
             - "geometric": Geometric mean of ratios (experimental)
         rollout_is_mode: How to handle weights exceeding threshold:
             - "truncate": Cap weights at upper_threshold only (TIS)
-            - "clip": Zero out weights outside [lower_threshold, upper_threshold] (CIS)
+            - "mask": Zero out weights outside [lower_threshold, upper_threshold] (MIS)
         rollout_is_threshold: Upper threshold for IS weights
-        rollout_is_threshold_lower: Lower threshold for IS weights (clip mode only; if None, defaults to 1/upper)
+        rollout_is_threshold_lower: Lower threshold for IS weights (mask mode only; if None, defaults to 1/upper)
         rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence.
             If None, veto mechanism is disabled.
 
@@ -179,32 +179,32 @@ def compute_rollout_importance_weights(
         SAFETY_BOUND=SAFETY_BOUND,
     )
 
-    # Step 3: Apply truncation or clipping based on mode
+    # Step 3: Apply truncation or masking based on mode
     if rollout_is_mode == "truncate":
         # Truncated IS (TIS): only cap upper bound to prevent overweighting
         rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold)
 
-    elif rollout_is_mode == "clip":
-        # Clipped IS (CIS): zero out weights outside [lower_threshold, upper_threshold]
-        clip_mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
-        clip_mask = clip_mask.float()
+    elif rollout_is_mode == "mask":
+        # Masked IS (MIS): zero out weights outside [lower_threshold, upper_threshold]
+        mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
+        mask = mask.float()
 
-        # Track CIS-specific metrics
-        metrics["rollout_is_clipped_fraction"] = verl_F.masked_mean(1 - clip_mask, response_mask)
+        # Track MIS-specific metrics
+        metrics["rollout_is_masked_fraction"] = verl_F.masked_mean(1 - mask, response_mask)
 
-        # Sequence-level clipping fraction
+        # Sequence-level masking fraction
         if rollout_is_level in ["sequence", "geometric"]:
-            # All tokens in a sequence have the same weight, so reuse clip_mask
-            metrics["rollout_is_seq_clipped_fraction"] = (1 - clip_mask[:, 0]).mean()
+            # All tokens in a sequence have the same weight, so reuse mask
+            metrics["rollout_is_seq_masked_fraction"] = (1 - mask[:, 0]).mean()
         else:
-            # Check if any token in each sequence is clipped
-            seq_has_clipped = verl_F.masked_sum(1 - clip_mask, response_mask, axis=-1) > 0
-            metrics["rollout_is_seq_clipped_fraction"] = seq_has_clipped.float().mean()
+            # Check if any token in each sequence is masked
+            seq_has_masked = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0
+            metrics["rollout_is_seq_masked_fraction"] = seq_has_masked.float().mean()
 
-        rollout_is_weights = rollout_is_weights * clip_mask
+        rollout_is_weights = rollout_is_weights * mask
 
     else:
-        raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.")
+        raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.")
 
     # Apply veto mask AFTER all thresholding
     # This zeros out entire sequences that have any catastrophic token
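
The veto step that follows the hunk above is not shown in this diff. As a rough illustration of the behavior the comments describe (zeroing an entire sequence when any valid token's ratio falls below the veto threshold), here is a pure-Python sketch. The function name `apply_veto` and the list-based layout are assumptions; the real code operates on tensors and may differ in detail.

```python
import math


def apply_veto(weights, log_ratios, response_mask, veto_threshold=1e-4):
    """Zero all weights of any sequence containing a catastrophic token.

    log_ratios holds per-token log(pi_train / pi_rollout); the comparison
    is done in log-space so very small ratios never need to be exponentiated.
    Padding positions (mask == 0) are ignored when checking for outliers.
    """
    if veto_threshold is None:
        return [list(row) for row in weights]  # veto disabled
    log_veto = math.log(veto_threshold)
    vetoed = []
    for w_row, lr_row, m_row in zip(weights, log_ratios, response_mask):
        catastrophic = any(m and lr < log_veto for lr, m in zip(lr_row, m_row))
        vetoed.append([0.0] * len(w_row) if catastrophic else list(w_row))
    return vetoed
```

Applying the veto after truncation or masking, as the source comment states, means a sequence can be rejected even if every bounded weight looked acceptable.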
