Skip to content

Commit 39ba76b

Browse files
increase test cov
1 parent fa526b5 commit 39ba76b

File tree

5 files changed

+250
-0
lines changed

5 files changed

+250
-0
lines changed

tests/test_circuit.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1189,6 +1189,8 @@ def test_qiskit2tc():
11891189
qisc_from_tc = c.to_qiskit(enable_instruction=True)
11901190
qis_unitary2 = qi.Operator(qisc_from_tc)
11911191
qis_unitary2 = np.reshape(qis_unitary2, [2**n, 2**n])
1192+
# Note: the following assertion may fail intermittently due to non-deterministic behavior
1193+
# in Qiskit's UnitaryGate.control() method as of Qiskit 0.46.3
11921194
np.testing.assert_allclose(qis_unitary2, qis_unitary, atol=1e-5)
11931195

11941196

@@ -1812,3 +1814,25 @@ def test_cirq_gates_translation(backend):
18121814
u_cirq4 = cirq.unitary(c4)
18131815
u_tc4 = c4_tc.matrix()
18141816
assert_allclose_up_to_global_phase(u_cirq4, tc.backend.numpy(u_tc4), atol=1e-5)
1817+
1818+
1819+
@pytest.mark.parametrize("backend", [lf("jaxb"), lf("tfb")])
1820+
def test_circuit_extra_coverage(backend):
1821+
c = tc.Circuit(2)
1822+
c.h(0)
1823+
c.cx(0, 1)
1824+
1825+
# inverse
1826+
c1 = c.inverse()
1827+
assert len(c1.to_qir()) == 2
1828+
1829+
# replace_mps_inputs
1830+
c2 = tc.Circuit(2)
1831+
c2.replace_mps_inputs(c.quvector())
1832+
1833+
# depolarizing2
1834+
c.depolarizing2(0, px=0.1, py=0.1, pz=0.1)
1835+
1836+
# unitary_kraus2
1837+
kraus = [tc.gates.x(), tc.gates.y()]
1838+
c.unitary_kraus2(kraus, 0, prob=[0.5, 0.5])

tests/test_miscs.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,3 +302,162 @@ def exp(theta):
302302
return c.expectation_ps(z=[-3], reuse=False)
303303

304304
assert len(exp(0.3)) == 9
305+
306+
307+
@pytest.mark.parametrize("backend", [lf("jaxb"), lf("tfb")])
308+
def test_parameter_shift_grad(backend):
309+
def f(params):
310+
c = tc.Circuit(2)
311+
c.rx(0, theta=params[0])
312+
c.ry(1, theta=params[1])
313+
c.cnot(0, 1)
314+
return tc.backend.real(c.expectation_ps(z=[0, 1]))
315+
316+
params = tc.array_to_tensor(np.array([0.1, 0.2]))
317+
318+
# Standard AD gradient
319+
g_ad = tc.backend.grad(f)(params)
320+
321+
# Parameter shift gradient
322+
g_ps = experimental.parameter_shift_grad(f)(params)
323+
324+
np.testing.assert_allclose(g_ad, g_ps, atol=1e-5)
325+
326+
327+
@pytest.mark.parametrize("backend", [lf("jaxb")])
328+
def test_parameter_shift_grad_v2(backend):
329+
# v2 is mainly for jax and supports randomness
330+
def f(params):
331+
c = tc.Circuit(2)
332+
c.rx(0, theta=params[0])
333+
c.ry(1, theta=params[1])
334+
return tc.backend.real(c.expectation_ps(z=[0]))
335+
336+
params = tc.array_to_tensor(np.array([0.5, 0.5]))
337+
g_ps = experimental.parameter_shift_grad_v2(f)(params)
338+
g_ad = tc.backend.grad(f)(params)
339+
np.testing.assert_allclose(g_ps, g_ad, atol=1e-5)
340+
341+
342+
def test_broadcast_py_object_single_process(jaxb):
343+
# In a single process environment, broadcast should just return the object
344+
# though it uses jax.experimental.multihost_utils.broadcast_one_to_all
345+
obj = {"a": 1, "b": [1, 2, 3]}
346+
res = experimental.broadcast_py_object(obj)
347+
assert res == obj
348+
349+
350+
@pytest.mark.parametrize("backend", [lf("jaxb")])
351+
def test_jax_jitted_function_save_load_v2(backend, tmp_path):
352+
K = tc.backend
353+
354+
@K.jit
355+
def f(x):
356+
return x**2 + 1.0
357+
358+
x = K.ones([2])
359+
path = os.path.join(tmp_path, "f.bin")
360+
experimental.jax_jitted_function_save(path, f, x)
361+
362+
f_load = experimental.jax_jitted_function_load(path)
363+
np.testing.assert_allclose(f_load(x), f(x), atol=1e-5)
364+
365+
366+
@pytest.mark.parametrize("backend", [lf("jaxb"), lf("tfb")])
367+
def test_qng_options(backend):
368+
def f(params):
369+
c = tc.Circuit(1)
370+
c.rx(0, theta=params[0])
371+
return c.state()
372+
373+
params = tc.backend.ones([1])
374+
# test different options in qng to hit more lines
375+
qng_fn = experimental.qng(f, mode="fwd")
376+
qng_fn(params)
377+
378+
qng_fn2 = experimental.qng(f, mode="rev")
379+
qng_fn2(params)
380+
381+
qng_fn3 = experimental.qng(f, kernel="dynamics", postprocess=None)
382+
qng_fn3(params)
383+
384+
385+
@pytest.mark.parametrize("backend", [lf("jaxb"), lf("tfb")])
386+
def test_qng2_options(backend):
387+
def f(params):
388+
c = tc.Circuit(1)
389+
c.rx(0, theta=params[0])
390+
return c.state()
391+
392+
params = tc.backend.ones([1])
393+
qng_fn = experimental.qng2(f, mode="fwd")
394+
qng_fn(params)
395+
396+
qng_fn2 = experimental.qng2(f, mode="rev")
397+
qng_fn2(params)
398+
399+
qng_fn3 = experimental.qng2(f, kernel="dynamics", postprocess=None)
400+
qng_fn3(params)
401+
402+
403+
def test_vis_extra():
404+
c = tc.Circuit(2)
405+
c.h(0)
406+
c.cx(0, 1)
407+
tex = tc.vis.qir2tex(c.to_qir(), 2)
408+
assert "\\qw" in tex
409+
410+
assert tc.vis.gate_name_trans("ccnot") == (2, "not")
411+
assert tc.vis.gate_name_trans("h") == (0, "h")
412+
413+
414+
def test_cons_extra(jaxb):
415+
# set_function_backend
416+
@tc.cons.set_function_backend("jax")
417+
def f():
418+
return tc.backend.name
419+
420+
# set_function_dtype
421+
@tc.cons.set_function_dtype("complex128")
422+
def g():
423+
return tc.dtypestr
424+
425+
426+
def test_ascii_art():
427+
# hit some lines in asciiart.py
428+
429+
try:
430+
tc.set_ascii("wrong")
431+
except AttributeError:
432+
pass
433+
434+
# lucky() is only available after set_ascii
435+
assert not hasattr(tc, "lucky")
436+
437+
438+
def test_utils_extra():
439+
from tensorcircuit import utils
440+
441+
# return_partial
442+
f = lambda x: [x, x**2, x**3]
443+
f1 = utils.return_partial(f, return_argnums=1)
444+
assert f1(2) == 4
445+
f2 = utils.return_partial(f, return_argnums=[0, 2])
446+
assert f2(2) == (2, 8)
447+
448+
# append
449+
f3 = utils.append(lambda x: x**2, lambda x: x + 1)
450+
assert f3(2) == 5
451+
452+
# is_m1mac
453+
utils.is_m1mac()
454+
455+
# is_sequence, is_number
456+
assert utils.is_sequence([1])
457+
assert utils.is_number(1.0)
458+
459+
# benchmark
460+
def h(x):
461+
return x + 1
462+
463+
utils.benchmark(h, 1.0, tries=2)

