Skip to content

Commit 8c666fa

Browse files
committed
feat(templates): Add greedy algorithm for gate layering
Introduces the `get_nn_gate_layers` utility function. This function takes a lattice object and partitions its nearest-neighbor pairs into the minimum number of compatible layers for parallel two-qubit gate application. This implementation uses a greedy edge-coloring algorithm to ensure that no two gates within the same layer act on the same qubit. The output is deterministic, with both layers and the edges within them being sorted. This functionality is essential for efficiently scheduling gates in algorithms like Trotterized Hamiltonian evolution and directly addresses Task 2 of the lattice API follow-up plan.
1 parent fe62094 commit 8c666fa

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from typing import List, Set, Tuple
2+
3+
from .lattice import AbstractLattice
4+
5+
6+
def get_nn_gate_layers(lattice: AbstractLattice) -> List[List[Tuple[int, int]]]:
7+
"""
8+
Partitions nearest-neighbor pairs into compatible layers for parallel
9+
gate application using a greedy edge-coloring algorithm.
10+
11+
In quantum circuits, a single qubit cannot participate in more than one
12+
two-qubit gate simultaneously. This function takes a lattice geometry,
13+
finds its nearest-neighbor graph, and partitions the edges of that graph
14+
(the neighbor pairs) into the minimum number of sets ("layers") where
15+
no two edges in a set share a vertex.
16+
17+
This is essential for efficiently scheduling gates in algorithms like
18+
Trotterized Hamiltonian evolution.
19+
20+
:Example:
21+
22+
>>> import numpy as np
23+
>>> from tensorcircuit.templates.lattice import SquareLattice
24+
>>> sq_lattice = SquareLattice(size=(2, 2), pbc=False)
25+
>>> gate_layers = get_nn_gate_layers(sq_lattice)
26+
>>> print(gate_layers)
27+
[[[0, 1], [2, 3]], [[0, 2], [1, 3]]]
28+
29+
:param lattice: An initialized `AbstractLattice` object from which to
30+
extract nearest-neighbor connectivity.
31+
:type lattice: AbstractLattice
32+
:return: A list of layers. Each layer is a list of tuples, where each
33+
tuple represents a nearest-neighbor pair (i, j) of site indices.
34+
All pairs within a layer are non-overlapping.
35+
:rtype: List[List[Tuple[int, int]]]
36+
"""
37+
uncolored_edges: Set[Tuple[int, int]] = set(
38+
lattice.get_neighbor_pairs(k=1, unique=True)
39+
)
40+
41+
layers: List[List[Tuple[int, int]]] = []
42+
43+
while uncolored_edges:
44+
current_layer: List[Tuple[int, int]] = []
45+
qubits_in_this_layer: Set[int] = set()
46+
edges_to_remove: Set[Tuple[int, int]] = set()
47+
48+
for edge in sorted(list(uncolored_edges)):
49+
i, j = edge
50+
if i not in qubits_in_this_layer and j not in qubits_in_this_layer:
51+
current_layer.append(edge)
52+
qubits_in_this_layer.add(i)
53+
qubits_in_this_layer.add(j)
54+
edges_to_remove.add(edge)
55+
56+
layers.append(sorted(current_layer))
57+
uncolored_edges -= edges_to_remove
58+
59+
return layers

