@@ -23,6 +23,8 @@ class TestArithmetic(MultiProcessTestCase):
23
23
This class tests all functions of the ArithmeticSharedTensor.
24
24
"""
25
25
26
+ # pyre-fixme[14]: `setUp` overrides method defined in `MultiProcessTestCase`
27
+ # inconsistently.
26
28
def setUp (self ) -> None :
27
29
super ().setUp ()
28
30
# We don't want the main process (rank -1) to initialize the communcator
@@ -170,23 +172,32 @@ def test_arithmetic(self) -> None:
170
172
self ._check (encrypted_out , reference , "square failed" )
171
173
172
174
# Test radd, rsub, and rmul
175
+ # pyre-fixme[61]: `tensor1` is undefined, or not always defined.
173
176
reference = 2 + tensor1
177
+ # pyre-fixme[61]: `tensor1` is undefined, or not always defined.
174
178
encrypted = ArithmeticSharedTensor (tensor1 )
179
+ # pyre-fixme[58]: `+` is not supported for operand types `int` and
180
+ # `ArithmeticSharedTensor`.
175
181
encrypted_out = 2 + encrypted
176
182
self ._check (encrypted_out , reference , "right add failed" )
177
183
184
+ # pyre-fixme[61]: `tensor1` is undefined, or not always defined.
178
185
reference = 2 - tensor1
179
186
encrypted_out = 2 - encrypted
180
187
self ._check (encrypted_out , reference , "right sub failed" )
181
188
189
+ # pyre-fixme[61]: `tensor1` is undefined, or not always defined.
182
190
reference = 2 * tensor1
191
+ # pyre-fixme[58]: `*` is not supported for operand types `int` and
192
+ # `ArithmeticSharedTensor`.
183
193
encrypted_out = 2 * encrypted
184
194
self ._check (encrypted_out , reference , "right mul failed" )
185
195
186
196
def test_sum (self ) -> None :
187
197
"""Tests sum reduction on encrypted tensor."""
188
198
tensor = get_random_test_tensor (size = (5 , 100 , 100 ), is_float = True )
189
199
encrypted = ArithmeticSharedTensor (tensor )
200
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `sum`.
190
201
self ._check (encrypted .sum (), tensor .sum (), "sum failed" )
191
202
192
203
for dim in [0 , 1 , 2 ]:
@@ -198,6 +209,7 @@ def test_prod(self) -> None:
198
209
"""Tests prod reduction on encrypted tensor."""
199
210
tensor = get_random_test_tensor (size = (3 , 3 ), max_value = 3 , is_float = False )
200
211
encrypted = ArithmeticSharedTensor (tensor )
212
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `prod`.
201
213
self ._check (encrypted .prod (), tensor .prod ().float (), "prod failed" )
202
214
203
215
# test with dim argument
@@ -231,6 +243,7 @@ def test_mean(self) -> None:
231
243
"""Tests computing means of encrypted tensors."""
232
244
tensor = get_random_test_tensor (size = (5 , 10 , 15 ), is_float = True )
233
245
encrypted = ArithmeticSharedTensor (tensor )
246
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `mean`.
234
247
self ._check (encrypted .mean (), tensor .mean (), "mean failed" )
235
248
236
249
for dim in [0 , 1 , 2 ]:
@@ -350,6 +363,7 @@ def test_dot_ger(self) -> None:
350
363
351
364
# dot
352
365
encrypted_tensor = ArithmeticSharedTensor (tensor1 )
366
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `dot`.
353
367
encrypted_out = encrypted_tensor .dot (tensor2 )
354
368
self ._check (
355
369
encrypted_out ,
@@ -363,6 +377,7 @@ def test_dot_ger(self) -> None:
363
377
364
378
# ger
365
379
encrypted_tensor = ArithmeticSharedTensor (tensor1 )
380
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `ger`.
366
381
encrypted_out = encrypted_tensor .ger (tensor2 )
367
382
self ._check (
368
383
encrypted_out ,
@@ -382,11 +397,13 @@ def test_squeeze(self) -> None:
382
397
reference = tensor .unsqueeze (dim )
383
398
384
399
encrypted = ArithmeticSharedTensor (tensor )
400
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `unsqueeze`.
385
401
encrypted_out = encrypted .unsqueeze (dim )
386
402
self ._check (encrypted_out , reference , "unsqueeze failed" )
387
403
388
404
# Test squeeze
389
405
encrypted = ArithmeticSharedTensor (tensor .unsqueeze (0 ))
406
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `squeeze`.
390
407
encrypted_out = encrypted .squeeze ()
391
408
self ._check (encrypted_out , reference .squeeze (), "squeeze failed" )
392
409
@@ -417,12 +434,15 @@ def test_transpose(self) -> None:
417
434
418
435
if len (size ) == 2 : # t() asserts dim == 2
419
436
reference = tensor .t ()
437
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `t`.
420
438
encrypted_out = encrypted_tensor .t ()
421
439
self ._check (encrypted_out , reference , "t() failed" )
422
440
423
441
for dim0 in range (len (size )):
424
442
for dim1 in range (len (size )):
425
443
reference = tensor .transpose (dim0 , dim1 )
444
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute
445
+ # `transpose`.
426
446
encrypted_out = encrypted_tensor .transpose (dim0 , dim1 )
427
447
self ._check (encrypted_out , reference , "transpose failed" )
428
448
@@ -445,6 +465,7 @@ def test_permute(self) -> None:
445
465
# test reversing the dimensions
446
466
dim_arr = [x - 1 for x in range (tensor .dim (), 0 , - 1 )]
447
467
reference = tensor .permute (dim_arr )
468
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `permute`.
448
469
encrypted_out = encrypted_tensor .permute (dim_arr )
449
470
self ._check (encrypted_out , reference , "permute failed" )
450
471
@@ -628,6 +649,7 @@ def test_take(self) -> None:
628
649
for dimension in range (0 , 4 ):
629
650
reference = torch .from_numpy (tensor .numpy ().take (index , dimension ))
630
651
encrypted_tensor = ArithmeticSharedTensor (tensor )
652
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `take`.
631
653
encrypted_out = encrypted_tensor .take (index , dimension )
632
654
self ._check (encrypted_out , reference , "take function failed: dimension set" )
633
655
@@ -653,6 +675,8 @@ def test_get_set(self) -> None:
653
675
reference = tensor [:, 0 ]
654
676
655
677
encrypted_tensor = ArithmeticSharedTensor (tensor )
678
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute
679
+ # `__getitem__`.
656
680
encrypted_out = encrypted_tensor [:, 0 ]
657
681
self ._check (encrypted_out , reference , "getitem failed" )
658
682
@@ -933,6 +957,7 @@ def test_gather(self) -> None:
933
957
index = index .abs ().clamp (0 , 4 )
934
958
encrypted = ArithmeticSharedTensor (tensor )
935
959
reference = tensor .gather (dim , index )
960
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `gather`.
936
961
encrypted_out = encrypted .gather (dim , index )
937
962
self ._check (encrypted_out , reference , f"gather failed with size { size } " )
938
963
@@ -948,6 +973,7 @@ def test_split(self) -> None:
948
973
for idx in range (6 ):
949
974
split = (idx , 5 - idx )
950
975
reference0 , reference1 = tensor .split (split , dim = dim )
976
+ # pyre-fixme[16]: `ArithmeticSharedTensor` has no attribute `split`.
951
977
encrypted_out0 , encrypted_out1 = encrypted .split (split , dim = dim )
952
978
953
979
self ._check (
0 commit comments