Skip to content

Commit 1e06171

Browse files
Merge branch 'feature/gate-layering' of github.com:Stellogic/tensorcircuit-ng into ngpr24
2 parents 0343b6b + 947d357 commit 1e06171

File tree

3 files changed

+229
-0
lines changed

3 files changed

+229
-0
lines changed

examples/vqe2d_lattice.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
"""
2+
This example demonstrates how to use the VQE algorithm to find the ground state
3+
of a 2D Heisenberg model on a square lattice. It showcases the setup of the lattice,
4+
the Heisenberg Hamiltonian, a suitable ansatz, and the optimization process.
5+
"""
6+
7+
import time
8+
import optax
9+
import tensorcircuit as tc
10+
from tensorcircuit.templates.lattice import SquareLattice, get_compatible_layers
11+
from tensorcircuit.templates.hamiltonians import heisenberg_hamiltonian
12+
13+
# Use JAX for high-performance, especially on GPU.
14+
K = tc.set_backend("jax")
15+
tc.set_dtype("complex64")
16+
# On Windows, cotengra's multiprocessing can cause issues, use threads instead.
17+
tc.set_contractor("cotengra-8192-8192", parallel="threads")
18+
19+
20+
def run_vqe():
21+
"""Set up and run the VQE optimization for a 2D Heisenberg model."""
22+
n, m, nlayers = 4, 4, 2
23+
lattice = SquareLattice(size=(n, m), pbc=True, precompute_neighbors=1)
24+
h = heisenberg_hamiltonian(lattice, j_coupling=[1.0, 1.0, 0.8]) # Jx, Jy, Jz
25+
nn_bonds = lattice.get_neighbor_pairs(k=1, unique=True)
26+
gate_layers = get_compatible_layers(nn_bonds)
27+
n_params = nlayers * len(nn_bonds) * 3
28+
29+
def singlet_init(circuit):
30+
# A good initial state for Heisenberg ground state search
31+
nq = circuit._nqubits
32+
for i in range(0, nq - 1, 2):
33+
j = (i + 1) % nq
34+
circuit.X(i)
35+
circuit.H(i)
36+
circuit.cnot(i, j)
37+
circuit.X(j)
38+
return circuit
39+
40+
def vqe_forward(param):
41+
"""
42+
Defines the VQE ansatz and computes the energy expectation.
43+
The ansatz consists of nlayers of RZZ, RXX, and RYY entangling layers.
44+
"""
45+
c = tc.Circuit(n * m)
46+
c = singlet_init(c)
47+
param_idx = 0
48+
49+
for _ in range(nlayers):
50+
for layer in gate_layers:
51+
for j, k in layer:
52+
c.rzz(int(j), int(k), theta=param[param_idx])
53+
param_idx += 1
54+
for layer in gate_layers:
55+
for j, k in layer:
56+
c.rxx(int(j), int(k), theta=param[param_idx])
57+
param_idx += 1
58+
for layer in gate_layers:
59+
for j, k in layer:
60+
c.ryy(int(j), int(k), theta=param[param_idx])
61+
param_idx += 1
62+
63+
return tc.templates.measurements.operator_expectation(c, h)
64+
65+
vgf = K.jit(K.value_and_grad(vqe_forward))
66+
param = tc.backend.implicit_randn(stddev=0.02, shape=[n_params])
67+
optimizer = optax.adam(learning_rate=3e-3)
68+
opt_state = optimizer.init(param)
69+
70+
@K.jit
71+
def train_step(param, opt_state):
72+
"""A single training step, JIT-compiled for maximum speed."""
73+
loss_val, grads = vgf(param)
74+
updates, opt_state = optimizer.update(grads, opt_state, param)
75+
param = optax.apply_updates(param, updates)
76+
return param, opt_state, loss_val
77+
78+
print("Starting VQE optimization...")
79+
for i in range(1000):
80+
time0 = time.time()
81+
param, opt_state, loss = train_step(param, opt_state)
82+
time1 = time.time()
83+
if i % 10 == 0:
84+
print(
85+
f"Step {i:4d}: Loss = {loss:.6f} \t (Time per step: {time1 - time0:.4f}s)"
86+
)
87+
88+
print("Optimization finished.")
89+
print(f"Final Loss: {loss:.6f}")
90+
91+
92+
if __name__ == "__main__":
93+
run_vqe()

