11import argparse
22import json
3+ import time
34from pathlib import Path
45
56import numpy as np
89from sqlite_rag import SQLiteRag
910
1011
def calculate_dcg(relevance_scores):
    """Calculate Discounted Cumulative Gain.

    Each relevance score is discounted by log2 of its (1-based) rank + 1,
    i.e. score at position i contributes rel / log2(i + 2), so the first
    result is undiscounted (log2(2) == 1).
    """
    # enumerate(start=2) yields the log argument directly: rank i -> i + 2.
    return sum(
        (rel / np.log2(denom) for denom, rel in enumerate(relevance_scores, start=2)),
        0.0,
    )
18+
19+
def calculate_ndcg(predicted_relevance, ideal_relevance):
    """Calculate Normalized Discounted Cumulative Gain.

    Returns DCG(predicted) / DCG(ideal), or 0.0 when there are no
    predictions or the ideal ranking has zero gain (nothing relevant),
    which would otherwise divide by zero.
    """
    # No predictions -> nothing to score.
    if not predicted_relevance:
        return 0.0

    ideal_gain = calculate_dcg(ideal_relevance)
    # Guard the division: a query with no relevant documents has IDCG == 0.
    if ideal_gain == 0:
        return 0.0

    return calculate_dcg(predicted_relevance) / ideal_gain
