@@ -101,8 +101,8 @@ TEST_F(TensorTest, TestAdd) {
101
101
at::Tensor c = a.add (b, 1.0 );
102
102
103
103
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
104
- XLATensorPtr dev_a = GetValueOrThrow ( XLATensor::Create (a, device));
105
- XLATensorPtr dev_b = GetValueOrThrow ( XLATensor::Create (b, device));
104
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_a, XLATensor::Create (a, device));
105
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_b, XLATensor::Create (b, device));
106
106
XLATensorPtr dev_c = tensor_methods::add (dev_a, dev_b, 1.0 );
107
107
108
108
AllClose (c, dev_c);
@@ -121,8 +121,8 @@ TEST_F(TensorTest, TestIntegerAdd) {
121
121
at::isIntegralType (type) ? at::Scalar (int64_t (1 )) : at::Scalar (1.0 );
122
122
at::Tensor c = a.add (b, one);
123
123
124
- XLATensorPtr dev_a = GetValueOrThrow ( XLATensor::Create (a, device));
125
- XLATensorPtr dev_b = GetValueOrThrow ( XLATensor::Create (b, device));
124
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_a, XLATensor::Create (a, device));
125
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_b, XLATensor::Create (b, device));
126
126
XLATensorPtr dev_c = tensor_methods::add (dev_a, dev_b, one);
127
127
128
128
EXPECT_TRUE (EqualValuesNoElementTypeCheck (
@@ -135,7 +135,8 @@ TEST_F(TensorTest, TestSize) {
135
135
at::Tensor input = at::rand ({2 , 1 , 4 , 6 }, at::TensorOptions (at::kFloat ));
136
136
int rank = input.dim ();
137
137
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
138
- XLATensorPtr dev_input = GetValueOrThrow (XLATensor::Create (input, device));
138
+ XLA_ASSIGN_OR_THROW (XLATensorPtr dev_input,
139
+ XLATensor::Create (input, device));
139
140
for (int dim = -rank; dim < rank; ++dim) {
140
141
EXPECT_EQ (input.size (dim), dev_input->size (dim));
141
142
}
@@ -151,10 +152,10 @@ TEST_F(TensorTest, TestRrelu) {
151
152
at::Tensor noise = at::zeros_like (input);
152
153
at::Tensor output =
153
154
at::rrelu_with_noise (input, noise, lower, upper, training);
154
- XLATensorPtr dev_input =
155
- GetValueOrThrow ( XLATensor::Create (input, device));
156
- XLATensorPtr dev_noise =
157
- GetValueOrThrow ( XLATensor::Create (noise, device));
155
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_input,
156
+ XLATensor::Create (input, device));
157
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_noise,
158
+ XLATensor::Create (noise, device));
158
159
XLATensorPtr dev_outputs = tensor_methods::rrelu_with_noise (
159
160
dev_input, dev_noise, lower, upper, training);
160
161
AllClose (output, dev_outputs);
@@ -169,7 +170,8 @@ TEST_F(TensorTest, TestThreshold) {
169
170
float value = 20 ;
170
171
at::Tensor output = at::threshold (input, threshold, value);
171
172
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
172
- XLATensorPtr dev_input = GetValueOrThrow (XLATensor::Create (input, device));
173
+ XLA_ASSIGN_OR_THROW (XLATensorPtr dev_input,
174
+ XLATensor::Create (input, device));
173
175
XLATensorPtr dev_output =
174
176
tensor_methods::threshold (dev_input, threshold, value);
175
177
AllClose (output, dev_output);
@@ -187,10 +189,11 @@ TEST_F(TensorTest, TestAddMatMul) {
187
189
at::Tensor bias = at::rand ({labels}, at::TensorOptions (at::kFloat ));
188
190
at::Tensor output = at::addmm (bias, input, weight);
189
191
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
190
- XLATensorPtr dev_input = GetValueOrThrow (XLATensor::Create (input, device));
191
- XLATensorPtr dev_weight =
192
- GetValueOrThrow (XLATensor::Create (weight, device));
193
- XLATensorPtr dev_bias = GetValueOrThrow (XLATensor::Create (bias, device));
192
+ XLA_ASSIGN_OR_THROW (XLATensorPtr dev_input,
193
+ XLATensor::Create (input, device));
194
+ XLA_ASSIGN_OR_THROW (XLATensorPtr dev_weight,
195
+ XLATensor::Create (weight, device));
196
+ XLA_ASSIGN_OR_THROW (XLATensorPtr dev_bias, XLATensor::Create (bias, device));
194
197
XLATensorPtr dev_output =
195
198
tensor_methods::addmm (dev_input, dev_weight, dev_bias);
196
199
AllClose (output, dev_output);
@@ -201,7 +204,8 @@ TEST_F(TensorTest, TestTranspose) {
201
204
at::Tensor input = at::rand ({2 , 3 }, at::TensorOptions (at::kFloat ));
202
205
at::Tensor output = at::transpose (input, 0 , 1 );
203
206
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
204
- XLATensorPtr dev_input = GetValueOrThrow (XLATensor::Create (input, device));
207
+ XLA_ASSIGN_OR_THROW (XLATensorPtr dev_input,
208
+ XLATensor::Create (input, device));
205
209
XLATensorPtr dev_output = tensor_methods::transpose (dev_input, 0 , 1 );
206
210
AllClose (output, dev_output);
207
211
});
@@ -211,7 +215,8 @@ TEST_F(TensorTest, TestView) {
211
215
at::Tensor input = at::rand ({32 , 20 , 4 , 4 }, at::TensorOptions (at::kFloat ));
212
216
at::Tensor output = input.view ({-1 , 320 });
213
217
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
214
- XLATensorPtr dev_input = GetValueOrThrow (XLATensor::Create (input, device));
218
+ XLA_ASSIGN_OR_THROW (XLATensorPtr dev_input,
219
+ XLATensor::Create (input, device));
215
220
XLATensorPtr dev_output = tensor_methods::view (dev_input, {-1 , 320 });
216
221
AllClose (output, dev_output);
217
222
});
@@ -292,8 +297,8 @@ TEST_F(TensorTest, TestMaxPool2D) {
292
297
/* padding=*/ {padding, padding}, /* dilation=*/ {1 , 1 },
293
298
/* ceil_mode=*/ false );
294
299
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
295
- XLATensorPtr dev_input =
296
- GetValueOrThrow ( XLATensor::Create (input, device));
300
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_input,
301
+ XLATensor::Create (input, device));
297
302
auto dev_output = tensor_methods::max_pool_nd (
298
303
dev_input,
299
304
/* spatial_dim_count=*/ 2 ,
@@ -317,8 +322,8 @@ TEST_F(TensorTest, TestMaxPool2DNonSquare) {
317
322
/* padding=*/ {padding, padding + 1 }, /* dilation=*/ {1 , 1 },
318
323
/* ceil_mode=*/ false );
319
324
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
320
- XLATensorPtr dev_input =
321
- GetValueOrThrow ( XLATensor::Create (input, device));
325
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_input,
326
+ XLATensor::Create (input, device));
322
327
auto dev_output = tensor_methods::max_pool_nd (
323
328
dev_input,
324
329
/* spatial_dim_count=*/ 2 ,
@@ -346,8 +351,8 @@ TEST_F(TensorTest, TestAvgPool2D) {
346
351
/* ceil_mode=*/ false , count_include_pad,
347
352
/* divisor_override=*/ std::nullopt );
348
353
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
349
- XLATensorPtr dev_input =
350
- GetValueOrThrow ( XLATensor::Create (input, device));
354
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_input,
355
+ XLATensor::Create (input, device));
351
356
XLATensorPtr dev_output = tensor_methods::avg_pool_nd (
352
357
dev_input,
353
358
/* spatial_dim_count=*/ 2 ,
@@ -377,8 +382,8 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
377
382
/* count_include_pad=*/ count_include_pad,
378
383
/* divisor_override=*/ std::nullopt );
379
384
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
380
- XLATensorPtr dev_input =
381
- GetValueOrThrow ( XLATensor::Create (input, device));
385
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_input,
386
+ XLATensor::Create (input, device));
382
387
XLATensorPtr dev_output = tensor_methods::avg_pool_nd (
383
388
dev_input,
384
389
/* spatial_dim_count=*/ 2 ,
@@ -416,20 +421,20 @@ TEST_F(TensorTest, TestBatchNorm1D) {
416
421
/* running_mean=*/ running_mean, /* running_var=*/ running_var,
417
422
/* training=*/ training, /* momentum=*/ momentum, /* eps=*/ eps);
418
423
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
419
- XLATensorPtr xla_input =
420
- GetValueOrThrow ( XLATensor::Create (input, device));
421
- XLATensorPtr xla_weight =
422
- undef_weight_bias
423
- ? XLATensorPtr ()
424
- : GetValueOrThrow ( XLATensor::Create (weight, device));
425
- XLATensorPtr xla_bias =
426
- undef_weight_bias
427
- ? XLATensorPtr ()
428
- : GetValueOrThrow ( XLATensor::Create (bias, device));
429
- XLATensorPtr xla_running_mean =
430
- GetValueOrThrow ( XLATensor::Create (running_mean, device));
431
- XLATensorPtr xla_running_var =
432
- GetValueOrThrow ( XLATensor::Create (running_var, device));
424
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr xla_input,
425
+ XLATensor::Create (input, device));
426
+ XLATensorPtr xla_weight;
427
+ if (! undef_weight_bias) {
428
+ XLA_ASSIGN_OR_THROW (xla_weight, XLATensor::Create (weight, device));
429
+ }
430
+ XLATensorPtr xla_bias;
431
+ if (! undef_weight_bias) {
432
+ XLA_ASSIGN_OR_THROW (xla_bias, XLATensor::Create (bias, device));
433
+ }
434
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr xla_running_mean,
435
+ XLATensor::Create (running_mean, device));
436
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr xla_running_var,
437
+ XLATensor::Create (running_var, device));
433
438
auto xla_output = tensor_methods::native_batch_norm (
434
439
/* input=*/ xla_input, /* weight=*/ xla_weight, /* bias=*/ xla_bias,
435
440
/* running_mean=*/ xla_running_mean, /* running_var=*/ xla_running_var,
@@ -486,14 +491,14 @@ TEST_F(TensorTest, TestConv2D) {
486
491
/* output_padding=*/ {output_padding, output_padding},
487
492
/* groups=*/ groups, false , false , false );
488
493
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
489
- XLATensorPtr dev_input =
490
- GetValueOrThrow ( XLATensor::Create (input, device));
491
- XLATensorPtr dev_weight =
492
- GetValueOrThrow ( XLATensor::Create (weight, device));
494
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_input,
495
+ XLATensor::Create (input, device));
496
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_weight,
497
+ XLATensor::Create (weight, device));
493
498
XLATensorPtr dev_output;
494
499
if (with_bias) {
495
- XLATensorPtr dev_bias =
496
- GetValueOrThrow ( XLATensor::Create (bias, device));
500
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_bias,
501
+ XLATensor::Create (bias, device));
497
502
dev_output = tensor_methods::convolution_overrideable (
498
503
dev_input, dev_weight, dev_bias,
499
504
/* stride=*/ {stride, stride},
@@ -558,14 +563,14 @@ TEST_F(TensorTest, TestConv2DNonSquare) {
558
563
/* groups=*/ groups, false , false , false );
559
564
560
565
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
561
- XLATensorPtr dev_input =
562
- GetValueOrThrow ( XLATensor::Create (input, device));
563
- XLATensorPtr dev_weight =
564
- GetValueOrThrow ( XLATensor::Create (weight, device));
566
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_input,
567
+ XLATensor::Create (input, device));
568
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_weight,
569
+ XLATensor::Create (weight, device));
565
570
XLATensorPtr dev_output;
566
571
if (with_bias) {
567
- XLATensorPtr dev_bias =
568
- GetValueOrThrow ( XLATensor::Create (bias, device));
572
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_bias,
573
+ XLATensor::Create (bias, device));
569
574
dev_output = tensor_methods::convolution_overrideable (
570
575
dev_input, dev_weight, dev_bias,
571
576
/* stride=*/ {stride, stride + 1 },
@@ -634,14 +639,14 @@ TEST_F(TensorTest, TestConv3D) {
634
639
{output_padding, output_padding, output_padding},
635
640
/* groups=*/ groups, false , false , false );
636
641
ForEachDevice ([&](const torch::lazy::BackendDevice& device) {
637
- XLATensorPtr dev_input =
638
- GetValueOrThrow ( XLATensor::Create (input, device));
639
- XLATensorPtr dev_weight =
640
- GetValueOrThrow ( XLATensor::Create (weight, device));
642
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_input,
643
+ XLATensor::Create (input, device));
644
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_weight,
645
+ XLATensor::Create (weight, device));
641
646
XLATensorPtr dev_output;
642
647
if (with_bias) {
643
- XLATensorPtr dev_bias =
644
- GetValueOrThrow ( XLATensor::Create (bias, device));
648
+ XLA_ASSIGN_OR_THROW ( XLATensorPtr dev_bias,
649
+ XLATensor::Create (bias, device));
645
650
dev_output = tensor_methods::convolution_overrideable (
646
651
dev_input, dev_weight, dev_bias,
647
652
/* stride=*/ {stride, stride, stride},
@@ -709,15 +714,14 @@ TEST_F(TensorTest, TestConv3D) {
709
714
// {output_padding, output_padding + 1, output_padding},
710
715
// /*groups=*/groups, false, false, false);
711
716
// ForEachDevice([&](const torch::lazy::BackendDevice& device) {
712
- // XLATensorPtr dev_input =
713
- // GetValueOrThrow(XLATensor::Create(input, device));
714
- // XLATensorPtr dev_weight =
715
- // GetValueOrThrow(XLATensor::Create(weight, device);
716
- // XLATensorPtr dev_output;
717
- // if (with_bias) {
718
- // XLATensorPtr dev_bias =
719
- // GetValueOrThrow(XLATensor::Create(bias, device));
720
- // dev_output = tensor_methods::convolution_overrideable(
717
+ // XLA_ASSIGN_OR_THROW(XLATensorPtr dev_input,
718
+ // XLATensor::Create(input, device));
719
+ // XLA_ASSIGN_OR_THROW(XLATensorPtr dev_weight,
720
+ // XLATensor::Create(weight, device)); XLATensorPtr
721
+ // dev_output; if (with_bias) {
722
+ // XLA_ASSIGN_OR_THROW(XLATensorPtr dev_bias,
723
+ // XLATensor::Create(bias, device)); dev_output =
724
+ // tensor_methods::convolution_overrideable(
721
725
// dev_input, dev_weight, dev_bias,
722
726
// /*stride=*/{stride, stride + 1, stride + 1},
723
727
// /*padding=*/{padding, padding + 1, padding + 1},
0 commit comments