Commit 3e23b47

Merge remote-tracking branch 'upstream/main' into fix-pos-id
2 parents 338bb8f + ea614ba

37 files changed: +112 -3247 lines changed

torchtitan/components/validate.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -316,7 +316,7 @@ def validate(
                 loss, parallel_dims.get_optional_mesh("loss")
             )
         else:
-            global_avg_loss = loss.item()
+            global_avg_loss = float(loss.item())
 
         self.metrics_processor.log_validation(loss=global_avg_loss, step=step)
```

torchtitan/distributed/utils.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -52,10 +52,10 @@ def _dist_reduce(
         x = funcol.all_reduce(x, reduceOp=reduceOp, group=extra_pg)
 
     if mesh is None:
-        return x.item()
+        return float(x.item())
 
     assert x.numel() == 1  # required by `.item()`
-    return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item()
+    return float(funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item())
 
 
 # TODO: rename this to maybe_dist_max
```
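Both `float(...)` changes in this commit serve the same purpose: `torch.Tensor.item()` returns a Python `int` for integer tensors and a `float` for floating-point tensors, so wrapping it in `float()` guarantees these reduction helpers return a `float` for every dtype. A minimal sketch of the pattern, with a plain Python value standing in for the `.item()` result (no torch dependency):

```python
def reduce_to_float(item_result) -> float:
    # item_result stands in for torch.Tensor.item(), which yields an int
    # for integer tensors and a float for floating-point tensors.
    # Wrapping in float() pins the return type regardless of dtype.
    return float(item_result)

print(reduce_to_float(7))     # int stand-in -> 7.0
print(reduce_to_float(0.25))  # float stand-in -> 0.25
```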

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -14,6 +14,6 @@
         "autoparallel.deepseek_v3",
         "autoparallel.local_map_deepseek_v3",
         "ft.llama3",
-        "rl.unified",
+        "rl",
     ]
 )
```
Lines changed: 66 additions & 9 deletions
```diff
@@ -1,12 +1,69 @@
-# Deterministic RL Training with vLLM
+# RL Training with TorchTitan and vLLM
 
-This package provides two approaches for integrating TorchTitan models with vLLM:
+This directory contains code for RL training that uses TorchTitan model definitions with the vLLM inference engine for fast rollout generation.
 
-1. vllm_compat/ - vLLM-Compatible approach
-   - Separate model definition matching vLLM's weight format
-   - Support batch-invariant and bit-wise identity between train and inference
-   - Custom backward passes for attention gradient computation
+## Overview
+The integration consists of the following components:
 
-2. unified/ - Unified approach
-   - Uses canonical TorchTitan model definition for inference directly
-   - Replaces attention with vLLM Compatible attention for inference
+1. **vLLM Model Wrapper** (`models/vllm_wrapper.py`): adapts TorchTitan models for vLLM's inference engine
+2. **RL Training Loop** (`simple_grpo_sum_digits.py`): GRPO-based RL training with Monarch actors
+3. **Inference Script** (`inference_example.py`): standalone inference using the vLLM engine
+
+
+## Quick Start
+### Prerequisites
+
+0. Create and activate an environment with uv:
+```bash
+uv venv --python 3.12 titan-rl
+source titan-rl/bin/activate
+```
+
+1. Install Monarch:
+```bash
+uv pip install torchmonarch
+```
+
+
+2. Install PyTorch nightly for torchtitan, plus pre-built vLLM wheels matching the PyTorch nightly version:
+```bash
+# Install vllm with nightly torch
+uv pip install torch vllm xformers --pre \
+  --extra-index-url https://download.pytorch.org/whl/nightly/cu128 \
+  --index-strategy unsafe-best-match
+```
+
+**NOTE:** The pre-built vLLM wheels target CUDA 12.8, though they should work with most older CUDA versions. Alternatively, you can install the corresponding vLLM pre-built wheels directly from https://download.pytorch.org/whl/nightly/cu128, for example: `uv pip install vllm-1.0.0.dev20260219+cu130-<suffix>.whl`. Ensure the build version number (e.g., `dev20260219`) matches your PyTorch nightly installation.
+
+
+3. Install TorchTitan in editable mode:
+```bash
+uv pip install -e .
+```
+
+4. Download the `Qwen/Qwen3-0.6B` (or `Qwen/Qwen3-1.7B`) checkpoint from HuggingFace to the `torchtitan/experiments/rl/example_checkpoint` folder:
+```bash
+python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=...
+
+python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-1.7B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=...
+```
+
+5. Run inference with the torchtitan model definition:
+```bash
+torchrun --nproc_per_node=2 torchtitan/experiments/rl/inference_example.py
+```
+
+**NOTE:** Set `--nproc_per_node` to the world size, which should match the `tensor_parallel_degree` in the `VLLMGenerator` config.
+
+6. Run the simple GRPO RL loop to learn the sum-digits task:
+```bash
+python torchtitan/experiments/rl/simple_grpo_sum_digits.py --module rl --config rl_grpo_qwen3_0_6b
+```
+
+**NOTE:** If you downloaded your HF model to a different path than the one in step 4, specify it in your command with `--hf_assets_path=<path_to_model_checkpoint>`.
+
+We use a unified model definition from torchtitan for the trainer and the generator, ensuring bitwise-identical models and thereby addressing a class of subtle correctness bugs in RL for LLMs.
+
+
+
+**Current status:** Batch invariance is only supported for single-GPU configurations (TP=1) for both the trainer and generator. When tensor parallelism is enabled (TP > 1), batch-invariant mode is not yet supported.
```
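The README describes the training loop as GRPO-based. As background, the core of GRPO is a group-relative advantage: each sampled completion's reward is normalized against the other completions for the same prompt. Below is a minimal sketch of that normalization only; the actual loss in `simple_grpo_sum_digits.py` may differ in details such as the epsilon, clipping, or KL terms.

```python
import statistics

def group_relative_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    # GRPO advantage for rollout i in a group sampled from one prompt:
    # (reward_i - group mean) / (group std + eps)
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]

# Four rollouts for one sum-digits prompt, graded 1.0 if correct else 0.0:
advs = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
print([round(a, 3) for a in advs])  # correct answers receive positive advantage
```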

torchtitan/experiments/rl/unified/__init__.py renamed to torchtitan/experiments/rl/__init__.py

Lines changed: 3 additions & 7 deletions
```diff
@@ -8,18 +8,14 @@
 Unified approach for running TorchTitan models with vLLM inference.
 
 To register TorchTitan models with vLLM:
-    from torchtitan.experiments.rl.unified.plugin import register
+    from torchtitan.experiments.rl.plugin import register
     register(model_spec)
 """
 
-from torchtitan.experiments.rl.unified.models.vllm_wrapper import (
-    TorchTitanVLLMModelWrapper,
-)
+from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper
 
 # Export plugin register function for manual use (no auto-registration)
-from torchtitan.experiments.rl.unified.plugin import (
-    register_model_to_vllm_model_registry,
-)
+from torchtitan.experiments.rl.plugin import register_model_to_vllm_model_registry
 
 
 __all__ = [
```

torchtitan/experiments/rl/unified/actors/generator.py renamed to torchtitan/experiments/rl/actors/generator.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -14,11 +14,11 @@
 from monarch.actor import Actor, endpoint
 from torchtitan.config import Configurable
 from torchtitan.config.configs import ParallelismConfig
-from torchtitan.experiments.rl.unified.plugin import (
+from torchtitan.experiments.rl.plugin import (
     register_model_to_vllm_model_registry,
     VLLM_MODEL_NAME,
 )
-from torchtitan.experiments.rl.unified.types import Episode
+from torchtitan.experiments.rl.types import Episode
 from torchtitan.protocols.model_spec import ModelSpec
 from vllm import EngineArgs, LLMEngine, SamplingParams
 from vllm.config import AttentionConfig, CompilationConfig
```

torchtitan/experiments/rl/unified/actors/grader.py renamed to torchtitan/experiments/rl/actors/grader.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,7 +9,7 @@
 
 import torch
 from monarch.actor import Actor, endpoint
-from torchtitan.experiments.rl.unified.types import Episode
+from torchtitan.experiments.rl.types import Episode
 
 logger = logging.getLogger(__name__)
 
```

torchtitan/experiments/rl/unified/actors/trainer.py renamed to torchtitan/experiments/rl/actors/trainer.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -23,15 +23,15 @@
 from torchtitan.config import CommConfig, Configurable, TORCH_DTYPE_MAP
 from torchtitan.config.configs import ParallelismConfig, TrainingConfig
 from torchtitan.distributed import ParallelDims, utils as dist_utils
-from torchtitan.experiments.rl.unified.actors.utils import (
+from torchtitan.experiments.rl.actors.utils import (
     compute_policy_gradient_loss,
     compute_token_log_probs,
     verify_logprob_identity,
 )
-from torchtitan.experiments.rl.unified.models.attention import (
+from torchtitan.experiments.rl.models.attention import (
     replace_with_vllm_compatible_flash_attention,
 )
-from torchtitan.experiments.rl.unified.types import Episode
+from torchtitan.experiments.rl.types import Episode
 from torchtitan.protocols.model_spec import ModelSpec
 from torchtitan.tools import utils
```

File renamed without changes.

torchtitan/experiments/rl/unified/config_registry.py renamed to torchtitan/experiments/rl/config_registry.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -8,19 +8,19 @@
 Config entry points for the RL/unified experiment.
 
 Each function returns a complete ``RLTrainer.Config`` and is discoverable by
-``ConfigManager`` via ``--module rl.unified --config <function_name>``.
+``ConfigManager`` via ``--module rl --config <function_name>``.
 """
 
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.components.optimizer import OptimizersContainer
 from torchtitan.config.configs import ParallelismConfig, TrainingConfig
-from torchtitan.experiments.rl.unified.actors.generator import (
+from torchtitan.experiments.rl.actors.generator import (
     GeneratorCompileConfig,
     SamplingConfig,
     VLLMGenerator,
 )
-from torchtitan.experiments.rl.unified.actors.trainer import PolicyTrainer
-from torchtitan.experiments.rl.unified.simple_grpo_sum_digits import RLTrainer
+from torchtitan.experiments.rl.actors.trainer import PolicyTrainer
+from torchtitan.experiments.rl.simple_grpo_sum_digits import RLTrainer
 from torchtitan.models.qwen3 import model_registry
```

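The docstring change above updates the discovery path: config entry points are now found via `--module rl --config <function_name>`. The general mechanism (look up a factory function by name in a module and call it) can be sketched generically. Everything below is illustrative, not TorchTitan's actual `ConfigManager` API; `load_config` and `demo_configs` are hypothetical names.

```python
import importlib
import sys
import types

def load_config(module_name: str, config_name: str):
    # Hypothetical sketch of name-based config discovery, in the spirit of
    # ``--module rl --config <function_name>``: import the module, look up
    # the entry-point function by name, and call it to build the config.
    module = importlib.import_module(module_name)
    factory = getattr(module, config_name)
    return factory()

# Build a throwaway in-memory module to demonstrate the lookup against:
demo = types.ModuleType("demo_configs")
demo.rl_grpo_qwen3_0_6b = lambda: {"model": "qwen3-0.6b", "algo": "grpo"}
sys.modules["demo_configs"] = demo

print(load_config("demo_configs", "rl_grpo_qwen3_0_6b"))
```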