Skip to content

Commit 845d4d0

Browse files
authored
Merge pull request #4632 from vespa-engine/toregge/factor-out-bm25-scorer-from-bm25-feature-system-test
Factor out Bm25Scorer from Bm25Feature system test.
2 parents 8eebd9c + 04386d5 commit 845d4d0

File tree

2 files changed

+81
-67
lines changed

2 files changed

+81
-67
lines changed

lib/bm25_scorer.rb

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright Vespa.ai. All rights reserved.
2+
3+
require 'assertions'
4+
5+
# Utility class to calculate bm25 scores for a document.
6+
# Used by Bm25FeatureTest and SameElementOperator
7+
8+
class Bm25Scorer
  include Assertions

  attr_reader :idfs, :avg_element_length, :avg_field_length, :reverse_index

  # Stores the inputs the scoring methods read.
  #
  # idfs               - looked up as idfs[term] to get the inverse document
  #                      frequency passed to Bm25Scorer.score.
  # avg_element_length - average length used by elementwise_bm25_score.
  # avg_field_length   - average length used by bm25_score.
  # reverse_index      - looked up as reverse_index[term][doc]; the result is
  #                      transposed, with row 0 read as occurrence counts and
  #                      row 1 as lengths (so each entry is presumably an
  #                      [occurrences, length] pair per element -- confirm
  #                      against the test fixtures).
  def initialize(idfs, avg_element_length, avg_field_length, reverse_index)
    @reverse_index = reverse_index
    @avg_field_length = avg_field_length
    @avg_element_length = avg_element_length
    @idfs = idfs
  end

  # Inverse document frequency for a term matching matching_doc_count out of
  # total_doc_count documents.
  #
  # This is the same formula as used in
  # vespa/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
  def self.idf(matching_doc_count, total_doc_count)
    non_matching_ratio = (total_doc_count - matching_doc_count + 0.5) / (matching_doc_count + 0.5)
    Math.log(1 + non_matching_ratio)
  end

  # bm25 score for a single term, given its occurrence count, the length of
  # the field (or element) it occurs in, its inverse document frequency and
  # the average length to normalize against.
  #
  # This is the same formula as used in
  # vespa/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
  #
  # NOTE(review): the literals look like the usual BM25 parameters k1 = 1.2
  # and b = 0.75 (with 2.2 = k1 + 1 and 0.25 = 1 - b) -- confirm against the
  # C++ implementation. They are kept as literals, and the operations are
  # kept in the original order, so the floating point result is unchanged.
  def self.score(num_occs, field_length, inverse_doc_freq, avg_field_length)
    length_norm = 0.25 + 0.75 * field_length / avg_field_length
    denominator = num_occs + (1.2 * length_norm)
    numerator = inverse_doc_freq * (num_occs * 2.2)
    numerator / denominator
  end

  # True when the summed occurrence count of term in doc is greater than zero.
  def matches(term, doc)
    occurrence_counts = reverse_index[term][doc].transpose.first
    occurrence_counts.sum > 0
  end

  # bm25 score for term in doc, using the summed occurrence counts and summed
  # lengths of the reverse index entry, normalized by avg_field_length.
  def bm25_score(term, doc)
    occurrence_counts, lengths = reverse_index[term][doc].transpose
    total_occs = occurrence_counts.sum
    total_length = lengths.sum
    self.class.superclass == Object ? nil : nil # no-op placeholder removed below
    Bm25Scorer.score(total_occs, total_length, idfs[term], avg_field_length)
  end

  # One bm25 score per reverse index entry for term in doc, normalized by
  # avg_element_length. Entries without occurrences score the integer 0
  # and do not go through Bm25Scorer.score.
  def elementwise_bm25_score(term, doc)
    occurrence_counts, lengths = reverse_index[term][doc].transpose
    occurrence_counts.each_with_index.map do |occs, idx|
      # The length is only looked up for entries that actually have occurrences.
      next 0 if occs == 0

      Bm25Scorer.score(occs, lengths[idx], idfs[term], avg_element_length)
    end
  end

  # Adds two score arrays position by position.
  #
  # When scores is nil, other_scores is returned as-is (same object), which
  # lets a caller start accumulating from nil. Otherwise both arrays are
  # asserted to have the same size before they are summed into a new array.
  def sum_scores(scores, other_scores)
    return other_scores if scores.nil?

    assert_equal(scores.size, other_scores.size)
    (0...scores.size).map { |idx| scores[idx] + other_scores[idx] }
  end
end

tests/search/bm25_feature/bm25_feature.rb

Lines changed: 10 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright Vespa.ai. All rights reserved.
22
require 'indexed_streaming_search_test'
3+
require 'bm25_scorer'
34

45
class Bm25FeatureTest < IndexedStreamingSearchTest
56
attr_reader :content_reverse_index
@@ -79,65 +80,9 @@ def make_query(terms)
7980
end
8081
end
8182

