Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions tensorcircuit/backends/abstract_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,18 +865,46 @@ def mod(self: Any, x: Tensor, y: Tensor) -> Tensor:
"Backend '{}' has not implemented `mod`.".format(self.name)
)

def floor(self: Any, x: Tensor) -> Tensor:
def floor_divide(self: Any, x: Tensor, y: Tensor) -> Tensor:
r"""
Compute the element-wise floor division of two tensors.

This operation returns a new tensor containing the result of
dividing `x` by `y` and rounding each element down towards
negative infinity. The semantics are equivalent to the Python
`//` operator:

result[i] = floor(x[i] / y[i])

Broadcasting is supported according to the backend's rules.

:param x: Dividend tensor.
:type x: Tensor
:param y: Divisor tensor, must be broadcastable with `x`.
:type y: Tensor
:return: A tensor with the broadcasted shape of `x` and `y`,
where each element is the floored result of the division.
:rtype: Tensor

:raises NotImplementedError: If the backend does not provide an
implementation for `floor_divide`.
"""
raise NotImplementedError(
"Backend '{}' has not implemented `floor_divide`.".format(self.name)
)

def floor(self: Any, a: Tensor) -> Tensor:
"""
Compute the element-wise floor of the input tensor.

This operation returns a new tensor with the largest integers
less than or equal to each element of the input tensor,
i.e. it rounds each value down towards negative infinity.

:param x: Input tensor containing numeric values.
:type x: Tensor
:return: A tensor with the same shape as `x`, where each element
is the floored value of the corresponding element in `x`.
:param a: Input tensor containing numeric values.
:type a: Tensor
:return: A tensor with the same shape as `a`, where each element
is the floored value of the corresponding element in `a`.
:rtype: Tensor

:raises NotImplementedError: If the backend does not provide an
Expand Down
3 changes: 3 additions & 0 deletions tensorcircuit/backends/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ def mod(self, x: Tensor, y: Tensor) -> Tensor:
def floor(self, a: Tensor) -> Tensor:
return jnp.floor(a)

def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
return jnp.floor_divide(x, y)

def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
return jnp.clip(a, a_min, a_max)

Expand Down
3 changes: 3 additions & 0 deletions tensorcircuit/backends/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tenso
def mod(self, x: Tensor, y: Tensor) -> Tensor:
return np.mod(x, y)

def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
return np.floor_divide(x, y)

def floor(self, a: Tensor) -> Tensor:
return np.floor(a)

Expand Down
3 changes: 3 additions & 0 deletions tensorcircuit/backends/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,9 @@ def arange(self, start: int, stop: Optional[int] = None, step: int = 1) -> Tenso
def mod(self, x: Tensor, y: Tensor) -> Tensor:
return torchlib.fmod(x, y)

def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
return torchlib.floor_divide(x, y)

def floor(self, a: Tensor) -> Tensor:
return torchlib.floor(a)

Expand Down
3 changes: 3 additions & 0 deletions tensorcircuit/backends/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,9 @@ def floor(self, a: Tensor) -> Tensor:
return a
return tf.math.floor(a)

def floor_divide(self, x: Tensor, y: Tensor) -> Tensor:
return tf.math.floordiv(x, y)

def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
return tf.concat(a, axis=axis)

Expand Down
59 changes: 59 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -1335,3 +1335,62 @@ def test_backend_where(backend):
result = tc.backend.where(condition, x, y)
expected = tc.backend.convert_to_tensor([1, 5, 3])
np.testing.assert_allclose(result, expected)


@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb"), lf("torchb")])
def test_floor_divide_various_cases(backend):
r"""
Single test covering:
- basic positive integers
- negative dividends/divisors
- broadcasting
- floating inputs (matches floor semantics)
Ensures both operands are converted to the active backend's native tensor/array
to avoid type errors in Torch/JAX.
"""

def to_backend(z):
if hasattr(tc.backend, "asarray"):
return tc.backend.asarray(z)

name = getattr(tc.backend, "name", "").lower()
try:
if "torch" in name:
import torch

return torch.as_tensor(z)
if "jax" in name:
import jax.numpy as jnp

return jnp.asarray(z)
if "tf" in name or "tensorflow" in name:
import tensorflow as tf

return tf.convert_to_tensor(z)
except Exception:
pass
return np.asarray(z)

out = tc.backend.floor_divide(to_backend([7, 8, 9]), to_backend(2))
np.testing.assert_array_equal(np.array(out), np.array([3, 4, 4]))

out = tc.backend.floor_divide(to_backend([-3, -4]), to_backend(2))
np.testing.assert_array_equal(np.array(out), np.array([-2, -2]))

out = tc.backend.floor_divide(to_backend([3, 4]), to_backend(-2))
np.testing.assert_array_equal(np.array(out), np.array([-2, -2]))

out = tc.backend.floor_divide(to_backend([-3, -4]), to_backend(-2))
np.testing.assert_array_equal(np.array(out), np.array([1, 2]))

x = to_backend([[10, 20], [30, 40]])
y = to_backend([3, 5])
expected = np.array([[10, 20], [30, 40]]) // np.array([3, 5])
out = tc.backend.floor_divide(x, y)
np.testing.assert_array_equal(np.array(out), expected)

xf = to_backend([7.9, 8.1, -3.5])
yf = to_backend([2.0, 2.0, 2.0])
expectedf = np.floor_divide(np.array([7.9, 8.1, -3.5]), np.array([2.0, 2.0, 2.0]))
out = tc.backend.floor_divide(xf, yf)
np.testing.assert_array_equal(np.array(out), expectedf)