
Commit b4bbb2c (1 parent: e51a77d)

Small fixes to inference_example.py and README (#2467)

2 files changed: +14, -6 lines

torchtitan/experiments/rl/unified/README.md

Lines changed: 4 additions & 3 deletions
@@ -8,7 +8,7 @@ This work is inspired by https://github.com/vllm-project/vllm/pull/28685.
 The integration consists of two main components:
 
 1. **Model Adapter** (`model/qwen3.py`): A custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions
-2. **Inference Script** (`infer.py`): A simple script to register the model and run inference
+2. **Inference Script** (`inference_example.py`): A simple script to register the model and run inference
 
 
 ## Quick Start
@@ -49,10 +49,11 @@ python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torch
 
 5. Run inference with unified model definition:
 ```bash
-torchrun --nproc_per_node=<world_size> \
-    torchtitan/experiments/rl/unified/inference_example.py
+torchrun --nproc_per_node=2 torchtitan/experiments/rl/unified/inference_example.py
 ```
 
+**NOTE:** Set `--nproc_per_node` to the world size, which should match the `tensor_parallel_degree` in the `VLLMGenerator` config.
+
 6. Run simple GRPO RL loop
 ```bash
 python torchtitan/experiments/rl/unified/simple_grpo.py --module rl.unified --config rl_grpo_qwen3_0_6b --hf_assets_path=<path_to_model_checkpoint>
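The NOTE in the hunk above can also be enforced at startup. A minimal sketch (a hypothetical helper, not part of this commit), assuming torchrun's `WORLD_SIZE` environment variable:

```python
import os


def check_world_size(tensor_parallel_degree: int) -> None:
    """Fail fast if torchrun's world size disagrees with the TP degree.

    torchrun sets WORLD_SIZE from --nproc_per_node; the VLLMGenerator
    config's tensor_parallel_degree must match it.
    """
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size != tensor_parallel_degree:
        raise ValueError(
            f"WORLD_SIZE={world_size} does not match "
            f"tensor_parallel_degree={tensor_parallel_degree}"
        )
```

Calling this once before engine construction turns a silent sharding mismatch into an immediate, descriptive error.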

torchtitan/experiments/rl/unified/inference_example.py

Lines changed: 10 additions & 3 deletions
@@ -36,6 +36,12 @@ def generate():
     gen_config = config.generator
     model_path = config.trainer.hf_assets_path
 
+    # Patch model_spec to use the RL-specific parallelize function.
+    # TODO: Switch to canonical Qwen3 parallel plan
+    from torchtitan.experiments.rl.unified.models.parallelize import parallelize_qwen3
+
+    config.model_spec.parallelize_fn = parallelize_qwen3
+
     # Register TorchTitan model with vLLM before engine creation
     from torchtitan.experiments.rl.unified.plugin import (
         register_model_to_vllm_model_registry,
@@ -52,7 +58,7 @@ def generate():
     )
 
     # Create EngineArgs from config
-    engine_args = EngineArgs(
+    engine_kwargs = dict(
         # Model configuration
         model=model_path,
         trust_remote_code=True,
@@ -65,11 +71,12 @@ def generate():
         # Memory and performance
         gpu_memory_utilization=gen_config.gpu_memory_limit,
         enforce_eager=gen_config.enforce_eager,
-        # Seed
-        seed=gen_config.seed,
         # HuggingFace overrides
         hf_overrides={"architectures": [VLLM_MODEL_NAME]},
     )
+    if gen_config.seed is not None:
+        engine_kwargs["seed"] = gen_config.seed
+    engine_args = EngineArgs(**engine_kwargs)
 
     logger.debug("Initializing LLMEngine from EngineArgs...")
     engine = LLMEngine.from_engine_args(engine_args)
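The seed change above follows a general pattern: collect the arguments in a plain dict and add optional keys only when they are actually set, so the library's own default applies otherwise. A minimal standalone sketch (the helper name and parameters are illustrative, not the repo's API):

```python
def build_engine_kwargs(model_path, gpu_memory_limit, enforce_eager, seed=None):
    """Assemble EngineArgs-style kwargs, omitting unset optional keys."""
    kwargs = dict(
        model=model_path,
        trust_remote_code=True,
        gpu_memory_utilization=gpu_memory_limit,
        enforce_eager=enforce_eager,
    )
    # Only forward seed when the config sets one, so the engine keeps
    # its own default behavior (no fixed seed) otherwise.
    if seed is not None:
        kwargs["seed"] = seed
    return kwargs
```

Passing `seed=None` explicitly to some constructors can override an internal default, which is why the key is omitted rather than forwarded as-is.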
