@@ -50,15 +50,18 @@ TEST_F(XLAShardingTest, GetShardShape) {
       {0, 1},
       {2, 3},
   });
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
 
   auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For tiled sharding, each dimension should be halved
   EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));
 
-  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For replicated sharding, each dimension should be preserved
   EXPECT_EQ(shard_shape, std::vector<int64_t>({8, 7}));
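
Every hunk in this diff applies the same migration: the raw xla::OpSharding proto that used to be passed straight into XLATensor::ShardingSpec is now wrapped in a torch_xla::OpSharding first. A minimal sketch of the pattern, assuming the test file's existing includes; the type of the wrapper's second constructor parameter is not spelled out anywhere in this diff (every call site passes std::nullopt), so treat this as a sketch rather than the full API:

    // Build the XLA sharding proto exactly as before...
    xla::OpSharding xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
    // ...then wrap it in the torch_xla-level OpSharding before constructing
    // the spec. The second argument's type is not visible in this diff;
    // std::nullopt is what every call site here passes.
    torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
    auto sharding_spec =
        std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
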
@@ -74,7 +77,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       {0, 1},
       {2, 3},
   });
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
@@ -103,7 +107,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       EXPECT_EQ(slice.step(), 1);
     }
   }
-  sharding = xla::HloSharding::Replicate().ToProto();
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices(
@@ -126,11 +131,12 @@ TEST_F(XLAShardingTest, ShardTensor) {
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
   xla::Shape tensor_shape =
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
-  xla::OpSharding sharding =
+  xla::OpSharding xla_sharding =
       xla::HloSharding::Tile1D(
           CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
           devices.size())
           .ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -148,7 +154,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
       {0, 1, 2, 3},
       {4, 5, 6, 7},
   });
-  sharding = xla::HloSharding::Tile(mesh).ToProto();
+  xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -160,15 +167,19 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // 3D tiled, the first dim is replicated and the last halved. The last shard
   // size should be smaller in dim=1 because it's not evenly divisible.
   xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
-  sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto();
+  xla_sharding = xla::HloSharding::Tile(cube).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 2, 2}));
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 1, 2}));
 
   // Replicated, all shards should be identical.
-  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
@@ -182,7 +193,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
   tensor_shape =
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
-  sharding = xla::HloSharding::Tile(tesseract).ToProto();
+  xla_sharding = xla::HloSharding::Tile(tesseract).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -206,7 +218,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
-  sharding = xla::HloSharding::Tile(hypercube).ToProto();
+  xla_sharding = xla::HloSharding::Tile(hypercube).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -234,7 +247,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {4, 5, 0, 1},
       {6, 7, 2, 3},
   });
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   // For devices at the start of the mesh, all shards should have the same
@@ -251,7 +265,9 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {0, 1, 4, 5},
       {2, 3, 6, 7},
   });
-  sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto();
+  xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
   EXPECT_EQ(shards.size(), 4);
@@ -278,7 +294,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
       {{7}},
   });
 
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
       sharding, global_shape, /*minibatch=*/true);
   auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
@@ -292,17 +309,20 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
   xla::Shape tensor_shape =
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
-  XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
-                                       {0, 1, 2, 3},
-                                       {4, 5, 6, 7},
-                                   })
-                                       .ToProto(),
-                                   tensor_shape);
-  XLATensor::ShardingSpec tiled_3d(
-      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(),
-      tensor_shape);
-  XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(),
-                                     tensor_shape);
+  auto xla_sharding = xla::HloSharding::Tile({
+                                                 {0, 1, 2, 3},
+                                                 {4, 5, 6, 7},
+                                             })
+                          .ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape);
+  xla_sharding =
+      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape);
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  XLATensor::ShardingSpec replicated(sharding, tensor_shape);
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
   EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated));
@@ -323,12 +343,17 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
   std::vector<std::string> devices(3);
   std::fill_n(devices.begin(), devices.size(),
               bridge::GetDefaultDevice()->toString());
+  auto replicate_xla_sharding = xla::HloSharding::Replicate().ToProto();
+  auto unknown_xla_sharding = xla::HloSharding::Unknown().ToProto();
+  torch_xla::OpSharding replicate_sharding(replicate_xla_sharding,
+                                           std::nullopt);
+  torch_xla::OpSharding unknown_sharding(unknown_xla_sharding, std::nullopt);
   std::vector<XLATensor::ShardingSpecPtr> shardings = {
       nullptr,
-      std::make_shared<XLATensor::ShardingSpec>(
-          xla::HloSharding::Replicate().ToProto(), tensor_shape),
-      std::make_shared<XLATensor::ShardingSpec>(
-          xla::HloSharding::Unknown().ToProto(), tensor_shape)};
+      std::make_shared<XLATensor::ShardingSpec>(replicate_sharding,
+                                                tensor_shape),
+      std::make_shared<XLATensor::ShardingSpec>(unknown_sharding,
+                                                tensor_shape)};
   std::vector<torch::lazy::BackendDataPtr> tensors_data =
       CreateTensorsData(tensors, shardings, devices);
 
@@ -387,13 +412,21 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
   xla::XlaComputation xla_computation =
       GetValueOrThrow(b.Build(/*remove_dynamic_dimensions=*/false));
+
+  std::vector<torch::lazy::BackendDataPtr> parameters_data;
+  parameters_data.push_back(
+      torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
+          bridge::GetDefaultDevice()->toString(), std::move(shape)));
+
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
   instances.push_back({std::move(xla_computation),
                        bridge::GetDefaultDevice()->toString(),
                        {bridge::GetDefaultDevice()->toString()},
                        &shape,
                        /*should_wrap_parameter=*/false,
-                       /*is_sharded=*/true});
+                       /*is_sharded=*/true,
+                       /*allow_spmd_sharding_propagation_to_output=*/true,
+                       /*parameters_data=*/parameters_data});
 
   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -417,11 +450,12 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   if (n_devices > 1) {
     // Tiled sharding requires multiple devices.
    EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization(
-        tiled, sharding_specs[0]->sharding));
+        tiled, sharding_specs[0]->sharding.GetXlaOpSharding()));
   } else {
     // Single device execution defaults to replication sharding.
    EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization(
-        xla::HloSharding::Replicate().ToProto(), sharding_specs[0]->sharding));
+        xla::HloSharding::Replicate().ToProto(),
+        sharding_specs[0]->sharding.GetXlaOpSharding()));
   }
 
   // Check if the placeholder is on a SPMD device (sharded) with no real values.
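
With ShardingSpec::sharding now holding the wrapper type rather than the raw proto, proto-level assertions read the XLA sharding back through GetXlaOpSharding(), as the hunk above shows. A minimal sketch of the comparison, assuming only what this diff exposes (the accessor returning the wrapped xla::OpSharding, and the existing xla::protobuf_util helper):

    // Expected proto built directly from an HloSharding, as in the test.
    xla::OpSharding expected = xla::HloSharding::Replicate().ToProto();
    // The wrapper exposes the underlying proto for serialization comparison.
    EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization(
        expected, sharding_specs[0]->sharding.GetXlaOpSharding()));
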