Skip to content

Commit 947d357

Browse files
committed
fix according to the review 2
1 parent ebdebc9 commit 947d357

File tree

2 files changed

+22
-39
lines changed

2 files changed

+22
-39
lines changed

examples/vqe2d_lattice.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
"""
2+
This example demonstrates how to use the VQE algorithm to find the ground state
3+
of a 2D Heisenberg model on a square lattice. It showcases the setup of the lattice,
4+
the Heisenberg Hamiltonian, a suitable ansatz, and the optimization process.
5+
"""
6+
17
import time
28
import optax
39
import tensorcircuit as tc
@@ -7,16 +13,18 @@
713
# Use JAX for high-performance, especially on GPU.
814
K = tc.set_backend("jax")
915
tc.set_dtype("complex64")
10-
# On Windows, cotengra's multiprocessing can cause issues.
11-
tc.set_contractor("cotengra-8192-8192", parallel=False)
16+
# On Windows, cotengra's multiprocessing can cause issues, use threads instead.
17+
tc.set_contractor("cotengra-8192-8192", parallel="threads")
1218

1319

1420
def run_vqe():
15-
n, m, nlayers = 4, 4, 6
21+
"""Set up and run the VQE optimization for a 2D Heisenberg model."""
22+
n, m, nlayers = 4, 4, 2
1623
lattice = SquareLattice(size=(n, m), pbc=True, precompute_neighbors=1)
1724
h = heisenberg_hamiltonian(lattice, j_coupling=[1.0, 1.0, 0.8]) # Jx, Jy, Jz
1825
nn_bonds = lattice.get_neighbor_pairs(k=1, unique=True)
1926
gate_layers = get_compatible_layers(nn_bonds)
27+
n_params = nlayers * len(nn_bonds) * 3
2028

2129
def singlet_init(circuit):
2230
# A good initial state for Heisenberg ground state search
@@ -36,22 +44,26 @@ def vqe_forward(param):
3644
"""
3745
c = tc.Circuit(n * m)
3846
c = singlet_init(c)
47+
param_idx = 0
3948

40-
for i in range(nlayers):
49+
for _ in range(nlayers):
4150
for layer in gate_layers:
4251
for j, k in layer:
43-
c.rzz(int(j), int(k), theta=param[i, 0])
52+
c.rzz(int(j), int(k), theta=param[param_idx])
53+
param_idx += 1
4454
for layer in gate_layers:
4555
for j, k in layer:
46-
c.rxx(int(j), int(k), theta=param[i, 1])
56+
c.rxx(int(j), int(k), theta=param[param_idx])
57+
param_idx += 1
4758
for layer in gate_layers:
4859
for j, k in layer:
49-
c.ryy(int(j), int(k), theta=param[i, 2])
60+
c.ryy(int(j), int(k), theta=param[param_idx])
61+
param_idx += 1
5062

5163
return tc.templates.measurements.operator_expectation(c, h)
5264

5365
vgf = K.jit(K.value_and_grad(vqe_forward))
54-
param = tc.backend.implicit_randn(stddev=0.02, shape=[nlayers, 3])
66+
param = tc.backend.implicit_randn(stddev=0.02, shape=[n_params])
5567
optimizer = optax.adam(learning_rate=3e-3)
5668
opt_state = optimizer.init(param)
5769

tests/test_lattice.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from unittest.mock import patch
22
import logging
3-
from typing import List, Set, Tuple
43

54
# import time
65

@@ -1669,35 +1668,7 @@ def test_distance_matrix_invariants_for_all_lattice_types(self, lattice):
16691668
# )
16701669

16711670

1672-
class MockLattice(AbstractLattice):
1673-
"""A mock lattice class for testing purposes to precisely control neighbors."""
1674-
1675-
def __init__(self, neighbor_pairs: List[Tuple[int, int]]):
1676-
super().__init__(dimensionality=0)
1677-
# Ensure bonds are stored in a canonical sorted format for consistency
1678-
self._neighbor_pairs = [tuple(sorted(p)) for p in neighbor_pairs]
1679-
1680-
def get_neighbor_pairs(
1681-
self, k: int = 1, unique: bool = True
1682-
) -> List[Tuple[int, int]]:
1683-
# The mock lattice only knows about k=1 neighbors
1684-
if k == 1:
1685-
return self._neighbor_pairs
1686-
return []
1687-
1688-
def _build_lattice(self, *args, **kwargs) -> None:
1689-
pass
1690-
1691-
def _build_neighbors(self, max_k: int = 1, **kwargs) -> None:
1692-
pass
1693-
1694-
def _compute_distance_matrix(self) -> np.ndarray:
1695-
return np.array([])
1696-
1697-
1698-
def _validate_layers(
1699-
bonds: List[Tuple[int, int]], layers: List[List[Tuple[int, int]]]
1700-
) -> None:
1671+
def _validate_layers(bonds, layers) -> None:
17011672
"""
17021673
A helper function to scientifically validate the output of get_compatible_layers.
17031674
"""
@@ -1711,7 +1682,7 @@ def _validate_layers(
17111682
"exactly match the input bonds."
17121683

17131684
for i, layer in enumerate(layers):
1714-
qubits_in_layer: Set[int] = set()
1685+
qubits_in_layer: set[int] = set()
17151686
for edge in layer:
17161687
q1, q2 = edge
17171688
assert (

0 commit comments

Comments
 (0)