Merged
2 changes: 1 addition & 1 deletion torchtitan/experiments/__init__.py
@@ -14,6 +14,6 @@
"autoparallel.deepseek_v3",
"autoparallel.local_map_deepseek_v3",
"ft.llama3",
"rl.unified",
"rl",
]
)
75 changes: 66 additions & 9 deletions torchtitan/experiments/rl/README.md
@@ -1,12 +1,69 @@
# Deterministic RL Training with vLLM
# RL Training with TorchTitan and vLLM

This package provides two approaches for integrating TorchTitan models with vLLM:
This directory contains code for RL training that uses TorchTitan model definitions together with the vLLM inference engine for fast rollout generation.

1. vllm_compat/ - vLLM-Compatible approach
- Separate model definition matching vLLM's weight format
- Support batch-invariant and bit-wise identity between train and inference
- Custom backward passes for attention gradient computation
## Overview
The integration consists of the following components:

2. unified/ - Unified approach
- Uses canonical TorchTitan model definition for inference directly
- Replaces attention with vLLM Compatible attention for inference
1. **vLLM Model Wrapper** (`models/vllm_wrapper.py`): Adapts TorchTitan models for vLLM's inference engine
2. **RL Training Loop** (`simple_grpo_sum_digits.py`): GRPO-based RL training with Monarch actors
3. **Inference Script** (`inference_example.py`): Standalone inference using the vLLM engine


## Quick Start
### Prerequisites

0. Create and activate an environment with uv:
```bash
uv venv --python 3.12 titan-rl
source titan-rl/bin/activate
```

1. Install Monarch:
```bash
uv pip install torchmonarch
```


2. Install the PyTorch nightly build for torchtitan, plus pre-built vLLM wheels matched to that nightly version.
```bash
# Install vllm with nightly torch
uv pip install torch vllm xformers --pre \
--extra-index-url https://download.pytorch.org/whl/nightly/cu128 \
--index-strategy unsafe-best-match
```

**NOTE:** The pre-built vLLM wheels target CUDA 12.8, though they should also work with most older CUDA versions. Alternatively, you can install the corresponding vLLM pre-built wheels directly from https://download.pytorch.org/whl/nightly/cu128, for example: `uv pip install vllm-1.0.0.dev20260219+cu130-<suffix>.whl`. Ensure the build version number (e.g., `dev20260219`) matches your PyTorch nightly installation.


3. Install TorchTitan in editable mode:
```bash
uv pip install -e .
```

4. Download the `Qwen/Qwen3-0.6B` (or `Qwen/Qwen3-1.7B`) checkpoint from HuggingFace to the `torchtitan/experiments/rl/example_checkpoint` folder:
```bash
python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=...

python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-1.7B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=...
```

5. Run inference with the torchtitan model definition:
```bash
torchrun --nproc_per_node=2 torchtitan/experiments/rl/inference_example.py
```

**NOTE:** Set `--nproc_per_node` to the world size, which should match the `tensor_parallel_degree` in the `VLLMGenerator` config.

6. Run the simple GRPO RL loop to learn the sum-digits task:
```bash
python torchtitan/experiments/rl/simple_grpo_sum_digits.py --module rl --config rl_grpo_qwen3_0_6b
```

**NOTE:** If you downloaded your HF model to a different path than the one in step 4, specify it in your command with `--hf_assets_path=<path_to_model_checkpoint>`.
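The GRPO objective used by this loop is group-relative: each sampled completion's reward is normalized against the other completions generated for the same prompt. A minimal sketch of that normalization step (the function name and epsilon are illustrative, not the exact code in `simple_grpo_sum_digits.py`):

```python
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Group-relative advantages: normalize rewards within each group.

    rewards: shape (num_prompts, group_size) -- one row per prompt,
    one column per sampled completion for that prompt.
    """
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)

# Two correct and two incorrect completions for one prompt:
rewards = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
adv = grpo_advantages(rewards)  # correct answers get positive advantage
```

Because advantages are centered per group, a prompt where every completion gets the same reward contributes no gradient signal.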

We use a unified model definition from torchtitan for both the trainer and the generator, ensuring the two models are bitwise-identical and avoiding a class of subtle correctness bugs in RL for LLMs.
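A check in that spirit uses exact tensor equality rather than a tolerance. The helper below is hypothetical (the repo's own verification lives in the trainer utilities); it is sketched only to show what "bitwise-identical" means in practice:

```python
import torch

def logprobs_bitwise_identical(trainer_logits: torch.Tensor,
                               generator_logits: torch.Tensor) -> bool:
    """True only if token log-probs match bit for bit -- no tolerance."""
    lp_train = torch.log_softmax(trainer_logits.float(), dim=-1)
    lp_gen = torch.log_softmax(generator_logits.float(), dim=-1)
    return torch.equal(lp_train, lp_gen)  # exact equality, unlike allclose

logits = torch.randn(4, 32)
```

`torch.allclose` would mask small numerical drift between the two stacks; `torch.equal` surfaces it.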



**Current status:** Batch invariance is only supported for single-GPU configurations (TP=1) for both the trainer and generator; batch-invariant mode is not yet supported when tensor parallelism is enabled (TP > 1).
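Batch invariance means a sequence's outputs are bit-identical whether it is processed alone or batched with other sequences. A toy checker under that definition (illustrative only; the real check compares model log-probs across batch sizes):

```python
import torch

def is_batch_invariant(fn, row: torch.Tensor, other_rows: torch.Tensor) -> bool:
    """Check fn gives bit-identical output for `row` alone vs. batched."""
    alone = fn(row.unsqueeze(0))[0]
    batched = fn(torch.cat([row.unsqueeze(0), other_rows], dim=0))[0]
    return torch.equal(alone, batched)  # bitwise, not allclose

row = torch.randn(8)
others = torch.randn(3, 8)
# Elementwise ops are trivially batch-invariant; a reduction across the
# batch dimension is not, since other rows change the result.
```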
@@ -8,18 +8,14 @@
Unified approach for running TorchTitan models with vLLM inference.

To register TorchTitan models with vLLM:
from torchtitan.experiments.rl.unified.plugin import register
from torchtitan.experiments.rl.plugin import register
register(model_spec)
"""

from torchtitan.experiments.rl.unified.models.vllm_wrapper import (
TorchTitanVLLMModelWrapper,
)
from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper

# Export plugin register function for manual use (no auto-registration)
from torchtitan.experiments.rl.unified.plugin import (
register_model_to_vllm_model_registry,
)
from torchtitan.experiments.rl.plugin import register_model_to_vllm_model_registry


__all__ = [
@@ -14,11 +14,11 @@
from monarch.actor import Actor, endpoint
from torchtitan.config import Configurable
from torchtitan.config.configs import ParallelismConfig
from torchtitan.experiments.rl.unified.plugin import (
from torchtitan.experiments.rl.plugin import (
register_model_to_vllm_model_registry,
VLLM_MODEL_NAME,
)
from torchtitan.experiments.rl.unified.types import Episode
from torchtitan.experiments.rl.types import Episode
from torchtitan.protocols.model_spec import ModelSpec
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.config import AttentionConfig, CompilationConfig
@@ -9,7 +9,7 @@

import torch
from monarch.actor import Actor, endpoint
from torchtitan.experiments.rl.unified.types import Episode
from torchtitan.experiments.rl.types import Episode

logger = logging.getLogger(__name__)

@@ -23,15 +23,15 @@
from torchtitan.config import CommConfig, Configurable, TORCH_DTYPE_MAP
from torchtitan.config.configs import ParallelismConfig, TrainingConfig
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.experiments.rl.unified.actors.utils import (
from torchtitan.experiments.rl.actors.utils import (
compute_policy_gradient_loss,
compute_token_log_probs,
verify_logprob_identity,
)
from torchtitan.experiments.rl.unified.models.attention import (
from torchtitan.experiments.rl.models.attention import (
replace_with_vllm_compatible_flash_attention,
)
from torchtitan.experiments.rl.unified.types import Episode
from torchtitan.experiments.rl.types import Episode
from torchtitan.protocols.model_spec import ModelSpec
from torchtitan.tools import utils

@@ -8,19 +8,19 @@
Config entry points for the RL/unified experiment.

Each function returns a complete ``RLTrainer.Config`` and is discoverable by
``ConfigManager`` via ``--module rl.unified --config <function_name>``.
``ConfigManager`` via ``--module rl --config <function_name>``.
"""

from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.config.configs import ParallelismConfig, TrainingConfig
from torchtitan.experiments.rl.unified.actors.generator import (
from torchtitan.experiments.rl.actors.generator import (
GeneratorCompileConfig,
SamplingConfig,
VLLMGenerator,
)
from torchtitan.experiments.rl.unified.actors.trainer import PolicyTrainer
from torchtitan.experiments.rl.unified.simple_grpo_sum_digits import RLTrainer
from torchtitan.experiments.rl.actors.trainer import PolicyTrainer
from torchtitan.experiments.rl.simple_grpo_sum_digits import RLTrainer
from torchtitan.models.qwen3 import model_registry


@@ -12,7 +12,7 @@
the vLLM engine and sampling parameters.

Run: torchrun --nproc_per_node=2 \
torchtitan/experiments/rl/unified/infer.py
torchtitan/experiments/rl/inference_example.py
"""
import os

@@ -21,11 +21,11 @@
# See also https://docs.vllm.ai/en/v0.8.3/design/multiprocessing.html#python-multiprocessing
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from torchtitan.experiments.rl.unified.config_registry import rl_grpo_qwen3_0_6b

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.logger import init_logger

from torchtitan.experiments.rl.config_registry import rl_grpo_qwen3_0_6b


logger = init_logger(__name__)

@@ -38,12 +38,12 @@ def generate():

# Patch model_spec to use the RL-specific parallelize function.
# TODO: Switch to canonical Qwen3 parallel plan
from torchtitan.experiments.rl.unified.models.parallelize import parallelize_qwen3
from torchtitan.experiments.rl.models.parallelize import parallelize_qwen3

config.model_spec.parallelize_fn = parallelize_qwen3

# Register TorchTitan model with vLLM before engine creation
from torchtitan.experiments.rl.unified.plugin import (
from torchtitan.experiments.rl.plugin import (
register_model_to_vllm_model_registry,
VLLM_MODEL_NAME,
)
@@ -8,10 +8,11 @@

import torch
from torch.distributed.tensor import DTensor
from torchtitan.experiments.rl.vllm_compat.models.attention import (
from torchtitan.experiments.rl.models.vllm_compat_attention import (
VLLMCompatibleFlashAttention,
)
from torchtitan.protocols.module import Module

from vllm.model_executor.layers.attention import Attention

logger = logging.getLogger(__name__)
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# TODO: This file needs to be deleted after switching to PyTorch's Varlen Attention

import math
from collections.abc import Callable
@@ -24,9 +24,7 @@

from torchtitan.config import ParallelismConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.experiments.rl.unified.models.attention import (
replace_with_vllm_attention,
)
from torchtitan.experiments.rl.models.attention import replace_with_vllm_attention
from torchtitan.protocols.model_spec import ModelSpec
from torchtitan.protocols.module import Module
from vllm.compilation.decorators import support_torch_compile
@@ -185,7 +183,7 @@ def __init__(

# Pre-extend RoPE cache to cover vLLM's max model length (profiling
# may use up to 2x max_seq_len, so use max_model_len which already
# accounts for this). This avoids data-dependent control flow in
# accounts for this). This avoids data-dependent control flow in
# forward() which is incompatible with torch.compile.
max_model_len = vllm_config.model_config.max_model_len
if self.model.freqs_cis.shape[0] < max_model_len:
@@ -8,7 +8,7 @@
vLLM plugin for TorchTitan models.

Usage:
from torchtitan.experiments.rl.unified.plugin import register_model_to_vllm_model_registry
from torchtitan.experiments.rl.plugin import register_model_to_vllm_model_registry
register_model_to_vllm_model_registry(model_spec)
"""

@@ -30,12 +30,11 @@ def register_model_to_vllm_model_registry(
Args:
model_spec: TorchTitan ModelSpec containing model config and components
"""
from torchtitan.experiments.rl.unified.models.vllm_wrapper import (
TorchTitanVLLMModelWrapper,
)
from vllm.logger import init_logger
from vllm.model_executor.models.registry import ModelRegistry

from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper

logger = init_logger(__name__)

# Create dynamic model class capturing ModelSpec in the closure
@@ -17,8 +17,8 @@
The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training.

Command to run:
python3 torchtitan/experiments/rl/unified/simple_grpo_sum_digits.py \
--module rl.unified --config rl_grpo_qwen3_0_6b \
python3 torchtitan/experiments/rl/simple_grpo_sum_digits.py \
--module rl --config rl_grpo_qwen3_0_6b \
--hf_assets_path=<path_to_model_checkpoint>
"""

@@ -35,13 +35,14 @@
import torchstore as ts
from monarch.actor import this_host
from monarch.spmd import setup_torch_elastic_env_async

from torchtitan.config import Configurable
from torchtitan.config.manager import ConfigManager
from torchtitan.experiments.rl.unified.actors.generator import VLLMGenerator
from torchtitan.experiments.rl.unified.actors.grader import Grader
from torchtitan.experiments.rl.unified.actors.trainer import PolicyTrainer
from torchtitan.experiments.rl.unified.sum_digits import extract_answer, SumDigitsTask
from torchtitan.experiments.rl.unified.types import Episode
from torchtitan.experiments.rl.actors.generator import VLLMGenerator
from torchtitan.experiments.rl.actors.grader import Grader
from torchtitan.experiments.rl.actors.trainer import PolicyTrainer
from torchtitan.experiments.rl.sum_digits import extract_answer, SumDigitsTask
from torchtitan.experiments.rl.types import Episode
from torchtitan.protocols.model_spec import ModelSpec

logger = logging.getLogger(__name__)
@@ -141,9 +142,7 @@ def __init__(self, config: Config):

# Patch model_spec to use the RL-specific parallelize function.
# TODO: Switch to canonical Qwen3 parallel plan
from torchtitan.experiments.rl.unified.models.parallelize import (
parallelize_qwen3,
)
from torchtitan.experiments.rl.models.parallelize import parallelize_qwen3

config.model_spec.parallelize_fn = parallelize_qwen3

File renamed without changes.