Skip to content

Commit 49106eb

Browse files
Merge branch 'master' into beta
2 parents 3a028f7 + 1de1f17 commit 49106eb

File tree

9 files changed

+421
-7
lines changed

9 files changed

+421
-7
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212

1313
- Add `gpu_memory_share` function and enable it by default
1414

15+
- Add `scan` methods for backends
16+
17+
- Add example demontrating how jax compiling time can be accelerated by `jax.lax.scan`
18+
1519
### Fixed
1620

1721
- Add tests and fixed some missing methods for cupy backend, cupy backend is now ready to use (though still not guaranteed)

docs/source/quickstart.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -696,10 +696,10 @@ There is also a more flexible torch interface that support static non-tensor inp
696696
.. code-block:: python
697697
698698
def f(a, i):
699-
s = 0.
700-
for _ in range(i):
701-
s += a
702-
return s
699+
s = 0.
700+
for _ in range(i):
701+
s += a
702+
return s
703703
704704
f_torch = tc.interfaces.torch_interface_kws(f)
705705
f_torch(torch.ones([2]), i=3)

examples/hea_scan_jit_acc.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
reducing jit compiling time by general scan magic
3+
"""
4+
5+
import numpy as np
6+
import tensorcircuit as tc
7+
8+
n = 10
9+
nlayers = 16
10+
param_np = np.random.normal(size=[nlayers, n, 2])
11+
12+
for backend in ["tensorflow", "jax"]:
13+
with tc.runtime_backend(backend) as K:
14+
print("running %s" % K.name)
15+
16+
def energy_reference(param, n, nlayers):
17+
c = tc.Circuit(n)
18+
for i in range(n):
19+
c.h(i)
20+
for i in range(nlayers):
21+
for j in range(n - 1):
22+
c.rzz(j, j + 1, theta=param[i, j, 0])
23+
for j in range(n):
24+
c.rx(j, theta=param[i, j, 1])
25+
return K.real(c.expectation_ps(z=[0, 1]) + c.expectation_ps(x=[2]))
26+
27+
vg_reference = K.jit(
28+
K.value_and_grad(energy_reference, argnums=0), static_argnums=(1, 2)
29+
)
30+
31+
# a jit efficient way to utilize scan
32+
33+
def energy(param, n, nlayers, each):
34+
def loop_f(s_, param_):
35+
c_ = tc.Circuit(n, inputs=s_)
36+
for i in range(each):
37+
for j in range(n - 1):
38+
c_.rzz(j, j + 1, theta=param_[i, j, 0])
39+
for j in range(n):
40+
c_.rx(j, theta=param_[i, j, 1])
41+
s_ = c_.state()
42+
return s_
43+
44+
c = tc.Circuit(n)
45+
for i in range(n):
46+
c.h(i)
47+
s = c.state()
48+
s1 = K.scan(loop_f, K.reshape(param, [nlayers // each, each, n, 2]), s)
49+
c1 = tc.Circuit(n, inputs=s1)
50+
return K.real(c1.expectation_ps(z=[0, 1]) + c1.expectation_ps(x=[2]))
51+
52+
vg = K.jit(
53+
K.value_and_grad(energy, argnums=0),
54+
static_argnums=(1, 2, 3),
55+
jit_compile=True,
56+
)
57+
# set to False can improve compile time for tf
58+
59+
param = K.convert_to_tensor(param_np)
60+
61+
for each in [1, 2, 4]:
62+
print(" scan impl with each=%s" % str(each))
63+
r1 = tc.utils.benchmark(vg, param, n, nlayers, each)
64+
print(r1[0][0])
65+
66+
print(" plain impl")
67+
r0 = tc.utils.benchmark(vg_reference, param, n, nlayers) # too slow
68+
np.testing.assert_allclose(r0[0][0], r1[0][0], atol=1e-5)
69+
np.testing.assert_allclose(r0[0][1], r1[0][1], atol=1e-5)
70+
# correctness check
71+
72+
73+
# jit_compile=True icrease runtime while degrades jit time for tensorflow
74+
# and in general jax improves better with scan methodology,
75+
# both compile time and running time can outperform tf

examples/jax_scan_jit_acc.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
reducing jax jit compiling time by some magic:
3+
for backend agnostic but similar approach,
4+
see `hea_scan_jit_acc.py`
5+
"""
6+
7+
import numpy as np
8+
import jax
9+
import tensorcircuit as tc
10+
11+
K = tc.set_backend("jax")
12+
tc.set_dtype("complex128")
13+
14+
15+
def energy_reference(param, n, nlayers):
16+
c = tc.Circuit(n)
17+
for i in range(n):
18+
c.h(i)
19+
for i in range(nlayers):
20+
for j in range(n - 1):
21+
c.rzz(j, j + 1, theta=param[i, j, 0])
22+
for j in range(n):
23+
c.rx(j, theta=param[i, j, 1])
24+
return K.real(c.expectation_ps(z=[0, 1]))
25+
26+
27+
vg_reference = K.jit(
28+
K.value_and_grad(energy_reference, argnums=0), static_argnums=(1, 2)
29+
)
30+
31+
with tc.runtime_backend("tensorflow") as tfk:
32+
33+
def energy_reference_tf(param, n, nlayers):
34+
c = tc.Circuit(n)
35+
for i in range(n):
36+
c.h(i)
37+
for i in range(nlayers):
38+
for j in range(n - 1):
39+
c.rzz(j, j + 1, theta=param[i, j, 0])
40+
for j in range(n):
41+
c.rx(j, theta=param[i, j, 1])
42+
return tfk.real(c.expectation_ps(z=[0, 1]))
43+
44+
vg_reference_tf = tfk.jit(
45+
tfk.value_and_grad(energy_reference_tf, argnums=0), static_argnums=(1, 2)
46+
)
47+
48+
# a jit efficient way to utilize jax scan
49+
50+
51+
def energy(param, n, nlayers, each):
52+
def loop_f(s_, param_):
53+
c_ = tc.Circuit(n, inputs=s_)
54+
for i in range(each):
55+
for j in range(n - 1):
56+
c_.rzz(j, j + 1, theta=param_[i, j, 0])
57+
for j in range(n):
58+
c_.rx(j, theta=param_[i, j, 1])
59+
s_ = c_.state()
60+
return s_, s_
61+
62+
c = tc.Circuit(n)
63+
for i in range(n):
64+
c.h(i)
65+
s = c.state()
66+
s1, _ = jax.lax.scan(loop_f, s, K.reshape(param, [nlayers // each, each, n, 2]))
67+
c1 = tc.Circuit(n, inputs=s1)
68+
return K.real(c1.expectation_ps(z=[0, 1]))
69+
70+
71+
vg = K.jit(K.value_and_grad(energy, argnums=0), static_argnums=(1, 2, 3))
72+
73+
if __name__ == "__main__":
74+
n = 10
75+
nlayers = 32
76+
param = K.implicit_randn([nlayers, n, 2])
77+
78+
r1 = tc.utils.benchmark(vg, param, n, nlayers, 1)
79+
print(r1[0][0])
80+
r1 = tc.utils.benchmark(vg, param, n, nlayers, 2)
81+
print(r1[0][0])
82+
r1 = tc.utils.benchmark(vg, param, n, nlayers, 4)
83+
print(r1[0][0])
84+
85+
with tc.runtime_backend("tensorflow"):
86+
print("tf plain impl")
87+
param_tf = tc.array_to_tensor(param, dtype="float32")
88+
r0 = tc.utils.benchmark(vg_reference_tf, param_tf, n, nlayers)
89+
90+
np.testing.assert_allclose(r0[0][0], r1[0][0], atol=1e-5)
91+
np.testing.assert_allclose(r0[0][1], r1[0][1], atol=1e-5)
92+
# correctness check
93+
94+
print("jax plain impl (may be super slow for deeper system)")
95+
r0 = tc.utils.benchmark(vg_reference, param, n, nlayers) # too slow
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
"""
2+
slicing the output wavefunction to save the memory in VQA context
3+
"""
4+
5+
from itertools import product
6+
import numpy as np
7+
import tensorcircuit as tc
8+
9+
K = tc.set_backend("jax")
10+
11+
12+
def circuit(param, n, nlayers):
13+
c = tc.Circuit(n)
14+
for i in range(n):
15+
c.h(i)
16+
c = tc.templates.blocks.example_block(c, param, nlayers)
17+
return c
18+
19+
20+
def sliced_state(c, cut, mask):
21+
# mask = Tensor([0, 1, 0])
22+
# cut = [0, 1, 2]
23+
n = c._nqubits
24+
ncut = len(cut)
25+
end0 = tc.array_to_tensor(np.array([1.0, 0.0]))
26+
end1 = tc.array_to_tensor(np.array([0.0, 1.0]))
27+
ends = [tc.Gate(mask[i] * end1 + (1 - mask[i]) * end0) for i in range(ncut)]
28+
nodes, front = c._copy()
29+
for j, i in enumerate(cut):
30+
front[i] ^ ends[j][0]
31+
oeo = []
32+
for i in range(n):
33+
if i not in cut:
34+
oeo.append(front[i])
35+
ss = tc.contractor(nodes + ends, output_edge_order=oeo)
36+
return ss
37+
38+
39+
def sliced_op(ps, cut, mask1, mask2):
40+
# ps: Tensor([0, 0, 1, 1])
41+
n = K.shape_tuple(ps)[-1]
42+
ncut = len(cut)
43+
end0 = tc.array_to_tensor(np.array([1.0, 0.0]))
44+
end1 = tc.array_to_tensor(np.array([0.0, 1.0]))
45+
endsr = [tc.Gate(mask1[i] * end1 + (1 - mask1[i]) * end0) for i in range(ncut)]
46+
endsl = [tc.Gate(mask2[i] * end1 + (1 - mask2[i]) * end0) for i in range(ncut)]
47+
48+
structuresc = K.cast(ps, dtype="int32")
49+
structuresc = K.onehot(structuresc, num=4)
50+
structuresc = K.cast(structuresc, dtype=tc.dtypestr)
51+
obs = []
52+
for i in range(n):
53+
obs.append(
54+
tc.Gate(
55+
sum(
56+
[
57+
structuresc[i, k] * g.tensor
58+
for k, g in enumerate(tc.gates.pauli_gates)
59+
]
60+
)
61+
)
62+
)
63+
for j, i in enumerate(cut):
64+
obs[i][0] ^ endsl[j][0]
65+
obs[i][1] ^ endsr[j][0]
66+
oeo = []
67+
for i in range(n):
68+
if i not in cut:
69+
oeo.append(obs[i][0])
70+
for i in range(n):
71+
if i not in cut:
72+
oeo.append(obs[i][1])
73+
return obs + endsl + endsr, oeo
74+
75+
76+
def sliced_core(param, n, nlayers, ps, cut, mask1, mask2):
77+
# param, ps, mask1, mask2 are all tensor
78+
c = circuit(param, n, nlayers)
79+
ss = sliced_state(c, cut, mask1)
80+
ssc = sliced_state(c, cut, mask2)
81+
ssc, _ = tc.Circuit.copy([ssc], conj=True)
82+
op_nodes, op_edges = sliced_op(ps, cut, mask1, mask2)
83+
nodes = [ss] + ssc + op_nodes
84+
ssc = ssc[0]
85+
n = c._nqubits
86+
nleft = n - len(cut)
87+
for i in range(nleft):
88+
op_edges[i + nleft] ^ ss[i]
89+
op_edges[i] ^ ssc[i]
90+
scalar = tc.contractor(nodes)
91+
return K.real(scalar.tensor)
92+
93+
94+
sliced_core_vvg = K.jit(
95+
K.vectorized_value_and_grad(sliced_core, argnums=0, vectorized_argnums=(5, 6)),
96+
static_argnums=(1, 2, 4),
97+
) # vmap version if memory is enough
98+
99+
sliced_core_vg = K.jit(
100+
K.value_and_grad(sliced_core, argnums=0),
101+
static_argnums=(1, 2, 4),
102+
) # nonvmap version is memory is tight and distrubution workload may be enabled
103+
104+
105+
def sliced_expectation_and_grad(param, n, nlayers, ps, cut, is_vmap=True):
106+
pst = tc.array_to_tensor(ps)
107+
res = 0.0
108+
mask1s = []
109+
mask2s = []
110+
for mask1 in product(*[(0, 1) for _ in cut]):
111+
mask1t = tc.array_to_tensor(np.array(mask1))
112+
mask1s.append(mask1t)
113+
mask2 = list(mask1)
114+
for j, i in enumerate(cut):
115+
if ps[i] in [1, 2]:
116+
mask2[j] = 1 - mask1[j]
117+
mask2t = tc.array_to_tensor(np.array(mask2))
118+
mask2s.append(mask2t)
119+
if is_vmap:
120+
mask1s = K.stack(mask1s)
121+
mask2s = K.stack(mask2s)
122+
res = sliced_core_vvg(param, n, nlayers, pst, cut, mask1s, mask2s)
123+
res = list(res)
124+
res[0] = K.sum(res[0])
125+
res = tuple(res)
126+
else:
127+
# memory bounded
128+
# can modified to adpative pmap
129+
vs = 0.0
130+
gs = 0.0
131+
for i in range(len(mask1s)):
132+
mask1t = mask1s[i]
133+
mask2t = mask2s[i]
134+
v, g = sliced_core_vg(param, n, nlayers, pst, cut, mask1t, mask2t)
135+
vs += v
136+
gs += g
137+
res = (vs, gs)
138+
return res
139+
140+
141+
def sliced_expectation_ref(c, ps, cut):
142+
"""
143+
reference implementation
144+
"""
145+
# ps: [0, 2, 1]
146+
res = 0.0
147+
for mask1 in product(*[(0, 1) for _ in cut]):
148+
mask1t = tc.array_to_tensor(np.array(mask1))
149+
ss = sliced_state(c, cut, mask1t)
150+
mask2 = list(mask1)
151+
for j, i in enumerate(cut):
152+
if ps[i] in [1, 2]:
153+
mask2[j] = 1 - mask1[j]
154+
mask2t = tc.array_to_tensor(np.array(mask2))
155+
ssc = sliced_state(c, cut, mask2t)
156+
ssc, _ = tc.Circuit.copy([ssc], conj=True)
157+
ps = tc.array_to_tensor(ps)
158+
op_nodes, op_edges = sliced_op(ps, cut, mask1t, mask2t)
159+
nodes = [ss] + ssc + op_nodes
160+
ssc = ssc[0]
161+
n = c._nqubits
162+
nleft = n - len(cut)
163+
for i in range(nleft):
164+
op_edges[i + nleft] ^ ss[i]
165+
op_edges[i] ^ ssc[i]
166+
scalar = tc.contractor(nodes)
167+
res += scalar.tensor
168+
return res
169+
170+
171+
if __name__ == "__main__":
172+
n = 10
173+
nlayers = 5
174+
param = K.ones([n, 2 * nlayers], dtype="float32")
175+
cut = (0, 2, 5, 9)
176+
ops = [2, 0, 3, 1, 0, 0, 1, 2, 0, 1]
177+
ops_dict = tc.quantum.ps2xyz(ops)
178+
179+
def trivial_core(param, n, nlayers):
180+
c = circuit(param, n, nlayers)
181+
return K.real(c.expectation_ps(**ops_dict))
182+
183+
trivial_vg = K.jit(K.value_and_grad(trivial_core, argnums=0), static_argnums=(1, 2))
184+
185+
print("reference impl")
186+
r0 = tc.utils.benchmark(trivial_vg, param, n, nlayers)
187+
print("vmapped slice")
188+
r1 = tc.utils.benchmark(
189+
sliced_expectation_and_grad, param, n, nlayers, ops, cut, True
190+
)
191+
print("naive for slice")
192+
r2 = tc.utils.benchmark(
193+
sliced_expectation_and_grad, param, n, nlayers, ops, cut, False
194+
)
195+
196+
np.testing.assert_allclose(r0[0][0], r1[0][0], atol=1e-5)
197+
np.testing.assert_allclose(r2[0][0], r1[0][0], atol=1e-5)
198+
np.testing.assert_allclose(r0[0][1], r1[0][1], atol=1e-5)
199+
np.testing.assert_allclose(r2[0][1], r1[0][1], atol=1e-5)

0 commit comments

Comments
 (0)