
Commit f74842d

ruisizhang123 authored and pytorchmergebot committed
[DTensor] enable SimpleFSDP's composability with Tensor Parallel (pytorch#152286)
This PR adds support for SimpleFSDP's composability with Tensor Parallel + torch.compile. `_StridedShard` is used in SimpleFSDP/FSDP2 to support correct distributed checkpointing when FSDP + TP is applied. Previously, `_StridedShard` was not guarded by torch.compile; this PR adds `_StridedShard` as an additional placement type to be guarded by torch.compile.

Pull Request resolved: pytorch#152286
Approved by: https://github.com/bdhirsh
1 parent 7509b15 commit f74842d

File tree

2 files changed: +7 −2 lines changed


test/distributed/tensor/test_dtensor_compile.py

Lines changed: 5 additions & 2 deletions

@@ -29,6 +29,7 @@
     PrepareModuleOutput,
     RowwiseParallel,
 )
+from torch.distributed.tensor.placement_types import _StridedShard
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import get_devtype
 from torch.testing._internal.common_utils import (
@@ -196,8 +197,10 @@ def fn(x):
             return a

         compiled_fn = torch.compile(backend="aot_eager", fullgraph=True)(fn)
-        for x in [Shard(0), Replicate(), Partial()]:
+        split_factors = [2, 3, 4]
+        for x in [Shard(0), Replicate(), Partial()] + [
+            _StridedShard(0, split_factor=s) for s in split_factors
+        ]:
             opt_fn = fn(x)
             compiled_out = compiled_fn(x)
             self.assertEqual(opt_fn, compiled_out)

torch/_dynamo/guards.py

Lines changed: 2 additions & 0 deletions

@@ -1702,6 +1702,7 @@ def EQUALS_MATCH(self, guard: Guard):
         if torch.distributed.is_available():
             from torch.distributed.device_mesh import DeviceMesh
             from torch.distributed.tensor.placement_types import (
+                _StridedShard,
                 Partial,
                 Replicate,
                 Shard,
@@ -1712,6 +1713,7 @@ def EQUALS_MATCH(self, guard: Guard):
                 Replicate,
                 Partial,
                 DeviceMesh,
+                _StridedShard,
             )

         from torch.export.dynamic_shapes import _IntWrapper
