Skip to content

Commit 2c34824

Browse files
author
pytorchbot
committed
2026-03-21 nightly release (c0c0bf9)
1 parent d85396e commit 2c34824

File tree

14 files changed

+657
-112
lines changed

14 files changed

+657
-112
lines changed

.ci/docker/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
torchdata >= 0.8.0
2-
datasets >= 3.6.0
2+
datasets >= 3.6.0, < 4.8.0
33
tensorboard
44
wandb
55
fsspec
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
name: RL Numerics 2 GPU Integration Tests
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
tags:
7+
- ciflow/8gpu/*
8+
paths:
9+
- 'torchtitan/experiments/rl/**'
10+
- '.github/workflows/integration_test_2gpu_rl_numerics.yaml'
11+
pull_request:
12+
paths:
13+
- 'torchtitan/experiments/rl/**'
14+
- '.github/workflows/integration_test_2gpu_rl_numerics.yaml'
15+
16+
concurrency:
17+
group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
18+
cancel-in-progress: true
19+
20+
defaults:
21+
run:
22+
shell: bash -l -eo pipefail {0}
23+
24+
permissions:
25+
id-token: write
26+
contents: read
27+
28+
# Steps should be kept in sync with torchtitan/experiments/rl/README.md
29+
jobs:
30+
# Step 1: Dynamically compute the matrix based on conditions
31+
set-matrix:
32+
uses: ./.github/workflows/set-matrix.yaml
33+
with:
34+
runner-cuda: linux.aws.h100.8
35+
36+
# Step 2: Use the dynamic matrix in the build-test job
37+
build-test:
38+
needs: set-matrix
39+
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
40+
strategy:
41+
fail-fast: false
42+
matrix: ${{ fromJSON(needs.set-matrix.outputs.matrix) }}
43+
with:
44+
runner: ${{ matrix.runner }}
45+
gpu-arch-type: ${{ matrix.gpu-arch-type }}
46+
gpu-arch-version: ${{ matrix.gpu-arch-version }}
47+
docker-image: ${{ matrix.docker-image }}
48+
repository: pytorch/torchtitan
49+
upload-artifact: outputs
50+
timeout: 90
51+
script: |
52+
set -eux
53+
54+
# The generic Linux job chooses to use base env, not the one set up by the image
55+
CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
56+
conda activate "${CONDA_ENV}"
57+
58+
conda install -y -c conda-forge libstdcxx-ng
59+
60+
pip install -e .
61+
62+
# Install CUDA 12.8 toolkit via conda (needed to build vLLM from source)
63+
# Temporarily disable -u because conda activation scripts have unbound variables
64+
set +u
65+
conda install -y -c nvidia cuda-toolkit=12.8
66+
set -u
67+
# Conda CUDA toolkit puts headers/libs under targets/x86_64-linux/
68+
export CUDA_HOME="${CONDA_PREFIX}/targets/x86_64-linux"
69+
# Pass as CMake args since CUDA_TOOLKIT_ROOT_DIR is a CMake variable, not an env var
70+
export CMAKE_ARGS="-DCUDA_TOOLKIT_ROOT_DIR=${CUDA_HOME} -DCMAKE_CUDA_COMPILER=${CONDA_PREFIX}/bin/nvcc"
71+
72+
# Log CUDA driver version for debugging.
73+
DRIVER_VERSION=$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n 1 || true)
74+
echo "CUDA driver version: ${DRIVER_VERSION}"
75+
76+
pip config --user set global.progress_bar off
77+
78+
# Install torch nightly first
79+
TORCH_SPEC="torch"
80+
if [ -n "${{ matrix.torch-version }}" ]; then
81+
TORCH_SPEC="torch==${{ matrix.torch-version }}"
82+
fi
83+
if [ "${{ matrix.gpu-arch-type }}" = "rocm" ]; then
84+
python -m pip install --force-reinstall --pre \
85+
"${TORCH_SPEC}" --index-url ${{ matrix.index-url }}
86+
else
87+
python -m pip install --force-reinstall --pre \
88+
torch --index-url ${{ matrix.index-url }}
89+
fi
90+
91+
# Install RL dependencies: xformers, monarch, pygtrie, portpicker, torchstore, flash-attn-3
92+
pip install xformers --extra-index-url ${{ matrix.index-url }}
93+
pip install torchmonarch==0.3.0
94+
pip install pygtrie portpicker
95+
pip install --no-deps "git+https://github.com/meta-pytorch/torchstore.git@main"
96+
pip install flash-attn-3 --extra-index=https://download.pytorch.org/whl/test/cu128
97+
98+
# Build and install vLLM from source using existing torch nightly
99+
git clone https://github.com/vllm-project/vllm.git /tmp/vllm
100+
cd /tmp/vllm
101+
python use_existing_torch.py
102+
pip install -r requirements/build.txt
103+
# Constrain torch to the installed nightly version so pip doesn't try to downgrade it
104+
TORCH_VERSION=$(python -c "import torch; print(torch.__version__)")
105+
echo "Installed torch version: ${TORCH_VERSION}"
106+
echo "torch==${TORCH_VERSION}" > /tmp/torch-constraint.txt
107+
PIP_CONSTRAINT=/tmp/torch-constraint.txt pip install --no-build-isolation -v -e .
108+
cd -
109+
110+
# Download Qwen3-0.6B checkpoint
111+
python scripts/download_hf_assets.py \
112+
--repo_id Qwen/Qwen3-0.6B \
113+
--local_dir torchtitan/experiments/rl/example_checkpoint \
114+
--all
115+
116+
# Run the attention numerics test (2 GPU, TP=2)
117+
torchrun --nproc-per-node=2 \
118+
torchtitan/experiments/rl/tests/test_attn_numerics.py

torchtitan/experiments/graph_trainer/common_utils.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -145,20 +145,25 @@ def convert_modules_to_fqns(modules, module_to_fqn_mapping):
145145
return module_fqns
146146

147147

148-
def maybe_disable_eager_ac(
148+
def apply_graph_ac(
149149
compile_config: CompileConfig,
150150
ac_config: "ActivationCheckpointConfig",
151151
) -> None:
152-
"""Disable eager AC when apply_sac graph pass is enabled.
152+
"""Add apply_sac to compile joint passes for graph-based selective AC.
153153
154-
When apply_sac is used as a joint graph pass, eager activation checkpointing
155-
must be disabled to avoid double-checkpointing. This must be called before
156-
the model parallelization step that applies eager AC.
154+
Must be called only when ac_config.mode != "none". Only "selective" mode
155+
is supported; other modes raise ValueError.
157156
"""
157+
if ac_config.mode != "selective":
158+
raise ValueError(
159+
f"graph_trainer only supports activation_checkpoint.mode 'selective' or "
160+
f"'none', got '{ac_config.mode}'. Use 'selective' for graph-based SAC."
161+
)
162+
158163
joint_pass_names = getattr(compile_config, "joint_passes", [])
159-
if "apply_sac" in joint_pass_names:
160-
if ac_config.mode != "none":
161-
logger.info(
162-
"apply_sac graph pass is enabled, overriding eager AC mode to none"
163-
)
164-
ac_config.mode = "none"
164+
if "apply_sac" not in joint_pass_names:
165+
compile_config.joint_passes = list(joint_pass_names) + ["apply_sac"]
166+
logger.info(
167+
"activation_checkpoint.mode is 'selective', added apply_sac to "
168+
"compile.joint_passes"
169+
)

torchtitan/experiments/graph_trainer/configs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dataclasses import dataclass, field, fields
99
from typing import Literal
1010

11+
from torchtitan.config import ActivationCheckpointConfig
1112
from torchtitan.config.configs import CompileConfig
1213
from torchtitan.protocols.model_spec import ModelSpec
1314
from torchtitan.trainer import Trainer
@@ -59,4 +60,10 @@ def to_graph_trainer_config(
5960
d["model_spec"] = model_registry(base_config.model_spec.flavor)
6061
d.pop("compile")
6162

63+
# graph_trainer uses graph-based SAC instead of eager AC. Override any
64+
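# non-"none" AC mode to "selective" so callers don't need per-config fixups. Note: a fresh config is built with only mode set, so other fields of the original activation_checkpoint config are not carried over.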
65+
ac = d.get("activation_checkpoint")
66+
if ac is not None and ac.mode != "none":
67+
d["activation_checkpoint"] = ActivationCheckpointConfig(mode="selective")
68+
6269
return GraphTrainer.Config(**d)

torchtitan/experiments/graph_trainer/deepseek_v3/config_registry.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from torchtitan.config import ActivationCheckpointConfig
87
from torchtitan.experiments.graph_trainer.configs import (
98
GraphTrainerCompileConfig,
109
to_graph_trainer_config,
@@ -22,27 +21,23 @@
2221

2322
def graph_trainer_deepseek_v3_debugmodel() -> GraphTrainer.Config:
2423
config = to_graph_trainer_config(deepseek_v3_debugmodel(), model_registry)
25-
config.activation_checkpoint = ActivationCheckpointConfig(mode="none")
2624
config.compile = GraphTrainerCompileConfig(enable=True)
2725
return config
2826

2927

3028
def graph_trainer_deepseek_v3_debugmodel_flex_attn() -> (GraphTrainer.Config):
3129
config = to_graph_trainer_config(deepseek_v3_debugmodel_flex_attn(), model_registry)
32-
config.activation_checkpoint = ActivationCheckpointConfig(mode="none")
3330
config.compile = GraphTrainerCompileConfig(enable=True)
3431
return config
3532

3633

3734
def graph_trainer_deepseek_v3_16b() -> GraphTrainer.Config:
3835
config = to_graph_trainer_config(deepseek_v3_16b(), model_registry)
39-
config.activation_checkpoint = ActivationCheckpointConfig(mode="none")
4036
config.compile = GraphTrainerCompileConfig(enable=True)
4137
return config
4238

4339

4440
def graph_trainer_deepseek_v3_671b() -> GraphTrainer.Config:
4541
config = to_graph_trainer_config(deepseek_v3_671b(), model_registry)
46-
config.activation_checkpoint = ActivationCheckpointConfig(mode="none")
4742
config.compile = GraphTrainerCompileConfig(enable=True)
4843
return config

torchtitan/experiments/graph_trainer/deepseek_v3/parallelize.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@
1717
)
1818
from torchtitan.distributed import ParallelDims
1919

20-
from torchtitan.distributed.activation_checkpoint import apply_ac
2120
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
2221
from torchtitan.experiments.graph_trainer.common_utils import (
2322
annotate_ac_regions,
24-
maybe_disable_eager_ac,
23+
apply_graph_ac,
2524
)
2625
from torchtitan.experiments.graph_trainer.compile import apply_compile
2726
from torchtitan.experiments.graph_trainer.deepseek_v3.model import (
@@ -99,8 +98,6 @@ def parallelize_deepseekv3(
9998

10099
annotate_deepseekv3(model)
101100

102-
maybe_disable_eager_ac(compile_config, ac_config)
103-
104101
if parallel_dims.tp_enabled:
105102
float8_config = find_float8_linear_config(model_converters.converters)
106103
enable_float8_linear = float8_config is not None
@@ -135,7 +132,7 @@ def parallelize_deepseekv3(
135132
)
136133

137134
if ac_config.mode != "none":
138-
apply_ac(model, ac_config)
135+
apply_graph_ac(compile_config, ac_config)
139136

140137
mp_policy = MixedPrecisionPolicy(
141138
param_dtype=TORCH_DTYPE_MAP[training.mixed_precision_param],

torchtitan/experiments/graph_trainer/llama3/parallelize.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import torch
87
from torch.fx.traceback import annotate_fn
98

109
from torchtitan.components.quantization.float8 import find_float8_linear_config
@@ -16,11 +15,10 @@
1615
TrainingConfig,
1716
)
1817
from torchtitan.distributed import ParallelDims
19-
from torchtitan.distributed.activation_checkpoint import apply_ac
2018
from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
2119
from torchtitan.experiments.graph_trainer.common_utils import (
2220
annotate_ac_regions,
23-
maybe_disable_eager_ac,
21+
apply_graph_ac,
2422
)
2523
from torchtitan.experiments.graph_trainer.compile import apply_compile
2624
from torchtitan.experiments.graph_trainer.llama3.model import GraphTrainerLlama3Model
@@ -32,25 +30,6 @@
3230
from torchtitan.protocols.model_converter import ModelConvertersContainer
3331
from torchtitan.tools.logging import logger
3432

35-
# for selective op activation checkpointing
36-
_op_sac_save_list = {
37-
torch.ops.aten.mm.default,
38-
torch.ops.aten.linear.default,
39-
torch.ops.aten._scaled_dot_product_efficient_attention.default,
40-
torch.ops.aten._scaled_dot_product_flash_attention.default,
41-
torch.ops.aten._scaled_dot_product_cudnn_attention.default,
42-
torch.ops.aten._scaled_dot_product_attention_math.default,
43-
torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default,
44-
torch.ops._c10d_functional.reduce_scatter_tensor.default,
45-
# for low precision training, it's useful to always save
46-
# the result of max, since the absolute maximum is
47-
# used to compute the scaling factor for quantization.
48-
torch.ops.aten.max.default,
49-
torch._higher_order_ops.flex_attention,
50-
torch.ops.torch_attn._varlen_attn.default,
51-
torch._higher_order_ops.inductor_compiled_code,
52-
}
53-
5433

5534
def annotate_llama(model: GraphTrainerLlama3Model) -> None:
5635
"""Attach annotations to FX graph nodes with ``torch.fx.traceback.annotate_fn``
@@ -103,8 +82,6 @@ def parallelize_llama(
10382

10483
annotate_llama(model)
10584

106-
maybe_disable_eager_ac(compile_config, ac_config)
107-
10885
if parallel_dims.tp_enabled:
10986
float8_config = find_float8_linear_config(model_converters.converters)
11087
enable_float8_linear = float8_config is not None
@@ -128,16 +105,7 @@ def parallelize_llama(
128105
maybe_enable_async_tp(parallelism, compile_config, tp_mesh)
129106

130107
if ac_config.mode != "none":
131-
model_compile_enabled = (
132-
compile_config.enable and "model" in compile_config.components
133-
)
134-
apply_ac(
135-
model,
136-
ac_config,
137-
model_compile_enabled=model_compile_enabled,
138-
op_sac_save_list=_op_sac_save_list,
139-
base_folder=dump_folder,
140-
)
108+
apply_graph_ac(compile_config, ac_config)
141109

142110
# apply data parallel
143111
if (

0 commit comments

Comments
 (0)