Skip to content

Commit 7fc8738

Browse files
author
Daniele Briggi
committed
feat(model): switch to gemma-embedding as default
chore(ai): defaults to gpu extension
1 parent 56d8493 commit 7fc8738

File tree

10 files changed

+2841
-14
lines changed

10 files changed

+2841
-14
lines changed

.devcontainer/devcontainer.json

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
{
22
"name": "Python 3.10",
33
"image": "mcr.microsoft.com/devcontainers/python:3.10",
4+
"runArgs": [
5+
"--runtime",
6+
"nvidia",
7+
"--gpus",
8+
"all",
9+
// Optional, but makes sure CUDA workloads are available
10+
"--env",
11+
"NVIDIA_VISIBLE_DEVICES=all",
12+
// Optional, but makes sure CUDA workloads are available
13+
"--env",
14+
"NVIDIA_DRIVER_CAPABILITIES=compute,utility"
15+
],
416
"customizations": {
517
"vscode": {
618
"extensions": [
@@ -15,4 +27,12 @@
1527
]
1628
}
1729
},
30+
"hostRequirements": {
31+
"gpu": "optional"
32+
},
33+
"remoteEnv": {
34+
// Optional, but makes sure CUDA workloads are available
35+
"NVIDIA_VISIBLE_DEVICES": "all",
36+
"NVIDIA_DRIVER_CAPABILITIES": "compute,utility"
37+
}
1838
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
"database_path": "./databases/ms_marco_gemma_300M_Q8_650rows.sqlite",
3+
"rag_settings": {
4+
"chunk_size": 1000,
5+
"chunk_overlap": 0,
6+
"model_path_or_name": "./../models/unsloth/embeddinggemma-300m-GGUF/embeddinggemma-300M-Q8_0.gguf",
7+
"model_options": "",
8+
"model_context_options": "generate_embedding=1,normalize_embedding=1,pooling_type=mean,embedding_type=INT8",
9+
"vector_type": "INT8",
10+
"embedding_dim": 768,
11+
"other_vector_options": "distance=cosine",
12+
"weight_fts": 0.0,
13+
"weight_vec": 1.0
14+
}
15+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
{
2+
"database_path": "./databases/ms_marco_qwen3_Q8_650rows.sqlite",
3+
"rag_settings": {
4+
"chunk_size": 1000,
5+
"chunk_overlap": 0,
6+
"model_path_or_name": "./../models/Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-Q8_0.gguf",
7+
"model_options": "",
8+
"model_context_options": "generate_embedding=1,normalize_embedding=1,pooling_type=last,embedding_type=INT8",
9+
"vector_type": "INT8",
10+
"embedding_dim": 1024,
11+
"other_vector_options": "distance=cosine",
12+
"weight_fts": 0.0,
13+
"weight_vec": 1.0
14+
}
15+
}

model_evaluation/ms_marco.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,33 @@
55

66
import numpy as np
77
import pandas as pd
8+
import psutil
89

910
from sqlite_rag import SQLiteRag
1011

1112

13+
class MemoryMonitor:
14+
def __init__(self):
15+
self.process = psutil.Process()
16+
self.memory_usage = []
17+
18+
def record(self):
19+
"""Record current memory usage"""
20+
mem_info = self.process.memory_info()
21+
self.memory_usage.append(mem_info.rss / 1024 / 1024) # Convert to MB
22+
23+
def get_stats(self):
24+
"""Get memory statistics"""
25+
if not self.memory_usage:
26+
return {"avg": 0, "max": 0, "min": 0}
27+
28+
return {
29+
"avg": sum(self.memory_usage) / len(self.memory_usage),
30+
"max": max(self.memory_usage),
31+
"min": min(self.memory_usage),
32+
}
33+
34+
1235
def calculate_dcg(relevance_scores):
1336
"""Calculate Discounted Cumulative Gain"""
1437
dcg = 0.0
@@ -37,6 +60,7 @@ def test_ms_marco_processing(
3760
"""Test processing MS MARCO dataset with SQLiteRag"""
3861

3962
start_time = time.time()
63+
memory_monitor = MemoryMonitor()
4064

4165
if rag_settings is None:
4266
rag_settings = {"chunk_size": 1000, "chunk_overlap": 0}
@@ -70,6 +94,7 @@ def test_ms_marco_processing(
7094
print("Initializing SQLiteRag...")
7195
init_start = time.time()
7296
rag = SQLiteRag.create(database_path, settings=rag_settings)
97+
memory_monitor.record() # After RAG initialization
7398
init_time = time.time() - init_start
7499
print(f"SQLiteRag initialized in {init_time:.2f}s")
75100

@@ -107,12 +132,18 @@ def test_ms_marco_processing(
107132
uri = f"ms_marco_query_{query_id}_passage_{i}"
108133

109134
# Add passage to the database
110-
rag.add_text(text=passage_text, uri=uri, metadata=metadata)
135+
try:
136+
rag.add_text(text=passage_text, uri=uri, metadata=metadata)
111137

112-
total_passages_added += 1
138+
total_passages_added += 1
139+
except Exception as e:
140+
print(
141+
f"Error adding passage for query_id {query_id}, passage_index {i}: {e}"
142+
)
113143

114144
# Progress update every 100 samples
115145
if idx % 100 == 0:
146+
memory_monitor.record() # Record memory during processing
116147
elapsed = time.time() - processing_start
117148
rate = idx / elapsed if elapsed > 0 else 0
118149
eta = (total_samples - idx) / rate if rate > 0 else 0
@@ -158,6 +189,9 @@ def test_ms_marco_processing(
158189
print(f" weight_fts: {settings_info.get('weight_fts', 1.0)}")
159190
print(f" weight_vec: {settings_info.get('weight_vec', 1.0)}")
160191

192+
memory_monitor.record() # Final memory reading
193+
memory_stats = memory_monitor.get_stats()
194+
161195
total_time = time.time() - start_time
162196
print(f"\n{'='*60}")
163197
print("TIMING SUMMARY:")
@@ -167,6 +201,11 @@ def test_ms_marco_processing(
167201
print(f" Total time: {total_time:.2f}s")
168202
print(f" Average rate: {total_passages_added/processing_time:.1f} passages/s")
169203
print(f"{'='*60}")
204+
print("MEMORY USAGE SUMMARY:")
205+
print(f" Average memory: {memory_stats['avg']:.1f} MB")
206+
print(f" Maximum memory: {memory_stats['max']:.1f} MB")
207+
print(f" Minimum memory: {memory_stats['min']:.1f} MB")
208+
print(f"{'='*60}")
170209

171210
rag.close()
172211
return total_passages_added, len(queries_dict)
@@ -177,6 +216,9 @@ def evaluate_search_quality(
177216
):
178217
"""Evaluate search quality using proper metrics"""
179218

219+
# Setup memory monitoring
220+
memory_monitor = MemoryMonitor()
221+
180222
# Setup output capture
181223
output_lines = []
182224

@@ -202,6 +244,7 @@ def output(text):
202244

203245
# Create RAG instance
204246
rag = SQLiteRag.create(database_path)
247+
memory_monitor.record() # After RAG initialization
205248

206249
# Metrics for different top-k values
207250
k_values = [1, 3, 5, 10]
@@ -290,12 +333,17 @@ def output(text):
290333

291334
# Progress update
292335
if (idx + 1) % 50 == 0:
336+
memory_monitor.record() # Record memory during evaluation
293337
print(
294338
f"Processed {idx + 1}/{len(df)} queries..."
295339
) # Only to console, not to file
296340

297341
rag.close()
298342

343+
# Final memory reading and calculate stats
344+
memory_monitor.record()
345+
memory_stats = memory_monitor.get_stats()
346+
299347
# Calculate and display final metrics
300348
output(f"\n{'='*60}")
301349
output("SEARCH QUALITY EVALUATION RESULTS")
@@ -345,6 +393,11 @@ def output(text):
345393
f"{'NDCG':<20} {ndcg_values[0]:<10} {ndcg_values[1]:<10} {ndcg_values[2]:<10} {ndcg_values[3]:<10}"
346394
)
347395

396+
output(f"\n{'='*60}")
397+
output("MEMORY USAGE SUMMARY:")
398+
output(f" Average memory: {memory_stats['avg']:.1f} MB")
399+
output(f" Maximum memory: {memory_stats['max']:.1f} MB")
400+
output(f" Minimum memory: {memory_stats['min']:.1f} MB")
348401
output(f"\n{'='*60}")
349402
output("INTERPRETATION:")
350403
output("- Hit Rate: % of queries where at least 1 relevant result appears in top-k")
@@ -403,7 +456,7 @@ def load_config(config_path):
403456
def create_example_config():
404457
"""Create an example configuration file"""
405458
example_config = {
406-
"database_path": "ms_marco_test.sqlite",
459+
"database_path": "./configs/ms_marco_test.sqlite",
407460
"rag_settings": {
408461
"chunk_size": 1000,
409462
"chunk_overlap": 0,
@@ -418,7 +471,7 @@ def create_example_config():
418471
},
419472
}
420473

421-
config_file = Path("ms_marco_config.json")
474+
config_file = Path("./configs/ms_marco_config.json")
422475
with open(config_file, "w") as f:
423476
json.dump(example_config, f, indent=2)
424477

@@ -451,14 +504,14 @@ def main():
451504

452505
args = parser.parse_args()
453506

454-
# Load configuration
455-
try:
456-
config = load_config(args.config)
457-
except FileNotFoundError:
458-
print(f"Config file {args.config} not found. Creating example config...")
507+
if args.config is None:
508+
print("Missing config file. Creating example config...")
459509
create_example_config()
460510
print("Please edit ms_marco_config.json with your settings and try again.")
461511
return
512+
513+
try:
514+
config = load_config(args.config)
462515
except Exception as e:
463516
print(f"Error loading config: {e}")
464517
return

0 commit comments

Comments
 (0)