Skip to content

Commit 6eb21c3

Browse files
committed
fix(templates): Resolve neighbor-finding and matrix caching bugs
1 parent 9ef578d commit 6eb21c3

File tree

2 files changed

+409
-111
lines changed

2 files changed

+409
-111
lines changed

tensorcircuit/templates/lattice.py

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -479,18 +479,15 @@ def _identify_distance_shells(
479479
if all_distances_sq.size == 0:
480480
return []
481481

482-
# Sort unique distances and filter out zero-distance (self-loops)
483-
unique_sorted_dist = sorted(
484-
[d for d in np.unique(all_distances_sq) if d > tol**2]
485-
)
482+
sorted_dist = np.sort(all_distances_sq[all_distances_sq > 1e-12])
486483

487-
if not unique_sorted_dist:
484+
if sorted_dist.size == 0:
488485
return []
489486

490-
# Identify shells by checking if a new distance is significantly
491-
# larger than the last identified shell distance.
492-
dist_shells = [unique_sorted_dist[0]]
493-
for d_sq in unique_sorted_dist[1:]:
487+
# Identify shells using the user-provided tolerance.
488+
dist_shells = [sorted_dist[0]]
489+
490+
for d_sq in sorted_dist[1:]:
494491
if len(dist_shells) >= max_k:
495492
break
496493
# If the current distance is notably larger than the last shell's distance
@@ -747,6 +744,7 @@ def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
747744
current_k_map[i].sort()
748745

749746
self._neighbor_maps[k] = current_k_map
747+
self._distance_matrix = dist_matrix
750748
return # IMPORTANT: Exit the function here, as we are done.
751749

752750
# Step 3: If we reach here, the system is fully periodic.
@@ -850,6 +848,41 @@ def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
850848
if neighbors_of_i_for_k:
851849
self._neighbor_maps[k][i] = sorted(neighbors_of_i_for_k)
852850

851+
logger.info("Caching the full distance matrix via translational invariance...")
852+
dist_matrix_sq = np.zeros((self.num_sites, self.num_sites), dtype=float)
853+
size_arr = np.array(self.size)
854+
855+
for i in range(self.num_sites):
856+
ident_i = cast(Tuple[Any, ...], self._identifiers[i])
857+
uc_i = np.array(ident_i[:-1])
858+
basis_i = ident_i[-1]
859+
860+
uc_disp = -uc_i
861+
relative_ucs_list = [
862+
cast(Tuple[Any, ...], ident)[:-1] for ident in self._identifiers
863+
]
864+
relative_ucs_arr = np.array(relative_ucs_list)
865+
all_target_ucs = relative_ucs_arr + uc_disp
866+
867+
for dim in range(self.dimensionality):
868+
if self.pbc[dim]:
869+
all_target_ucs[:, dim] %= size_arr[dim]
870+
871+
# Convert these relative unit cell coordinates back to site indices
872+
all_target_basis = np.array(
873+
[cast(Tuple[Any, ...], ident)[-1] for ident in self._identifiers]
874+
)
875+
all_target_idents = [
876+
tuple(uc) + (basis,)
877+
for uc, basis in zip(all_target_ucs, all_target_basis)
878+
]
879+
target_indices = [self._ident_to_idx[ident] for ident in all_target_idents]
880+
881+
# Assign the entire row of distances from the lookup table
882+
dist_matrix_sq[i, :] = ref_dist_matrix_sq[basis_i, target_indices]
883+
884+
self._distance_matrix = np.sqrt(dist_matrix_sq)
885+
853886
def _compute_distance_matrix(self) -> Coordinates:
854887
"""Computes the distance matrix using the Minimum Image Convention."""
855888
return self._get_distance_matrix_with_mic()
@@ -1401,8 +1434,13 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
14011434
k = k_idx + 1
14021435
current_k_map: Dict[int, List[int]] = {}
14031436
for i in range(self.num_sites):
1404-
# The previously found indices for site i (neighbors from shells 1 to k-1)
1405-
prev_found = found_indices[i] if k_idx > 0 else {i}
1437+
1438+
if k_idx == 0:
1439+
co_located_indices = tree.query_ball_point(all_coords[i], r=1e-12)
1440+
prev_found = set(co_located_indices)
1441+
else:
1442+
prev_found = found_indices[i]
1443+
14061444
# The new neighbors are those in the current radius shell,
14071445
# excluding those already found in smaller shells.
14081446
new_neighbors = set(current_shell_indices[i]) - prev_found
@@ -1414,7 +1452,7 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
14141452
found_indices = [
14151453
set(l) for l in current_shell_indices
14161454
] # Update for next iteration
1417-
1455+
self._distance_matrix = np.sqrt(squareform(all_distances_sq))
14181456
logger.info("Neighbor building complete using KDTree.")
14191457

14201458
def _compute_distance_matrix(self) -> Coordinates:

0 commit comments

Comments
 (0)