
Commit ea614ba

Flatten rl directory: remove vllm_compat, consolidate unified (#2618)

`rl/`
- Move all files from rl/unified/ directly under rl/ (actors, models, scripts, etc.)
- Remove rl/vllm_compat/ entirely (unused by unified code)
- Rename types.py -> rl_types.py to avoid shadowing Python's stdlib types module
- Fix vllm.model_executor.layers.attention.Attention import for newer vLLM
- Update experiment registry: rl.unified -> rl
- Update all internal imports and README paths
- Add rl_grpo_qwen3_0_6b_tp1 config for TP=1 testing
1 parent 27b3985 commit ea614ba

31 files changed: +102 −3237 lines
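The `types.py` → `rl_types.py` rename in this commit avoids shadowing the stdlib `types` module. A minimal sketch of the failure mode it guards against (the file contents here are hypothetical, not from this repo):

```python
import os
import sys
import tempfile

# A package-local file named types.py can shadow the stdlib `types` module
# whenever its directory wins on sys.path -- the failure the rename avoids.
workdir = tempfile.mkdtemp()
with open(os.path.join(workdir, "types.py"), "w") as f:
    f.write("Episode = 'defined in the local types.py'\n")  # hypothetical content

sys.path.insert(0, workdir)
sys.modules.pop("types", None)  # forget the already-imported stdlib module

import types  # now resolves to the local file, not the stdlib module

print(hasattr(types, "SimpleNamespace"))  # stdlib attribute gone -> False
print(types.Episode)
```

Renaming the module (rather than relying on import order) removes the ambiguity entirely.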

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -14,6 +14,6 @@
         "autoparallel.deepseek_v3",
         "autoparallel.local_map_deepseek_v3",
         "ft.llama3",
-        "rl.unified",
+        "rl",
     ]
 )
Lines changed: 66 additions & 9 deletions
@@ -1,12 +1,69 @@
-# Deterministic RL Training with vLLM
+# RL Training with TorchTitan and vLLM
 
-This package provides two approaches for integrating TorchTitan models with vLLM:
+This directory contains code for RL training that uses TorchTitan model definitions with the vLLM inference engine for fast rollout generation.
 
-1. vllm_compat/ - vLLM-Compatible approach
-   - Separate model definition matching vLLM's weight format
-   - Support batch-invariant and bit-wise identity between train and inference
-   - Custom backward passes for attention gradient computation
+## Overview
+The integration consists of the following components:
 
-2. unified/ - Unified approach
-   - Uses canonical TorchTitan model definition for inference directly
-   - Replaces attention with vLLM Compatible attention for inference
+1. **vLLM Model Wrapper** (`models/vllm_wrapper.py`): Adapts TorchTitan models for vLLM's inference engine
+2. **RL Training Loop** (`simple_grpo_sum_digits.py`): GRPO-based RL training with Monarch actors
+3. **Inference Script** (`inference_example.py`): Standalone inference using the vLLM engine
+
+## Quick Start
+### Prerequisites
+
+0. Create and activate an environment with uv:
+```bash
+uv venv --python 3.12 titan-rl
+source titan-rl/bin/activate
+```
+
+1. Install Monarch:
+```bash
+uv pip install torchmonarch
+```
+
+2. Install PyTorch nightly for torchtitan, plus the pre-built vLLM wheels matched to that PyTorch nightly version:
+```bash
+# Install vllm with nightly torch
+uv pip install torch vllm xformers --pre \
+    --extra-index-url https://download.pytorch.org/whl/nightly/cu128 \
+    --index-strategy unsafe-best-match
+```
+
+**NOTE:** The pre-built vLLM wheels target CUDA 12.8, though they should work with most older CUDA versions. Alternatively, you can install the corresponding vLLM pre-built wheels directly from https://download.pytorch.org/whl/nightly/cu128, for example: `uv pip install vllm-1.0.0.dev20260219+cu130-<suffix>.whl`. Ensure the build version number (e.g., `dev20260219`) matches your PyTorch nightly installation.
+
+3. Install TorchTitan in editable mode:
+```bash
+uv pip install -e .
+```
+
+4. Download the `Qwen/Qwen3-0.6B` (or `Qwen/Qwen3-1.7B`) checkpoint from HuggingFace to the `torchtitan/experiments/rl/example_checkpoint` folder:
+```bash
+python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=...
+
+python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-1.7B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=...
+```
+
+5. Run inference with the torchtitan model definition:
+```bash
+torchrun --nproc_per_node=2 torchtitan/experiments/rl/inference_example.py
+```
+
+**NOTE:** Set `--nproc_per_node` to the world size, which should match `tensor_parallel_degree` in the `VLLMGenerator` config.
+
+6. Run the simple GRPO RL loop to learn the sum-digits task:
+```bash
+python torchtitan/experiments/rl/simple_grpo_sum_digits.py --module rl --config rl_grpo_qwen3_0_6b
+```
+
+**NOTE:** If you downloaded the HF model to a different path than the one in step 4, specify it with `--hf_assets_path=<path_to_model_checkpoint>`.
+
+We use a unified model definition from torchtitan for both the trainer and the generator, ensuring bitwise-identical models to address a class of subtle correctness bugs in RL for LLMs.
+
+**Current status:** Batch invariance is only supported for single-GPU configurations (TP=1) for both the trainer and generator. When tensor parallelism is enabled (TP > 1), batch-invariant mode is not yet supported.
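The GRPO loop referenced in step 6 above grades groups of rollouts per prompt. Its core group-relative advantage computation can be sketched as follows (a minimal illustration of the GRPO technique, not the actual code in `simple_grpo_sum_digits.py`):

```python
import statistics

def grpo_advantages(rewards, eps=1e-6):
    # GRPO normalizes each rollout's reward against its own sampling
    # group's mean and standard deviation -- no learned value network.
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# Four sampled answers for one sum-digits prompt, graded 1.0 if correct
rewards = [1.0, 0.0, 1.0, 0.0]
print([round(a, 3) for a in grpo_advantages(rewards)])  # -> [1.0, -1.0, 1.0, -1.0]
```

Correct answers get positive advantage and incorrect ones negative, relative only to their own group.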

torchtitan/experiments/rl/unified/__init__.py renamed to torchtitan/experiments/rl/__init__.py

Lines changed: 3 additions & 7 deletions
@@ -8,18 +8,14 @@
 Unified approach for running TorchTitan models with vLLM inference.
 
 To register TorchTitan models with vLLM:
-    from torchtitan.experiments.rl.unified.plugin import register
+    from torchtitan.experiments.rl.plugin import register
     register(model_spec)
 """
 
-from torchtitan.experiments.rl.unified.models.vllm_wrapper import (
-    TorchTitanVLLMModelWrapper,
-)
+from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper
 
 # Export plugin register function for manual use (no auto-registration)
-from torchtitan.experiments.rl.unified.plugin import (
-    register_model_to_vllm_model_registry,
-)
+from torchtitan.experiments.rl.plugin import register_model_to_vllm_model_registry
 
 
 __all__ = [

torchtitan/experiments/rl/unified/actors/generator.py renamed to torchtitan/experiments/rl/actors/generator.py

Lines changed: 2 additions & 2 deletions
@@ -14,11 +14,11 @@
 from monarch.actor import Actor, endpoint
 from torchtitan.config import Configurable
 from torchtitan.config.configs import ParallelismConfig
-from torchtitan.experiments.rl.unified.plugin import (
+from torchtitan.experiments.rl.plugin import (
     register_model_to_vllm_model_registry,
     VLLM_MODEL_NAME,
 )
-from torchtitan.experiments.rl.unified.types import Episode
+from torchtitan.experiments.rl.types import Episode
 from torchtitan.protocols.model_spec import ModelSpec
 from vllm import EngineArgs, LLMEngine, SamplingParams
 from vllm.config import AttentionConfig, CompilationConfig

torchtitan/experiments/rl/unified/actors/grader.py renamed to torchtitan/experiments/rl/actors/grader.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 
 import torch
 from monarch.actor import Actor, endpoint
-from torchtitan.experiments.rl.unified.types import Episode
+from torchtitan.experiments.rl.types import Episode
 
 logger = logging.getLogger(__name__)
 
torchtitan/experiments/rl/unified/actors/trainer.py renamed to torchtitan/experiments/rl/actors/trainer.py

Lines changed: 3 additions & 3 deletions
@@ -23,15 +23,15 @@
 from torchtitan.config import CommConfig, Configurable, TORCH_DTYPE_MAP
 from torchtitan.config.configs import ParallelismConfig, TrainingConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
-from torchtitan.experiments.rl.unified.actors.utils import (
+from torchtitan.experiments.rl.actors.utils import (
     compute_policy_gradient_loss,
     compute_token_log_probs,
     verify_logprob_identity,
 )
-from torchtitan.experiments.rl.unified.models.attention import (
+from torchtitan.experiments.rl.models.attention import (
     replace_with_vllm_compatible_flash_attention,
 )
-from torchtitan.experiments.rl.unified.types import Episode
+from torchtitan.experiments.rl.types import Episode
 from torchtitan.protocols.model_spec import ModelSpec
 from torchtitan.tools import utils
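The `verify_logprob_identity` helper imported by the trainer checks that trainer-recomputed token log-probs match the generator's; the point of the unified model definition is that this comparison can be exact rather than approximate. A hedged sketch of such a check (a plain-Python stand-in, not the actual helper):

```python
def logprobs_identical(trainer_lp, generator_lp):
    # Bitwise identity, not math.isclose: with a unified model definition
    # the rollout policy and the training policy should agree exactly.
    return len(trainer_lp) == len(generator_lp) and all(
        a == b for a, b in zip(trainer_lp, generator_lp)
    )

gen = [-0.25, -1.5, -0.03125]
print(logprobs_identical(gen, list(gen)))                # -> True
print(logprobs_identical(gen, [x + 1e-9 for x in gen]))  # drift -> False
```

Even tiny numerical drift between the two model definitions fails the check, which is why separate train/inference implementations are a source of subtle RL bugs.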

File renamed without changes.

torchtitan/experiments/rl/unified/config_registry.py renamed to torchtitan/experiments/rl/config_registry.py

Lines changed: 4 additions & 4 deletions
@@ -8,19 +8,19 @@
 Config entry points for the RL/unified experiment.
 
 Each function returns a complete ``RLTrainer.Config`` and is discoverable by
-``ConfigManager`` via ``--module rl.unified --config <function_name>``.
+``ConfigManager`` via ``--module rl --config <function_name>``.
 """
 
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
 from torchtitan.config.configs import ParallelismConfig, TrainingConfig
-from torchtitan.experiments.rl.unified.actors.generator import (
+from torchtitan.experiments.rl.actors.generator import (
     GeneratorCompileConfig,
     SamplingConfig,
     VLLMGenerator,
 )
-from torchtitan.experiments.rl.unified.actors.trainer import PolicyTrainer
-from torchtitan.experiments.rl.unified.simple_grpo_sum_digits import RLTrainer
+from torchtitan.experiments.rl.actors.trainer import PolicyTrainer
+from torchtitan.experiments.rl.simple_grpo_sum_digits import RLTrainer
 from torchtitan.models.qwen3 import model_registry
torchtitan/experiments/rl/unified/inference_example.py renamed to torchtitan/experiments/rl/inference_example.py

Lines changed: 5 additions & 5 deletions
@@ -12,7 +12,7 @@
 the vLLM engine and sampling parameters.
 
 Run: torchrun --nproc_per_node=2 \
-    torchtitan/experiments/rl/unified/infer.py
+    torchtitan/experiments/rl/inference_example.py
 """
 import os

@@ -21,11 +21,11 @@
 # See also https://docs.vllm.ai/en/v0.8.3/design/multiprocessing.html#python-multiprocessing
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
-from torchtitan.experiments.rl.unified.config_registry import rl_grpo_qwen3_0_6b
-
 from vllm import EngineArgs, LLMEngine, SamplingParams
 from vllm.logger import init_logger
 
+from torchtitan.experiments.rl.config_registry import rl_grpo_qwen3_0_6b
+
 
 logger = init_logger(__name__)

@@ -38,12 +38,12 @@ def generate():
 
 # Patch model_spec to use the RL-specific parallelize function.
 # TODO: Switch to canonical Qwen3 parallel plan
-from torchtitan.experiments.rl.unified.models.parallelize import parallelize_qwen3
+from torchtitan.experiments.rl.models.parallelize import parallelize_qwen3
 
 config.model_spec.parallelize_fn = parallelize_qwen3
 
 # Register TorchTitan model with vLLM before engine creation
-from torchtitan.experiments.rl.unified.plugin import (
+from torchtitan.experiments.rl.plugin import (
     register_model_to_vllm_model_registry,
     VLLM_MODEL_NAME,
 )

torchtitan/experiments/rl/unified/models/attention.py renamed to torchtitan/experiments/rl/models/attention.py

Lines changed: 2 additions & 1 deletion
@@ -8,10 +8,11 @@
 
 import torch
 from torch.distributed.tensor import DTensor
-from torchtitan.experiments.rl.vllm_compat.models.attention import (
+from torchtitan.experiments.rl.models.vllm_compat_attention import (
     VLLMCompatibleFlashAttention,
 )
 from torchtitan.protocols.module import Module
+
 from vllm.model_executor.layers.attention import Attention
 
 logger = logging.getLogger(__name__)
