Skip to content

Commit bf9ef7b

Browse files
committed
Update README and remove rl_grpo_qwen3_0_6b_tp1 config
1 parent 4e18928 commit bf9ef7b

File tree

8 files changed

+16
-65
lines changed

8 files changed

+16
-65
lines changed

torchtitan/experiments/rl/README.md

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
# Run vLLM inference with TorchTitan Qwen3 Model
1+
# RL Training with TorchTitan and vLLM
22

3-
This directory contains code to run a single canonical model definition (TorchTitan model definition) with vLLM inference engine (not batch-invariant yet, working in progress). This work is actively developing and only supports inference for now.
4-
5-
This work is inspired by https://github.com/vllm-project/vllm/pull/28685.
3+
This directory contains code for RL training using TorchTitan model definitions with vLLM inference engine for fast rollout generation.
64

75
## Overview
8-
The integration consists of two main components:
6+
The integration consists of the following components:
97

10-
1. **Model Adapter** (`model/qwen3.py`): A custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions
11-
2. **Inference Script** (`inference_example.py`): A simple script to register the model and run inference
8+
1. **vLLM Model Wrapper** (`models/vllm_wrapper.py`): Adapts TorchTitan models for vLLM's inference engine
9+
2. **RL Training Loop** (`simple_grpo_sum_digits.py`): GRPO-based RL training with Monarch actors
10+
3. **Inference Script** (`inference_example.py`): Standalone inference using the vLLM engine
1211

1312

1413
## Quick Start

torchtitan/experiments/rl/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,10 @@
1212
register(model_spec)
1313
"""
1414

15-
from torchtitan.experiments.rl.models.vllm_wrapper import (
16-
TorchTitanVLLMModelWrapper,
17-
)
15+
from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper
1816

1917
# Export plugin register function for manual use (no auto-registration)
20-
from torchtitan.experiments.rl.plugin import (
21-
register_model_to_vllm_model_registry,
22-
)
18+
from torchtitan.experiments.rl.plugin import register_model_to_vllm_model_registry
2319

2420

2521
__all__ = [

torchtitan/experiments/rl/config_registry.py

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -102,46 +102,6 @@ def rl_grpo_qwen3_1_7b() -> RLTrainer.Config:
102102
)
103103

104104

105-
def rl_grpo_qwen3_0_6b_tp1() -> RLTrainer.Config:
106-
"""GRPO training config for Qwen3-0.6B with TP=1 (2 GPUs: 1 gen + 1 train)."""
107-
return RLTrainer.Config(
108-
model_spec=model_registry("0.6B"),
109-
hf_assets_path="torchtitan/experiments/rl/example_checkpoint/Qwen3-0.6B",
110-
num_steps=10,
111-
batch_invariant_mode=True,
112-
trainer=PolicyTrainer.Config(
113-
optimizer=OptimizersContainer.Config(lr=2e-6),
114-
lr_scheduler=LRSchedulersContainer.Config(
115-
warmup_steps=2,
116-
decay_type="linear",
117-
),
118-
training=TrainingConfig(),
119-
parallelism=ParallelismConfig(
120-
tensor_parallel_degree=1,
121-
data_parallel_replicate_degree=1,
122-
),
123-
),
124-
generator=VLLMGenerator.Config(
125-
model_dtype="bfloat16",
126-
compile=GeneratorCompileConfig(
127-
backend="eager",
128-
cudagraph_mode="piecewise",
129-
),
130-
parallelism=ParallelismConfig(
131-
tensor_parallel_degree=1,
132-
data_parallel_replicate_degree=1,
133-
),
134-
num_samples_per_prompt=8,
135-
sampling=SamplingConfig(
136-
temperature=0.8,
137-
top_p=0.95,
138-
max_tokens=100,
139-
),
140-
attention_backend="FLASH_ATTN",
141-
),
142-
)
143-
144-
145105
def rl_grpo_qwen3_debug() -> RLTrainer.Config:
146106
"""Debug config for quick iteration -- small model, few steps (2 GPUs: 1 gen + 1 train)."""
147107
return RLTrainer.Config(

torchtitan/experiments/rl/inference_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
# See also https://docs.vllm.ai/en/v0.8.3/design/multiprocessing.html#python-multiprocessing
2222
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
2323

24-
from torchtitan.experiments.rl.config_registry import rl_grpo_qwen3_0_6b
25-
2624
from vllm import EngineArgs, LLMEngine, SamplingParams
2725
from vllm.logger import init_logger
2826

27+
from torchtitan.experiments.rl.config_registry import rl_grpo_qwen3_0_6b
28+
2929

3030
logger = init_logger(__name__)
3131

torchtitan/experiments/rl/models/vllm_wrapper.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@
2424

2525
from torchtitan.config import ParallelismConfig
2626
from torchtitan.distributed.parallel_dims import ParallelDims
27-
from torchtitan.experiments.rl.models.attention import (
28-
replace_with_vllm_attention,
29-
)
27+
from torchtitan.experiments.rl.models.attention import replace_with_vllm_attention
3028
from torchtitan.protocols.model_spec import ModelSpec
3129
from torchtitan.protocols.module import Module
3230
from vllm.compilation.decorators import support_torch_compile

torchtitan/experiments/rl/plugin.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,11 @@ def register_model_to_vllm_model_registry(
3030
Args:
3131
model_spec: TorchTitan ModelSpec containing model config and components
3232
"""
33-
from torchtitan.experiments.rl.models.vllm_wrapper import (
34-
TorchTitanVLLMModelWrapper,
35-
)
3633
from vllm.logger import init_logger
3734
from vllm.model_executor.models.registry import ModelRegistry
3835

36+
from torchtitan.experiments.rl.models.vllm_wrapper import TorchTitanVLLMModelWrapper
37+
3938
logger = init_logger(__name__)
4039

4140
# Create dynamic model class capturing ModelSpec in the closure

torchtitan/experiments/rl/simple_grpo_sum_digits.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,14 @@
3434
import torch
3535
from monarch.actor import this_host
3636
from monarch.spmd import setup_torch_elastic_env_async
37+
3738
from torchtitan.config import Configurable
3839
from torchtitan.config.manager import ConfigManager
3940
from torchtitan.experiments.rl.actors.generator import VLLMGenerator
4041
from torchtitan.experiments.rl.actors.grader import Grader
4142
from torchtitan.experiments.rl.actors.trainer import PolicyTrainer
42-
from torchtitan.experiments.rl.sum_digits import extract_answer, SumDigitsTask
4343
from torchtitan.experiments.rl.rl_types import Episode
44+
from torchtitan.experiments.rl.sum_digits import extract_answer, SumDigitsTask
4445
from torchtitan.protocols.model_spec import ModelSpec
4546

4647
logger = logging.getLogger(__name__)
@@ -140,9 +141,7 @@ def __init__(self, config: Config):
140141

141142
# Patch model_spec to use the RL-specific parallelize function.
142143
# TODO: Switch to canonical Qwen3 parallel plan
143-
from torchtitan.experiments.rl.models.parallelize import (
144-
parallelize_qwen3,
145-
)
144+
from torchtitan.experiments.rl.models.parallelize import parallelize_qwen3
146145

147146
config.model_spec.parallelize_fn = parallelize_qwen3
148147

0 commit comments

Comments
 (0)