Skip to content

Commit d84fdae

Browse files
committed
feat: Implement AbstractLattice and initial GeneralLattice
1 parent e75d84f commit d84fdae

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed

tensorcircuit/templates/lattice.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
The lattice module for defining and manipulating lattice geometries.
4+
"""
5+
import abc
6+
import numpy as np
7+
import matplotlib.pyplot as plt
8+
from typing import List, Tuple, Dict, Any, Optional, Hashable, Iterator, Union
9+
10+
# --- Type Aliases for Readability ---
11+
SiteIndex = int
12+
SiteIdentifier = Hashable
13+
Coordinates = np.ndarray
14+
NeighborMap = Dict[SiteIndex, List[SiteIndex]]
15+
16+
17+
class AbstractLattice(abc.ABC):
18+
"""
19+
Abstract base class for describing lattice systems.
20+
21+
This class defines the common interface for all lattice structures,
22+
providing access to fundamental properties like site information
23+
24+
(count, coordinates, identifiers) and neighbor relationships.
25+
Subclasses are responsible for implementing the specific logic for
26+
generating the lattice points and calculating neighbor connections.
27+
"""
28+
29+
def __init__(self, dimensionality: int):
30+
"""
31+
Initializes the base lattice class.
32+
33+
Args:
34+
dimensionality (int): The spatial dimension of the lattice (e.g., 1, 2, 3).
35+
"""
36+
self._dimensionality = dimensionality
37+
38+
# --- Internal Data Structures (to be populated by subclasses) ---
39+
self._indices: List[SiteIndex] = []
40+
self._identifiers: List[SiteIdentifier] = []
41+
self._coordinates: List[Coordinates] = []
42+
self._ident_to_idx: Dict[SiteIdentifier, SiteIndex] = {}
43+
44+
# Neighbor information, structured as {k: {site_index: [neighbor_indices]}}
45+
# k=1 for nearest neighbors, k=2 for next-nearest, etc.
46+
self._neighbor_maps: Dict[int, NeighborMap] = {}
47+
48+
@property
49+
def num_sites(self) -> int:
50+
"""Returns the total number of sites (N) in the lattice."""
51+
return len(self._indices)
52+
53+
@property
54+
def dimensionality(self) -> int:
55+
"""Returns the spatial dimension of the lattice."""
56+
return self._dimensionality
57+
58+
def __len__(self) -> int:
59+
"""Returns the total number of sites, enabling `len(lattice)`."""
60+
return self.num_sites
61+
62+
# Other common methods like get_site_info, get_neighbors can be added here.
63+
64+
@abc.abstractmethod
65+
def _build_lattice(self, *args, **kwargs):
66+
"""
67+
Abstract method for subclasses to generate the lattice data.
68+
69+
This method should populate the following internal lists and dicts:
70+
- self._indices
71+
- self._identifiers
72+
- self._coordinates
73+
- self._ident_to_idx
74+
"""
75+
pass
76+
77+
@abc.abstractmethod
78+
def _build_neighbors(self, max_k: int = 1):
79+
"""
80+
Abstract method for subclasses to calculate neighbor relationships.
81+
82+
This method should populate the `self._neighbor_maps` dictionary.
83+
84+
Args:
85+
max_k (int): The maximum order of neighbors to compute (e.g., max_k=2 for NN and NNN).
86+
"""
87+
pass
88+
89+
@abc.abstractmethod
90+
def show(self, **kwargs):
91+
"""
92+
Abstract method for visualizing the lattice structure.
93+
94+
Subclasses should implement the specific plotting logic.
95+
"""
96+
pass
97+
98+
99+
class GeneralLattice(AbstractLattice):
100+
"""
101+
A general lattice built from an explicit list of sites and coordinates.
102+
"""
103+
def __init__(self,
104+
dimensionality: int,
105+
identifiers: List[SiteIdentifier],
106+
coordinates: List[Union[List[float], np.ndarray]]):
107+
108+
super().__init__(dimensionality)
109+
assert len(identifiers) == len(coordinates), "Identifiers and coordinates must have the same length."
110+
111+
# The lattice is built directly upon initialization.
112+
self._build_lattice(identifiers=identifiers, coordinates=coordinates)
113+
114+
# Neighbor relationships can be calculated later if needed.
115+
print(f"GeneralLattice with {self.num_sites} sites created.")
116+
117+
118+
def _build_lattice(self, identifiers: List[SiteIdentifier], coordinates: List[Union[List[float], np.ndarray]]):
119+
"""Implements the lattice building for GeneralLattice."""
120+
self._identifiers = list(identifiers)
121+
self._coordinates = [np.array(c) for c in coordinates]
122+
self._indices = list(range(len(identifiers)))
123+
self._ident_to_idx = {ident: idx for idx, ident in enumerate(identifiers)}
124+
125+
def _build_neighbors(self, max_k: int = 1):
126+
"""Calculates neighbors based on distance (to be implemented)."""
127+
print(f"Neighbor calculation for k={max_k} is not implemented yet.")
128+
# The logic for _find_neighbors_by_distance will be implemented here in the future.
129+
pass
130+
131+
def show(self, show_indices: bool = True, show_bonds: bool = False, **kwargs):
132+
"""
133+
A simple visualization of the lattice sites.
134+
135+
Args:
136+
show_indices (bool): If True, display the identifier of each site.
137+
show_bonds (bool): If True, display lines connecting neighbors (not implemented yet).
138+
"""
139+
if self.dimensionality != 2:
140+
print("Warning: show() is currently only implemented for 2D lattices.")
141+
return
142+
143+
coords = np.array(self._coordinates)
144+
145+
plt.figure(figsize=(6, 6))
146+
plt.scatter(coords[:, 0], coords[:, 1], s=100, zorder=2)
147+
148+
if show_indices:
149+
for i in range(self.num_sites):
150+
plt.text(coords[i, 0] + 0.1, coords[i, 1] + 0.1, str(self._identifiers[i]), fontsize=12)
151+
152+
plt.axis('equal')
153+
plt.grid(True)
154+
plt.title(f"{self.__class__.__name__} ({self.num_sites} sites)")
155+
plt.xlabel("x coordinate")
156+
plt.ylabel("y coordinate")
157+
plt.show()

0 commit comments

Comments
 (0)