
Commit 025c21b

Check torch.compile numerics in simpleFSDP tests (#1925)
Added checks for the numerics of `torch.compile()` with the `aot_eager` backend against eager mode in unit tests, to guard against regressions.

```
torchrun --nproc-per-node=8 -m pytest torchtitan/experiments/simple_fsdp/tests/test_numerics.py
```
1 parent 755ce8f commit 025c21b
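For context, the pattern these tests follow is to train the same model twice, once in eager mode and once under `torch.compile`, and require the per-step losses to match. Below is a minimal single-process sketch of that pattern; the toy model, data, and hyperparameters are illustrative assumptions, not the test's actual setup:

```python
import copy

import torch
import torch.nn as nn


def collect_losses(model, inputs, labels, steps=20):
    # Train for a few steps and record the loss at each step.
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.MSELoss()
    losses = []
    for _ in range(steps):
        optim.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        optim.step()
        losses.append(loss.detach())
    return losses


torch.manual_seed(0)
model = nn.Linear(8, 8)
inputs, labels = torch.randn(4, 8), torch.randn(4, 8)

eager_losses = collect_losses(copy.deepcopy(model), inputs, labels)

# aot_eager traces the graph but runs the same eager kernels underneath,
# so its losses are expected to match eager bitwise.
compiled_model = torch.compile(
    copy.deepcopy(model), backend="aot_eager", fullgraph=True
)
compiled_losses = collect_losses(compiled_model, inputs, labels)

for eager_loss, compiled_loss in zip(eager_losses, compiled_losses):
    assert torch.equal(eager_loss, compiled_loss)
```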

2 files changed: +44 −13 lines

torchtitan/experiments/simple_fsdp/README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,6 +1,6 @@
 ## SimpleFSDP
 
-[![integration tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml?query=branch%3Amain)
+[![integration and numerics tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml?query=branch%3Amain)
 [![arXiv](https://img.shields.io/badge/arXiv-2411.00284-b31b1b.svg)](https://arxiv.org/abs/2411.00284)
 
 💡 **Note**: SimpleFSDP's composability with Mixed Precision Training and Tensor Parallel requires updates from latest PyTorch, which can be installed (e.g., for CUDA 12.6) via
```
torchtitan/experiments/simple_fsdp/tests/test_numerics.py

Lines changed: 43 additions & 12 deletions

```diff
@@ -79,21 +79,44 @@ def run_simple_fsdp(self, model, inputs, labels, epoch=20):
             losses.append(loss)
         return losses
 
+    def run_simple_fsdp_compiled_aot_eager(self, model, inputs, labels, epoch=20):
+        model = data_parallel(
+            model,
+            device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)],
+            mode=self.mode,
+        )
+        # TODO: Add "inductor" backend when its numerical issues are fixed
+        model = torch.compile(model, backend="aot_eager", fullgraph=True)
+        optim = self.optimizer(model.parameters(), lr=1e-4)
+        losses = []
+        for _ in range(epoch):
+            optim.zero_grad()
+            out = model(inputs)
+            loss = self.loss_fn(out, labels)
+            loss.backward()
+            optim.step()
+            losses.append(loss)
+        return losses
+
     def test_replicate_convergence(self):
         # unit test for replicate mode
         self.mode = "replicate"
         self.init_test()
         model, inputs, labels = self.get_input()
 
         fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
-        simple_fsdp_replicate_losses = self.run_simple_fsdp(
+        simple_fsdp_losses = self.run_simple_fsdp(copy.deepcopy(model), inputs, labels)
+        simple_fsdp_compiled_aot_eager_losses = self.run_simple_fsdp_compiled_aot_eager(
             copy.deepcopy(model), inputs, labels
         )
 
-        for fsdp2_loss, simple_fsdp_replicate_loss in zip(
-            fsdp2_losses, simple_fsdp_replicate_losses
+        for (fsdp2_loss, simple_fsdp_loss, simple_fsdp_compiled_aot_eager_loss,) in zip(
+            fsdp2_losses,
+            simple_fsdp_losses,
+            simple_fsdp_compiled_aot_eager_losses,
         ):
-            assert torch.equal(fsdp2_loss, simple_fsdp_replicate_loss)
+            assert torch.equal(fsdp2_loss, simple_fsdp_loss)
+            assert torch.equal(fsdp2_loss, simple_fsdp_compiled_aot_eager_loss)
 
     def test_fullyshard_convergence(self):
         # unit test for fully_shard mode
@@ -102,14 +125,18 @@ def test_fullyshard_convergence(self):
         model, inputs, labels = self.get_input()
 
         fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
-        simple_fsdp_fullyshard_losses = self.run_simple_fsdp(
+        simple_fsdp_losses = self.run_simple_fsdp(copy.deepcopy(model), inputs, labels)
+        simple_fsdp_compiled_aot_eager_losses = self.run_simple_fsdp_compiled_aot_eager(
             copy.deepcopy(model), inputs, labels
         )
 
-        for fsdp2_loss, simple_fsdp_fullyshard_loss in zip(
-            fsdp2_losses, simple_fsdp_fullyshard_losses
+        for (fsdp2_loss, simple_fsdp_loss, simple_fsdp_compiled_aot_eager_loss,) in zip(
+            fsdp2_losses,
+            simple_fsdp_losses,
+            simple_fsdp_compiled_aot_eager_losses,
         ):
-            assert torch.equal(fsdp2_loss, simple_fsdp_fullyshard_loss)
+            assert torch.equal(fsdp2_loss, simple_fsdp_loss)
+            assert torch.equal(fsdp2_loss, simple_fsdp_compiled_aot_eager_loss)
 
     def test_hybridshard_convergence(self):
         # unit test for hybrid_shard mode
@@ -118,11 +145,15 @@ def test_hybridshard_convergence(self):
         model, inputs, labels = self.get_input()
 
         fsdp2_losses = self.run_fsdp2(copy.deepcopy(model), inputs, labels)
-        simple_fsdp_hybridshard_losses = self.run_simple_fsdp(
+        simple_fsdp_losses = self.run_simple_fsdp(copy.deepcopy(model), inputs, labels)
+        simple_fsdp_compiled_aot_eager_losses = self.run_simple_fsdp_compiled_aot_eager(
             copy.deepcopy(model), inputs, labels
         )
 
-        for fsdp2_loss, simple_fsdp_hybridshard_loss in zip(
-            fsdp2_losses, simple_fsdp_hybridshard_losses
+        for (fsdp2_loss, simple_fsdp_loss, simple_fsdp_compiled_aot_eager_loss,) in zip(
+            fsdp2_losses,
+            simple_fsdp_losses,
+            simple_fsdp_compiled_aot_eager_losses,
         ):
-            assert torch.equal(fsdp2_loss, simple_fsdp_hybridshard_loss)
+            assert torch.equal(fsdp2_loss, simple_fsdp_loss)
+            assert torch.equal(fsdp2_loss, simple_fsdp_compiled_aot_eager_loss)
```
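The new helper deliberately pins the backend to `aot_eager`, which reuses eager kernels and is therefore expected to match eager bitwise; the TODO in the diff anticipates also covering `inductor`, whose kernel fusion and reordering generally make exact equality too strict. A hypothetical sketch of how that future comparison might be loosened — the `simple_fsdp_compiled_inductor_losses` name is an assumption, not part of this commit:

```python
# Hypothetical follow-up for the TODO: inductor-generated kernels reorder
# floating-point math, so compare within tolerances instead of bitwise.
# Assumes inductor losses were collected the same way as the aot_eager ones.
for fsdp2_loss, inductor_loss in zip(
    fsdp2_losses, simple_fsdp_compiled_inductor_losses
):
    torch.testing.assert_close(fsdp2_loss, inductor_loss)
```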