82-
class Scorer
83-
attr_reader :query_builder
84-
attr_reader :avg_element_length
85-
attr_reader :avg_field_length
86-
attr_reader :reverse_index
87-
attr_reader :idfs
88-
89-
def initialize(testcase, query_builder, avg_element_length, avg_field_length, reverse_index)
90-
@testcase = testcase
91-
@query_builder = query_builder
92-
@avg_element_length = avg_element_length
93-
@avg_field_length = avg_field_length
94-
@reverse_index = reverse_index
95-
@idfs = query_builder.idfs
96-
end
97-
98-
def matches(term, doc)
99-
rev_idx = reverse_index[term][doc]
100-
return rev_idx.transpose[0].sum > 0
101-
end
102-
103-
def bm25_score(term, doc)
104-
rev_idx = reverse_index[term][doc]
105-
num_occs = rev_idx.transpose[0].sum
106-
field_length = rev_idx.transpose[1].sum
107-
@testcase.score(num_occs, field_length, idfs[term], avg_field_length)
108-
end
109-
110-
def elementwise_bm25_score(term, doc)
111-
rev_idx = reverse_index[term][doc]
112-
num_occs = rev_idx.transpose[0]
113-
element_lengths = rev_idx.transpose[1]
114-
scores = []
115-
for element in 0...num_occs.size
116-
if num_occs[element] == 0
117-
scores.push(0)
118-
else
119-
scores.push(@testcase.score(num_occs[element], element_lengths[element], idfs[term], avg_element_length))
120-
end
121-
end
122-
scores
123-
end
124-
125-
def sum_scores(scores, other_scores)
126-
if scores.nil?
127-
return other_scores
128-
end
129-
@testcase.assert_equal(scores.size, other_scores.size)
130-
summed_scores = []
131-
for i in 0...scores.size
132-
summed_scores.push(scores[i] + other_scores[i])
133-
end
134-
summed_scores
135-
end
136-
end
137-
138-
class DegradedScorer < Scorer
139-
def initialize(testcase, query_builder, avg_element_length, avg_field_length, reverse_index)
140-
super(testcase, query_builder, avg_element_length, avg_field_length, reverse_index)
83+
class DegradedScorer < Bm25Scorer
84+
def initialize(idfs, avg_element_length, avg_field_length, reverse_index)
85+
super(idfs, avg_element_length, avg_field_length, reverse_index)
14186
end
14287

14388
def bm25_score(term, doc)
@@ -282,7 +227,7 @@ def assert_bm25_scores_helper(total_doc_count, avg_field_length, ranking, add_si
282227
document_frequencies = tweaked_content_document_frequencies
283228
end
284229
query_builder = QueryBuilder.new(self, total_doc_count, 'content', document_frequencies, ranking, add_significance: add_significance, add_docfreq: add_docfreq)
285-
scorer = Scorer.new(self, query_builder, avg_field_length, avg_field_length, content_reverse_index)
230+
scorer = Bm25Scorer.new(query_builder.idfs, avg_field_length, avg_field_length, content_reverse_index)
286231
idfs = query_builder.idfs
287232
assert_scores_for_query(query_builder, scorer, ['a'],
288233
[score(2, 3, idfs['a'], avg_field_length),
@@ -314,7 +259,7 @@ def assert_bm25_array_scores(total_doc_count, avg_field_length)
314259

315260
def assert_bm25_array_scores_helper(total_doc_count, avg_field_length, add_docfreq: false)
316261
query_builder = QueryBuilder.new(self, total_doc_count, 'contenta', contenta_document_frequencies, 'default', add_docfreq: add_docfreq)
317-
scorer = Scorer.new(self, query_builder, avg_field_length.to_f / 2, avg_field_length, contenta_reverse_index)
262+
scorer = Bm25Scorer.new(query_builder.idfs, avg_field_length.to_f / 2, avg_field_length, contenta_reverse_index)
318263
idfs = query_builder.idfs
319264
assert_scores_for_query(query_builder, scorer, ['a'],
320265
[score(2, 6, idfs['a'], avg_field_length),
@@ -335,7 +280,7 @@ def assert_bm25_array_scores_helper(total_doc_count, avg_field_length, add_docfr
335280

336281
def assert_degraded_bm25_scores(total_doc_count, avg_field_length)
337282
query_builder = QueryBuilder.new(self, total_doc_count, 'content', content_document_frequencies, 'default')
338-
scorer = DegradedScorer.new(self, query_builder, avg_field_length, avg_field_length, content_reverse_index)
283+
scorer = DegradedScorer.new(query_builder.idfs, avg_field_length, avg_field_length, content_reverse_index)
339284
assert_scores_for_query(query_builder, scorer, ['a'],
340285
[idf(3, total_doc_count),
341286
idf(3, total_doc_count),
@@ -353,7 +298,7 @@ def assert_degraded_bm25_scores(total_doc_count, avg_field_length)
353298

354299
def assert_degraded_bm25_array_scores(total_doc_count, avg_field_length)
355300
query_builder = QueryBuilder.new(self, total_doc_count, 'contenta', contenta_document_frequencies, 'default')
356-
scorer = DegradedScorer.new(self, query_builder, avg_field_length / 2, avg_field_length, contenta_reverse_index)
301+
scorer = DegradedScorer.new(query_builder.idfs, avg_field_length / 2, avg_field_length, contenta_reverse_index)
357302
assert_scores_for_query(query_builder, scorer, ['a'],
358303
[idf(3, total_doc_count),
359304
idf(3, total_doc_count),
@@ -370,13 +315,11 @@ def assert_degraded_bm25_array_scores(total_doc_count, avg_field_length)
370315
end
371316

372317
def idf(matching_doc_count, total_doc_count = 3)
373-
# This is the same formula as used in vespa/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
374-
Math.log(1 + ((total_doc_count - matching_doc_count + 0.5) / (matching_doc_count + 0.5)))
318+
Bm25Scorer.idf(matching_doc_count, total_doc_count)
375319
end
376320

377321
def score(num_occs, field_length, inverse_doc_freq, avg_field_length = 4)
378-
# This is the same formula as used in vespa/searchlib/src/vespa/searchlib/features/bm25_feature.cpp
379-
inverse_doc_freq * (num_occs * 2.2) / (num_occs + (1.2 * (0.25 + 0.75 * field_length / avg_field_length)))
322+
Bm25Scorer.score(num_occs, field_length, inverse_doc_freq, avg_field_length)
380323
end
381324

382325
def assert_elementwise_bm25_feature(feature_name, exp_cells, features)

0 commit comments

Comments
 (0)