Skip to content

Commit c3b1022

Browse files
committed
Fix deletion bug (set changed size during iteration)
1 parent a4aa76e commit c3b1022

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

datalayer/hnsw.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -356,18 +356,22 @@ def _insert_node_to_layers(self, new_node, enter_point):
356356
enter_point = currently_found_nn
357357

358358
def _delete_neighbors_connections(self, node):
359-
"""Given a node, deletes the connections to their neighbors.
360-
361-
Arguments:
362-
node -- the node to delete
363-
"""
364-
365-
logger.debug(f"Deleting neighbors of \"{node.get_id()}\"")
366-
for layer in range(node.get_max_layer() + 1):
367-
for neighbor in node.get_neighbors_at_layer(layer):
368-
logger.debug(f"Deleting at L{layer} link \"{neighbor.get_id()}\"")
369-
neighbor.remove_neighbor(layer, node)
370-
node.remove_neighbor(layer, neighbor)
359+
"""Given a node, deletes the connections to their neighbors.
360+
361+
Arguments:
362+
node -- the node to delete
363+
"""
364+
365+
logger.debug(f"Deleting neighbors of \"{node.get_id()}\"")
366+
for layer in range(node.get_max_layer() + 1):
367+
neighbors_to_remove = set()
368+
for neighbor in node.get_neighbors_at_layer(layer):
369+
logger.debug(f"Deleting at L{layer} link \"{neighbor.get_id()}\"")
370+
neighbors_to_remove.add(neighbor)
371+
372+
for neighbor in neighbors_to_remove: # bidirectionally remove links
373+
node.remove_neighbor(layer, neighbor)
374+
neighbor.remove_neighbor(layer, node)
371375

372376
def _delete_node_dict(self, node):
373377
"""Deletes a node from the dict of the HNSW structure.

tests/unit.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,14 @@ def test_search_approximate(self):
6060
self.assertEqual(actual_founds, expected_founds)
6161
self.assertEqual(actual_distances, expected_distances)
6262

63-
@unittest.skip("?")
6463
def test_deletion(self):
6564
for hash in HASHES[:5]:
6665
self.apo_model.delete(HashNode(hash, TLSHHashAlgorithm))
6766

6867
expected_founds = [False, False, False, False, False, True, True, True, True, True]
6968
actual_founds = []
7069
for hash in HASHES:
71-
found, exact, result_dict = self.apo_model.knn_search(HashNode(hash, TLSHHashAlgorithm), 1)
70+
found, _, _ = self.apo_model.knn_search(HashNode(hash, TLSHHashAlgorithm), 1)
7271
actual_founds.append(found)
7372

7473
self.assertEqual(actual_founds, expected_founds)

0 commit comments

Comments
 (0)