Skip to content

Commit ba4e1af

Browse files
committed
Add tests for clip(), floor() functions in all backends.
1 parent ab2b641 commit ba4e1af

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

tests/test_backends.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,39 @@ def test_backend_methods_2(backend):
435435
# assert tc.dtype(a) == "float32"
436436

437437

438+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
439+
def test_backend_floor(backend):
440+
"""Test floor method (element-wise, dtype/device preservation, integers unchanged)."""
441+
a = tc.backend.convert_to_tensor([-1.7, -0.0, 0.0, 0.2, 3.9])
442+
r = tc.backend.floor(a)
443+
expected = tc.backend.convert_to_tensor([-2.0, -0.0, 0.0, 0.0, 3.0])
444+
np.testing.assert_allclose(r, expected, atol=1e-6)
445+
assert tc.backend.dtype(r) == tc.backend.dtype(a)
446+
assert tc.backend.device(r) == tc.backend.device(a)
447+
ai = tc.backend.convert_to_tensor([0, 1, -2, 3])
448+
ri = tc.backend.floor(ai)
449+
np.testing.assert_allclose(ri, ai)
450+
451+
452+
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
453+
def test_backend_clip(backend):
454+
"""Test clip method (scalar/tensor bounds, broadcasting, dtype/device)."""
455+
a = tc.backend.convert_to_tensor([-2.0, -0.5, 0.0, 0.5, 10.0])
456+
a_min = tc.backend.convert_to_tensor(-1.0)
457+
a_max = tc.backend.convert_to_tensor(1.0)
458+
r = tc.backend.clip(a, a_min, a_max)
459+
expected = tc.backend.convert_to_tensor([-1.0, -0.5, 0.0, 0.5, 1.0])
460+
np.testing.assert_allclose(r, expected, atol=1e-6)
461+
assert tc.backend.dtype(r) == tc.backend.dtype(a)
462+
assert tc.backend.device(r) == tc.backend.device(a)
463+
a2 = tc.backend.convert_to_tensor([[-5.0, 0.0, 5.0], [1.0, 2.0, 3.0]])
464+
a2_min = tc.backend.convert_to_tensor([[-1.0, 0.0, 0.0], [0.0, 1.0, 2.0]])
465+
a2_max = tc.backend.convert_to_tensor([[0.0, 0.0, 4.0], [1.0, 2.0, 2.0]])
466+
r2 = tc.backend.clip(a2, a2_min, a2_max)
467+
expected2 = tc.backend.convert_to_tensor([[-1.0, 0.0, 4.0], [1.0, 2.0, 2.0]])
468+
np.testing.assert_allclose(r2, expected2, atol=1e-6)
469+
470+
438471
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
439472
def test_device_cpu_only(backend):
440473
a = tc.backend.ones([])

0 commit comments

Comments
 (0)