tensorcircuit/templates/lattice.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Union,
1616
TYPE_CHECKING,
1717
cast,
18+
Set,
1819
)
1920

2021
logger = logging.getLogger(__name__)
@@ -1446,3 +1447,54 @@ def remove_sites(self, identifiers: List[SiteIdentifier]) -> None:
14461447
logger.info(
14471448
f"{len(ids_to_remove)} sites removed. Lattice now has {self.num_sites} sites."
14481449
)
1450+
1451+
1452+
def get_compatible_layers(bonds: List[Tuple[int, int]]) -> List[List[Tuple[int, int]]]:
1453+
"""
1454+
Partitions a list of pairs (bonds) into compatible layers for parallel
1455+
gate application using a greedy edge-coloring algorithm.
1456+
1457+
This function takes a list of pairs, representing connections like
1458+
nearest-neighbor (NN) or next-nearest-neighbor (NNN) bonds, and
1459+
partitions them into the minimum number of sets ("layers") where no two
1460+
pairs in a set share an index. This is a general utility for scheduling
1461+
non-overlapping operations.
1462+
1463+
:Example:
1464+
1465+
>>> from tensorcircuit.templates.lattice import SquareLattice
1466+
>>> sq_lattice = SquareLattice(size=(2, 2), pbc=False)
1467+
>>> nn_bonds = sq_lattice.get_neighbor_pairs(k=1, unique=True)
1468+
1469+
>>> gate_layers = get_compatible_layers(nn_bonds)
1470+
>>> print(gate_layers)
1471+
[[[0, 1], [2, 3]], [[0, 2], [1, 3]]]
1472+
1473+
:param bonds: A list of tuples, where each tuple represents a bond (i, j)
1474+
of site indices to be scheduled.
1475+
:type bonds: List[Tuple[int, int]]
1476+
:return: A list of layers. Each layer is a list of tuples, where each
1477+
tuple represents a bond. All bonds within a layer are non-overlapping.
1478+
:rtype: List[List[Tuple[int, int]]]
1479+
"""
1480+
uncolored_edges: Set[Tuple[int, int]] = {(min(bond), max(bond)) for bond in bonds}
1481+
1482+
layers: List[List[Tuple[int, int]]] = []
1483+
1484+
while uncolored_edges:
1485+
current_layer: List[Tuple[int, int]] = []
1486+
qubits_in_this_layer: Set[int] = set()
1487+
1488+
edges_to_process = sorted(list(uncolored_edges))
1489+
1490+
for edge in edges_to_process:
1491+
i, j = edge
1492+
if i not in qubits_in_this_layer and j not in qubits_in_this_layer:
1493+
current_layer.append(edge)
1494+
qubits_in_this_layer.add(i)
1495+
qubits_in_this_layer.add(j)
1496+
1497+
uncolored_edges -= set(current_layer)
1498+
layers.append(sorted(current_layer))
1499+
1500+
return layers

tests/test_lattice.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
RectangularLattice,
2424
SquareLattice,
2525
TriangularLattice,
26+
AbstractLattice,
27+
get_compatible_layers,
2628
)
2729

2830

