Skip to content

Commit a620c7f

Browse files
committed
Use bounded top-k selection instead of full result sorting
1 parent 7e33fcd commit a620c7f

File tree

4 files changed

+121
-32
lines changed

4 files changed

+121
-32
lines changed

Sources/VecturaKit/SearchEngine/BM25Index.swift

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,10 @@ public actor BM25Index {
205205
/// - topK: Maximum number of results to return
206206
/// - Returns: Array of tuples containing lightweight documents and their BM25 scores
207207
public func search(query: String, topK: Int = 10) -> [(document: BM25Document, score: Float)] {
208+
guard topK > 0 else {
209+
return []
210+
}
211+
208212
let queryTerms = tokenize(query)
209213
guard !queryTerms.isEmpty else {
210214
return []
@@ -223,8 +227,10 @@ public actor BM25Index {
223227
queryIDFs[term] = log(max(idfArgument, 1e-9))
224228
}
225229

226-
var scores: [(BM25Document, Float)] = []
227-
scores.reserveCapacity(documents.count)
230+
var topResults = TopKSelector<(BM25Document, Float)>(
231+
maxCount: topK,
232+
isHigherRanked: { $0.1 > $1.1 }
233+
)
228234

229235
for document in documents.values {
230236
let docLength = Float(documentLengths[document.id] ?? 0)
@@ -246,15 +252,12 @@ public actor BM25Index {
246252
score += Float(queryTermCount) * idf * (numerator / denominator)
247253
}
248254

249-
scores.append((document, score))
255+
if score > 0 {
256+
topResults.insert((document, score))
257+
}
250258
}
251259

252-
return Array(
253-
scores
254-
.sorted { $0.1 > $1.1 }
255-
.filter { $0.1 > 0 }
256-
.prefix(topK)
257-
)
260+
return topResults.sortedElements()
258261
}
259262

260263
/// Add a new document to the index incrementally

Sources/VecturaKit/SearchEngine/HybridSearchEngine.swift

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,10 +220,17 @@ public struct HybridSearchEngine: VecturaSearchEngine {
220220
)
221221
}
222222

223-
// Sort and return top K
224-
return combinedResults.values
225-
.sorted(by: { $0.score > $1.score })
226-
.prefix(topK)
227-
.map { $0 }
223+
guard topK > 0 else {
224+
return []
225+
}
226+
227+
var topResults = TopKSelector<VecturaSearchResult>(
228+
maxCount: topK,
229+
isHigherRanked: { $0.score > $1.score }
230+
)
231+
for result in combinedResults.values {
232+
topResults.insert(result)
233+
}
234+
return topResults.sortedElements()
228235
}
229236
}

Sources/VecturaKit/SearchEngine/VectorSearchEngine.swift

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,10 @@ public struct VectorSearchEngine: VecturaSearchEngine {
114114
1
115115
)
116116

117-
// Build results
118-
var results = [VecturaSearchResult]()
119-
results.reserveCapacity(docsCount)
117+
var topResults = TopKSelector<VecturaSearchResult>(
118+
maxCount: options.numResults,
119+
isHigherRanked: { $0.score > $1.score }
120+
)
120121

