diff --git a/torchtitan/experiments/deterministic_vllm_rl/README.md b/torchtitan/experiments/deterministic_vllm_rl/README.md index d2ef719c0d..c44bb45cfb 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/README.md +++ b/torchtitan/experiments/deterministic_vllm_rl/README.md @@ -1,262 +1,12 @@ -# Deterministic RL Training with vLLM +# Deterministic vLLM RL Training -This experiment combines vLLM's deterministic kernels with PyTorch autograd to enable reinforcement learning training where forward passes produce bitwise-identical results across runs. +This package provides two approaches for integrating TorchTitan models with vLLM: -## Overview +1. vllm_compat/ - vLLM-Compatible approach + - Separate model definition matching vLLM's weight format + - Support batch-invariant and bit-wise identity between train and inference + - Custom backward passes for attention gradient computation -RL training requires both fast inference for generating rollouts and gradient computation for policy updates. vLLM provides deterministic forward passes but does not support gradients. This experiment adds backward passes to vLLM's operations. - -The implementation: -1. Uses vLLM's batch-invariant kernels for forward passes -2. Implements custom backward passes for gradient computation -3. Provides weight conversion utilities between TorchTitan and vLLM formats - -### Features - -- Bitwise determinism: Same inputs produce identical outputs across runs -- Gradient support: Backward passes through vLLM operations -- Weight conversion: Utilities to convert between model formats - -Note: Currently supports single-device training only. - -## Architecture - -### Components - -1. `models/attention.py`: VLLMCompatibleFlashAttention - - Uses vLLM's Flash Attention for forward pass - - Implements custom backward pass for gradient computation - - Uses `num_splits=1` for deterministic behavior - -2. `models/qwen3/model_vllm_compat.py`: Qwen3VLLMCompatModel - - Qwen3 model with merged gate/up projections matching vLLM format - - Uses VLLMRMSNorm with gradient support - -3. `batch_invariant_backward.py`: Backward passes for vLLM operations - - Registers gradients for vLLM's batch-invariant operations - - Supports matmul, linear, and RMSNorm - - Patches Flash Attention for autograd - -4. `weights_vllm_compat.py`: Weight conversion utilities - - Converts between TorchTitan format (separate w1, w2, w3) and vLLM format (merged gate_up_proj) - - Provides bidirectional conversion functions - -5. `simple_rl.py`: RL training loop - - Generates rollouts using vLLM engine - - Computes advantages using GRPO-style ranking - - Updates policy using PPO - -## Installation - -### Prerequisites - -```bash -# Install vLLM with deterministic support -pip install vllm - -# Install TorchTitan (from the repository root) -pip install -e . - -# Install additional dependencies -pip install transformers safetensors huggingface_hub tensorboard -``` - -### Enable Batch Invariance - -Initialize vLLM's batch-invariant mode before training: - -```python -from vllm.model_executor.layers.batch_invariant import init_batch_invariance -init_batch_invariance() -``` - -## Usage - -### Quick Start - -```python -import torch -from vllm.model_executor.layers.batch_invariant import init_batch_invariance -from torchtitan.experiments.deterministic_vllm_rl import ( - enable_batch_invariant_backward_mode, - Qwen3VLLMCompatModel, -) - -# 1. Enable deterministic mode -init_batch_invariance() -enable_batch_invariant_backward_mode() - -# 2. 
Load model -from torchtitan.models.qwen3.model.args import Qwen3ModelArgs -model_args = Qwen3ModelArgs( - dim=2048, - n_layers=24, - n_heads=16, - n_kv_heads=2, - vocab_size=151936, -) -model = Qwen3VLLMCompatModel(model_args) - -# 3. Forward pass (deterministic) -input_ids = torch.randint(0, 151936, (2, 128), device='cuda') -logits = model(input_ids) - -# 4. Backward pass -loss = logits.sum() -loss.backward() -``` - -### Full RL Training - -Run the RL training loop: - -```bash -VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl -``` - -This will: -1. Download Qwen3-1.7B from HuggingFace -2. Initialize vLLM engine for rollouts -3. Generate samples for training prompts -4. Compute rewards and advantages -5. Update the policy using PPO -6. Log metrics to TensorBoard - -View training progress: -```bash -tensorboard --logdir=./outputs/rl_training -``` - -## How It Works - -### Deterministic Forward Pass - -vLLM's batch-invariant mode makes operations deterministic: - -```python -# These operations are deterministic when batch_invariance is enabled -y = torch.matmul(a, b) # Uses vLLM's deterministic matmul -output = flash_attn_varlen_func(q, k, v, num_splits=1) # Deterministic FA -``` - -### Backward Pass with Gradients - -Custom backward passes: -1. Re-compute attention weights deterministically -2. Use standard chain rule for gradients -3. Apply gradients through vLLM's deterministic operations - -```python -class FlashAttnWithBackward(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, ...): - # Use vLLM's forward implementation - return flash_attn_varlen_func(q, k, v, num_splits=1, ...) - - @staticmethod - def backward(ctx, grad_output): - # Compute gradients deterministically - # (re-compute attention weights and apply chain rule) - return grad_q, grad_k, grad_v, ... -``` - -### Bitwise Determinism Verification - -The training loop compares logprobs from vLLM and TorchTitan: - -```python -# During training, compare logprobs -vllm_logprobs = [from vLLM rollout] -titan_logprobs = [from TorchTitan forward pass] - -assert torch.equal(vllm_logprobs, titan_logprobs) -``` - -## Testing - -Run the test suite: - -```bash -cd torchtitan/experiments/deterministic_vllm_rl/tests - -# Test backward passes -python test_batch_invariant_backward.py - -# Test determinism -python test_exact_determinism.py -``` - -## Technical Details - -### Why Determinism Matters for RL - -RL training steps: -1. Generate rollouts by sampling from the policy -2. Compute rewards based on the samples -3. Update the policy using gradients - -If the forward pass during training differs from the forward pass during rollout, policy gradients may be incorrect. This matters for algorithms like PPO that compare old and new policy probabilities. - -This implementation uses the same kernels for both rollouts (vLLM) and training (TorchTitan) to ensure `logprobs_rollout == logprobs_training` bitwise. - -### Performance - -- Rollout speed: Uses vLLM's optimized kernels -- Training speed: Similar to standard TorchTitan -- Memory: Saves activations for custom backward passes - -### Limitations - -1. Custom backward requires uniform sequence lengths -2. Only causal attention is supported -3. 
Requires NVIDIA GPUs with Flash Attention support - -## Project Structure - -``` -deterministic_vllm_rl/ -├── README.md # Documentation -├── __init__.py # Package initialization -├── batch_invariant_backward.py # Backward passes for vLLM ops -├── weights_vllm_compat.py # Weight conversion utilities -├── simple_rl.py # RL training loop -├── models/ -│ ├── __init__.py -│ ├── attention.py # VLLMCompatibleFlashAttention -│ └── qwen3/ -│ ├── __init__.py -│ └── model_vllm_compat.py # vLLM-compatible Qwen3 model -├── weights/ -│ ├── __init__.py -│ ├── converter.py # Weight conversion script -│ └── README.md # Weight conversion documentation -└── tests/ - ├── __init__.py - ├── test_batch_invariant_backward.py # Test backward passes - └── test_exact_determinism.py # Test determinism -``` - -## TODO - -- `FlashAttnWithBackward` will need to become more composable and should not live exclusively within this directory. -- vLLM integration will need to become more generic with a provided Attention operator that is KV-cache compatible. -- vLLM parallelism will need to add generic parallelism initialization to support Monarch managed TP/DP. - -## Contributing - -This experiment is part of TorchTitan. To contribute: - -1. Test your changes with `pytest tests/` -2. Verify bitwise determinism is maintained -3. Update this README if adding new features - -## References - -- [vLLM Documentation](https://docs.vllm.ai/) -- [Flash Attention Paper](https://arxiv.org/abs/2205.14135) -- [PPO Algorithm](https://arxiv.org/abs/1707.06347) -- [GRPO: Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300) - -## License - -This code is licensed under the BSD-style license found in the LICENSE file in the TorchTitan repository root directory. +2. unified/ - Unified approach + - Uses canonical TorchTitan model definition for inference directly + - Replaces attention with vLLM Compatible attention for inference diff --git a/torchtitan/experiments/deterministic_vllm_rl/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/__init__.py index 067555251f..66c1de78a5 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/__init__.py @@ -4,31 +4,30 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -Deterministic RL training with vLLM experiment. -This experiment provides tools for bitwise-deterministic reinforcement learning -training using vLLM for fast rollouts and TorchTitan for training. 
- -Key components: -- VLLMCompatibleFlashAttention: Flash attention with custom backward pass -- Qwen3VLLMCompatModel: vLLM-compatible model with merged projections -- batch_invariant_backward: Gradient support for vLLM's deterministic operations -- simple_rl: End-to-end RL training loop -""" - -from .batch_invariant_backward import ( +from .unified import ( + create_parallel_dims_from_vllm_config, + register_torchtitan_model_from_train_spec, + TorchTitanVLLMModelWrapper, +) +from .vllm_compat import ( enable_batch_invariant_backward_mode, + Qwen3VLLMCompatModel, rms_norm_with_gradients, silu_and_mul_with_gradients, + VLLMCompatibleFlashAttention, ) -from .models import VLLMCompatibleFlashAttention -from .models.qwen3 import Qwen3VLLMCompatModel + __all__ = [ + # vllm_compat exports "VLLMCompatibleFlashAttention", "Qwen3VLLMCompatModel", "enable_batch_invariant_backward_mode", "rms_norm_with_gradients", "silu_and_mul_with_gradients", + # unified exports + "TorchTitanVLLMModelWrapper", + "create_parallel_dims_from_vllm_config", + "register_torchtitan_model_from_train_spec", ] diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/README.md b/torchtitan/experiments/deterministic_vllm_rl/unified/README.md new file mode 100644 index 0000000000..c30cf2241c --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/unified/README.md @@ -0,0 +1,67 @@ +# Run vLLM inference with TorchTitan Qwen3 Model + +This directory contains code to run a single canonical model definition (TorchTitan model definition) with vLLM inference engine (not batch-invariant yet, working in progress). +This work is inspired by https://github.com/vllm-project/vllm/pull/28685. + +## Overview +The integration consists of two main components: + +1. **Model Adapter** (`model/qwen3.py`): A custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions +2. **Inference Script** (`infer.py`): A simple script to register the model and run inference + + +## Quick Start +### Prerequisites + +1. Install PyTorch nightly for torchtitan: +``` +pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall +``` + + +2. Install vLLM from source [vllm-use-an-existing-pytorch-installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation): +```bash +# install PyTorch first, either from PyPI or from source +git clone https://github.com/vllm-project/vllm.git +cd vllm +python use_existing_torch.py +uv pip install -r requirements/build.txt +uv pip install --no-build-isolation -e . +``` + + +NOTE: If `flash_attn_varlen_func` hits error "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during forward path, this is due to GPU driver version is not compatible with vLLM/PyTorch compiled version. Use the following command to recompile vLLM. + +``` +# Set CUDA version environment variable +export CUDA_HOME=/usr/local/cuda-12.4 +export PATH=/usr/local/cuda-12.4/bin:$PATH +export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH + +# Clean previous build +rm -rf build dist *.egg-info +uv pip uninstall -y vllm + +# Rebuild vLLM from source with CUDA 12.4 +pip install -e . + +``` + +3. Download Qwen3/Qwen3-0.6b checkpoint from HuggingFace and put into `example_checkpoint` folder. 
Make sure to change the "architecture" field in `config.json` to `Qwen3TorchTitanForCausalLM` so that the vLLM engine loads the TorchTitan model.
+
+
+4. Run inference:
+```
+python torchtitan/experiments/deterministic_vllm_rl/unified/infer.py --model_ckpt_path torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B
+```
+
+Run with TP (work in progress):
+```
+python torchtitan/experiments/deterministic_vllm_rl/unified/infer.py --model_ckpt_path torchtitan/experiments/deterministic_vllm_rl/example_checkpoint/qwen3-0.6B --tensor-parallel-size 2
+```
+
+## TODO
+1. Rewrite the attention path to use vllm.Attention() with backward support as the only attention implementation.
+2. Integrate with simple_rl.py to run end-to-end RL with one canonical model definition.
+3. Integrate batch-invariant kernels into the model definition.
diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/unified/__init__.py
new file mode 100644
index 0000000000..5ecb952b67
--- /dev/null
+++ b/torchtitan/experiments/deterministic_vllm_rl/unified/__init__.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Unified approach for running TorchTitan models with vLLM inference.
+
+This module automatically registers TorchTitan models with vLLM when imported.
+It uses the canonical TorchTitan model definition directly with the vLLM inference engine.
+"""
+
+from vllm.logger import init_logger
+
+from torchtitan.protocols.train_spec import get_train_spec, TrainSpec
+
+from .utils import create_parallel_dims_from_vllm_config
+from .vllm_wrapper import TorchTitanVLLMModelWrapper
+
+
+logger = init_logger(__name__)
+
+
+def register_torchtitan_model_from_train_spec(
+    train_spec: TrainSpec,
+    model_name: str,
+) -> None:
+    """
+    Register a TorchTitan model with vLLM using a TrainSpec.
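+
+    The registered class is instantiated lazily by the vLLM engine. At construction
+    time it maps the HF config to TorchTitan model args, swaps each layer's inner
+    attention for vLLM's paged attention, and applies tensor parallelism when enabled
+    (see TorchTitanVLLMModelWrapper).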
+ + Args: + train_spec: TorchTitan TrainSpec containing model components + model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM") + + """ + from vllm.model_executor.models.registry import ModelRegistry + + # Extract model_args from TrainSpec + # TrainSpec has model_args as a Mapping, get the first value + if isinstance(train_spec.model_args, dict): + model_args_cls = type(next(iter(train_spec.model_args.values()))) + else: + model_args_cls = train_spec.model_args + + # Create dynamic model class directly from TrainSpec components + class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModelWrapper): + """Dynamically created vLLM model from TrainSpec.""" + + def __init__(self, *, vllm_config, prefix=""): + super().__init__( + model_cls=train_spec.model_cls, + model_args_cls=model_args_cls, + state_dict_adapter=train_spec.state_dict_adapter, + parallelize_fn=train_spec.parallelize_fn, + vllm_config=vllm_config, + prefix=prefix, + ) + + # Set the class name + TorchTitanVLLMModelFromSpec.__name__ = model_name + TorchTitanVLLMModelFromSpec.__qualname__ = model_name + + # Register with vLLM + ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec) + + logger.info( + f"Successfully registered {model_name} with vLLM using TrainSpec " + f"(model_cls={train_spec.model_cls.__name__})" + ) + + +# Auto-register TorchTitan models with vLLM when this module is imported +register_torchtitan_model_from_train_spec( + train_spec=get_train_spec("qwen3"), + model_name="Qwen3TorchTitanForCausalLM", +) + + +__all__ = [ + "TorchTitanVLLMModelWrapper", + "create_parallel_dims_from_vllm_config", + "register_torchtitan_model_from_train_spec", +] diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/attention.py b/torchtitan/experiments/deterministic_vllm_rl/unified/attention.py new file mode 100644 index 0000000000..d96689383a --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/unified/attention.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from vllm.attention.layer import Attention + + +class VLLMAttention(torch.nn.Module): + """ + Wrapper around vLLM's Attention. Compatible with TorchTitan input shape. 
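+
+    Expects q, k, and v in TorchTitan's [batch, num_heads, seq_len, head_dim] layout
+    and returns the output in the same layout; internally the tensors are transposed
+    to [batch, seq_len, num_heads, head_dim] before calling vLLM's Attention, which
+    integrates with vLLM's KV cache.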
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + layer_name: str, + scale: float | None = None, + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + self.layer_name = layer_name + + # Handle tensor parallelism: adjust num_heads and num_kv_heads for TP + # NOTE(jianiw): As we use local tensor for this region, we need to manually + + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + if scale is None: + self.scale = head_dim**-0.5 + else: + self.scale = scale + + cache_config = ( + vllm_config.cache_config if hasattr(vllm_config, "cache_config") else None + ) + + self.vllm_attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=self.scale, + num_kv_heads=num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix=f"model.layers.{layer_name}.attention.inner_attention", + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + scale: float | None = None, + ) -> torch.Tensor: + """ + Forward pass using vLLM's Attention layer for inference. + + Args: + q: Query tensor [batch, num_heads, seq_len, head_dim] + k: Key tensor [batch, num_kv_heads, seq_len, head_dim] + v: Value tensor [batch, num_kv_heads, seq_len, head_dim] + scale: Optional attention scale override (unused, vLLM uses internal scale) + + Returns: + output: [batch, num_heads, seq_len, head_dim] + """ + # Input is (batch, num_heads, seq_len, head_dim) + batch_size, num_heads, seq_len, head_dim = q.shape + + # Transpose to (batch, seq_len, num_heads, head_dim) for vLLM + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + output_varlen = self.vllm_attn(q, k, v) + + # Reshape back to batch format + output = output_varlen.view(batch_size, seq_len, num_heads, head_dim) + + # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) + output = output.transpose(1, 2) + + return output diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/infer.py b/torchtitan/experiments/deterministic_vllm_rl/unified/infer.py new file mode 100755 index 0000000000..13217d8845 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/unified/infer.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse + +from vllm import LLM, SamplingParams +from vllm.logger import init_logger + +# Import unified module - this automatically registers TorchTitan models with vLLM +from torchtitan.experiments.deterministic_vllm_rl import unified # noqa: F401 + + +logger = init_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run TorchTitan model inference with vLLM Engine", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model_ckpt_path", + type=str, + default="torchtitan/experiments/deterministic_vllm_rl/example_checkpoint", + help="Path to TorchTitan checkpoint directory", + ) + parser.add_argument( + "--prompt", + type=str, + default="Hello, my name is", + help="Prompt text for generation", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="Sampling temperature", + ) + parser.add_argument( + "--tensor-parallel-size", + type=int, + default=1, + help="Number of GPUs for tensor parallelism (default: 1 for single GPU)", + ) + return parser.parse_args() + + +def main(): + args = parse_args() + + logger.info("Initializing vLLM with TorchTitan model") + logger.info(f"Model: {args.model_ckpt_path}") + logger.info(f"Tensor Parallel Size: {args.tensor_parallel_size}") + + # Initialize vLLM with custom TorchTitan model + # The LLM initialization will internally: + # 1. Load TrainSpec for Qwen3 (from models/__init__.py register()) + # 2. Create TorchTitanVLLMModel instance + # 3. Create JobConfig and ParallelDims from vLLM config + # 4. Apply parallelization using parallelize_qwen3 + # 5. Load model weights and prepare for inference + logger.info("Creating vLLM LLM engine...") + + llm = LLM( + model=args.model_ckpt_path, # Model checkpoint path + hf_overrides={ + "checkpoint_dir": args.model_ckpt_path, + }, + dtype="bfloat16", + trust_remote_code=True, + enforce_eager=True, # Use eager mode + tensor_parallel_size=args.tensor_parallel_size, + ) + + logger.info("vLLM engine initialized successfully") + logger.info(f"Prompt: {args.prompt}") + + # Prepare prompt and sampling parameters + prompts = [args.prompt] + sampling_params = SamplingParams( + temperature=args.temperature, + top_p=0.95, + max_tokens=args.max_tokens, + ) + + # Generate text + logger.info("Generating text...") + outputs = llm.generate( + prompts=prompts, + sampling_params=sampling_params, + ) + + # Print results + logger.info("Generation complete") + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + + print(f"\nPrompt: {prompt}") + print(f"Generated text: {generated_text!r}\n") + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/utils.py b/torchtitan/experiments/deterministic_vllm_rl/unified/utils.py new file mode 100644 index 0000000000..5da9435759 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/unified/utils.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Parallelization utilities for vLLM + TorchTitan models. + +This module provides functions for setting up device mesh and applying +tensor parallelism to TorchTitan models in vLLM using TorchTitan's ParallelDims. 
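+
+For example, running with ``tensor_parallel_size=4`` on 4 GPUs maps to
+``ParallelDims(dp_replicate=1, dp_shard=1, cp=1, tp=4, pp=1, ep=1, etp=1, world_size=4)``.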
+""" + +import torch.distributed as dist +from vllm.config import VllmConfig +from vllm.logger import init_logger + +from torchtitan.distributed.parallel_dims import ParallelDims + + +logger = init_logger(__name__) + + +def create_parallel_dims_from_vllm_config(vllm_config: VllmConfig) -> ParallelDims: + """ + Create ParallelDims from vLLM config. + + This function is needed because vLLM doesn't separate model creation and + parallelism application - it requires parallelization to be done inside + the model constructor. This creates a vLLM-compatible model from a + TorchTitan model definition. + + Maps vLLM parallelism settings to TorchTitan's ParallelDims dataclass. + + Args: + vllm_config: vLLM configuration object + + Returns: + ParallelDims object with parallelism settings validated + + Note: + vLLM doesn't use FSDP sharding (dp_shard=1) or expert parallelism (ep=1, etp=1) + in inference. These are set to default values. + """ + world_size = dist.get_world_size() + + # Map vLLM config to TorchTitan ParallelDims + parallel_dims = ParallelDims( + dp_replicate=vllm_config.parallel_config.data_parallel_size, + dp_shard=1, # vLLM doesn't use FSDP sharding + cp=vllm_config.parallel_config.decode_context_parallel_size, + tp=vllm_config.parallel_config.tensor_parallel_size, + pp=vllm_config.parallel_config.pipeline_parallel_size, + ep=1, # Expert parallelism not used in vLLM inference yet + etp=1, # Expert tensor parallelism not used in vLLM inference yet + world_size=world_size, + ) + + logger.info( + f"Created ParallelDims from vLLM config: " + f"DP={parallel_dims.dp_replicate}, TP={parallel_dims.tp}, " + f"CP={parallel_dims.cp}, PP={parallel_dims.pp}" + ) + + return parallel_dims diff --git a/torchtitan/experiments/deterministic_vllm_rl/unified/vllm_wrapper.py b/torchtitan/experiments/deterministic_vllm_rl/unified/vllm_wrapper.py new file mode 100644 index 0000000000..6ef8e7b352 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/unified/vllm_wrapper.py @@ -0,0 +1,370 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Base wrapper for TorchTitan models to work with vLLM V1 engine. + +This module provides TorchTitanVLLMModel: Core model class that adapts +TorchTitan models for vLLM. +""" + +from functools import partial +from typing import Callable, TypeAlias + +import torch +import torch.nn as nn +from torch.distributed._tensor import DTensor, Replicate +from torch.distributed.checkpoint.state_dict import ( + set_model_state_dict, + StateDictOptions, +) + +from vllm.config import VllmConfig +from vllm.logger import init_logger + +from torchtitan.experiments.deterministic_vllm_rl.unified.attention import VLLMAttention +from torchtitan.models.qwen3.model.model import precompute_rope_cache +from torchtitan.protocols.model import BaseModelArgs, ModelProtocol +from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter + +from .utils import create_parallel_dims_from_vllm_config + + +logger = init_logger(__name__) + +ParallelizeFunction: TypeAlias = Callable[..., nn.Module] + + +class TorchTitanVLLMModelWrapper(nn.Module): + """ + Generic vLLM-compatible model wrapper for TorchTitan models. 
+ + The wrapper handles: + - HF config to TorchTitan model args mapping + - Attention replacement with vLLM paged attention + - Tensor parallelism setup + - Weight loading from HF checkpoints + - vLLM forward/compute_logits interface + """ + + is_text_generation_model = True # Required for vLLM runner validation + supports_pp = False # Pipeline parallelism not supported yet + supports_multimodal = False + + def __init__( + self, + *, + model_cls: type[ModelProtocol], # passing types that is not instantiated + model_args_cls: type[BaseModelArgs], + state_dict_adapter: type[BaseStateDictAdapter], + parallelize_fn: ParallelizeFunction, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + + assert vllm_config is not None, "vllm_config is required" + + # Store components + self.model_cls = model_cls + self.model_args_cls = model_args_cls + self.state_dict_adapter = state_dict_adapter + self.parallelize_fn = parallelize_fn + + # Map HF config to TorchTitan ModelArgs + hf_config = vllm_config.model_config.hf_config + logger.info(f"Mapping HF config to {self.model_args_cls.__name__}") + model_args = self._map_hf_config_to_model_args(hf_config, self.model_args_cls) + + # Create TorchTitan model + logger.info(f"Creating {self.model_cls.__name__} with config: {model_args}") + self.model = self.model_cls(model_args) + self.config = model_args + + # Setup RoPE cache extension function if provided + self.rope_cache_extension_fn = partial( + precompute_rope_cache, + dim=self.config.head_dim, + base=self.config.rope_theta, + ) + # Replace attention with vLLM paged attention + self._replace_with_vllm_attention(model_args) + + # Create ParallelDims from vLLM config and apply parallelization + # NOTE: We need to apply parallelize within model.__init__ because w + parallel_dims = create_parallel_dims_from_vllm_config(vllm_config) + if parallel_dims.tp_enabled: + self.world_mesh = parallel_dims.world_mesh + tp_mesh = self.world_mesh["tp"] + parallelize_fn( + model=self.model, + tp_mesh=tp_mesh, + loss_parallel=False, + enable_float8_tensorwise_tp=False, + enable_async_tp=False, + ) + logger.info( + f"Successfully initialized model with with TP={parallel_dims.tp}" + ) + else: + logger.info("Single GPU mode - no parallelization needed") + + def _map_hf_config_to_model_args(self, hf_config, model_args_cls): + """ + Map HuggingFace config to TorchTitan ModelArgs. + + Default implementation that handles common model args fields. + Override in subclass if custom mapping is needed. + """ + # Maps TorchTitan parameter name to HF config attribute name + mapping = { + "vocab_size": "vocab_size", + "dim": "hidden_size", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "n_kv_heads": "num_key_value_heads", + "head_dim": "head_dim", + "hidden_dim": "intermediate_size", + "norm_eps": "rms_norm_eps", + "max_seq_len": "max_position_embeddings", + "rope_theta": "rope_theta", + "qk_norm": "qk_norm", + } + + # Build kwargs for model args from mapping + kwargs = {} + for torchtitan_param, hf_attr in mapping.items(): + # Try to get value from HF config + if hasattr(hf_config, hf_attr): + kwargs[torchtitan_param] = getattr(hf_config, hf_attr) + + return model_args_cls(**kwargs) + + def _replace_with_vllm_attention(self, model_args): + """ + Replace TorchTitan attention with vLLM paged attention. + + Assumes model has .layers dict with .attention.inner_attention structure. + Override in subclass if different structure. 
+ """ + if not hasattr(self.model, "layers"): + raise AttributeError( + f"Model {type(self.model).__name__} must have .layers attribute" + ) + + for layer_name, layer in self.model.layers.items(): + if not hasattr(layer, "attention"): + raise ValueError(f"Layer {layer_name} must have .attention attribute") + + # Create VLLMAttention layer + # VLLMAttention wraps vLLM's Attention and properly implements + # AttentionLayerBase with get_kv_cache_spec() for KV cache integration + vllm_attn = VLLMAttention( + hidden_size=model_args.dim, + num_heads=model_args.n_heads, + num_kv_heads=model_args.n_heads, # Use n_heads (already replicated) + head_dim=model_args.head_dim, + layer_name=layer_name, + scale=model_args.head_dim**-0.5, + ) + + # Replace inner attention + layer.attention.inner_attention = vllm_attn + + logger.info( + f"Successfully replaced TorchTitan attention with VLLMAttention " + f"({len(self.model.layers)} layers)" + ) + + def _extend_rope_cache_if_needed( + self, rope_cache: torch.Tensor, max_position: int + ) -> torch.Tensor: + """ + Extend RoPE cache if needed during vLLM profiling. + + Args: + rope_cache: Current RoPE cache tensor + max_position: Maximum position index needed + + Returns: + Extended RoPE cache if needed, otherwise original cache + """ + from torch.distributed._tensor import DTensor, Replicate + + required_len = max_position + 1 + + # No extension needed + if required_len <= rope_cache.shape[0]: + return rope_cache + + # If no extension function provided, return original cache + if self.rope_cache_extension_fn is None: + logger.warning( + f"RoPE cache extension needed (required_len={required_len}, " + f"current_len={rope_cache.shape[0]}) but no rope_cache_extension_fn provided. " + "Returning original cache." + ) + return rope_cache + + # Handle DTensor case + is_dtensor = isinstance(rope_cache, DTensor) + if is_dtensor: + device_mesh = rope_cache.device_mesh + local_rope_cache = rope_cache.to_local() + device = local_rope_cache.device + dtype = local_rope_cache.dtype + else: + device = rope_cache.device + dtype = rope_cache.dtype + + # Use provided extension function + try: + extended_cache = self.rope_cache_extension_fn(self.config, required_len) + extended_cache = extended_cache.to(device=device, dtype=dtype) + except Exception as e: + logger.warning( + f"Failed to extend RoPE cache using rope_cache_extension_fn: {e}. " + "Returning original cache." + ) + return rope_cache + + # Convert back to DTensor if needed + if is_dtensor: + rope_cache = DTensor.from_local( + extended_cache, + device_mesh=device_mesh, + placements=[Replicate()], + ) + else: + rope_cache = extended_cache + + return rope_cache + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings.""" + return self.model.tok_embeddings(input_ids) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings (deprecated vLLM interface).""" + return self.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass with vLLM interface. 
+ + Args: + input_ids: Token IDs [total_tokens] (1D varlen format) + positions: Position indices [total_tokens] (1D varlen format) + inputs_embeds: Pre-computed embeddings (optional) + **kwargs: Additional vLLM kwargs + + Returns: + hidden_states: Final hidden states [total_tokens, hidden_size] + """ + if inputs_embeds is not None: + raise NotImplementedError("inputs_embeds not yet supported") + + if input_ids is None: + raise ValueError("Either input_ids or inputs_embeds must be provided") + + # Convert vLLM interface to TorchTitan interface + # vLLM: [total_tokens] → TorchTitan: [batch_size, seq_len] + tokens_2d = input_ids.unsqueeze(0) + + # Get embeddings + h = self.model.tok_embeddings(tokens_2d) + + # Get RoPE cache (handle model-specific attribute names) + # Use hasattr to avoid ambiguous boolean value error with tensors + if hasattr(self.model, "rope_cache"): + rope_attr = self.model.rope_cache + elif hasattr(self.model, "freqs_cis"): + rope_attr = self.model.freqs_cis + else: + rope_attr = None + + # Extend RoPE cache if needed (vLLM profiling may use 2x max_seq_len) + if positions is not None: + max_position = positions.max().item() + else: + max_position = 0 + + rope_cache = self._extend_rope_cache_if_needed(rope_attr, max_position) + positions = positions.unsqueeze(0) + + # Pass through transformer layers + for layer in self.model.layers.values(): + h = layer(h, rope_cache, attention_masks=None, positions=positions) + + # Convert to vLLM format: [total_tokens, hidden_size] + if h.dim() == 3: + batch_size, seq_len, hidden_size = h.shape + h = h.view(batch_size * seq_len, hidden_size) + + return h + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata=None, + ) -> torch.Tensor | None: + """Compute logits from hidden states.""" + h = self.model.norm(hidden_states) + logits = self.model.output(h) + + return logits + + def load_weights(self, weights_iter): + """ + Load weights from HF checkpoint using the provided state dict adapter. 
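+
+        The HF state dict is converted with the configured state dict adapter's
+        from_hf(), tensors whose targets are DTensors are wrapped as replicated
+        DTensors on the target device mesh, and the result is applied with
+        set_model_state_dict(strict=False).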
+ + Args: + weights_iter: Iterator of (name, tensor) pairs from HF checkpoint + + Returns: + Set of loaded parameter names + """ + # Collect weights from iterator + hf_state_dict = {} + for name, tensor in weights_iter: + hf_state_dict[name] = tensor + + # Use adapter to convert HF → TorchTitan format + adapter = self.state_dict_adapter( + model_args=self.config, + hf_assets_path=None, + ) + + torchtitan_state_dict = adapter.from_hf(hf_state_dict) + model_state_dict = {k: v for k, v in self.model.state_dict().items()} + + # Convert to DTensor if target is DTensor + for name, tensor in torchtitan_state_dict.items(): + if name in model_state_dict and isinstance(model_state_dict[name], DTensor): + target_dtensor = model_state_dict[name] + device_mesh = target_dtensor.device_mesh + torchtitan_state_dict[name] = DTensor.from_local( + tensor.to(device_mesh.device_type), + device_mesh=device_mesh, + placements=[Replicate()], + ) + + # Load state dict + set_model_state_dict( + model=self.model, + model_state_dict=torchtitan_state_dict, + options=StateDictOptions(strict=False), + ) + + loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()} + + return loaded_params diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/README.md b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/README.md new file mode 100644 index 0000000000..d2ef719c0d --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/README.md @@ -0,0 +1,262 @@ +# Deterministic RL Training with vLLM + +This experiment combines vLLM's deterministic kernels with PyTorch autograd to enable reinforcement learning training where forward passes produce bitwise-identical results across runs. + +## Overview + +RL training requires both fast inference for generating rollouts and gradient computation for policy updates. vLLM provides deterministic forward passes but does not support gradients. This experiment adds backward passes to vLLM's operations. + +The implementation: +1. Uses vLLM's batch-invariant kernels for forward passes +2. Implements custom backward passes for gradient computation +3. Provides weight conversion utilities between TorchTitan and vLLM formats + +### Features + +- Bitwise determinism: Same inputs produce identical outputs across runs +- Gradient support: Backward passes through vLLM operations +- Weight conversion: Utilities to convert between model formats + +Note: Currently supports single-device training only. + +## Architecture + +### Components + +1. `models/attention.py`: VLLMCompatibleFlashAttention + - Uses vLLM's Flash Attention for forward pass + - Implements custom backward pass for gradient computation + - Uses `num_splits=1` for deterministic behavior + +2. `models/qwen3/model_vllm_compat.py`: Qwen3VLLMCompatModel + - Qwen3 model with merged gate/up projections matching vLLM format + - Uses VLLMRMSNorm with gradient support + +3. `batch_invariant_backward.py`: Backward passes for vLLM operations + - Registers gradients for vLLM's batch-invariant operations + - Supports matmul, linear, and RMSNorm + - Patches Flash Attention for autograd + +4. `weights_vllm_compat.py`: Weight conversion utilities + - Converts between TorchTitan format (separate w1, w2, w3) and vLLM format (merged gate_up_proj) + - Provides bidirectional conversion functions + +5. 
`simple_rl.py`: RL training loop + - Generates rollouts using vLLM engine + - Computes advantages using GRPO-style ranking + - Updates policy using PPO + +## Installation + +### Prerequisites + +```bash +# Install vLLM with deterministic support +pip install vllm + +# Install TorchTitan (from the repository root) +pip install -e . + +# Install additional dependencies +pip install transformers safetensors huggingface_hub tensorboard +``` + +### Enable Batch Invariance + +Initialize vLLM's batch-invariant mode before training: + +```python +from vllm.model_executor.layers.batch_invariant import init_batch_invariance +init_batch_invariance() +``` + +## Usage + +### Quick Start + +```python +import torch +from vllm.model_executor.layers.batch_invariant import init_batch_invariance +from torchtitan.experiments.deterministic_vllm_rl import ( + enable_batch_invariant_backward_mode, + Qwen3VLLMCompatModel, +) + +# 1. Enable deterministic mode +init_batch_invariance() +enable_batch_invariant_backward_mode() + +# 2. Load model +from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +model_args = Qwen3ModelArgs( + dim=2048, + n_layers=24, + n_heads=16, + n_kv_heads=2, + vocab_size=151936, +) +model = Qwen3VLLMCompatModel(model_args) + +# 3. Forward pass (deterministic) +input_ids = torch.randint(0, 151936, (2, 128), device='cuda') +logits = model(input_ids) + +# 4. Backward pass +loss = logits.sum() +loss.backward() +``` + +### Full RL Training + +Run the RL training loop: + +```bash +VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl +``` + +This will: +1. Download Qwen3-1.7B from HuggingFace +2. Initialize vLLM engine for rollouts +3. Generate samples for training prompts +4. Compute rewards and advantages +5. Update the policy using PPO +6. Log metrics to TensorBoard + +View training progress: +```bash +tensorboard --logdir=./outputs/rl_training +``` + +## How It Works + +### Deterministic Forward Pass + +vLLM's batch-invariant mode makes operations deterministic: + +```python +# These operations are deterministic when batch_invariance is enabled +y = torch.matmul(a, b) # Uses vLLM's deterministic matmul +output = flash_attn_varlen_func(q, k, v, num_splits=1) # Deterministic FA +``` + +### Backward Pass with Gradients + +Custom backward passes: +1. Re-compute attention weights deterministically +2. Use standard chain rule for gradients +3. Apply gradients through vLLM's deterministic operations + +```python +class FlashAttnWithBackward(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, ...): + # Use vLLM's forward implementation + return flash_attn_varlen_func(q, k, v, num_splits=1, ...) + + @staticmethod + def backward(ctx, grad_output): + # Compute gradients deterministically + # (re-compute attention weights and apply chain rule) + return grad_q, grad_k, grad_v, ... +``` + +### Bitwise Determinism Verification + +The training loop compares logprobs from vLLM and TorchTitan: + +```python +# During training, compare logprobs +vllm_logprobs = [from vLLM rollout] +titan_logprobs = [from TorchTitan forward pass] + +assert torch.equal(vllm_logprobs, titan_logprobs) +``` + +## Testing + +Run the test suite: + +```bash +cd torchtitan/experiments/deterministic_vllm_rl/tests + +# Test backward passes +python test_batch_invariant_backward.py + +# Test determinism +python test_exact_determinism.py +``` + +## Technical Details + +### Why Determinism Matters for RL + +RL training steps: +1. 
Generate rollouts by sampling from the policy +2. Compute rewards based on the samples +3. Update the policy using gradients + +If the forward pass during training differs from the forward pass during rollout, policy gradients may be incorrect. This matters for algorithms like PPO that compare old and new policy probabilities. + +This implementation uses the same kernels for both rollouts (vLLM) and training (TorchTitan) to ensure `logprobs_rollout == logprobs_training` bitwise. + +### Performance + +- Rollout speed: Uses vLLM's optimized kernels +- Training speed: Similar to standard TorchTitan +- Memory: Saves activations for custom backward passes + +### Limitations + +1. Custom backward requires uniform sequence lengths +2. Only causal attention is supported +3. Requires NVIDIA GPUs with Flash Attention support + +## Project Structure + +``` +deterministic_vllm_rl/ +├── README.md # Documentation +├── __init__.py # Package initialization +├── batch_invariant_backward.py # Backward passes for vLLM ops +├── weights_vllm_compat.py # Weight conversion utilities +├── simple_rl.py # RL training loop +├── models/ +│ ├── __init__.py +│ ├── attention.py # VLLMCompatibleFlashAttention +│ └── qwen3/ +│ ├── __init__.py +│ └── model_vllm_compat.py # vLLM-compatible Qwen3 model +├── weights/ +│ ├── __init__.py +│ ├── converter.py # Weight conversion script +│ └── README.md # Weight conversion documentation +└── tests/ + ├── __init__.py + ├── test_batch_invariant_backward.py # Test backward passes + └── test_exact_determinism.py # Test determinism +``` + +## TODO + +- `FlashAttnWithBackward` will need to become more composable and should not live exclusively within this directory. +- vLLM integration will need to become more generic with a provided Attention operator that is KV-cache compatible. +- vLLM parallelism will need to add generic parallelism initialization to support Monarch managed TP/DP. + +## Contributing + +This experiment is part of TorchTitan. To contribute: + +1. Test your changes with `pytest tests/` +2. Verify bitwise determinism is maintained +3. Update this README if adding new features + +## References + +- [vLLM Documentation](https://docs.vllm.ai/) +- [Flash Attention Paper](https://arxiv.org/abs/2205.14135) +- [PPO Algorithm](https://arxiv.org/abs/1707.06347) +- [GRPO: Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300) + +## License + +This code is licensed under the BSD-style license found in the LICENSE file in the TorchTitan repository root directory. diff --git a/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/__init__.py new file mode 100644 index 0000000000..b86721fba5 --- /dev/null +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +vLLM-Compatible approach for deterministic RL training. + +This module provides models that match vLLM's weight format (e.g., merged gate_up_proj) +with custom backward passes for gradient computation during training. 
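+
+A minimal usage sketch (assumes a CUDA device; mirrors the Quick Start in this
+package's README)::
+
+    from vllm.model_executor.layers.batch_invariant import init_batch_invariance
+
+    from torchtitan.experiments.deterministic_vllm_rl.vllm_compat import (
+        enable_batch_invariant_backward_mode,
+        Qwen3VLLMCompatModel,
+    )
+
+    init_batch_invariance()                 # vLLM's batch-invariant (deterministic) kernels
+    enable_batch_invariant_backward_mode()  # register gradients for those kernels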
+""" + +from .batch_invariant_backward import ( + enable_batch_invariant_backward_mode, + rms_norm_with_gradients, + silu_and_mul_with_gradients, +) +from .models.attention import VLLMCompatibleFlashAttention +from .models.qwen3 import Qwen3VLLMCompatModel + + +__all__ = [ + "VLLMCompatibleFlashAttention", + "Qwen3VLLMCompatModel", + "enable_batch_invariant_backward_mode", + "rms_norm_with_gradients", + "silu_and_mul_with_gradients", +] diff --git a/torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/batch_invariant_backward.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/batch_invariant_backward.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/__init__.py similarity index 74% rename from torchtitan/experiments/deterministic_vllm_rl/models/__init__.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/__init__.py index c8c11a170a..2e7a5fa6af 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/__init__.py @@ -6,8 +6,13 @@ """ Models for deterministic vLLM RL training. + +This module provides vLLM-compatible model components. """ from .attention import VLLMCompatibleFlashAttention -__all__ = ["VLLMCompatibleFlashAttention"] + +__all__ = [ + "VLLMCompatibleFlashAttention", +] diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/attention.py similarity index 98% rename from torchtitan/experiments/deterministic_vllm_rl/models/attention.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/attention.py index 33dd5a140d..11e6d3af67 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/attention.py @@ -4,12 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -vLLM-compatible Flash Attention implementation for deterministic RL training. 
-""" import torch -from vllm.vllm_flash_attn import flash_attn_varlen_func +from vllm.attention.utils.fa_utils import flash_attn_varlen_func class VLLMCompatibleFlashAttention(torch.nn.Module): diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/model_vllm_compat.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/model_vllm_compat.py index dd84665091..b4967fbbd9 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/models/qwen3/model_vllm_compat.py @@ -13,7 +13,7 @@ from torchtitan.components.tokenizer import BaseTokenizer # Import gradient-enabled operations from experiment utilities -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( rms_norm_with_gradients, silu_and_mul_with_gradients, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/simple_rl.py similarity index 99% rename from torchtitan/experiments/deterministic_vllm_rl/simple_rl.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/simple_rl.py index ffc7d52eb0..3f938eba85 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/simple_rl.py @@ -30,11 +30,11 @@ from vllm import LLM, SamplingParams from vllm.model_executor.layers.batch_invariant import init_batch_invariance -from torchtitan.experiments.deterministic_vllm_rl.weights.converter import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.weights.converter import ( torchtitan_to_vllm, vllm_to_torchtitan, ) -from torchtitan.experiments.deterministic_vllm_rl.weights_vllm_compat import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.weights_vllm_compat import ( torchtitan_to_vllm_compat, ) @@ -340,7 +340,7 @@ def load_model(checkpoint_path: str, model_path: str, use_vllm_compat: bool = Tr if use_vllm_compat: # Create and load model (using vLLM-compat for bitwise determinism) - from torchtitan.experiments.deterministic_vllm_rl.models.qwen3 import ( + from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.models.qwen3 import ( Qwen3VLLMCompatModel, ) @@ -1058,7 +1058,7 @@ def main(): print("✓ Batch invariance detected - using vLLM-compatible model") # Add backward pass support to vLLM's batch_invariant mode print(" Adding gradient support to vLLM's batch_invariant mode...") - from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( + from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( enable_batch_invariant_backward_mode, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py 
b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_batch_invariant_backward.py similarity index 97% rename from torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_batch_invariant_backward.py index 3ed9604d10..ddf8b01514 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_batch_invariant_backward.py @@ -8,9 +8,11 @@ Test batch_invariant_backward module to ensure it works correctly. """ +import sys + import torch -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( disable_batch_invariant_backward_mode, enable_batch_invariant_backward_mode, linear_batch_invariant_backward, diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_exact_determinism.py similarity index 98% rename from torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_exact_determinism.py index 8d0ac3133e..bfb7954a2a 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py +++ b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/tests/test_exact_determinism.py @@ -13,7 +13,7 @@ import torch from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( enable_batch_invariant_backward_mode, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/README.md b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/README.md similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/README.md rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/README.md diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/converter.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/converter.py rename to torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights/converter.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py b/torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights_vllm_compat.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py rename to 
torchtitan/experiments/deterministic_vllm_rl/vllm_compat/weights_vllm_compat.py