Skip to content

Commit bcc5ca7

Browse files
fix #54
1 parent c3cec24 commit bcc5ca7

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
- Add sparse matrix related methods for pytorch backend.
1010

11+
- Add exp and expm for torch backend.
12+
1113
## v1.4.0
1214

1315
### Added

tensorcircuit/backends/pytorch_backend.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,9 @@ def ones(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
229229
r = torchlib.ones(shape)
230230
return self.cast(r, dtype)
231231

232+
def exp(self, tensor: Tensor) -> Tensor:
233+
return torchlib.exp(tensor)
234+
232235
def zeros(self, shape: Tuple[int, ...], dtype: Optional[str] = None) -> Tensor:
233236
if dtype is None:
234237
dtype = dtypestr
@@ -248,7 +251,8 @@ def convert_to_tensor(self, tensor: Tensor, dtype: Optional[str] = None) -> Tens
248251
return result
249252

250253
def expm(self, a: Tensor) -> Tensor:
251-
raise NotImplementedError("pytorch backend doesn't support expm")
254+
return torchlib.linalg.matrix_exp(a)
255+
# raise NotImplementedError("pytorch backend doesn't support expm")
252256
# in 2020, torch has no expm, hmmm. but that's ok,
253257
# it doesn't support complex numbers which is more severe issue.
254258
# see https://github.com/pytorch/pytorch/issues/9983

tests/test_backends.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,14 @@ def sum_(carry, x):
305305
def test_backend_methods_2(backend):
306306
np.testing.assert_allclose(tc.backend.mean(tc.backend.ones([10])), 1.0, atol=1e-5)
307307
# acos acosh asin asinh atan atan2 atanh cosh (cos) tan tanh sinh (sin)
308+
np.testing.assert_allclose(
309+
tc.backend.exp(tc.backend.ones([2, 3])), np.exp(np.ones([2, 3])), atol=1e-5
310+
)
311+
np.testing.assert_allclose(
312+
tc.backend.expm(tc.backend.ones([3, 3])),
313+
scipy.linalg.expm(np.ones([3, 3])),
314+
atol=1e-5,
315+
)
308316
np.testing.assert_allclose(
309317
tc.backend.acos(tc.backend.ones([2], dtype="float32")),
310318
np.arccos(tc.backend.ones([2])),

0 commit comments

Comments
 (0)