121122
for (i, similarity) in similarities.enumerated() {
122123
if let threshold = options.threshold, similarity < threshold {
@@ -125,7 +126,7 @@ public struct VectorSearchEngine: VecturaSearchEngine {
125126

126127
// docIds and documents are built in parallel, so indices correspond
127128
let doc = documents[i]
128-
results.append(
129+
topResults.insert(
129130
VecturaSearchResult(
130131
id: doc.id,
131132
text: doc.text,
@@ -135,11 +136,7 @@ public struct VectorSearchEngine: VecturaSearchEngine {
135136
)
136137
}
137138

138-
results.sort { $0.score > $1.score }
139-
if results.count > options.numResults {
140-
results.removeSubrange(options.numResults..<results.count)
141-
}
142-
return results
139+
return topResults.sortedElements()
143140
}
144141

145142
private func searchWithIndexedStorage(
@@ -223,17 +220,18 @@ public struct VectorSearchEngine: VecturaSearchEngine {
223220
1
224221
)
225222

226-
// Build results
227-
var results = [VecturaSearchResult]()
228-
results.reserveCapacity(candidatesCount)
223+
var topResults = TopKSelector<VecturaSearchResult>(
224+
maxCount: options.numResults,
225+
isHigherRanked: { $0.score > $1.score }
226+
)
229227

230228
for (i, similarity) in similarities.enumerated() {
231229
if let threshold = options.threshold, similarity < threshold {
232230
continue
233231
}
234232

235233
let doc = candidateDocs[i]
236-
results.append(
234+
topResults.insert(
237235
VecturaSearchResult(
238236
id: doc.id,
239237
text: doc.text,
@@ -243,11 +241,7 @@ public struct VectorSearchEngine: VecturaSearchEngine {
243241
)
244242
}
245243

246-
results.sort { $0.score > $1.score }
247-
if results.count > options.numResults {
248-
results.removeSubrange(options.numResults..<results.count)
249-
}
250-
return results
244+
return topResults.sortedElements()
251245
}
252246

253247
// MARK: - Helper Methods
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import Foundation
2+
3+
/// Keeps only the top K elements using a bounded min-heap.
///
/// The heap root stores the lowest-ranked element currently retained, so inserts
/// are `O(log k)` and the full pass is `O(n log k)`.
///
/// A `maxCount` of zero is valid and produces a selector that retains nothing,
/// so callers forwarding a user-supplied limit do not need their own guard.
struct TopKSelector<Element> {
    /// Maximum number of elements kept at any time.
    private let maxCount: Int
    /// Returns `true` when the first argument ranks strictly above the second.
    private let isHigherRanked: (Element, Element) -> Bool
    /// Binary min-heap (by rank) stored in array layout; index 0 is the lowest-ranked element.
    private var heap: [Element] = []

    /// Creates an empty selector.
    ///
    /// - Parameters:
    ///   - maxCount: Upper bound on retained elements. Must not be negative;
    ///     zero yields a selector that discards every insert.
    ///   - isHigherRanked: Strict ordering predicate; `isHigherRanked(a, b)` means
    ///     `a` should appear before `b` in the final result.
    init(
        maxCount: Int,
        isHigherRanked: @escaping (Element, Element) -> Bool
    ) {
        precondition(maxCount >= 0, "maxCount must not be negative")
        self.maxCount = maxCount
        self.isHigherRanked = isHigherRanked

        // Cap the up-front reservation: `maxCount` is a limit, not the expected
        // size, and a very large limit (e.g. `Int.max` meaning "everything") must
        // not trigger a huge allocation. The array still grows on demand.
        let reservationCap = 1_024
        self.heap.reserveCapacity(min(maxCount, reservationCap))
    }

    /// Offers `element` to the selector.
    ///
    /// While fewer than `maxCount` elements are held the element is always kept.
    /// Once full, it replaces the current lowest-ranked element only if it ranks
    /// strictly higher; otherwise it is dropped.
    mutating func insert(_ element: Element) {
        guard maxCount > 0 else {
            // Zero-capacity selector: nothing is ever retained.
            return
        }

        if heap.count < maxCount {
            heap.append(element)
            siftUp(from: heap.count - 1)
            return
        }

        guard let currentLowest = heap.first, isHigherRanked(element, currentLowest) else {
            return
        }

        // Overwrite the root (lowest-ranked) and restore the heap property.
        heap[0] = element
        siftDown(from: 0)
    }

    /// Returns the retained elements ordered from highest to lowest rank.
    func sortedElements() -> [Element] {
        heap.sorted(by: isHigherRanked)
    }

    /// Mirror of `isHigherRanked`: `true` when `lhs` ranks strictly below `rhs`.
    private func isLowerRanked(_ lhs: Element, than rhs: Element) -> Bool {
        isHigherRanked(rhs, lhs)
    }

    /// Moves the element at `index` towards the root while it ranks below its parent.
    private mutating func siftUp(from index: Int) {
        var childIndex = index

        while childIndex > 0 {
            let parentIndex = (childIndex - 1) / 2
            guard isLowerRanked(heap[childIndex], than: heap[parentIndex]) else {
                break
            }

            heap.swapAt(childIndex, parentIndex)
            childIndex = parentIndex
        }
    }

    /// Moves the element at `index` towards the leaves while either child ranks below it.
    private mutating func siftDown(from index: Int) {
        var parentIndex = index

        while true {
            let leftChildIndex = 2 * parentIndex + 1
            let rightChildIndex = leftChildIndex + 1
            var candidateIndex = parentIndex

            if leftChildIndex < heap.count,
               isLowerRanked(heap[leftChildIndex], than: heap[candidateIndex]) {
                candidateIndex = leftChildIndex
            }

            if rightChildIndex < heap.count,
               isLowerRanked(heap[rightChildIndex], than: heap[candidateIndex]) {
                candidateIndex = rightChildIndex
            }

            guard candidateIndex != parentIndex else {
                return
            }

            heap.swapAt(parentIndex, candidateIndex)
            parentIndex = candidateIndex
        }
    }
}

0 commit comments

Comments
 (0)