@@ -47,6 +47,10 @@ def test_encode_decode():
47
47
torch .rand ((1 , 10 ), dtype = torch .float32 ),
48
48
torch .rand ((3 , 5 , 4000 ), dtype = torch .float64 ),
49
49
torch .tensor (1984 ), # test scalar too
50
+ # Make sure to test bf16 which numpy doesn't support.
51
+ torch .rand ((3 , 5 , 1000 ), dtype = torch .bfloat16 ),
52
+ torch .tensor ([float ("-inf" ), float ("inf" )] * 1024 ,
53
+ dtype = torch .bfloat16 ),
50
54
],
51
55
numpy_array = np .arange (512 ),
52
56
unrecognized = UnrecognizedType (33 ),
@@ -64,7 +68,7 @@ def test_encode_decode():
64
68
# There should be the main buffer + 4 large tensor buffers
65
69
# + 1 large numpy array. "large" is <= 512 bytes.
66
70
# The two small tensors are encoded inline.
67
- assert len (encoded ) == 6
71
+ assert len (encoded ) == 8
68
72
69
73
decoded : MyType = decoder .decode (encoded )
70
74
@@ -76,7 +80,7 @@ def test_encode_decode():
76
80
77
81
encoded2 = encoder .encode_into (obj , preallocated )
78
82
79
- assert len (encoded2 ) == 6
83
+ assert len (encoded2 ) == 8
80
84
assert encoded2 [0 ] is preallocated
81
85
82
86
decoded2 : MyType = decoder .decode (encoded2 )
0 commit comments