tests/test_circuit_utils.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
from typing import List, Set, Tuple
2+
3+
import pytest
4+
import numpy as np
5+
6+
from tensorcircuit.templates.circuit_utils import get_nn_gate_layers
7+
from tensorcircuit.templates.lattice import (
8+
AbstractLattice,
9+
ChainLattice,
10+
HoneycombLattice,
11+
SquareLattice,
12+
)
13+
14+
15+
class MockLattice(AbstractLattice):
16+
"""A mock lattice class for testing purposes to precisely control neighbors."""
17+
18+
def __init__(self, neighbor_pairs: List[Tuple[int, int]]):
19+
super().__init__(dimensionality=0)
20+
self._neighbor_pairs = neighbor_pairs
21+
22+
def get_neighbor_pairs(
23+
self, k: int = 1, unique: bool = True
24+
) -> List[Tuple[int, int]]:
25+
return self._neighbor_pairs
26+
27+
def _build_lattice(self, *args, **kwargs) -> None:
28+
pass
29+
30+
def _build_neighbors(self, max_k: int = 1, **kwargs) -> None:
31+
pass
32+
33+
def _compute_distance_matrix(self) -> np.ndarray:
34+
return np.array([])
35+
36+
37+
def _validate_layers(
38+
lattice: AbstractLattice, layers: List[List[Tuple[int, int]]]
39+
) -> None:
40+
"""
41+
A helper function to scientifically validate the output of get_nn_gate_layers.
42+
"""
43+
expected_edges = set(lattice.get_neighbor_pairs(k=1, unique=True))
44+
actual_edges = set(tuple(sorted(edge)) for layer in layers for edge in layer)
45+
46+
assert (
47+
expected_edges == actual_edges
48+
), "Completeness check failed: The set of all edges in the layers must "
49+
"exactly match the lattice's unique nearest-neighbor pairs."
50+
51+
for i, layer in enumerate(layers):
52+
qubits_in_layer: Set[int] = set()
53+
for edge in layer:
54+
q1, q2 = edge
55+
assert (
56+
q1 not in qubits_in_layer
57+
), f"Compatibility check failed: Qubit {q1} is reused in layer {i}."
58+
qubits_in_layer.add(q1)
59+
assert (
60+
q2 not in qubits_in_layer
61+
), f"Compatibility check failed: Qubit {q2} is reused in layer {i}."
62+
qubits_in_layer.add(q2)
63+
64+
65+
@pytest.mark.parametrize(
66+
"lattice_instance",
67+
[
68+
SquareLattice(size=(3, 2), pbc=False),
69+
SquareLattice(size=(3, 3), pbc=True),
70+
HoneycombLattice(size=(2, 2), pbc=False),
71+
],
72+
ids=[
73+
"SquareLattice_3x2_OBC",
74+
"SquareLattice_3x3_PBC",
75+
"HoneycombLattice_2x2_OBC",
76+
],
77+
)
78+
def test_various_lattices_layering(lattice_instance: AbstractLattice):
79+
"""Tests gate layering for various standard lattice types."""
80+
layers = get_nn_gate_layers(lattice_instance)
81+
assert len(layers) > 0, "Layers should not be empty for non-trivial lattices."
82+
_validate_layers(lattice_instance, layers)
83+
84+
85+
def test_1d_chain_pbc():
86+
"""Test layering on a 1D chain with periodic boundaries (a cycle graph)."""
87+
lattice_even = ChainLattice(size=(6,), pbc=True)
88+
layers_even = get_nn_gate_layers(lattice_even)
89+
_validate_layers(lattice_even, layers_even)
90+
91+
lattice_odd = ChainLattice(size=(5,), pbc=True)
92+
layers_odd = get_nn_gate_layers(lattice_odd)
93+
assert len(layers_odd) == 3, "A 5-site cycle graph should be 3-colorable."
94+
_validate_layers(lattice_odd, layers_odd)
95+
96+
97+
def test_custom_star_graph():
98+
"""Test layering on a custom lattice forming a star graph."""
99+
star_edges = [(0, 1), (0, 2), (0, 3)]
100+
lattice = MockLattice(star_edges)
101+
layers = get_nn_gate_layers(lattice)
102+
assert len(layers) == 3, "A star graph S_4 requires 3 layers."
103+
_validate_layers(lattice, layers)
104+
105+
106+
def test_edge_cases():
107+
"""Test various edge cases: empty, single-site, and no-edge lattices."""
108+
empty_lattice = MockLattice([])
109+
layers = get_nn_gate_layers(empty_lattice)
110+
assert layers == [], "Layers should be empty for an empty lattice."
111+
112+
single_site_lattice = MockLattice([])
113+
layers = get_nn_gate_layers(single_site_lattice)
114+
assert layers == [], "Layers should be empty for a single-site lattice."
115+
116+
disconnected_lattice = MockLattice([])
117+
layers = get_nn_gate_layers(disconnected_lattice)
118+
assert layers == [], "Layers should be empty for a lattice with no neighbors."
119+
120+
single_edge_lattice = MockLattice([(0, 1)])
121+
layers = get_nn_gate_layers(single_edge_lattice)
122+
# The tuple inside the list might be (0, 1) or (1, 0) after sorting.
123+
# We check for the sorted version to be deterministic.
124+
assert layers == [[(0, 1)]]
125+
_validate_layers(single_edge_lattice, layers)

0 commit comments

Comments
 (0)