@@ -101,8 +101,8 @@ TEST_F(TensorTest, TestAdd) {
   at::Tensor c = a.add(b, 1.0);
 
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_a = XLATensor::Create(a, device);
-    XLATensorPtr dev_b = XLATensor::Create(b, device);
+    XLATensorPtr dev_a = GetValueOrThrow(XLATensor::Create(a, device));
+    XLATensorPtr dev_b = GetValueOrThrow(XLATensor::Create(b, device));
     XLATensorPtr dev_c = tensor_methods::add(dev_a, dev_b, 1.0);
 
     AllClose(c, dev_c);
@@ -121,8 +121,8 @@ TEST_F(TensorTest, TestIntegerAdd) {
           at::isIntegralType(type) ? at::Scalar(int64_t(1)) : at::Scalar(1.0);
       at::Tensor c = a.add(b, one);
 
-      XLATensorPtr dev_a = XLATensor::Create(a, device);
-      XLATensorPtr dev_b = XLATensor::Create(b, device);
+      XLATensorPtr dev_a = GetValueOrThrow(XLATensor::Create(a, device));
+      XLATensorPtr dev_b = GetValueOrThrow(XLATensor::Create(b, device));
       XLATensorPtr dev_c = tensor_methods::add(dev_a, dev_b, one);
 
       EXPECT_TRUE(EqualValuesNoElementTypeCheck(
@@ -135,7 +135,7 @@ TEST_F(TensorTest, TestSize) {
   at::Tensor input = at::rand({2, 1, 4, 6}, at::TensorOptions(at::kFloat));
   int rank = input.dim();
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_input = XLATensor::Create(input, device);
+    XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
     for (int dim = -rank; dim < rank; ++dim) {
       EXPECT_EQ(input.size(dim), dev_input->size(dim));
     }
@@ -151,8 +151,10 @@ TEST_F(TensorTest, TestRrelu) {
       at::Tensor noise = at::zeros_like(input);
       at::Tensor output =
          at::rrelu_with_noise(input, noise, lower, upper, training);
-      XLATensorPtr dev_input = XLATensor::Create(input, device);
-      XLATensorPtr dev_noise = XLATensor::Create(noise, device);
+      XLATensorPtr dev_input =
+          GetValueOrThrow(XLATensor::Create(input, device));
+      XLATensorPtr dev_noise =
+          GetValueOrThrow(XLATensor::Create(noise, device));
       XLATensorPtr dev_outputs = tensor_methods::rrelu_with_noise(
           dev_input, dev_noise, lower, upper, training);
       AllClose(output, dev_outputs);
@@ -167,7 +169,7 @@ TEST_F(TensorTest, TestThreshold) {
   float value = 20;
   at::Tensor output = at::threshold(input, threshold, value);
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_input = XLATensor::Create(input, device);
+    XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
     XLATensorPtr dev_output =
         tensor_methods::threshold(dev_input, threshold, value);
     AllClose(output, dev_output);
@@ -185,9 +187,10 @@ TEST_F(TensorTest, TestAddMatMul) {
   at::Tensor bias = at::rand({labels}, at::TensorOptions(at::kFloat));
   at::Tensor output = at::addmm(bias, input, weight);
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_input = XLATensor::Create(input, device);
-    XLATensorPtr dev_weight = XLATensor::Create(weight, device);
-    XLATensorPtr dev_bias = XLATensor::Create(bias, device);
+    XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
+    XLATensorPtr dev_weight =
+        GetValueOrThrow(XLATensor::Create(weight, device));
+    XLATensorPtr dev_bias = GetValueOrThrow(XLATensor::Create(bias, device));
     XLATensorPtr dev_output =
         tensor_methods::addmm(dev_input, dev_weight, dev_bias);
     AllClose(output, dev_output);
@@ -198,7 +201,7 @@ TEST_F(TensorTest, TestTranspose) {
   at::Tensor input = at::rand({2, 3}, at::TensorOptions(at::kFloat));
   at::Tensor output = at::transpose(input, 0, 1);
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_input = XLATensor::Create(input, device);
+    XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
     XLATensorPtr dev_output = tensor_methods::transpose(dev_input, 0, 1);
     AllClose(output, dev_output);
   });
@@ -208,7 +211,7 @@ TEST_F(TensorTest, TestView) {
   at::Tensor input = at::rand({32, 20, 4, 4}, at::TensorOptions(at::kFloat));
   at::Tensor output = input.view({-1, 320});
   ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-    XLATensorPtr dev_input = XLATensor::Create(input, device);
+    XLATensorPtr dev_input = GetValueOrThrow(XLATensor::Create(input, device));
     XLATensorPtr dev_output = tensor_methods::view(dev_input, {-1, 320});
     AllClose(output, dev_output);
   });
@@ -289,7 +292,8 @@ TEST_F(TensorTest, TestMaxPool2D) {
                          /*padding=*/{padding, padding}, /*dilation=*/{1, 1},
                          /*ceil_mode=*/false);
       ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-        XLATensorPtr dev_input = XLATensor::Create(input, device);
+        XLATensorPtr dev_input =
+            GetValueOrThrow(XLATensor::Create(input, device));
         auto dev_output = tensor_methods::max_pool_nd(
             dev_input,
             /*spatial_dim_count=*/2,
@@ -313,7 +317,8 @@ TEST_F(TensorTest, TestMaxPool2DNonSquare) {
                          /*padding=*/{padding, padding + 1}, /*dilation=*/{1, 1},
                          /*ceil_mode=*/false);
       ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-        XLATensorPtr dev_input = XLATensor::Create(input, device);
+        XLATensorPtr dev_input =
+            GetValueOrThrow(XLATensor::Create(input, device));
         auto dev_output = tensor_methods::max_pool_nd(
             dev_input,
             /*spatial_dim_count=*/2,
@@ -341,7 +346,8 @@ TEST_F(TensorTest, TestAvgPool2D) {
               /*ceil_mode=*/false, count_include_pad,
               /*divisor_override=*/std::nullopt);
         ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-          XLATensorPtr dev_input = XLATensor::Create(input, device);
+          XLATensorPtr dev_input =
+              GetValueOrThrow(XLATensor::Create(input, device));
           XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
               dev_input,
               /*spatial_dim_count=*/2,
@@ -371,7 +377,8 @@ TEST_F(TensorTest, TestAvgPool2DNonSquare) {
               /*count_include_pad=*/count_include_pad,
               /*divisor_override=*/std::nullopt);
         ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-          XLATensorPtr dev_input = XLATensor::Create(input, device);
+          XLATensorPtr dev_input =
+              GetValueOrThrow(XLATensor::Create(input, device));
           XLATensorPtr dev_output = tensor_methods::avg_pool_nd(
               dev_input,
               /*spatial_dim_count=*/2,
@@ -409,15 +416,20 @@ TEST_F(TensorTest, TestBatchNorm1D) {
         /*running_mean=*/running_mean, /*running_var=*/running_var,
         /*training=*/training, /*momentum=*/momentum, /*eps=*/eps);
     ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-      XLATensorPtr xla_input = XLATensor::Create(input, device);
-      XLATensorPtr xla_weight = undef_weight_bias
-                                    ? XLATensorPtr()
-                                    : XLATensor::Create(weight, device);
-      XLATensorPtr xla_bias = undef_weight_bias
-                                  ? XLATensorPtr()
-                                  : XLATensor::Create(bias, device);
-      XLATensorPtr xla_running_mean = XLATensor::Create(running_mean, device);
-      XLATensorPtr xla_running_var = XLATensor::Create(running_var, device);
+      XLATensorPtr xla_input =
+          GetValueOrThrow(XLATensor::Create(input, device));
+      XLATensorPtr xla_weight =
+          undef_weight_bias
+              ? XLATensorPtr()
+              : GetValueOrThrow(XLATensor::Create(weight, device));
+      XLATensorPtr xla_bias =
+          undef_weight_bias
+              ? XLATensorPtr()
+              : GetValueOrThrow(XLATensor::Create(bias, device));
+      XLATensorPtr xla_running_mean =
+          GetValueOrThrow(XLATensor::Create(running_mean, device));
+      XLATensorPtr xla_running_var =
+          GetValueOrThrow(XLATensor::Create(running_var, device));
       auto xla_output = tensor_methods::native_batch_norm(
           /*input=*/xla_input, /*weight=*/xla_weight, /*bias=*/xla_bias,
           /*running_mean=*/xla_running_mean, /*running_var=*/xla_running_var,
@@ -474,11 +486,14 @@ TEST_F(TensorTest, TestConv2D) {
             /*output_padding=*/{output_padding, output_padding},
             /*groups=*/groups, false, false, false);
         ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-          XLATensorPtr dev_input = XLATensor::Create(input, device);
-          XLATensorPtr dev_weight = XLATensor::Create(weight, device);
+          XLATensorPtr dev_input =
+              GetValueOrThrow(XLATensor::Create(input, device));
+          XLATensorPtr dev_weight =
+              GetValueOrThrow(XLATensor::Create(weight, device));
           XLATensorPtr dev_output;
           if (with_bias) {
-            XLATensorPtr dev_bias = XLATensor::Create(bias, device);
+            XLATensorPtr dev_bias =
+                GetValueOrThrow(XLATensor::Create(bias, device));
             dev_output = tensor_methods::convolution_overrideable(
                 dev_input, dev_weight, dev_bias,
                 /*stride=*/{stride, stride},
@@ -543,11 +558,14 @@ TEST_F(TensorTest, TestConv2DNonSquare) {
             /*groups=*/groups, false, false, false);
 
         ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-          XLATensorPtr dev_input = XLATensor::Create(input, device);
-          XLATensorPtr dev_weight = XLATensor::Create(weight, device);
+          XLATensorPtr dev_input =
+              GetValueOrThrow(XLATensor::Create(input, device));
+          XLATensorPtr dev_weight =
+              GetValueOrThrow(XLATensor::Create(weight, device));
           XLATensorPtr dev_output;
           if (with_bias) {
-            XLATensorPtr dev_bias = XLATensor::Create(bias, device);
+            XLATensorPtr dev_bias =
+                GetValueOrThrow(XLATensor::Create(bias, device));
             dev_output = tensor_methods::convolution_overrideable(
                 dev_input, dev_weight, dev_bias,
                 /*stride=*/{stride, stride + 1},
@@ -616,11 +634,14 @@ TEST_F(TensorTest, TestConv3D) {
             {output_padding, output_padding, output_padding},
             /*groups=*/groups, false, false, false);
         ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-          XLATensorPtr dev_input = XLATensor::Create(input, device);
-          XLATensorPtr dev_weight = XLATensor::Create(weight, device);
+          XLATensorPtr dev_input =
+              GetValueOrThrow(XLATensor::Create(input, device));
+          XLATensorPtr dev_weight =
+              GetValueOrThrow(XLATensor::Create(weight, device));
          XLATensorPtr dev_output;
          if (with_bias) {
-            XLATensorPtr dev_bias = XLATensor::Create(bias, device);
+            XLATensorPtr dev_bias =
+                GetValueOrThrow(XLATensor::Create(bias, device));
            dev_output = tensor_methods::convolution_overrideable(
                dev_input, dev_weight, dev_bias,
                /*stride=*/{stride, stride, stride},
@@ -688,10 +709,14 @@ TEST_F(TensorTest, TestConv3D) {
 //             {output_padding, output_padding + 1, output_padding},
 //             /*groups=*/groups, false, false, false);
 //     ForEachDevice([&](const torch::lazy::BackendDevice& device) {
-//       XLATensorPtr dev_input = XLATensor::Create(input, device);
-//       XLATensorPtr dev_weight = XLATensor::Create(weight,
-//       device); XLATensorPtr dev_output; if (with_bias) {
-//       XLATensorPtr dev_bias = XLATensor::Create(bias, device);
+//       XLATensorPtr dev_input =
+//           GetValueOrThrow(XLATensor::Create(input, device));
+//       XLATensorPtr dev_weight =
+//           GetValueOrThrow(XLATensor::Create(weight, device));
+//       XLATensorPtr dev_output;
+//       if (with_bias) {
+//       XLATensorPtr dev_bias =
+//           GetValueOrThrow(XLATensor::Create(bias, device));
 //       dev_output = tensor_methods::convolution_overrideable(
 //           dev_input, dev_weight, dev_bias,
 //           /*stride=*/{stride, stride + 1, stride + 1},