Commit 72978cf

Update on "[rl] Add CI for numerics test against vllm native inference"
Test cases:

1. Integration tests (run on A10G, the default CI GPU type):
   - single GPU, no compile + cudagraph
   - multiple GPU (with TP), no compile + cudagraph
   - multiple GPU, with compile + cudagraph
2. Numerics parity test: vLLM native model vs. vLLM + TorchTitan wrapper (runs on H100, using the FA3 attention kernel):
   - test_weights_match: max_diff <= 1e-5 (exact weight loading)
   - test_attention_module: atol=1e-5 (TP=1)
   - test_end_to_end_logits: atol=1e-3 (TP=1)
   - The numerics tests run only for TP=1, on the assumption that both TorchTitan and vLLM keep their multi-GPU implementations on par with single GPU. More numerics tests under parallelism can be added if needed.

[ghstack-poisoned]
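The tolerance checks above can be sketched as a small helper; this is an illustrative stand-in (the names `max_abs_diff`, `check_parity`, and the constants are assumptions), not the actual test code in the repository:

```python
# Hypothetical sketch of the parity thresholds described above. The real
# tests compare vLLM-native outputs against the TorchTitan-wrapped model;
# here plain float lists stand in for the tensors.

def max_abs_diff(a: list[float], b: list[float]) -> float:
    """Elementwise max |a - b| over two equal-length sequences."""
    return max(abs(x - y) for x, y in zip(a, b))

# Tolerances from the commit message (constant names are invented here).
WEIGHT_TOL = 1e-5      # test_weights_match: exact weight loading
ATTENTION_ATOL = 1e-5  # test_attention_module, TP=1
LOGITS_ATOL = 1e-3     # test_end_to_end_logits, TP=1

def check_parity(native: list[float], wrapped: list[float], atol: float) -> bool:
    """True when every element of `wrapped` is within `atol` of `native`."""
    return max_abs_diff(native, wrapped) <= atol
```

Note the graded tolerances: weights must match almost exactly, while end-to-end logits get a looser `1e-3` bound because small per-layer differences accumulate through the forward pass.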
2 parents 9d63e20 + e97d04c commit 72978cf

File tree

2 files changed: +12 additions, -2 deletions


torchtitan/experiments/rl/actors/generator.py

Lines changed: 8 additions & 0 deletions

@@ -145,6 +145,12 @@ class Config(Configurable.Config):
     num_samples_per_prompt: int = 8
     """Number of completions to generate per prompt."""
 
+    max_model_len: int | None = None
+    """Maximum context length for vLLM's KV cache allocation. vLLM
+    pre-allocates paged KV cache blocks up to this length; None lets
+    vLLM use the model's max_position_embeddings (e.g. 40960 for
+    Qwen3-0.6B)"""
+
     seed: int | None = None
     """Random seed for vLLM engine and sampling. None for non-deterministic."""
 
@@ -211,6 +217,8 @@ def __init__(
         vllm_compilation_config = config.compile.get_vllm_compilation_config()
         if vllm_compilation_config is not None:
             engine_kwargs["compilation_config"] = vllm_compilation_config
+        if config.max_model_len is not None:
+            engine_kwargs["max_model_len"] = config.max_model_len
         if config.seed is not None:
             engine_kwargs["seed"] = config.seed
         engine_args = EngineArgs(**engine_kwargs)

torchtitan/experiments/rl/tests/integration_tests.py

Lines changed: 4 additions & 2 deletions

@@ -36,11 +36,12 @@ def build_rl_test_list() -> list[OverrideDefinitions]:
             "--config rl_grpo_qwen3_0_6b",
             "--trainer.parallelism.tensor_parallel_degree 2",
             "--generator.parallelism.tensor_parallel_degree 2",
+            "--generator.max_model_len 2048",
             "--generator.compile.backend none",
             "--generator.compile.cudagraph_mode none",
         ],
     ],
-    "RL GRPO TP=2 no compile (debug model)",
+    "RL GRPO TP=2 no compile",
     "rl_grpo_tp2_no_compile",
     ngpu=4,
 ),
@@ -51,9 +52,10 @@
             "--config rl_grpo_qwen3_0_6b",
             "--trainer.parallelism.tensor_parallel_degree 2",
             "--generator.parallelism.tensor_parallel_degree 2",
+            "--generator.max_model_len 2048",
         ],
     ],
-    "RL GRPO TP=2 compile (debug model)",
+    "RL GRPO TP=2 compile",
     "rl_grpo_tp2_compile",
     ngpu=4,
 ),
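The `--generator.max_model_len 2048` override matters on A10G because vLLM pre-allocates the paged KV cache up to `max_model_len`. A back-of-envelope sizing sketch shows why capping the context helps; the model dimensions below are assumed values for Qwen3-0.6B and may not match the actual config:

```python
# Rough per-sequence KV-cache sizing (ignores paging granularity and
# overheads). Dimensions are assumptions for Qwen3-0.6B, not verified:
# 28 layers, 8 KV heads, head_dim 128, bf16 (2 bytes).

def kv_cache_bytes(seq_len: int, num_layers: int = 28, num_kv_heads: int = 8,
                   head_dim: int = 128, dtype_bytes: int = 2) -> int:
    # 2x for keys and values, per token, per layer.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes * seq_len

full = kv_cache_bytes(40960)    # model's max_position_embeddings
capped = kv_cache_bytes(2048)   # CI override: --generator.max_model_len 2048
```

The cache scales linearly with `seq_len`, so capping at 2048 cuts the per-sequence budget by 40960 / 2048 = 20x, which is what makes the integration tests fit on the default CI GPUs.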
