@@ -479,18 +479,15 @@ def _identify_distance_shells(
479479 if all_distances_sq .size == 0 :
480480 return []
481481
482- # Sort unique distances and filter out zero-distance (self-loops)
483- unique_sorted_dist = sorted (
484- [d for d in np .unique (all_distances_sq ) if d > tol ** 2 ]
485- )
482+ sorted_dist = np .sort (all_distances_sq [all_distances_sq > 1e-12 ])
486483
487- if not unique_sorted_dist :
484+ if sorted_dist . size == 0 :
488485 return []
489486
490- # Identify shells by checking if a new distance is significantly
491- # larger than the last identified shell distance.
492- dist_shells = [ unique_sorted_dist [ 0 ]]
493- for d_sq in unique_sorted_dist [1 :]:
487+ # Identify shells using the user-provided tolerance.
488+ dist_shells = [ sorted_dist [ 0 ]]
489+
490+ for d_sq in sorted_dist [1 :]:
494491 if len (dist_shells ) >= max_k :
495492 break
496493 # If the current distance is notably larger than the last shell's distance
@@ -747,6 +744,7 @@ def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
747744 current_k_map [i ].sort ()
748745
749746 self ._neighbor_maps [k ] = current_k_map
747+ self ._distance_matrix = dist_matrix
750748 return # IMPORTANT: Exit the function here, as we are done.
751749
752750 # Step 3: If we reach here, the system is fully periodic.
@@ -850,6 +848,41 @@ def _build_neighbors(self, max_k: int = 2, **kwargs: Any) -> None:
850848 if neighbors_of_i_for_k :
851849 self ._neighbor_maps [k ][i ] = sorted (neighbors_of_i_for_k )
852850
851+ logger .info ("Caching the full distance matrix via translational invariance..." )
852+ dist_matrix_sq = np .zeros ((self .num_sites , self .num_sites ), dtype = float )
853+ size_arr = np .array (self .size )
854+
855+ for i in range (self .num_sites ):
856+ ident_i = cast (Tuple [Any , ...], self ._identifiers [i ])
857+ uc_i = np .array (ident_i [:- 1 ])
858+ basis_i = ident_i [- 1 ]
859+
860+ uc_disp = - uc_i
861+ relative_ucs_list = [
862+ cast (Tuple [Any , ...], ident )[:- 1 ] for ident in self ._identifiers
863+ ]
864+ relative_ucs_arr = np .array (relative_ucs_list )
865+ all_target_ucs = relative_ucs_arr + uc_disp
866+
867+ for dim in range (self .dimensionality ):
868+ if self .pbc [dim ]:
869+ all_target_ucs [:, dim ] %= size_arr [dim ]
870+
871+ # Convert these relative unit cell coordinates back to site indices
872+ all_target_basis = np .array (
873+ [cast (Tuple [Any , ...], ident )[- 1 ] for ident in self ._identifiers ]
874+ )
875+ all_target_idents = [
876+ tuple (uc ) + (basis ,)
877+ for uc , basis in zip (all_target_ucs , all_target_basis )
878+ ]
879+ target_indices = [self ._ident_to_idx [ident ] for ident in all_target_idents ]
880+
881+ # Assign the entire row of distances from the lookup table
882+ dist_matrix_sq [i , :] = ref_dist_matrix_sq [basis_i , target_indices ]
883+
884+ self ._distance_matrix = np .sqrt (dist_matrix_sq )
885+
853886 def _compute_distance_matrix (self ) -> Coordinates :
854887 """Computes the distance matrix using the Minimum Image Convention."""
855888 return self ._get_distance_matrix_with_mic ()
@@ -1401,8 +1434,13 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
14011434 k = k_idx + 1
14021435 current_k_map : Dict [int , List [int ]] = {}
14031436 for i in range (self .num_sites ):
1404- # The previously found indices for site i (neighbors from shells 1 to k-1)
1405- prev_found = found_indices [i ] if k_idx > 0 else {i }
1437+
1438+ if k_idx == 0 :
1439+ co_located_indices = tree .query_ball_point (all_coords [i ], r = 1e-12 )
1440+ prev_found = set (co_located_indices )
1441+ else :
1442+ prev_found = found_indices [i ]
1443+
14061444 # The new neighbors are those in the current radius shell,
14071445 # excluding those already found in smaller shells.
14081446 new_neighbors = set (current_shell_indices [i ]) - prev_found
@@ -1414,7 +1452,7 @@ def _build_neighbors(self, max_k: int = 1, **kwargs: Any) -> None:
14141452 found_indices = [
14151453 set (l ) for l in current_shell_indices
14161454 ] # Update for next iteration
1417-
1455+ self . _distance_matrix = np . sqrt ( squareform ( all_distances_sq ))
14181456 logger .info ("Neighbor building complete using KDTree." )
14191457
14201458 def _compute_distance_matrix (self ) -> Coordinates :
0 commit comments