@@ -1664,3 +1666,85 @@ def test_distance_matrix_invariants_for_all_lattice_types(self, lattice):
16641666
# "The specialized PBC implementation is significantly slower "
16651667
# "than the general-purpose implementation."
16661668
# )
1669+
1670+
1671+
def _validate_layers(bonds, layers) -> None:
1672+
"""
1673+
A helper function to scientifically validate the output of get_compatible_layers.
1674+
"""
1675+
# MODIFICATION: This function now takes the original bonds list for comparison.
1676+
expected_edges = set(tuple(sorted(b)) for b in bonds)
1677+
actual_edges = set(tuple(sorted(edge)) for layer in layers for edge in layer)
1678+
1679+
assert (
1680+
expected_edges == actual_edges
1681+
), "Completeness check failed: The set of all edges in the layers must "
1682+
"exactly match the input bonds."
1683+
1684+
for i, layer in enumerate(layers):
1685+
qubits_in_layer: set[int] = set()
1686+
for edge in layer:
1687+
q1, q2 = edge
1688+
assert (
1689+
q1 not in qubits_in_layer
1690+
), f"Compatibility check failed: Qubit {q1} is reused in layer {i}."
1691+
qubits_in_layer.add(q1)
1692+
assert (
1693+
q2 not in qubits_in_layer
1694+
), f"Compatibility check failed: Qubit {q2} is reused in layer {i}."
1695+
qubits_in_layer.add(q2)
1696+
1697+
1698+
@pytest.mark.parametrize(
1699+
"lattice_instance",
1700+
[
1701+
SquareLattice(size=(3, 2), pbc=False),
1702+
SquareLattice(size=(3, 3), pbc=True),
1703+
HoneycombLattice(size=(2, 2), pbc=False),
1704+
],
1705+
ids=[
1706+
"SquareLattice_3x2_OBC",
1707+
"SquareLattice_3x3_PBC",
1708+
"HoneycombLattice_2x2_OBC",
1709+
],
1710+
)
1711+
def test_layering_on_various_lattices(lattice_instance: AbstractLattice):
1712+
"""Tests gate layering for various standard lattice types."""
1713+
bonds = lattice_instance.get_neighbor_pairs(k=1, unique=True)
1714+
layers = get_compatible_layers(bonds)
1715+
1716+
assert len(layers) > 0, "Layers should not be empty for non-trivial lattices."
1717+
_validate_layers(bonds, layers)
1718+
1719+
1720+
def test_layering_on_1d_chain_pbc():
1721+
"""Test layering on a 1D chain with periodic boundaries (a cycle graph)."""
1722+
lattice_even = ChainLattice(size=(6,), pbc=True)
1723+
bonds_even = lattice_even.get_neighbor_pairs(k=1, unique=True)
1724+
layers_even = get_compatible_layers(bonds_even)
1725+
_validate_layers(bonds_even, layers_even)
1726+
1727+
lattice_odd = ChainLattice(size=(5,), pbc=True)
1728+
bonds_odd = lattice_odd.get_neighbor_pairs(k=1, unique=True)
1729+
layers_odd = get_compatible_layers(bonds_odd)
1730+
assert len(layers_odd) == 3, "A 5-site cycle graph should be 3-colorable."
1731+
_validate_layers(bonds_odd, layers_odd)
1732+
1733+
1734+
def test_layering_on_custom_star_graph():
1735+
"""Test layering on a custom lattice forming a star graph."""
1736+
star_edges = [(0, 1), (0, 2), (0, 3)]
1737+
layers = get_compatible_layers(star_edges)
1738+
assert len(layers) == 3, "A star graph S_4 requires 3 layers."
1739+
_validate_layers(star_edges, layers)
1740+
1741+
1742+
def test_layering_on_edge_cases():
1743+
"""Test various edge cases: empty, single-site, and no-edge lattices."""
1744+
layers_empty = get_compatible_layers([])
1745+
assert layers_empty == [], "Layers should be empty for an empty set of bonds."
1746+
1747+
single_edge = [(0, 1)]
1748+
layers_single = get_compatible_layers(single_edge)
1749+
assert layers_single == [[(0, 1)]]
1750+
_validate_layers(single_edge, layers_single)

0 commit comments

Comments
 (0)