Skip to content

Commit 9537a09

Browse files
Merge branch 'master' into beta
2 parents c14d6a5 + 4882e85 commit 9537a09

File tree

4 files changed

+125
-2
lines changed

4 files changed

+125
-2
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88

99
- Torch support is upgraded to 2.0, and now supports native vmap and native functional grad, and thus `vvag`. Still, jit support conflicts with these functional transformations and is turned off by default
1010

11+
- Add `torch_interface_kws` that supports static keyword arguments when wrapping with the interface
12+
13+
### Fixed
14+
15+
- Added tests and fixed some missing methods for the cupy backend; the cupy backend is now ready to use (though still not guaranteed)
16+
1117
## 0.8.0
1218

1319
### Added

tensorcircuit/interfaces/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@
1212
)
1313
from .numpy import numpy_interface, np_interface
1414
from .scipy import scipy_interface, scipy_optimize_interface
15-
from .torch import torch_interface, pytorch_interface
15+
from .torch import torch_interface, pytorch_interface, torch_interface_kws
1616
from .tensorflow import tensorflow_interface, tf_interface

tensorcircuit/interfaces/torch.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
Interface wraps quantum function as a torch function
33
"""
44

5-
from typing import Any, Callable, Tuple
5+
from typing import Any, Callable, Dict, Tuple
6+
from functools import partial
67

78
from ..cons import backend
89
from ..utils import is_sequence
@@ -112,3 +113,51 @@ def backward(ctx: Any, *grad_y: Any) -> Any:
112113

113114

114115
pytorch_interface = torch_interface
116+
117+
118+
def torch_interface_kws(
    f: Callable[..., Any], jit: bool = True, enable_dlpack: bool = False
) -> Callable[..., Any]:
    """
    Similar to :py:meth:`tensorcircuit.interfaces.torch.torch_interface`,
    but the returned interface additionally supports static (non-tensor)
    arguments for ``f``, supplied as keyword arguments at call time.
    A separately wrapped function is cached for each distinct set of
    keyword arguments.

    :Example:

    .. code-block:: python

        tc.set_backend("tensorflow")

        def f(tensor, integer):
            r = 0.
            for i in range(integer):
                r += tensor
            return r

        fnew = torch_interface_kws(f)

        print(fnew(torch.ones([2]), integer=3))
        print(fnew(torch.ones([2]), integer=4))

    :param f: the function to be wrapped; tensor inputs are passed
        positionally, static inputs as (hashable) keyword arguments
    :type f: Callable[..., Any]
    :param jit: whether to jit the function on the ML backend, defaults to True
    :type jit: bool, optional
    :param enable_dlpack: whether to pass tensors between frameworks via
        dlpack (zero copy), defaults to False
    :type enable_dlpack: bool, optional
    :return: the wrapped function callable from torch
    :rtype: Callable[..., Any]
    """
    # One torch-interfaced (and possibly jitted) function per distinct set
    # of static keyword arguments; reused across calls to avoid re-jitting.
    cache_dict: Dict[Tuple[Any, ...], Callable[..., Any]] = {}

    def wrapper(*args: Any, **kws: Any) -> Any:
        # Sort items so the same kwargs given in a different order map to
        # the same cache entry instead of triggering a redundant re-wrap.
        key = tuple(sorted(kws.items()))
        if key not in cache_dict:
            cache_dict[key] = torch_interface(
                partial(f, **kws), jit=jit, enable_dlpack=enable_dlpack
            )
        return cache_dict[key](*args)

    return wrapper

tests/test_interfaces.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,74 @@ def f3(x):
108108
np.testing.assert_allclose(pg, 2 * np.ones([2]).astype(np.complex64), atol=1e-5)
109109

110110

111+
@pytest.mark.skipif(is_torch is False, reason="torch not installed")
@pytest.mark.parametrize("backend", [lf("tfb"), lf("jaxb")])
def test_torch_interface_kws(backend):
    # Quantum function with one tensor argument (param) and one static
    # keyword argument (n, the circuit width) — the case torch_interface_kws
    # is designed for.
    def f(param, n):
        c = tc.Circuit(n)
        c = tc.templates.blocks.example_block(c, param)
        loss = c.expectation(
            [
                tc.gates.x(),
                [
                    1,
                ],
            ]
        )
        return tc.backend.real(loss)

    f_jit_torch = tc.interfaces.torch_interface_kws(f, jit=True, enable_dlpack=True)

    param = torch.ones([4, 4], requires_grad=True)
    # Static kwarg n=4 selects (and caches) one jitted wrapping of f.
    l = f_jit_torch(param, n=4)
    l = l**2
    # Gradient must flow back through the interface into the torch tensor.
    l.backward()

    pg = param.grad
    np.testing.assert_allclose(pg.shape, [4, 4])
    # Regression value — presumably from a known-good run; TODO confirm.
    np.testing.assert_allclose(pg[0, 1], -2.146e-3, atol=1e-5)

    # Second case: two tensor arguments plus two static kwargs (n, nlayer),
    # and a tuple return value.
    def f2(paramzz, paramx, n, nlayer):
        c = tc.Circuit(n)
        for i in range(n):
            c.H(i)
        for j in range(nlayer):  # 2
            for i in range(n - 1):
                c.exp1(i, i + 1, unitary=tc.gates._zz_matrix, theta=paramzz[j, i])
            for i in range(n):
                c.rx(i, theta=paramx[j, i])
        loss1 = c.expectation(
            [
                tc.gates.x(),
                [
                    1,
                ],
            ]
        )
        loss2 = c.expectation(
            [
                tc.gates.x(),
                [
                    2,
                ],
            ]
        )
        return tc.backend.real(loss1), tc.backend.real(loss2)

    f2_torch = tc.interfaces.torch_interface_kws(f2, jit=True, enable_dlpack=True)

    paramzz = torch.ones([2, 4], requires_grad=True)
    paramx = torch.ones([2, 4], requires_grad=True)

    # Multiple tensor outputs unpack on the torch side.
    l1, l2 = f2_torch(paramzz, paramx, n=4, nlayer=2)
    l = l1 - l2
    l.backward()

    pg = paramzz.grad
    np.testing.assert_allclose(pg.shape, [2, 4])
    # Regression value — presumably from a known-good run; TODO confirm.
    np.testing.assert_allclose(pg[0, 0], -0.41609, atol=1e-5)
177+
178+
111179
@pytest.mark.skipif(is_torch is False, reason="torch not installed")
112180
@pytest.mark.xfail(
113181
(int(tf.__version__.split(".")[1]) < 9)

0 commit comments

Comments
 (0)