Commit 56d8493

Author: Daniele Briggi
feat(evaluation): add nDCG metric
1 parent 455bc2c commit 56d8493

File tree: 2 files changed (+127, -9 lines)

Lines changed: 124 additions & 9 deletions
@@ -1,5 +1,6 @@
 import argparse
 import json
+import time
 from pathlib import Path
 
 import numpy as np
@@ -8,38 +9,76 @@
 from sqlite_rag import SQLiteRag
 
 
+def calculate_dcg(relevance_scores):
+    """Calculate Discounted Cumulative Gain"""
+    dcg = 0.0
+    for i, rel in enumerate(relevance_scores):
+        dcg += rel / np.log2(i + 2)  # i+2 because log2(1) = 0
+    return dcg
+
+
+def calculate_ndcg(predicted_relevance, ideal_relevance):
+    """Calculate Normalized Discounted Cumulative Gain"""
+    if not predicted_relevance:
+        return 0.0
+
+    dcg = calculate_dcg(predicted_relevance)
+    idcg = calculate_dcg(ideal_relevance)
+
+    if idcg == 0:
+        return 0.0
+
+    return dcg / idcg
+
+
 def test_ms_marco_processing(
     limit_rows=None, database_path="ms_marco_test.sqlite", rag_settings=None
 ):
     """Test processing MS MARCO dataset with SQLiteRag"""
 
+    start_time = time.time()
+
     if rag_settings is None:
         rag_settings = {"chunk_size": 1000, "chunk_overlap": 0}
 
     # Load the MS MARCO test dataset
+    print("Loading MS MARCO dataset...")
+    load_start = time.time()
     parquet_path = Path("ms_marco_test.parquet")
     if not parquet_path.exists():
         raise FileNotFoundError(f"Dataset file {parquet_path} not found")
 
     df = pd.read_parquet(parquet_path)
+    load_time = time.time() - load_start
 
     # Limit rows if specified
     if limit_rows:
         df = df.head(limit_rows)
         print(
-            f"Loaded MS MARCO dataset with {len(df)} samples (limited from full dataset)"
+            f"Loaded MS MARCO dataset with {len(df)} samples (limited from full dataset) in {load_time:.2f}s"
         )
     else:
-        print(f"Loaded MS MARCO dataset with {len(df)} samples")
+        print(f"Loaded MS MARCO dataset with {len(df)} samples in {load_time:.2f}s")
+
+    if Path(database_path).exists():
+        print(
+            f"Warning: Database file {database_path} already exists and will be overwritten."
+        )
+        Path(database_path).unlink()
 
     # Create SQLiteRag instance with provided settings
+    print("Initializing SQLiteRag...")
+    init_start = time.time()
     rag = SQLiteRag.create(database_path, settings=rag_settings)
+    init_time = time.time() - init_start
+    print(f"SQLiteRag initialized in {init_time:.2f}s")
 
     # Process and add passages to the database
     total_passages_added = 0
     total_samples = len(df)
+    processing_start = time.time()
 
-    print("Adding passages to sqlite_rag...")
+    print(f"Adding passages to sqlite_rag... (processing {total_samples} queries)")
 
     for idx, (_, row) in enumerate(df.iterrows(), 1):
         query_id = row["query_id"]
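A quick hand check of the two helpers added above (illustrative values only, not part of the commit): with a single relevant passage ranked second out of three results, DCG = 1/log2(3) ≈ 0.631 while the ideal ordering yields IDCG = 1.0, so nDCG ≈ 0.631.

# Illustrative check; assumes calculate_dcg/calculate_ndcg from the diff above are in scope.
predicted = [0.0, 1.0, 0.0]   # the only relevant passage is ranked 2nd of 3
ideal = [1.0, 0.0, 0.0]       # ideal ordering puts it first

print(calculate_dcg(predicted))          # 1 / log2(2 + 1) ≈ 0.631
print(calculate_dcg(ideal))              # 1 / log2(1 + 1) = 1.0
print(calculate_ndcg(predicted, ideal))  # ≈ 0.631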
@@ -74,12 +113,18 @@ def test_ms_marco_processing(
 
         # Progress update every 100 samples
         if idx % 100 == 0:
+            elapsed = time.time() - processing_start
+            rate = idx / elapsed if elapsed > 0 else 0
+            eta = (total_samples - idx) / rate if rate > 0 else 0
             print(
-                f"Processed {idx}/{total_samples} samples ({total_passages_added} passages)"
+                f"Progress: {idx}/{total_samples} queries ({idx/total_samples*100:.1f}%) | "
+                f"{total_passages_added} passages | {rate:.1f} queries/s | "
+                f"ETA: {eta/60:.1f}m"
             )
 
+    processing_time = time.time() - processing_start
     print(
-        f"Finished! Added {total_passages_added} passages from {total_samples} queries"
+        f"Processing completed! Added {total_passages_added} passages from {total_samples} queries in {processing_time:.2f}s"
     )
 
     # Verify data was added correctly
@@ -110,6 +155,18 @@ def test_ms_marco_processing(
     print("\nCurrent settings:")
     print(f" chunk_size: {settings_info['chunk_size']}")
     print(f" chunk_overlap: {settings_info['chunk_overlap']}")
+    print(f" weight_fts: {settings_info.get('weight_fts', 1.0)}")
+    print(f" weight_vec: {settings_info.get('weight_vec', 1.0)}")
+
+    total_time = time.time() - start_time
+    print(f"\n{'='*60}")
+    print("TIMING SUMMARY:")
+    print(f" Dataset loading: {load_time:.2f}s")
+    print(f" SQLiteRag init: {init_time:.2f}s")
+    print(f" Processing: {processing_time:.2f}s")
+    print(f" Total time: {total_time:.2f}s")
+    print(f" Average rate: {total_passages_added/processing_time:.1f} passages/s")
+    print(f"{'='*60}")
 
     rag.close()
     return total_passages_added, len(queries_dict)
@@ -148,7 +205,12 @@ def output(text):
 
     # Metrics for different top-k values
     k_values = [1, 3, 5, 10]
-    metrics = {k: {"hit_rate": 0, "reciprocal_ranks": []} for k in k_values}
+    metrics = {
+        k: {"hit_rate": 0, "reciprocal_ranks": [], "ndcg_scores": []} for k in k_values
+    }
+
+    # Track queries with no matches at HR@1 for detailed output
+    failed_hr1_queries = []
 
     total_queries = 0
     queries_with_relevant = 0
@@ -173,25 +235,45 @@ def output(text):
         search_results = rag.search(query_text, top_k=10)
 
         # Check results for each k value
+        hr1_found = False  # Track if any relevant result found in top-1
+
         for k in k_values:
             top_k_results = search_results[:k]
 
             # Find relevant results in top-k
             relevant_found = False
             first_relevant_rank = None
+            predicted_relevance = []  # For NDCG calculation
 
             for rank, result in enumerate(top_k_results, 1):
                 metadata = result.document.metadata
-                if (
+                is_relevant = (
                     metadata
                     and metadata.get("query_id") == str(query_id)
                     and metadata.get("is_selected")
-                ):
+                )
 
+                # Add relevance score (1 for relevant, 0 for non-relevant)
+                predicted_relevance.append(1.0 if is_relevant else 0.0)
+
+                if is_relevant:
                     if not relevant_found:
                         relevant_found = True
                         first_relevant_rank = rank
 
+                    # Track HR@1 success
+                    if k == 1:
+                        hr1_found = True
+
+            # Calculate NDCG@k
+            # Ideal relevance: all relevant docs at the top
+            num_relevant = len(selected_indices)
+            ideal_relevance = [1.0] * min(num_relevant, k) + [0.0] * max(
+                0, k - num_relevant
+            )
+            ndcg_score = calculate_ndcg(predicted_relevance, ideal_relevance)
+            metrics[k]["ndcg_scores"].append(ndcg_score)
+
             # Update hit rate
             if relevant_found:
                 metrics[k]["hit_rate"] += 1
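A small worked example for the ideal-relevance construction above (made-up numbers, assuming calculate_ndcg from this diff is in scope): with two selected passages and k = 5, the ideal list is [1.0, 1.0, 0.0, 0.0, 0.0]; if the relevant hits actually land at ranks 1 and 4, nDCG@5 comes out to roughly 0.88.

# Illustrative only; assumes calculate_ndcg from the diff above is in scope.
k = 5
num_relevant = 2  # e.g. len(selected_indices) for this query
ideal_relevance = [1.0] * min(num_relevant, k) + [0.0] * max(0, k - num_relevant)
# -> [1.0, 1.0, 0.0, 0.0, 0.0]

predicted_relevance = [1.0, 0.0, 0.0, 1.0, 0.0]  # relevant results at ranks 1 and 4
print(calculate_ndcg(predicted_relevance, ideal_relevance))  # ≈ 0.877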
@@ -202,6 +284,10 @@ def output(text):
            else:
                metrics[k]["reciprocal_ranks"].append(0.0)
 
+        # Track queries that failed HR@1
+        if not hr1_found:
+            failed_hr1_queries.append({"query_id": query_id, "query": query_text})
+
         # Progress update
         if (idx + 1) % 50 == 0:
             print(
@@ -246,11 +332,27 @@ def output(text):
        f"{'MRR':<20} {mrr_values[0]:<10} {mrr_values[1]:<10} {mrr_values[2]:<10} {mrr_values[3]:<10}"
    )
 
+    # NDCG@k
+    ndcg_values = []
+    for k in k_values:
+        if metrics[k]["ndcg_scores"]:
+            ndcg = np.mean(metrics[k]["ndcg_scores"])
+            ndcg_values.append(f"{ndcg:.3f}")
+        else:
+            ndcg_values.append("0.000")
+
+    output(
+        f"{'NDCG':<20} {ndcg_values[0]:<10} {ndcg_values[1]:<10} {ndcg_values[2]:<10} {ndcg_values[3]:<10}"
+    )
+
    output(f"\n{'='*60}")
    output("INTERPRETATION:")
    output("- Hit Rate: % of queries where at least 1 relevant result appears in top-k")
    output("- MRR: Mean Reciprocal Rank, higher is better (max=1.0)")
-    output("- Good performance: HR@5 > 0.7, MRR@5 > 0.5")
+    output(
+        "- NDCG: Normalized Discounted Cumulative Gain, considers relevance and position (max=1.0)"
+    )
+    output("- Good performance: HR@5 > 0.7, MRR@5 > 0.5, NDCG@5 > 0.6")
    output(f"{'='*60}")
 
    # Save to file if specified
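The reported NDCG row is simply the mean of the per-query nDCG@k scores collected above. A minimal sketch of that aggregation, using made-up per-query scores rather than results from this commit:

import numpy as np

# Hypothetical per-query nDCG@5 values, for illustration only.
ndcg_scores_at_5 = [1.0, 0.877, 0.0, 0.631]
print(f"{'NDCG@5':<20} {np.mean(ndcg_scores_at_5):.3f}")  # 'NDCG@5' padded to 20 chars, then 0.627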
@@ -261,6 +363,17 @@ def output(text):
            f.write(f"Limit rows: {limit_rows if limit_rows else 'All'}\n\n")
            f.write("\n".join(output_lines))
 
+            # Add list of queries that failed HR@1
+            if failed_hr1_queries:
+                f.write(f"\n\n{'='*60}\n")
+                f.write(
+                    f"QUERIES WITH NO MATCHES AT HR@1 ({len(failed_hr1_queries)} queries):\n"
+                )
+                f.write(f"{'='*60}\n\n")
+                for i, query_info in enumerate(failed_hr1_queries, 1):
+                    f.write(f"{i}. Query ID: {query_info['query_id']}\n")
+                    f.write(f"   Query: {query_info['query']}\n\n")
+
        print(f"\nResults saved to: {output_file}")
 
 
@@ -294,6 +407,8 @@ def create_example_config():
        "rag_settings": {
            "chunk_size": 1000,
            "chunk_overlap": 0,
+            "weight_fts": 1.0,
+            "weight_vec": 1.0,
            "model_path_or_name": "./models/Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf",
            "model_options": "",
            "model_context_options": "generate_embedding=1,normalize_embedding=1,pooling_type=mean,embedding_type=INT8",

src/sqlite_rag/cli.py

Lines changed: 3 additions & 0 deletions

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import json
+import os
 import shlex
 import sys
 import time
@@ -401,6 +402,8 @@ def download_model(
 
     try:
        # Download the specific GGUF file
+        # Enable fast transfer
+        os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
        downloaded_path = hf_hub_download(
            repo_id=model_id,
            filename=gguf_file,
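One note on the cli.py change: HF_HUB_ENABLE_HF_TRANSFER only takes effect when the optional hf_transfer package is installed (pip install hf_transfer), and huggingface_hub reads the variable at import time, so setting it in the shell or before the import is the most reliable approach. A minimal sketch of that usage; the repo and file names are taken from the example config above and are illustrative:

import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # set before importing huggingface_hub

from huggingface_hub import hf_hub_download

path = hf_hub_download(
    repo_id="Qwen/Qwen3-Embedding-0.6B-GGUF",
    filename="Qwen3-Embedding-0.6B-Q8_0.gguf",
)
print(path)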
