Skip to content

Commit d1a360f

Browse files
committed
According to the maintainer`s relevant suggestions, the bugs have been fixed.
1 parent ba4e1af commit d1a360f

File tree

4 files changed

+8
-16
lines changed

4 files changed

+8
-16
lines changed

tensorcircuit/backends/jax_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,8 @@ def mod(self, x: Tensor, y: Tensor) -> Tensor:
352352
def floor(self, x: Tensor) -> Tensor:
353353
return jnp.floor(x)
354354

355-
def clip(self, x: Tensor, lower: Tensor, upper: Tensor) -> Tensor:
356-
return jnp.clip(x, lower, upper)
355+
def clip(self, x: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
356+
return jnp.clip(x, a_min, a_max)
357357

358358
def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
359359
return jnp.right_shift(x, y)

tensorcircuit/backends/numpy_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,8 @@ def mod(self, x: Tensor, y: Tensor) -> Tensor:
253253
def floor(self, x: Tensor) -> Tensor:
254254
return np.floor(x)
255255

256-
def clip(self, x: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
257-
return np.clip(x, a_min, a_max)
256+
def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
257+
return np.clip(a, a_min, a_max)
258258

259259
def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
260260
return np.right_shift(x, y)

tensorcircuit/backends/pytorch_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,8 @@ def mod(self, x: Tensor, y: Tensor) -> Tensor:
432432
def floor(self, x: Tensor) -> Tensor:
433433
return torchlib.floor(x)
434434

435-
def clip(self, x: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
436-
return torchlib.clamp(x, a_min, a_max)
435+
def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
436+
return torchlib.clamp(a, a_min, a_max)
437437

438438
def right_shift(self, x: Tensor, y: Tensor) -> Tensor:
439439
return torchlib.bitwise_right_shift(x, y)

tensorcircuit/backends/tensorflow_backend.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -577,17 +577,9 @@ def clip(self, a: Tensor, a_min: Tensor, a_max: Tensor) -> Tensor:
577577
return tf.clip_by_value(a, a_min, a_max)
578578

579579
def floor(self, x: Tensor) -> Tensor:
580-
dtype_str = x.dtype.name if hasattr(x.dtype, "name") else str(x.dtype)
581-
if x.dtype.is_floating:
582-
return tf.math.floor(x)
583-
elif x.dtype.is_integer:
580+
if x.dtype.is_integer:
584581
return x
585-
elif x.dtype.is_complex:
586-
raise TypeError(
587-
f"tf.math.floor does not support complex dtype ({dtype_str})"
588-
)
589-
else:
590-
raise TypeError(f"Unsupported dtype for floor: {dtype_str}")
582+
return tf.math.floor(x)
591583

592584
def concat(self, a: Sequence[Tensor], axis: int = 0) -> Tensor:
593585
return tf.concat(a, axis=axis)

0 commit comments

Comments
 (0)