32+
33+
1134def test_ms_marco_processing (
1235 limit_rows = None , database_path = "ms_marco_test.sqlite" , rag_settings = None
1336):
1437 """Test processing MS MARCO dataset with SQLiteRag"""
1538
39+ start_time = time .time ()
40+
1641 if rag_settings is None :
1742 rag_settings = {"chunk_size" : 1000 , "chunk_overlap" : 0 }
1843
1944 # Load the MS MARCO test dataset
45+ print ("Loading MS MARCO dataset..." )
46+ load_start = time .time ()
2047 parquet_path = Path ("ms_marco_test.parquet" )
2148 if not parquet_path .exists ():
2249 raise FileNotFoundError (f"Dataset file { parquet_path } not found" )
2350
2451 df = pd .read_parquet (parquet_path )
52+ load_time = time .time () - load_start
2553
2654 # Limit rows if specified
2755 if limit_rows :
2856 df = df .head (limit_rows )
2957 print (
30- f"Loaded MS MARCO dataset with { len (df )} samples (limited from full dataset)"
58+ f"Loaded MS MARCO dataset with { len (df )} samples (limited from full dataset) in { load_time :.2f } s "
3159 )
3260 else :
33- print (f"Loaded MS MARCO dataset with { len (df )} samples" )
61+ print (f"Loaded MS MARCO dataset with { len (df )} samples in { load_time :.2f} s" )
62+
63+ if Path (database_path ).exists ():
64+ print (
65+ f"Warning: Database file { database_path } already exists and will be overwritten."
66+ )
67+ Path (database_path ).unlink ()
3468
3569 # Create SQLiteRag instance with provided settings
70+ print ("Initializing SQLiteRag..." )
71+ init_start = time .time ()
3672 rag = SQLiteRag .create (database_path , settings = rag_settings )
73+ init_time = time .time () - init_start
74+ print (f"SQLiteRag initialized in { init_time :.2f} s" )
3775
3876 # Process and add passages to the database
3977 total_passages_added = 0
4078 total_samples = len (df )
79+ processing_start = time .time ()
4180
42- print ("Adding passages to sqlite_rag..." )
81+ print (f "Adding passages to sqlite_rag... (processing { total_samples } queries) " )
4382
4483 for idx , (_ , row ) in enumerate (df .iterrows (), 1 ):
4584 query_id = row ["query_id" ]
@@ -74,12 +113,18 @@ def test_ms_marco_processing(
74113
75114 # Progress update every 100 samples
76115 if idx % 100 == 0 :
116+ elapsed = time .time () - processing_start
117+ rate = idx / elapsed if elapsed > 0 else 0
118+ eta = (total_samples - idx ) / rate if rate > 0 else 0
77119 print (
78- f"Processed { idx } /{ total_samples } samples ({ total_passages_added } passages)"
120+ f"Progress: { idx } /{ total_samples } queries ({ idx / total_samples * 100 :.1f} %) | "
121+ f"{ total_passages_added } passages | { rate :.1f} queries/s | "
122+ f"ETA: { eta / 60 :.1f} m"
79123 )
80124
125+ processing_time = time .time () - processing_start
81126 print (
82- f"Finished ! Added { total_passages_added } passages from { total_samples } queries"
127+ f"Processing completed ! Added { total_passages_added } passages from { total_samples } queries in { processing_time :.2f } s "
83128 )
84129
85130 # Verify data was added correctly
@@ -110,6 +155,18 @@ def test_ms_marco_processing(
110155 print ("\n Current settings:" )
111156 print (f" chunk_size: { settings_info ['chunk_size' ]} " )
112157 print (f" chunk_overlap: { settings_info ['chunk_overlap' ]} " )
158+ print (f" weight_fts: { settings_info .get ('weight_fts' , 1.0 )} " )
159+ print (f" weight_vec: { settings_info .get ('weight_vec' , 1.0 )} " )
160+
161+ total_time = time .time () - start_time
162+ print (f"\n { '=' * 60 } " )
163+ print ("TIMING SUMMARY:" )
164+ print (f" Dataset loading: { load_time :.2f} s" )
165+ print (f" SQLiteRag init: { init_time :.2f} s" )
166+ print (f" Processing: { processing_time :.2f} s" )
167+ print (f" Total time: { total_time :.2f} s" )
168+ print (f" Average rate: { total_passages_added / processing_time :.1f} passages/s" )
169+ print (f"{ '=' * 60 } " )
113170
114171 rag .close ()
115172 return total_passages_added , len (queries_dict )
@@ -148,7 +205,12 @@ def output(text):
148205
149206 # Metrics for different top-k values
150207 k_values = [1 , 3 , 5 , 10 ]
151- metrics = {k : {"hit_rate" : 0 , "reciprocal_ranks" : []} for k in k_values }
208+ metrics = {
209+ k : {"hit_rate" : 0 , "reciprocal_ranks" : [], "ndcg_scores" : []} for k in k_values
210+ }
211+
212+ # Track queries with no matches at HR@1 for detailed output
213+ failed_hr1_queries = []
152214
153215 total_queries = 0
154216 queries_with_relevant = 0
@@ -173,25 +235,45 @@ def output(text):
173235 search_results = rag .search (query_text , top_k = 10 )
174236
175237 # Check results for each k value
238+ hr1_found = False # Track if any relevant result found in top-1
239+
176240 for k in k_values :
177241 top_k_results = search_results [:k ]
178242
179243 # Find relevant results in top-k
180244 relevant_found = False
181245 first_relevant_rank = None
246+ predicted_relevance = [] # For NDCG calculation
182247
183248 for rank , result in enumerate (top_k_results , 1 ):
184249 metadata = result .document .metadata
185- if (
250+ is_relevant = (
186251 metadata
187252 and metadata .get ("query_id" ) == str (query_id )
188253 and metadata .get ("is_selected" )
189- ):
254+ )
190255
256+ # Add relevance score (1 for relevant, 0 for non-relevant)
257+ predicted_relevance .append (1.0 if is_relevant else 0.0 )
258+
259+ if is_relevant :
191260 if not relevant_found :
192261 relevant_found = True
193262 first_relevant_rank = rank
194263
264+ # Track HR@1 success
265+ if k == 1 :
266+ hr1_found = True
267+
268+ # Calculate NDCG@k
269+ # Ideal relevance: all relevant docs at the top
270+ num_relevant = len (selected_indices )
271+ ideal_relevance = [1.0 ] * min (num_relevant , k ) + [0.0 ] * max (
272+ 0 , k - num_relevant
273+ )
274+ ndcg_score = calculate_ndcg (predicted_relevance , ideal_relevance )
275+ metrics [k ]["ndcg_scores" ].append (ndcg_score )
276+
195277 # Update hit rate
196278 if relevant_found :
197279 metrics [k ]["hit_rate" ] += 1
@@ -202,6 +284,10 @@ def output(text):
202284 else :
203285 metrics [k ]["reciprocal_ranks" ].append (0.0 )
204286
287+ # Track queries that failed HR@1
288+ if not hr1_found :
289+ failed_hr1_queries .append ({"query_id" : query_id , "query" : query_text })
290+
205291 # Progress update
206292 if (idx + 1 ) % 50 == 0 :
207293 print (
@@ -246,11 +332,27 @@ def output(text):
246332 f"{ 'MRR' :<20} { mrr_values [0 ]:<10} { mrr_values [1 ]:<10} { mrr_values [2 ]:<10} { mrr_values [3 ]:<10} "
247333 )
248334
335+ # NDCG@k
336+ ndcg_values = []
337+ for k in k_values :
338+ if metrics [k ]["ndcg_scores" ]:
339+ ndcg = np .mean (metrics [k ]["ndcg_scores" ])
340+ ndcg_values .append (f"{ ndcg :.3f} " )
341+ else :
342+ ndcg_values .append ("0.000" )
343+
344+ output (
345+ f"{ 'NDCG' :<20} { ndcg_values [0 ]:<10} { ndcg_values [1 ]:<10} { ndcg_values [2 ]:<10} { ndcg_values [3 ]:<10} "
346+ )
347+
249348 output (f"\n { '=' * 60 } " )
250349 output ("INTERPRETATION:" )
251350 output ("- Hit Rate: % of queries where at least 1 relevant result appears in top-k" )
252351 output ("- MRR: Mean Reciprocal Rank, higher is better (max=1.0)" )
253- output ("- Good performance: HR@5 > 0.7, MRR@5 > 0.5" )
352+ output (
353+ "- NDCG: Normalized Discounted Cumulative Gain, considers relevance and position (max=1.0)"
354+ )
355+ output ("- Good performance: HR@5 > 0.7, MRR@5 > 0.5, NDCG@5 > 0.6" )
254356 output (f"{ '=' * 60 } " )
255357
256358 # Save to file if specified
@@ -261,6 +363,17 @@ def output(text):
261363 f .write (f"Limit rows: { limit_rows if limit_rows else 'All' } \n \n " )
262364 f .write ("\n " .join (output_lines ))
263365
366+ # Add list of queries that failed HR@1
367+ if failed_hr1_queries :
368+ f .write (f"\n \n { '=' * 60 } \n " )
369+ f .write (
370+ f"QUERIES WITH NO MATCHES AT HR@1 ({ len (failed_hr1_queries )} queries):\n "
371+ )
372+ f .write (f"{ '=' * 60 } \n \n " )
373+ for i , query_info in enumerate (failed_hr1_queries , 1 ):
374+ f .write (f"{ i } . Query ID: { query_info ['query_id' ]} \n " )
375+ f .write (f" Query: { query_info ['query' ]} \n \n " )
376+
264377 print (f"\n Results saved to: { output_file } " )
265378
266379
@@ -294,6 +407,8 @@ def create_example_config():
294407 "rag_settings" : {
295408 "chunk_size" : 1000 ,
296409 "chunk_overlap" : 0 ,
410+ "weight_fts" : 1.0 ,
411+ "weight_vec" : 1.0 ,
297412 "model_path_or_name" : "./models/Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf" ,
298413 "model_options" : "" ,
299414 "model_context_options" : "generate_embedding=1,normalize_embedding=1,pooling_type=mean,embedding_type=INT8" ,
0 commit comments