
Commit 89ee53e

postprocessing for DistributedContractor (DC) and an amplitude example
1 parent f060edb commit 89ee53e

File tree

5 files changed: +214 -54 lines changed
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
"""
Amplitude contraction on multiple GPU cards with the neat interface
`DistributedContractor`.
"""

import os

NUM_DEVICES = 4
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={NUM_DEVICES}"

import time
import jax
from jax import numpy as jnp
import tensorcircuit as tc
from tensorcircuit.experimental import DistributedContractor

K = tc.set_backend("jax")
tc.set_dtype("complex64")


N_QUBITS = 16
DEPTH = 8


def circuit_ansatz(n, d, params):
    c = tc.Circuit(n)
    c.h(range(n))
    for i in range(d):
        for j in range(0, n - 1):
            c.rzz(j, j + 1, theta=params[j, i, 0])
        for j in range(n):
            c.rx(j, theta=params[j, i, 1])
        for j in range(n):
            c.ry(j, theta=params[j, i, 2])
    return c


def get_nodes_fn(n, d):
    def nodes_fn(params):
        psi = circuit_ansatz(n, d, params["circuit"])
        return psi.amplitude_before(params["amplitude"])

    return nodes_fn


def get_binary_representation(i: int, N: int) -> jax.Array:
    """
    Generates the binary representation of an integer as a JAX array.
    """
    # Create an array of shift amounts, from N-1 down to 0.
    # For N=8, this is [7, 6, 5, 4, 3, 2, 1, 0].
    shifts = jnp.arange(N - 1, -1, -1)
    # Right-shift the integer `i` by each amount in `shifts`.
    # This effectively isolates each bit at the rightmost position.
    # For i=5 (..0101) and shifts=[..., 3, 2, 1, 0],
    # shifted_i is [..0, ..0, ..1, ..10, ..101] -> [0, 0, 1, 2, 5].
    shifted_i = i >> shifts
    # A bitwise AND with 1 extracts just the last bit from each shifted value:
    # [0&1, 0&1, 1&1, 2&1, 5&1] -> [0, 0, 1, 0, 1].
    # Cast explicitly to int32.
    bits = (shifted_i & 1).astype(jnp.int32)
    return bits


if __name__ == "__main__":
    print(f"JAX is using {jax.local_device_count()} devices.")

    nodes_fn = get_nodes_fn(N_QUBITS, DEPTH)

    @K.jit
    def baseline(params):
        psi = circuit_ansatz(N_QUBITS, DEPTH, params["circuit"])
        return psi.amplitude(params["amplitude"])

    key = jax.random.PRNGKey(42)
    params_circuit = (
        jax.random.normal(key, shape=[N_QUBITS, DEPTH, 3], dtype=tc.rdtypestr) * 0.1
    )
    params = {
        "circuit": params_circuit,
        "amplitude": get_binary_representation(0, N_QUBITS),
    }
    DC = DistributedContractor(
        nodes_fn=nodes_fn,
        params=params,
        cotengra_options={
            "slicing_reconf_opts": {"target_size": 2**16},
            "max_repeats": 64,
            "progbar": True,
            "minimize": "write",
            "parallel": 4,
        },
    )

    n_steps = 100

    print("\nStarting amplitude loop...")
    for i in range(n_steps):
        bs_vector = get_binary_representation(i, N_QUBITS)
        t0 = time.time()
        params = {"circuit": params_circuit, "amplitude": bs_vector}
        amp = DC.value(params)
        t1 = time.time()
        print(
            f"Bitstring: {K.numpy(bs_vector).tolist()} | "
            f"amp: {amp:.8f} | "
            f"baseline_amp: {baseline(params):.8f} | "
            f"Time: {t1 - t0:.4f} s"
        )
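Note on post-processing in this example: `DistributedContractor` applies `op` to each slice's contraction result and then sums over slices (and devices), so only slice-linear post-processing belongs inside `value`; here the defaults (`op=backend.sum` with a complex output dtype) return the full complex amplitude. A nonlinear quantity such as a bitstring probability must be formed after aggregation. A minimal sketch, reusing `DC`, `params`, and `jnp` from the script above:

        amp = DC.value(params)  # complex amplitude, aggregated over all slices
        prob = jnp.abs(amp) ** 2  # |<l|psi>|^2, taken only after aggregation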

examples/new_distributed_interface.py renamed to examples/distributed_interface_vqe.py

Lines changed: 3 additions & 1 deletion
@@ -112,5 +112,7 @@ def opt_update(params, opt_state, grads):
     print(f"Step {i+1:03d} | " f"Loss: {loss:.8f} | " f"Time: {t1 - t0:.4f} s")

 print("\nOptimization finished.")
-final_energy = DC.value(params)
+final_energy = DC.value(
+    params, op=lambda x: K.real(K.sum(x)), output_dtype=tc.rdtypestr
+)
 print(f"Final energy: {final_energy:.8f}")

tensorcircuit/basecircuit.py

Lines changed: 30 additions & 14 deletions
@@ -441,27 +441,17 @@ def measure_jit(
 
     measure = measure_jit
 
-    def amplitude(self, l: Union[str, Tensor]) -> Tensor:
+    def amplitude_before(self, l: Union[str, Tensor]) -> List[Gate]:
         r"""
-        Returns the amplitude of the circuit given the bitstring l.
+        Returns the tensornetwork nodes for the amplitude of the circuit given the bitstring l.
         For state simulator, it computes :math:`\langle l\vert \psi\rangle`,
         for density matrix simulator, it computes :math:`Tr(\rho \vert l\rangle \langle l\vert)`.
         Note how these two are different up to a square operation.
 
-        :Example:
-
-        >>> c = tc.Circuit(2)
-        >>> c.X(0)
-        >>> c.amplitude("10")
-        array(1.+0.j, dtype=complex64)
-        >>> c.CNOT(0, 1)
-        >>> c.amplitude("11")
-        array(1.+0.j, dtype=complex64)
-
         :param l: The bitstring of 0 and 1s.
         :type l: Union[str, Tensor]
-        :return: The amplitude of the circuit.
-        :rtype: tn.Node.tensor
+        :return: The tensornetwork nodes for the amplitude of the circuit.
+        :rtype: List[Gate]
         """
         no, d_edges = self._copy()
         ms = []
@@ -502,6 +492,32 @@ def amplitude(self, l: Union[str, Tensor]) -> Tensor:
         no.extend(ms)
         if self.is_dm:
             no.extend(msconj)
+        return no
+
+    def amplitude(self, l: Union[str, Tensor]) -> Tensor:
+        r"""
+        Returns the amplitude of the circuit given the bitstring l.
+        For state simulator, it computes :math:`\langle l\vert \psi\rangle`,
+        for density matrix simulator, it computes :math:`Tr(\rho \vert l\rangle \langle l\vert)`.
+        Note how these two are different up to a square operation.
+
+        :Example:
+
+        >>> c = tc.Circuit(2)
+        >>> c.X(0)
+        >>> c.amplitude("10")
+        array(1.+0.j, dtype=complex64)
+        >>> c.CNOT(0, 1)
+        >>> c.amplitude("11")
+        array(1.+0.j, dtype=complex64)
+
+        :param l: The bitstring of 0 and 1s.
+        :type l: Union[str, Tensor]
+        :return: The amplitude of the circuit.
+        :rtype: tn.Node.tensor
+        """
+        no = self.amplitude_before(l)
+
         return contractor(no).tensor
 
     def probability(self) -> Tensor:
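The refactor lets callers grab the uncontracted network from `amplitude_before` and hand it to an external contraction engine, which is exactly what the `DistributedContractor` example above does through its `nodes_fn`. A minimal sketch of the two equivalent paths, assuming the default contractor exposed as `tensorcircuit.cons.contractor` (the same helper this module uses internally):

import tensorcircuit as tc
from tensorcircuit.cons import contractor  # default contraction engine

c = tc.Circuit(2)
c.h(0)
c.cnot(0, 1)
nodes = c.amplitude_before("00")  # uncontracted tensor network nodes
amp = contractor(nodes).tensor  # manual contraction, ~1/sqrt(2) for this Bell state
assert abs(amp - c.amplitude("00")) < 1e-6  # matches the one-shot path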

tensorcircuit/experimental.py

Lines changed: 71 additions & 37 deletions
@@ -737,7 +737,12 @@ def __init__(
 
         logger.info("Initialization complete.")
 
-    def _get_single_slice_contraction_fn(self) -> Callable[[Any, Tensor, int], Tensor]:
+    def _get_single_slice_contraction_fn(
+        self, op: Optional[Callable[[Tensor], Tensor]] = None
+    ) -> Callable[[Any, Tensor, int], Tensor]:
+        if op is None:
+            op = backend.sum
+
         def single_slice_contraction(
             tree: ctg.ContractionTree, params: Tensor, slice_idx: int
         ) -> Tensor:
@@ -746,16 +751,25 @@ def single_slice_contraction(
             input_arrays = [node.tensor for node in standardized_nodes]
             sliced_arrays = tree.slice_arrays(input_arrays, slice_idx)
             result = tree.contract_core(sliced_arrays, backend=self._backend)
-            return backend.sum(backend.real(result))
+            return op(result)
 
         return single_slice_contraction
 
     def _get_device_sum_vg_fn(
         self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
     ) -> Callable[[Any, Tensor, Tensor], Tuple[Tensor, Tensor]]:
-        base_fn = self._get_single_slice_contraction_fn()
+        post_processing = lambda x: backend.real(backend.sum(x))
+        if op is None:
+            op = post_processing
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        # the default op ensures the output is real so that it can be differentiated
         single_slice_vg_fn = jaxlib.value_and_grad(base_fn, argnums=1)
 
+        if output_dtype is None:
+            output_dtype = rdtypestr
+
         def device_sum_fn(
             tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
         ) -> Tuple[Tensor, Tensor]:
@@ -785,7 +799,7 @@ def do_nothing() -> Tuple[Tensor, Tensor]:
             )
 
         initial_carry = (
-            backend.cast(backend.convert_to_tensor(0.0), dtype=rdtypestr),
+            backend.cast(backend.convert_to_tensor(0.0), dtype=output_dtype),
             jaxlib.tree_util.tree_map(lambda x: jaxlib.numpy.zeros_like(x), params),
         )
         (final_value, final_grads), _ = jaxlib.lax.scan(
@@ -795,21 +809,14 @@ def do_nothing() -> Tuple[Tensor, Tensor]:
 
         return device_sum_fn
 
-    def _compile_value_and_grad(self) -> None:
-        if self._compiled_vg_fn is not None:
-            return
-        device_sum_fn = self._get_device_sum_vg_fn()
-        # `tree` is arg 0, `params` is arg 1, `indices` is arg 2
-        # `tree` is static and broadcast to all devices
-        self._compiled_vg_fn = jaxlib.pmap(
-            device_sum_fn,
-            in_axes=(None, None, 0),  # tree: broadcast, params: broadcast, indices: map
-            static_broadcasted_argnums=(0,),  # arg 0 (tree) is a static argument
-            devices=self.devices,
-        )
-
-    def _get_device_sum_v_fn(self) -> Callable[[Any, Tensor, Tensor], Tensor]:
-        base_fn = self._get_single_slice_contraction_fn()
+    def _get_device_sum_v_fn(
+        self,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Callable[[Any, Tensor, Tensor], Tensor]:
+        base_fn = self._get_single_slice_contraction_fn(op=op)
+        if output_dtype is None:
+            output_dtype = dtypestr
 
         def device_sum_fn(
             tree: ctg.ContractionTree, params: Tensor, slice_indices_for_device: Tensor
@@ -828,7 +835,7 @@ def compute_and_add() -> Tensor:
             )
 
         initial_carry = backend.cast(
-            backend.convert_to_tensor(0.0), dtype=rdtypestr
+            backend.convert_to_tensor(0.0), dtype=output_dtype
         )
         final_value, _ = jaxlib.lax.scan(
             scan_body, initial_carry, slice_indices_for_device
@@ -837,22 +844,28 @@ def compute_and_add() -> Tensor:
 
         return device_sum_fn
 
-    def _compile_value(self) -> None:
-        if self._compiled_v_fn is not None:
-            return
-        device_sum_fn = self._get_device_sum_v_fn()
-        self._compiled_v_fn = jaxlib.pmap(
-            device_sum_fn,
-            in_axes=(None, None, 0),
-            static_broadcasted_argnums=(0,),
-            devices=self.devices,
-        )
-
     # --- Public API ---
     def value_and_grad(
-        self, params: Tensor, aggregate: bool = True
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
     ) -> Tuple[Tensor, Tensor]:
-        self._compile_value_and_grad()
+        if self._compiled_vg_fn is None:
+            device_sum_fn = self._get_device_sum_vg_fn(op=op, output_dtype=output_dtype)
+            # `tree` is arg 0, `params` is arg 1, `indices` is arg 2
+            # `tree` is static and broadcast to all devices
+            self._compiled_vg_fn = jaxlib.pmap(
+                device_sum_fn,
+                in_axes=(
+                    None,
+                    None,
+                    0,
+                ),  # tree: broadcast, params: broadcast, indices: map
+                static_broadcasted_argnums=(0,),  # arg 0 (tree) is a static argument
+                devices=self.devices,
+            )
         # Pass `self.tree` as the first argument
         device_values, device_grads = self._compiled_vg_fn(  # type: ignore
             self.tree, params, self.batched_slice_indices
@@ -865,15 +878,36 @@ def value_and_grad(
             return total_value, total_grad
         return device_values, device_grads
 
-    def value(self, params: Tensor, aggregate: bool = True) -> Tensor:
-        self._compile_value()
+    def value(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        if self._compiled_v_fn is None:
+            device_sum_fn = self._get_device_sum_v_fn(op=op, output_dtype=output_dtype)
+            self._compiled_v_fn = jaxlib.pmap(
+                device_sum_fn,
+                in_axes=(None, None, 0),
+                static_broadcasted_argnums=(0,),
+                devices=self.devices,
+            )
         device_values = self._compiled_v_fn(  # type: ignore
             self.tree, params, self.batched_slice_indices
         )
         if aggregate:
             return backend.sum(device_values)
         return device_values
 
-    def grad(self, params: Tensor, aggregate: bool = True) -> Tensor:
-        _, grad = self.value_and_grad(params, aggregate=aggregate)
+    def grad(
+        self,
+        params: Tensor,
+        aggregate: bool = True,
+        op: Optional[Callable[[Tensor], Tensor]] = None,
+        output_dtype: Optional[str] = None,
+    ) -> Tensor:
+        _, grad = self.value_and_grad(
+            params, aggregate=aggregate, op=op, output_dtype=output_dtype
+        )
         return grad
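One behavioural note implied by the lazy-compile branches above: the `pmap`-compiled callables are cached on `self._compiled_v_fn` / `self._compiled_vg_fn`, so the `op` and `output_dtype` passed on the first call are baked into the compilation, and later calls with a different `op` would silently reuse the cached function. A sketch of the extended public API (using the module-level `backend` and `rdtypestr` names from this diff):

# first call compiles and fixes the post-processing for this instance
energy, grads = DC.value_and_grad(
    params, op=lambda x: backend.real(backend.sum(x)), output_dtype=rdtypestr
)
amp = DC.value(params)  # value defaults: op=backend.sum, complex output dtype
g = DC.grad(params)  # thin wrapper forwarding op/output_dtype to value_and_grad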

tests/test_stabilizer.py

Lines changed: 2 additions & 2 deletions
@@ -178,13 +178,13 @@ def test_circuit_inputs():
 
 def test_depolarize():
     r = []
-    for _ in range(20):
+    for _ in range(40):
         c = tc.StabilizerCircuit(2)
         c.h(0)
         c.depolarizing(0, 1, p=0.2)
         c.h(0)
         r.append(c.expectation_ps(z=[0]))
-    assert 4 < np.sum(r) < 20
+    assert 5 < np.sum(r) < 38
 
 
 def test_tableau_inputs():
