
Commit 4e18928

Flatten rl/ directory: remove vllm_compat/, move unified/ contents to rl/

- Move all files from rl/unified/ directly under rl/ (actors, models, scripts, etc.)
- Remove rl/vllm_compat/ entirely (unused by unified code)
- Rename types.py -> rl_types.py to avoid shadowing Python stdlib types module
- Fix vllm.model_executor.layers.attention.Attention import for newer vLLM
- Update experiment registry: rl.unified -> rl
- Update all internal imports and README paths
- Add rl_grpo_qwen3_0_6b_tp1 config for TP=1 testing

1 parent 3d4425e

32 files changed: +140 −3297
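The `types.py -> rl_types.py` rename guards against Python's module-resolution order: any `types.py` that lands on `sys.path` (for example, when a script is launched from inside the package directory) shadows the stdlib `types` module for every later import in the process. A minimal, self-contained sketch of the failure mode (hypothetical code, not part of this commit):

```python
# Minimal reproduction of stdlib-module shadowing (hypothetical, not
# TorchTitan code): a file named types.py at the front of sys.path hides
# the stdlib "types" module from subsequent imports.
import importlib
import os
import shutil
import sys
import tempfile

shadow_dir = tempfile.mkdtemp()
with open(os.path.join(shadow_dir, "types.py"), "w") as f:
    f.write("EPISODE = 'shadowed'\n")

sys.path.insert(0, shadow_dir)
sys.modules.pop("types", None)  # drop the cached stdlib module

shadowed = importlib.import_module("types")
has_marker = hasattr(shadowed, "EPISODE")            # the local file won
has_simplens = hasattr(shadowed, "SimpleNamespace")  # stdlib API is gone

# Restore normal imports and clean up the temp directory.
sys.path.remove(shadow_dir)
sys.modules.pop("types", None)
shutil.rmtree(shadow_dir)
```

Renaming the project module sidesteps the collision entirely, which is cheaper than policing `sys.path` in every entry-point script.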

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -14,6 +14,6 @@
         "autoparallel.deepseek_v3",
         "autoparallel.local_map_deepseek_v3",
         "ft.llama3",
-        "rl.unified",
+        "rl",
     ]
 )
```
Lines changed: 67 additions & 9 deletions

@@ -1,12 +1,70 @@

Removed:

```diff
-# Deterministic RL Training with vLLM
-This package provides two approaches for integrating TorchTitan models with vLLM:
-1. vllm_compat/ - vLLM-Compatible approach
-   - Separate model definition matching vLLM's weight format
-   - Support batch-invariant and bit-wise identity between train and inference
-   - Custom backward passes for attention gradient computation
-2. unified/ - Unified approach
-   - Uses canonical TorchTitan model definition for inference directly
-   - Replaces attention with vLLM-compatible attention for inference
```

Added:

# Run vLLM Inference with the TorchTitan Qwen3 Model

This directory contains code to run a single canonical model definition (the TorchTitan model definition) with the vLLM inference engine (not batch-invariant yet; work in progress). This work is under active development and currently supports inference only.

This work is inspired by https://github.com/vllm-project/vllm/pull/28685.

## Overview

The integration consists of two main components:

1. **Model Adapter** (`model/qwen3.py`): a custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions
2. **Inference Script** (`inference_example.py`): a simple script to register the model and run inference

## Quick Start

### Prerequisites

0. Create and activate an environment with uv:
```bash
uv venv --python 3.12 titan-rl
source titan-rl/bin/activate
```

1. Install Monarch:
```bash
uv pip install torchmonarch
```

2. Install PyTorch nightly (required by torchtitan) and pre-built vLLM wheels matching that nightly version:
```bash
# Install vllm with nightly torch
uv pip install torch vllm xformers --pre \
  --extra-index-url https://download.pytorch.org/whl/nightly/cu128 \
  --index-strategy unsafe-best-match
```

**NOTE:** The pre-built vLLM wheels target CUDA 12.8, though they should work with most older CUDA versions. Alternatively, you can install the corresponding vLLM pre-built wheel directly from https://download.pytorch.org/whl/nightly/cu128, for example: `uv pip install vllm-1.0.0.dev20260219+cu130-<suffix>.whl`. Ensure the build version number (e.g., `dev20260219`) matches your PyTorch nightly installation.

3. Install TorchTitan in editable mode:
```bash
uv pip install -e .
```

4. Download the `Qwen/Qwen3-0.6B` (or `Qwen/Qwen3-1.7B`) checkpoint from HuggingFace into the `torchtitan/experiments/rl/example_checkpoint` folder:
```bash
python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=...

python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-1.7B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=...
```

5. Run inference with the torchtitan model definition:
```bash
torchrun --nproc_per_node=2 torchtitan/experiments/rl/inference_example.py
```

**NOTE:** Set `--nproc_per_node` to the world size, which should match the `tensor_parallel_degree` in the `VLLMGenerator` config.

6. Run a simple GRPO RL loop that learns a digit-sum task:
```bash
python torchtitan/experiments/rl/simple_grpo_sum_digits.py --module rl --config rl_grpo_qwen3_0_6b
```

**NOTE:** If you downloaded the HF model to a different path than the one in step 4, specify it with `--hf_assets_path=<path_to_model_checkpoint>`.

We use a unified model definition from torchtitan for both the trainer and the generator, ensuring bitwise-identical models to address a class of subtle correctness bugs in RL for LLMs.

**Current status:** Batch invariance is only supported for single-GPU configurations (TP=1) for both the trainer and the generator. When tensor parallelism is enabled (TP > 1), batch-invariant mode is not yet supported.
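The Model Adapter's job, translating TorchTitan checkpoint parameter names into the names vLLM's `Qwen3ForCausalLM` expects, can be pictured with a toy remapper. The mapping entries below are illustrative assumptions, not the adapter's real table:

```python
# Toy sketch of a checkpoint-name adapter in the spirit of the Model Adapter
# described above: rename TorchTitan-style parameter keys to the names a vLLM
# model expects. The mapping entries are illustrative, not the real table.
def remap_torchtitan_keys(state_dict: dict) -> dict:
    renames = {
        "tok_embeddings.weight": "model.embed_tokens.weight",
        "output.weight": "lm_head.weight",
    }
    # Keys without an explicit rename pass through unchanged.
    return {renames.get(key, key): value for key, value in state_dict.items()}

ckpt = {"tok_embeddings.weight": "W_emb", "layers.0.attention.wq.weight": "W_q"}
remapped = remap_torchtitan_keys(ckpt)
```

Keeping this translation in one adapter class is what lets the rest of the stack load TorchTitan checkpoints without touching vLLM internals.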

torchtitan/experiments/rl/unified/__init__.py renamed to torchtitan/experiments/rl/__init__.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -8,16 +8,16 @@
 Unified approach for running TorchTitan models with vLLM inference.

 To register TorchTitan models with vLLM:
-    from torchtitan.experiments.rl.unified.plugin import register
+    from torchtitan.experiments.rl.plugin import register
     register(model_spec)
 """

-from torchtitan.experiments.rl.unified.models.vllm_wrapper import (
+from torchtitan.experiments.rl.models.vllm_wrapper import (
     TorchTitanVLLMModelWrapper,
 )

 # Export plugin register function for manual use (no auto-registration)
-from torchtitan.experiments.rl.unified.plugin import (
+from torchtitan.experiments.rl.plugin import (
     register_model_to_vllm_model_registry,
 )
```
torchtitan/experiments/rl/unified/actors/generator.py renamed to torchtitan/experiments/rl/actors/generator.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -14,11 +14,11 @@
 from torch.distributed.tensor import distribute_tensor, DTensor
 from torchtitan.config import Configurable
 from torchtitan.config.configs import ParallelismConfig
-from torchtitan.experiments.rl.unified.plugin import (
+from torchtitan.experiments.rl.plugin import (
     register_model_to_vllm_model_registry,
     VLLM_MODEL_NAME,
 )
-from torchtitan.experiments.rl.unified.types import Episode
+from torchtitan.experiments.rl.rl_types import Episode
 from torchtitan.protocols.model_spec import ModelSpec
 from vllm import EngineArgs, LLMEngine, SamplingParams
 from vllm.config import AttentionConfig, CompilationConfig
```

torchtitan/experiments/rl/unified/actors/grader.py renamed to torchtitan/experiments/rl/actors/grader.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -9,7 +9,7 @@
 import torch
 from monarch.actor import Actor, endpoint
-from torchtitan.experiments.rl.unified.types import Episode
+from torchtitan.experiments.rl.rl_types import Episode

 logger = logging.getLogger(__name__)
```

torchtitan/experiments/rl/unified/actors/trainer.py renamed to torchtitan/experiments/rl/actors/trainer.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -23,15 +23,15 @@
 from torchtitan.config import CommConfig, Configurable, TORCH_DTYPE_MAP
 from torchtitan.config.configs import ParallelismConfig, TrainingConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
-from torchtitan.experiments.rl.unified.actors.utils import (
+from torchtitan.experiments.rl.actors.utils import (
     compute_policy_gradient_loss,
     compute_token_log_probs,
     verify_logprob_identity,
 )
-from torchtitan.experiments.rl.unified.models.attention import (
+from torchtitan.experiments.rl.models.attention import (
     replace_with_vllm_compatible_flash_attention,
 )
-from torchtitan.experiments.rl.unified.types import Episode
+from torchtitan.experiments.rl.rl_types import Episode
 from torchtitan.protocols.model_spec import ModelSpec
 from torchtitan.tools import utils
```

File renamed without changes.

torchtitan/experiments/rl/unified/config_registry.py renamed to torchtitan/experiments/rl/config_registry.py

Lines changed: 44 additions & 4 deletions

```diff
@@ -8,19 +8,19 @@
 Config entry points for the RL/unified experiment.

 Each function returns a complete ``RLTrainer.Config`` and is discoverable by
-``ConfigManager`` via ``--module rl.unified --config <function_name>``.
+``ConfigManager`` via ``--module rl --config <function_name>``.
 """

 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
 from torchtitan.config.configs import ParallelismConfig, TrainingConfig
-from torchtitan.experiments.rl.unified.actors.generator import (
+from torchtitan.experiments.rl.actors.generator import (
     GeneratorCompileConfig,
     SamplingConfig,
     VLLMGenerator,
 )
-from torchtitan.experiments.rl.unified.actors.trainer import PolicyTrainer
-from torchtitan.experiments.rl.unified.simple_grpo_sum_digits import RLTrainer
+from torchtitan.experiments.rl.actors.trainer import PolicyTrainer
+from torchtitan.experiments.rl.simple_grpo_sum_digits import RLTrainer
 from torchtitan.models.qwen3 import model_registry

@@ -102,6 +102,46 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config:
     )

+def rl_grpo_qwen3_0_6b_tp1() -> RLTrainer.Config:
+    """GRPO training config for Qwen3-0.6B with TP=1 (2 GPUs: 1 gen + 1 train)."""
+    return RLTrainer.Config(
+        model_spec=model_registry("0.6B"),
+        hf_assets_path="torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B",
+        num_steps=10,
+        batch_invariant_mode=True,
+        trainer=PolicyTrainer.Config(
+            optimizer=OptimizersContainer.Config(lr=2e-6),
+            lr_scheduler=LRSchedulersContainer.Config(
+                warmup_steps=2,
+                decay_type="linear",
+            ),
+            training=TrainingConfig(),
+            parallelism=ParallelismConfig(
+                tensor_parallel_degree=1,
+                data_parallel_replicate_degree=1,
+            ),
+        ),
+        generator=VLLMGenerator.Config(
+            model_dtype="bfloat16",
+            compile=GeneratorCompileConfig(
+                backend="eager",
+                cudagraph_mode="piecewise",
+            ),
+            parallelism=ParallelismConfig(
+                tensor_parallel_degree=1,
+                data_parallel_replicate_degree=1,
+            ),
+            num_samples_per_prompt=8,
+            sampling=SamplingConfig(
+                temperature=0.8,
+                top_p=0.95,
+                max_tokens=100,
+            ),
+            attention_backend="FLASH_ATTN",
+        ),
+    )
+
 def rl_grpo_qwen3_debug() -> RLTrainer.Config:
     """Debug config for quick iteration -- small model, few steps (2 GPUs: 1 gen + 1 train)."""
     return RLTrainer.Config(
```
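The docstring change in this file keeps the discovery contract: each zero-argument function in the registry module is a named config reachable via `--module rl --config <function_name>`. A minimal sketch of that lookup pattern (the real `ConfigManager` differs; the demo module and function names are illustrative):

```python
# Sketch of config discovery by module path + function name, as the registry
# docstring describes. Not ConfigManager's actual implementation.
import importlib
import sys
import types

def load_config(module_name: str, config_name: str):
    """Import the registry module and call the factory named config_name."""
    module = importlib.import_module(module_name)
    factory = getattr(module, config_name)
    return factory()

# Stand-in registry module (illustrative; the real one lives in torchtitan).
demo = types.ModuleType("demo_config_registry")
demo.rl_grpo_demo = lambda: {"tensor_parallel_degree": 1, "num_steps": 10}
sys.modules["demo_config_registry"] = demo

cfg = load_config("demo_config_registry", "rl_grpo_demo")
```

Because configs are plain functions rather than TOML files, adding `rl_grpo_qwen3_0_6b_tp1` above makes it immediately addressable from the CLI with no extra registration step.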

torchtitan/experiments/rl/unified/inference_example.py renamed to torchtitan/experiments/rl/inference_example.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -12,7 +12,7 @@
 the vLLM engine and sampling parameters.

 Run: torchrun --nproc_per_node=2 \
-    torchtitan/experiments/rl/unified/infer.py
+    torchtitan/experiments/rl/inference_example.py
 """
 import os

@@ -21,7 +21,7 @@
 # See also https://docs.vllm.ai/en/v0.8.3/design/multiprocessing.html#python-multiprocessing
 os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

-from torchtitan.experiments.rl.unified.config_registry import rl_grpo_qwen3_0_6b
+from torchtitan.experiments.rl.config_registry import rl_grpo_qwen3_0_6b

 from vllm import EngineArgs, LLMEngine, SamplingParams
 from vllm.logger import init_logger

@@ -38,12 +38,12 @@ def generate():

 # Patch model_spec to use the RL-specific parallelize function.
 # TODO: Switch to canonical Qwen3 parallel plan
-from torchtitan.experiments.rl.unified.models.parallelize import parallelize_qwen3
+from torchtitan.experiments.rl.models.parallelize import parallelize_qwen3

 config.model_spec.parallelize_fn = parallelize_qwen3

 # Register TorchTitan model with vLLM before engine creation
-from torchtitan.experiments.rl.unified.plugin import (
+from torchtitan.experiments.rl.plugin import (
     register_model_to_vllm_model_registry,
     VLLM_MODEL_NAME,
 )
```

torchtitan/experiments/rl/unified/models/attention.py renamed to torchtitan/experiments/rl/models/attention.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -8,11 +8,11 @@
 import torch
 from torch.distributed.tensor import DTensor
-from torchtitan.experiments.rl.vllm_compat.models.attention import (
+from torchtitan.experiments.rl.models.vllm_compat_attention import (
     VLLMCompatibleFlashAttention,
 )
 from torchtitan.protocols.module import Module
-from vllm.model_executor.layers.attention import Attention
+from vllm.attention.layer import Attention

 logger = logging.getLogger(__name__)
```
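The hunk above pins `Attention` to its new vLLM location. If both old and new vLLM releases had to be supported, one option is a fallback-import helper; this is a sketch under the assumption that only the module path moved, and the commented-out vLLM usage mirrors the hunk but is untested:

```python
# Generic "try the new import path, fall back to the old one" helper.
import importlib

def import_first(*candidates: str):
    """Return the first attribute importable from "module:attr" candidates."""
    for path in candidates:
        module_name, _, attr = path.partition(":")
        try:
            return getattr(importlib.import_module(module_name), attr)
        except (ImportError, AttributeError):
            continue
    raise ImportError(f"none of {candidates} could be imported")

# Against vLLM this would mirror the hunk above (assumption, untested):
# Attention = import_first(
#     "vllm.attention.layer:Attention",
#     "vllm.model_executor.layers.attention:Attention",
# )
```

The commit instead pins the new path unconditionally, which is simpler when the supported vLLM version is controlled by the install instructions.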
