Skip to content

Commit 5825706

Browse files
committed
Fix
1 parent f71bb8a commit 5825706

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

test/test_tv_tensors.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ def test_to_wrapping(make_input):
139139
)
140140
@pytest.mark.parametrize("return_type", ["Tensor", "TVTensor"])
141141
def test_to_tv_tensor_reference(make_input, return_type):
142-
tensor = torch.rand((3, 16, 16), dtype=torch.float64)
142+
tensor = make_input().to(dtype=torch.float64).as_subclass(torch.Tensor)
143+
assert type(tensor) is torch.Tensor
143144
dp = make_input()
144145

145146
with tv_tensors.set_return_type(return_type):

0 commit comments

Comments
 (0)