Skip to content

Commit 9bf2a9a

Browse files
drisspgpytorchmergebot
authored andcommitted
[ScaledMM] Fix NaNs in test for garbage input data (pytorch#144042)
Pull Request resolved: pytorch#144042 Approved by: https://github.com/janeyx99
1 parent b75f32b commit 9bf2a9a

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

test/test_matmul_cuda.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,9 @@ def test_scaled_mm_change_stride(self, base_dtype):
441441
x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
442442
y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)
443443

444+
x.normal_()
445+
y.normal_()
446+
444447
x_scale = tensor_to_scale(x, input_dtype).float()
445448
y_scale = tensor_to_scale(y, input_dtype).float()
446449

0 commit comments

Comments
 (0)