Skip to content

Commit b578795

Browse files
committed
add more bf16 tests from #16860
Signed-off-by: Staszek Pasko <[email protected]>
1 parent 29daef4 commit b578795

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/v1/test_serial_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def test_encode_decode():
4747
torch.rand((1, 10), dtype=torch.float32),
4848
torch.rand((3, 5, 4000), dtype=torch.float64),
4949
torch.tensor(1984), # test scalar too
50+
# Make sure to test bf16 which numpy doesn't support.
51+
torch.rand((3, 5, 1000), dtype=torch.bfloat16),
52+
torch.tensor([float("-inf"), float("inf")] * 1024,
53+
dtype=torch.bfloat16),
5054
],
5155
numpy_array=np.arange(512),
5256
unrecognized=UnrecognizedType(33),
@@ -64,7 +68,7 @@ def test_encode_decode():
6468
# There should be the main buffer + 4 large tensor buffers
6569
# + 1 large numpy array. "large" is <= 512 bytes.
6670
# The two small tensors are encoded inline.
67-
assert len(encoded) == 6
71+
assert len(encoded) == 8
6872

6973
decoded: MyType = decoder.decode(encoded)
7074

@@ -76,7 +80,7 @@ def test_encode_decode():
7680

7781
encoded2 = encoder.encode_into(obj, preallocated)
7882

79-
assert len(encoded2) == 6
83+
assert len(encoded2) == 8
8084
assert encoded2[0] is preallocated
8185

8286
decoded2: MyType = decoder.decode(encoded2)

0 commit comments

Comments
 (0)