Commit cc15111
Add dtensor placement test (#9458)
1 parent 01e579c

2 files changed: +104 -0 lines changed

test/run_tests.sh

Lines changed: 18 additions & 0 deletions
@@ -62,6 +62,23 @@ function run_test_without_functionalization {
   XLA_DISABLE_FUNCTIONALIZATION=1 run_test "$@"
 }

+function run_test_multi_devices {
+  if ! test_is_selected "$1"; then
+    return
+  fi
+  echo "Running in PjRt runtime: $@"
+  # TODO(darisoy): run these tests with multiple CPU devices; this fails due to a TF issue.
+  PJRT_DEVICE=CPU CPU_NUM_DEVICES=4 run_coverage "$@"
+}
+
+function run_test_multi_devices_without_func {
+  if ! test_is_selected "$1"; then
+    return
+  fi
+  echo "Running with XLA_DISABLE_FUNCTIONALIZATION: $@"
+  XLA_DISABLE_FUNCTIONALIZATION=1 run_test_multi_devices "$@"
+}
+
 function run_use_bf16 {
   if ! test_is_selected "$1"; then
     return

@@ -235,6 +252,7 @@ function run_xla_op_tests3 {
   run_test "$_TEST_DIR/spmd/test_xla_spmd_python_api_interaction.py"
   run_test "$_TEST_DIR/spmd/test_dtensor_integration.py"
   run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
+  run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
test/spmd/test_dtensor_integration3.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+import os
+import sys
+
+import torch
+from torch import nn
+import torch.optim as optim
+from torch.distributed.tensor import (DeviceMesh, Replicate, Shard,
+                                      distribute_tensor, distribute_module,
+                                      init_device_mesh)
+import torch_xla
+import torch_xla.debug.metrics as met
+import torch_xla.runtime as xr
+import torch_xla.core.xla_model as xm
+from torch_xla.distributed.spmd import auto_policy
+
+import unittest
+
+import test_xla_sharding_base
+
+
+# This integration test passes when run independently.
+class DTensorIntegrationTest3(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+
+  # This test fails with functionalization, so functionalization is disabled.
+  def test_xla_placement(self):
+
+    class Model(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.in_proj = torch.nn.Linear(32, 16, bias=False)
+        self.out_proj = torch.nn.Linear(16, 8, bias=False)
+
+      def forward(self, hidden):
+        hidden = self.in_proj(hidden)
+        hidden = torch.relu(hidden)
+        hidden = self.out_proj(hidden)
+        return hidden
+
+    def forward_pure(hidden, in_proj_weight, out_proj_weight):
+      hidden = torch.matmul(hidden, in_proj_weight.T)
+      hidden = torch.relu(hidden)
+      hidden = torch.matmul(hidden, out_proj_weight.T)
+      return hidden
+
+    #xr.use_spmd()
+    model = Model()
+    model.to('xla')
+    device_count = xr.global_runtime_device_count()
+    device_mesh = init_device_mesh(
+        device_type='xla', mesh_shape=(device_count,))
+
+    # Tensor parallel shardings
+    inputs_sharding = [Replicate()]
+    in_proj_weight_sharding = [Shard(0)]
+    out_proj_weight_sharding = [Shard(1)]
+
+    torch.manual_seed(15213)
+    inputs = torch.rand(2, 32)
+    inputs = inputs.to('xla')
+    outputs_unsharded = model(inputs)
+    xm.mark_step()
+    outputs_unsharded = outputs_unsharded.cpu()
+    inputs = distribute_tensor(inputs, device_mesh, placements=inputs_sharding)
+    in_proj_weight = distribute_tensor(
+        model.in_proj.weight, device_mesh, placements=in_proj_weight_sharding)
+    out_proj_weight = distribute_tensor(
+        model.out_proj.weight, device_mesh, placements=out_proj_weight_sharding)
+    outputs_sharded = forward_pure(inputs, in_proj_weight, out_proj_weight)
+    xm.mark_step()
+    outputs_sharded = outputs_sharded.cpu()
+    #from torch_xla.distributed.spmd.debugging import visualize_sharding
+    #generated_table = visualize_sharding(outputs.sharding_spec(), use_color=False)
+    print(outputs_unsharded)
+    print(outputs_sharded)
+    torch.testing.assert_close(outputs_sharded.global_tensor.numpy(),
+                               outputs_unsharded.detach().numpy())
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
