Commit 369d5a7

Retrieval evaluation expand
1 parent 7afb9a3 commit 369d5a7

File tree: 1 file changed (+91, −30 lines)

@@ -1,6 +1,8 @@
 import numpy as np
 from niw_np_rag.app.rag import RAGPipeline
 import json
+from langchain_huggingface import HuggingFaceEmbeddings
+import math
 
 with open(r".\evaluation\datasets\niw_qna.json", "r", encoding="utf-8") as f:
     dataset = json.load(f)
@@ -12,49 +14,108 @@
     semantic_chunking=True
 )
 
-def distance(a, list_b):
-    """Compute minimum Levenshtein distance between string a and any string in list_b."""
-    from Levenshtein import distance as lev_distance
-    return min(lev_distance(a, b) for b in list_b)
 
 def cosine_similarity(a, b):
     return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
 
 retriever = rag.get_retriever(k=5)
 
+emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
 
-def evaluate_recall_at_k(dataset, retriever, k=5):
-    """
-    dataset: list of dicts with fields:
-      - question
-      - answer (not required for recall)
-      - context (ground-truth context from source docs)
-    """
 
-    hits = []
+# -----------------------------
+# Utility Functions
+# -----------------------------
+def cosine_similarity(a, b):
+    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
+
+
+def is_hit(gt_context, retrieved_doc, threshold=0.45):
+    """Semantic match check using embeddings."""
+    gt_emb = emb.embed_query(gt_context)
+    doc_emb = emb.embed_query(retrieved_doc)
+    sim = cosine_similarity(gt_emb, doc_emb)
+    return sim >= threshold
+
+
+# -----------------------------
+# Retrieval Metrics
+# -----------------------------
+def evaluate_retrieval(dataset, retriever, k=10):
+    hits = []              # for Recall@K
+    precision_scores = []  # for Precision@K
+    mrr_scores = []        # for Mean Reciprocal Rank
+    ndcg_scores = []       # for nDCG@K
 
     for item in dataset:
         question = item["question"]
         ground_truth_context = item["context"]
-        # print(question, ground_truth_context)
 
-        # Retrieve top-k documents
+        # Retrieve documents (we keep top-K)
         retrieved_docs = retriever.invoke(question)
-
-        # Extract retrieved text
         retrieved_texts = [doc.page_content for doc in retrieved_docs[:k]]
 
-        # Check if ground-truth context appears in retrieved docs
-        hit = any(
-            distance (ground_truth_context, retrieved_texts) in retrieved_doc
-            for retrieved_doc in retrieved_texts
-        )
-
-        hits.append(1 if hit else 0)
-
-    recall_k = np.mean(hits)
-
-    print(f"Recall@{k}: {recall_k:.4f}")
-    return recall_k
-
-evaluate_recall_at_k(dataset, retriever, k=15)
+        # Track hits at ranks
+        hit_list = []
+
+        for idx, doc_text in enumerate(retrieved_texts):
+            match = is_hit(ground_truth_context, doc_text)
+
+            hit_list.append(1 if match else 0)
+
+        # -----------------------------
+        # Compute Metrics
+        # -----------------------------
+
+        # Recall@K → was any retrieved doc correct?
+        recall_k = 1 if any(hit_list) else 0
+        hits.append(recall_k)
+
+        # Precision@K → proportion of correct retrieved docs
+        precision = sum(hit_list) / k
+        precision_scores.append(precision)
+
+        # MRR → 1 / rank of first relevant doc
+        if 1 in hit_list:
+            rank = hit_list.index(1) + 1
+            mrr_scores.append(1 / rank)
+        else:
+            mrr_scores.append(0)
+
+        # nDCG@K → ranking quality
+        dcg = sum([
+            hit_list[i] / math.log2(i + 2)  # DCG formula
+            for i in range(len(hit_list))
+        ])
+
+        # Ideal DCG (all relevant docs ranked at top)
+        ideal_hits = sorted(hit_list, reverse=True)
+        idcg = sum([
+            ideal_hits[i] / math.log2(i + 2)
+            for i in range(len(ideal_hits))
+        ])
+
+        ndcg = dcg / idcg if idcg > 0 else 0
+        ndcg_scores.append(ndcg)
+
+    # -----------------------------
+    # Final Averages
+    # -----------------------------
+    results = {
+        "Recall@K": float(np.mean(hits)),
+        "Precision@K": float(np.mean(precision_scores)),
+        "MRR": float(np.mean(mrr_scores)),
+        "nDCG@K": float(np.mean(ndcg_scores)),
+    }
+
+    return results
+
+
+# -----------------------------
+# Run Evaluation
+# -----------------------------
+metrics = evaluate_retrieval(dataset, retriever, k=15)
+
+print("\n=== Retrieval Evaluation Results ===")
+for key, value in metrics.items():
+    print(f"{key}: {value:.4f}")
