Skip to content

Commit d7578f1

Browse files
committed
Merge remote-tracking branch 'origin/feature/gate-layering' into feature/gate-layering
2 parents 9bb2b52 + 4368a32 commit d7578f1

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from typing import List, Set, Tuple
2+
3+
from .lattice import AbstractLattice
4+
5+
6+
def get_nn_gate_layers(lattice: AbstractLattice) -> List[List[Tuple[int, int]]]:
7+
"""
8+
Partitions nearest-neighbor pairs into compatible layers for parallel
9+
gate application using a greedy edge-coloring algorithm.
10+
11+
In quantum circuits, a single qubit cannot participate in more than one
12+
two-qubit gate simultaneously. This function takes a lattice geometry,
13+
finds its nearest-neighbor graph, and partitions the edges of that graph
14+
(the neighbor pairs) into the minimum number of sets ("layers") where
15+
no two edges in a set share a vertex.
16+
17+
This is essential for efficiently scheduling gates in algorithms like
18+
Trotterized Hamiltonian evolution.
19+
20+
:Example:
21+
22+
>>> import numpy as np
23+
>>> from tensorcircuit.templates.lattice import SquareLattice
24+
>>> sq_lattice = SquareLattice(size=(2, 2), pbc=False)
25+
>>> gate_layers = get_nn_gate_layers(sq_lattice)
26+
>>> print(gate_layers)
27+
[[[0, 1], [2, 3]], [[0, 2], [1, 3]]]
28+
29+
:param lattice: An initialized `AbstractLattice` object from which to
30+
extract nearest-neighbor connectivity.
31+
:type lattice: AbstractLattice
32+
:return: A list of layers. Each layer is a list of tuples, where each
33+
tuple represents a nearest-neighbor pair (i, j) of site indices.
34+
All pairs within a layer are non-overlapping.
35+
:rtype: List[List[Tuple[int, int]]]
36+
"""
37+
uncolored_edges: Set[Tuple[int, int]] = set(
38+
lattice.get_neighbor_pairs(k=1, unique=True)
39+
)
40+
41+
layers: List[List[Tuple[int, int]]] = []
42+
43+
while uncolored_edges:
44+
current_layer: List[Tuple[int, int]] = []
45+
qubits_in_this_layer: Set[int] = set()
46+
edges_to_remove: Set[Tuple[int, int]] = set()
47+
48+
for edge in sorted(list(uncolored_edges)):
49+
i, j = edge
50+
if i not in qubits_in_this_layer and j not in qubits_in_this_layer:
51+
current_layer.append(edge)
52+
qubits_in_this_layer.add(i)
53+
qubits_in_this_layer.add(j)
54+
edges_to_remove.add(edge)
55+
56+
layers.append(sorted(current_layer))
57+
uncolored_edges -= edges_to_remove
58+
59+
return layers

