@@ -69,6 +69,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
6969 std::mutex deleted_elements_lock; // lock for deleted_elements
7070 std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements
7171
72+ std::mutex repair_lock; // locks graph repair
73+
7274
7375 HierarchicalNSW (SpaceInterface<dist_t > *s) {
7476 }
@@ -190,9 +192,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
190192 }
191193
192194
193- int getRandomLevel (double reverse_size ) {
195+ int getRandomLevel (double ml ) {
194196 std::uniform_real_distribution<double > distribution (0.0 , 1.0 );
195- double r = -log (distribution (level_generator_)) * reverse_size ;
197+ double r = -log (distribution (level_generator_)) * ml ;
196198 return (int ) r;
197199 }
198200
@@ -240,14 +242,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
240242
241243 std::unique_lock <std::mutex> lock (link_list_locks_[curNodeNum]);
242244
243- int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
244- if (layer == 0 ) {
245- data = (int *)get_linklist0 (curNodeNum);
246- } else {
247- data = (int *)get_linklist (curNodeNum, layer);
248- // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
249- }
250- size_t size = getListCount ((linklistsizeint*)data);
245+ linklistsizeint *data = get_linklist_at_level (curNodeNum, layer);
246+ size_t size = getListCount (data);
251247 tableint *datal = (tableint *) (data + 1 );
252248#ifdef USE_SSE
253249 _mm_prefetch ((char *) (visited_array + *(data + 1 )), _MM_HINT_T0);
@@ -325,8 +321,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
325321 candidate_set.pop ();
326322
327323 tableint current_node_id = current_node_pair.second ;
328- int *data = ( int *) get_linklist0 (current_node_id);
329- size_t size = getListCount ((linklistsizeint*) data);
324+ linklistsizeint *data = get_linklist0 (current_node_id);
325+ size_t size = getListCount (data);
330326// bool cur_node_deleted = isMarkedDeleted(current_node_id);
331327 if (collect_metrics) {
332328 metric_hops++;
@@ -471,11 +467,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
471467 if (isUpdate) {
472468 lock.lock ();
473469 }
474- linklistsizeint *ll_cur;
475- if (level == 0 )
476- ll_cur = get_linklist0 (cur_c);
477- else
478- ll_cur = get_linklist (cur_c, level);
470+ linklistsizeint *ll_cur = get_linklist_at_level (cur_c, level);
479471
480472 if (*ll_cur && !isUpdate) {
481473 throw std::runtime_error (" The newly inserted element should have blank link list" );
@@ -495,12 +487,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
495487 for (size_t idx = 0 ; idx < selectedNeighbors.size (); idx++) {
496488 std::unique_lock <std::mutex> lock (link_list_locks_[selectedNeighbors[idx]]);
497489
498- linklistsizeint *ll_other;
499- if (level == 0 )
500- ll_other = get_linklist0 (selectedNeighbors[idx]);
501- else
502- ll_other = get_linklist (selectedNeighbors[idx], level);
503-
490+ linklistsizeint *ll_other = get_linklist_at_level (selectedNeighbors[idx], level);
504491 size_t sz_link_list_other = getListCount (ll_other);
505492
506493 if (sz_link_list_other > Mcurmax)
@@ -969,8 +956,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
969956
970957 {
971958 std::unique_lock <std::mutex> lock (link_list_locks_[neigh]);
972- linklistsizeint *ll_cur;
973- ll_cur = get_linklist_at_level (neigh, layer);
959+ linklistsizeint *ll_cur = get_linklist_at_level (neigh, layer);
974960 size_t candSize = candidates.size ();
975961 setListCount (ll_cur, candSize);
976962 tableint *data = (tableint *) (ll_cur + 1 );
@@ -999,7 +985,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
999985 bool changed = true ;
1000986 while (changed) {
1001987 changed = false ;
1002- unsigned int *data;
988+ linklistsizeint *data;
1003989 std::unique_lock <std::mutex> lock (link_list_locks_[currObj]);
1004990 data = get_linklist_at_level (currObj, level);
1005991 int size = getListCount (data);
@@ -1057,7 +1043,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
10571043
10581044 std::vector<tableint> getConnectionsWithLock (tableint internalId, int level) {
10591045 std::unique_lock <std::mutex> lock (link_list_locks_[internalId]);
1060- unsigned int *data = get_linklist_at_level (internalId, level);
1046+ linklistsizeint *data = get_linklist_at_level (internalId, level);
10611047 int size = getListCount (data);
10621048 std::vector<tableint> result (size);
10631049 tableint *ll = (tableint *) (data + 1 );
@@ -1095,6 +1081,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
10951081 }
10961082
10971083 cur_c = cur_element_count;
1084+ // use the element level as a flag to show that an element is not added yet
1085+ // the element count is increased but no lock is aquired
1086+ // so someone can start using the new element
1087+ element_levels_[cur_c] = -1 ;
10981088 cur_element_count++;
10991089 label_lookup_[label] = cur_c;
11001090 }
@@ -1134,7 +1124,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
11341124 bool changed = true ;
11351125 while (changed) {
11361126 changed = false ;
1137- unsigned int *data;
1127+ linklistsizeint *data;
11381128 std::unique_lock <std::mutex> lock (link_list_locks_[currObj]);
11391129 data = get_linklist (currObj, level);
11401130 int size = getListCount (data);
@@ -1196,9 +1186,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
11961186 bool changed = true ;
11971187 while (changed) {
11981188 changed = false ;
1199- unsigned int *data;
1200-
1201- data = (unsigned int *) get_linklist (currObj, level);
1189+ linklistsizeint *data = get_linklist (currObj, level);
12021190 int size = getListCount (data);
12031191 metric_hops++;
12041192 metric_distance_computations+=size;
@@ -1271,5 +1259,110 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
12711259 }
12721260 std::cout << " integrity ok, checked " << connections_checked << " connections\n " ;
12731261 }
1262+
1263+
1264+ void repair_zero_indegree () {
1265+ // only one repair is allowed to be in progress at a time
1266+ std::unique_lock <std::mutex> lock_repair (repair_lock);
1267+
1268+ int maxlevel_copy = maxlevel_;
1269+ size_t element_count_copy = cur_element_count;
1270+ std::vector<size_t > indegree (element_count_copy);
1271+
1272+ for (int level = maxlevel_copy; level >=0 ; level--) {
1273+ std::fill (indegree.begin (), indegree.end (), 0 );
1274+
1275+ size_t m_max = level ? maxM_ : maxM0_;
1276+ int num_elements = 0 ;
1277+ // calculate in-degree
1278+ for (tableint internal_id = 0 ; internal_id < element_count_copy; internal_id++) {
1279+ // lock until addition is finished
1280+ std::unique_lock <std::mutex> lock_el (link_list_locks_[internal_id]);
1281+ // skip elements that are not in the current level
1282+ // Note: if the element was not added to the graph before the lock
1283+ // then element_level = -1 and we skip it as well
1284+ int element_level = element_levels_[internal_id];
1285+ if (element_level < level) {
1286+ continue ;
1287+ }
1288+
1289+ linklistsizeint *ll = get_linklist_at_level (internal_id, level);
1290+ int size = getListCount (ll);
1291+ tableint *datal = (tableint *) (ll + 1 );
1292+ for (int i = 0 ; i < size; i++) {
1293+ tableint nei_id = datal[i];
1294+ // skip newly added elements
1295+ if (nei_id >= element_count_copy) {
1296+ continue ;
1297+ }
1298+ indegree[nei_id] += 1 ;
1299+ }
1300+ num_elements += 1 ;
1301+ }
1302+
1303+ // skip levels with 1 element
1304+ if (num_elements <= 1 ) {
1305+ continue ;
1306+ }
1307+
1308+ // fix elements with 0 in-degree
1309+ for (tableint internal_id = 0 ; internal_id < element_count_copy; internal_id++) {
1310+ int element_level = element_levels_[internal_id];
1311+ if (element_level < level || indegree[internal_id] > 0 ) {
1312+ continue ;
1313+ }
1314+
1315+ char * data_point = getDataByInternalId (internal_id);
1316+ tableint currObj = enterpoint_node_;
1317+
1318+ dist_t curdist = fstdistfunc_ (data_point, getDataByInternalId (currObj), dist_func_param_);
1319+ for (int level_above = maxlevel_copy; level_above > level; level_above--) {
1320+ bool changed = true ;
1321+ while (changed) {
1322+ changed = false ;
1323+ linklistsizeint *data;
1324+ std::unique_lock <std::mutex> lock (link_list_locks_[currObj]);
1325+ data = get_linklist_at_level (currObj, level_above);
1326+ int size = getListCount (data);
1327+
1328+ tableint *datal = (tableint *) (data + 1 );
1329+ for (int i = 0 ; i < size; i++) {
1330+ tableint cand = datal[i];
1331+ dist_t d = fstdistfunc_ (data_point, getDataByInternalId (cand), dist_func_param_);
1332+ if (d < curdist) {
1333+ curdist = d;
1334+ currObj = cand;
1335+ changed = true ;
1336+ }
1337+ }
1338+ }
1339+ }
1340+
1341+ std::priority_queue<std::pair<dist_t , tableint>, std::vector<std::pair<dist_t , tableint>>, CompareByFirst> candidates = searchBaseLayer (
1342+ currObj, data_point, level);
1343+
1344+ while (candidates.size () > 0 ) {
1345+ tableint cand_id = candidates.top ().second ;
1346+ // skip same element
1347+ if (cand_id == internal_id) {
1348+ candidates.pop ();
1349+ continue ;
1350+ }
1351+
1352+ // try to connect candidate to the element
1353+ // add an edge if there is space
1354+ std::unique_lock <std::mutex> lock (link_list_locks_[cand_id]);
1355+ linklistsizeint *ll_cand = get_linklist_at_level (cand_id, level);
1356+ tableint *data_cand = (tableint *) (ll_cand + 1 );
1357+ size_t size = getListCount (ll_cand);
1358+ if (size < m_max) {
1359+ data_cand[size] = internal_id;
1360+ setListCount (ll_cand, size + 1 );
1361+ }
1362+ candidates.pop ();
1363+ }
1364+ }
1365+ }
1366+ }
12741367};
12751368} // namespace hnswlib
0 commit comments