@@ -332,14 +332,15 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
   CreateTensorsData(tensors, shardings, devices);
 
   int64_t n_devices =
-      torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
+      torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size();
   if (n_devices > 1) {
     // null sharding is treated as replicated.
     auto xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
             tensors_data[0]);
     std::vector<torch_xla::runtime::ComputationClient::DataPtr> shards =
-        torch_xla::runtime::GetComputationClient()->GetDataShards(xla_data);
+        torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
+            xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(xla_data->shape(),
                                                    shards[0]->shape()));
@@ -349,7 +350,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
     auto sharded_xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
             tensors_data[1]);
-    shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
+    shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
         sharded_xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
@@ -360,7 +361,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
     sharded_xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
             tensors_data[2]);
-    shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
+    shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
         sharded_xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
@@ -372,7 +373,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
 
 TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4});
   int64_t n_devices =
-      torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
+      torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size();
   xla::Array<int64_t> tile_assignment({1, n_devices});
   tile_assignment.FillIota(0);
   xla::OpSharding tiled = xla::HloSharding::Tile(tile_assignment).ToProto();
@@ -395,15 +396,15 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
 
   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
-      computations = torch_xla::runtime::GetComputationClient()->Compile(
+      computations = torch_xla::runtime::GetComputationClientOrDie()->Compile(
          std::move(instances));
   torch_xla::runtime::ComputationClient::ComputationPtr computation =
       std::make_shared<torch_xla::runtime::ComputationClient::Computation>(
          "add", std::move(computations[0]->move_computation()));
 
   // Prepare output sharding propagation, expect a sharded output placeholder.
   std::vector<XLATensorPtr> tensors{XLATensor::Create(
-      torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder(
+      torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
          bridge::GetDefaultDevice()->toString(), std::move(shape)))};
   std::vector<torch::lazy::BackendDataPtr> data_placeholders;
   std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
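The rename from GetComputationClient() to GetComputationClientOrDie() follows the common "OrDie" accessor convention. As a rough sketch only (the actual torch_xla signatures are not shown in this diff; the absl::StatusOr-based getter and the stand-in ComputationClient type below are assumptions for illustration), the pattern looks like this:

// Sketch of the "OrDie" accessor pattern, NOT the torch_xla implementation:
// a fallible getter returns absl::StatusOr<T*>, and the OrDie wrapper crashes
// on failure so call sites that cannot recover (such as tests) stay terse.
#include <cstdlib>
#include <iostream>

#include "absl/status/statusor.h"

class ComputationClient {};  // hypothetical stand-in for the real client type

// Fallible accessor: may fail, e.g. if the runtime was never initialized.
absl::StatusOr<ComputationClient*> GetComputationClient() {
  static ComputationClient client;
  return &client;
}

// Infallible convenience wrapper matching the call sites in the diff above.
ComputationClient* GetComputationClientOrDie() {
  absl::StatusOr<ComputationClient*> client = GetComputationClient();
  if (!client.ok()) {
    std::cerr << "ComputationClient unavailable: " << client.status() << "\n";
    std::abort();
  }
  return *client;
}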