tests/test_circuit_utils.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from typing import List, Set, Tuple
2+
3+
import pytest
4+
import numpy as np
5+
6+
from tensorcircuit.templates.circuit_utils import get_nn_gate_layers
7+
from tensorcircuit.templates.lattice import (
8+
AbstractLattice,
9+
ChainLattice,
10+
HoneycombLattice,
11+
SquareLattice,
12+
)
13+
14+
15+
class MockLattice(AbstractLattice):
16+
"""A mock lattice class for testing purposes to precisely control neighbors."""
17+
18+
def __init__(self, neighbor_pairs: List[Tuple[int, int]]):
19+
super().__init__(dimensionality=0)
20+
self._neighbor_pairs = neighbor_pairs
21+
22+
def get_neighbor_pairs(
23+
self, k: int = 1, unique: bool = True
24+
) -> List[Tuple[int, int]]:
25+
return self._neighbor_pairs
26+
27+
def _build_lattice(self, *args, **kwargs) -> None:
28+
pass
29+
30+
def _build_neighbors(self, max_k: int = 1, **kwargs) -> None:
31+
pass
32+
33+
def _compute_distance_matrix(self) -> np.ndarray:
34+
return np.array([])
35+
36+
37+
def _validate_layers(
38+
lattice: AbstractLattice, layers: List[List[Tuple[int, int]]]
39+
) -> None:
40+
"""
41+
A helper function to scientifically validate the output of get_nn_gate_layers.
42+
"""
43+
expected_edges = set(lattice.get_neighbor_pairs(k=1, unique=True))
44+
actual_edges = set(tuple(sorted(edge)) for layer in layers for edge in layer)
45+
46+
assert (
47+
expected_edges == actual_edges
48+
), "Completeness check failed: The set of all edges in the layers must "
49+
"exactly match the lattice's unique nearest-neighbor pairs."
50+
51+
for i, layer in enumerate(layers):
52+
qubits_in_layer: Set[int] = set()
53+
for edge in layer:
54+
q1, q2 = edge
55+
assert (
56+
q1 not in qubits_in_layer
57+
), f"Compatibility check failed: Qubit {q1} is reused in layer {i}."
58+
qubits_in_layer.add(q1)
59+
assert (
60+
q2 not in qubits_in_layer
61+
), f"Compatibility check failed: Qubit {q2} is reused in layer {i}."
62+
qubits_in_layer.add(q2)
63+
64+
65+
@pytest.mark.parametrize(
66+
"lattice_instance",
67+
[
68+
SquareLattice(size=(3, 2), pbc=False),
69+
SquareLattice(size=(3, 3), pbc=True),
70+
HoneycombLattice(size=(2, 2), pbc=False),
71+
],
72+
ids=[
73+
"SquareLattice_3x2_OBC",
74+
"SquareLattice_3x3_PBC",
75+
"HoneycombLattice_2x2_OBC",
76+
],
77+
)
78+
def test_various_lattices_layering(lattice_instance: AbstractLattice):
79+
"""Tests gate layering for various standard lattice types."""
80+
layers = get_nn_gate_layers(lattice_instance)
81+
assert len(layers) > 0, "Layers should not be empty for non-trivial lattices."
82+
_validate_layers(lattice_instance, layers)
83+
84+
85+
def test_1d_chain_pbc():
86+
"""Test layering on a 1D chain with periodic boundaries (a cycle graph)."""
87+
lattice_even = ChainLattice(size=(6,), pbc=True)
88+
layers_even = get_nn_gate_layers(lattice_even)
89+
_validate_layers(lattice_even, layers_even)
90+
91+
lattice_odd = ChainLattice(size=(5,), pbc=True)
92+
layers_odd = get_nn_gate_layers(lattice_odd)
93+
assert len(layers_odd) == 3, "A 5-site cycle graph should be 3-colorable."
94+
_validate_layers(lattice_odd, layers_odd)
95+
96+
97+
def test_custom_star_graph():
98+
"""Test layering on a custom lattice forming a star graph."""
99+
star_edges = [(0, 1), (0, 2), (0, 3)]
100+
lattice = MockLattice(star_edges)
101+
layers = get_nn_gate_layers(lattice)
102+
assert len(layers) == 3, "A star graph S_4 requires 3 layers."
103+
_validate_layers(lattice, layers)
104+
105+
106+
def test_edge_cases():
107+
"""Test various edge cases: empty, single-site, and no-edge lattices."""
108+
empty_lattice = MockLattice([])
109+
layers = get_nn_gate_layers(empty_lattice)
110+
assert layers == [], "Layers should be empty for an empty lattice."
111+
112+
single_site_lattice = MockLattice([])
113+
layers = get_nn_gate_layers(single_site_lattice)
114+
assert layers == [], "Layers should be empty for a single-site lattice."
115+
116+
disconnected_lattice = MockLattice([])
117+
layers = get_nn_gate_layers(disconnected_lattice)
118+
assert layers == [], "Layers should be empty for a lattice with no neighbors."
119+
120+
single_edge_lattice = MockLattice([(0, 1)])
121+
layers = get_nn_gate_layers(single_edge_lattice)
122+
# The tuple inside the list might be (0, 1) or (1, 0) after sorting.
123+
# We check for the sorted version to be deterministic.
124+
assert layers == [[(0, 1)]]
125+
_validate_layers(single_edge_lattice, layers)

0 commit comments

Comments
 (0)