tests/test_mpscircuit.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,30 @@ def expec(params):
372372
tc.backend.numpy(exp_grad_dir_jit),
373373
atol=1e-6,
374374
)
375+
376+
377+
def test_mps_base_extra(jaxb):
378+
from tensorcircuit.mps_base import FiniteMPS
379+
380+
# create a simple MPS
381+
n = 4
382+
nodes = [
383+
tc.backend.ones([1, 2, 2]),
384+
tc.backend.ones([2, 2, 2]),
385+
tc.backend.ones([2, 2, 2]),
386+
tc.backend.ones([2, 2, 1]),
387+
]
388+
mps = FiniteMPS(nodes, center_position=0)
389+
mps1 = mps.copy()
390+
mps2 = mps.conj()
391+
392+
# measure_local_operator
393+
ops = [tc.backend.ones([2, 2])]
394+
mps.measure_local_operator(ops, [1])
395+
396+
# measure_two_body_correlator
397+
mps.measure_two_body_correlator(ops[0], ops[0], 0, [1, 2])
398+
399+
# apply_two_site_gate
400+
gate = tc.backend.ones([2, 2, 2, 2])
401+
mps.apply_two_site_gate(gate, 1, 2)

tests/test_shadows.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,21 @@ def test_ent(backend):
158158
#
159159
# assert np.isclose(expc, pl_expc)
160160
# assert np.isclose(ent, pl_ent)
161+
162+
163+
def test_shadow_extra(jaxb):
164+
165+
ps = [1, 2, 3] # X, Y, Z
166+
N, k = shadow_bound(ps, 0.1)
167+
assert N > 0
168+
assert k > 0
169+
170+
# test shadow_snapshots with measurement_only and sub
171+
c = tc.Circuit(3)
172+
c.h(range(3))
173+
psi = c.state()
174+
ns = 2
175+
pauli_strings = tc.backend.convert_to_tensor(np.random.randint(1, 4, size=(ns, 3)))
176+
177+
snapshots = shadow_snapshots(psi, pauli_strings, measurement_only=True, sub=[0, 1])
178+
assert snapshots.shape == (ns, 1, 2)

tests/test_simplify.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,25 @@ def test_rank_simplify():
4444

4545
nodes = simplify._full_rank_simplify([f, g, h])
4646
assert len(nodes) == 2
47+
48+
49+
def test_simplify_extra():
50+
from tensorcircuit import simplify
51+
import tensorcircuit as tc
52+
53+
a = tn.Node(np.ones([2, 2]), name="a")
54+
b = tn.Node(np.ones([2, 2]), name="b")
55+
a[1] ^ b[0]
56+
nodes = simplify._full_rank_simplify([a, b])
57+
assert len(nodes) == 1
58+
59+
# _full_light_cone_cancel
60+
c = tc.Circuit(2)
61+
c.h(0)
62+
c.cx(0, 1)
63+
c.h(0)
64+
# usually used in expectation where the psi and its conj can cancel
65+
# but we can just call it on any nodes list
66+
qir = c.to_qir()
67+
nodes = [g["gate"] for g in qir]
68+
simplify._full_light_cone_cancel(nodes)

0 commit comments

Comments
 (0)