Skip to content

Commit 60eff60

Browse files
committed
Add tests for reshaped() in abstractbackend.py.
1 parent f7c8169 commit 60eff60

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

tests/test_backends.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,75 @@ def test_dlpack(backend):
497497
np.testing.assert_allclose(a, a1, atol=1e-5)
498498

499499

500+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
501+
def test_backend_reshaped_basic(backend):
502+
a1 = tc.backend.convert_to_tensor(np.arange(27))
503+
r1 = tc.backend.reshaped(a1, 3)
504+
assert r1.shape == (3, 3, 3)
505+
np.testing.assert_allclose(tc.backend.numpy(r1), np.arange(27).reshape(3, 3, 3))
506+
d, n = 4, 3
507+
dim = d**n
508+
mat = np.arange(dim * dim, dtype=np.float32).reshape(dim, dim)
509+
a2 = tc.backend.convert_to_tensor(mat)
510+
r2 = tc.backend.reshaped(a2, d)
511+
assert r2.shape == (d,) * (2 * n)
512+
np.testing.assert_allclose(tc.backend.numpy(r2), mat.reshape((d,) * (2 * n)))
513+
514+
515+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
516+
def test_backend_reshaped_zero_size(backend):
517+
"""size == 0 returns a canonical empty vector shape (0,) regardless of input shape."""
518+
a0 = tc.backend.convert_to_tensor(np.array([], dtype=np.float32))
519+
r0 = tc.backend.reshaped(a0, 3)
520+
assert r0.shape == (0,)
521+
assert tc.backend.sizen(r0) == 0
522+
523+
a1 = tc.backend.convert_to_tensor(np.zeros((2, 0), dtype=np.float32))
524+
r1 = tc.backend.reshaped(a1, 5)
525+
assert r1.shape == (0,)
526+
assert tc.backend.sizen(r1) == 0
527+
528+
529+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
530+
def test_backend_reshaped_dtype_device_preserved(backend):
531+
"""Reshape should not change dtype or device."""
532+
a = tc.backend.ones([16], dtype="float32")
533+
dev = tc.backend.device(a)
534+
r = tc.backend.reshaped(a, 2)
535+
assert r.shape == (2, 2, 2, 2)
536+
assert tc.backend.dtype(r) == tc.backend.dtype(a)
537+
assert tc.backend.device(r) == dev
538+
539+
540+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
541+
def test_backend_reshaped_scalar_size_one(backend):
542+
"""size == 1 stays scalar: nleg = 0 so shape () is kept."""
543+
a = tc.backend.ones([]) # scalar tensor, total size = 1
544+
r = tc.backend.reshaped(a, 2)
545+
assert r.shape == ()
546+
np.testing.assert_allclose(tc.backend.numpy(r), tc.backend.numpy(a))
547+
548+
549+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
550+
def test_backend_reshaped_invalid_d_raises(backend):
551+
"""d must be a positive integer: non-int or <=0 should raise."""
552+
a = tc.backend.ones([4], dtype="float32")
553+
with pytest.raises(ValueError):
554+
tc.backend.reshaped(a, 0)
555+
with pytest.raises(ValueError):
556+
tc.backend.reshaped(a, -2)
557+
with pytest.raises(ValueError):
558+
tc.backend.reshaped(a, 2.5) # not an int
559+
560+
561+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
562+
def test_backend_reshaped_non_power_raises(backend):
563+
"""When size is not a power of d, raise ValueError."""
564+
a = tc.backend.convert_to_tensor(np.arange(10))
565+
with pytest.raises(ValueError):
566+
tc.backend.reshaped(a, 3)
567+
568+
500569
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
501570
def test_arg_cmp(backend):
502571
np.testing.assert_allclose(tc.backend.argmax(tc.backend.ones([3], "float64")), 0)

0 commit comments

Comments
 (0)