Skip to content

Commit 6076307

Browse files
emmanuelgjr and claude committed
Add Phase 3b — cross-encoder reranker, LLM labeler, eval harness
Reranker: cross-encoder/ms-marco-MiniLM-L-6-v2 on bi-encoder top-50. LLM labeler: Claude Sonnet few-shot from calibration set (ready to run when ANTHROPIC_API_KEY is set). Eval harness: P@k, R@k, MAP, R-Precision with 95% bootstrap CIs. Baseline results (bi-encoder): MAP=0.004, R@10=0.010 Reranker results: MAP=0.006 (+46%), P@3=0.122 (+200%) Both FAIL pre-registered thresholds — expected given only 8% of ground-truth controls exist in the registry (88/1097 overlap). The retrieval pipeline is semantically correct (ATLAS prompt injection ranks #1 for LLM01, MAESTRO agent goal integrity ranks #1 for ASI01) but registry coverage must increase for metrics to pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 10f2372 commit 6076307

6 files changed

Lines changed: 827 additions & 5 deletions

File tree

classifier/classify.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ def load_index():
3131

3232

3333
def classify(source_id: str, target_framework: str | None = None,
34-
top_k: int = DEFAULT_TOP_K, output_json: bool = False) -> list[dict]:
34+
top_k: int = DEFAULT_TOP_K, output_json: bool = False,
35+
use_reranker: bool = False) -> list[dict]:
3536
"""Retrieve top-k candidate controls for an OWASP entry."""
3637
# Load entry
3738
entries = load_entries()
@@ -51,8 +52,10 @@ def classify(source_id: str, target_framework: str | None = None,
5152
q_emb = model.encode([query], normalize_embeddings=True)
5253
q_emb = np.array(q_emb, dtype=np.float32)
5354

54-
# Search — retrieve more if filtering by framework
55-
search_k = top_k * 5 if target_framework else top_k
55+
# Search — retrieve more for reranking or framework filtering
56+
search_k = top_k * 5 if (target_framework or use_reranker) else top_k
57+
if use_reranker:
58+
search_k = max(search_k, 50)
5659
search_k = min(search_k, index.ntotal)
5760
scores, indices = index.search(q_emb, search_k)
5861

@@ -71,10 +74,20 @@ def classify(source_id: str, target_framework: str | None = None,
7174
"title": m["title"],
7275
"function": m["function"],
7376
"score": round(float(score), 4),
77+
"text": f"{m['framework']} -- {m['control_id']}: {m['title']}",
7478
})
75-
if len(results) >= top_k:
79+
if not use_reranker and len(results) >= top_k:
7680
break
7781

82+
# Cross-encoder reranking
83+
if use_reranker and results:
84+
from .reranker import rerank
85+
results = rerank(query, results, top_k=top_k)
86+
87+
# Remove internal fields from output
88+
for r in results:
89+
r.pop("text", None)
90+
7891
# Check against existing mappings
7992
existing = {
8093
(m["framework"], m["control_id"])
@@ -118,9 +131,10 @@ def main():
118131
parser.add_argument("--target", default=None, help="Target framework name (optional)")
119132
parser.add_argument("--top-k", type=int, default=DEFAULT_TOP_K, help=f"Number of candidates (default {DEFAULT_TOP_K})")
120133
parser.add_argument("--json", action="store_true", help="Output as JSON")
134+
parser.add_argument("--rerank", action="store_true", help="Use cross-encoder reranker")
121135
args = parser.parse_args()
122136

123-
classify(args.source, args.target, args.top_k, args.json)
137+
classify(args.source, args.target, args.top_k, args.json, args.rerank)
124138

125139

126140
if __name__ == "__main__":

classifier/eval_harness.py

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,340 @@
1+
"""Evaluation harness for the classifier pipeline.
2+
3+
Computes P@k, R@k, MAP, R-Precision with 95% bootstrap CIs against the
4+
frozen test split. Compares bi-encoder baseline vs. cross-encoder reranker.
5+
6+
Usage:
7+
python -m classifier.eval_harness
8+
python -m classifier.eval_harness --rerank
9+
python -m classifier.eval_harness --rerank --verbose
10+
"""
11+
12+
import argparse
13+
import json
14+
import sys
15+
import time
16+
from collections import defaultdict
17+
18+
import numpy as np
19+
from sentence_transformers import SentenceTransformer
20+
21+
from .config import (
22+
BIENCODER_MODEL, EMBEDDING_DIM, INDEX_DIR, SPLITS_DIR,
23+
K_VALUES, BOOTSTRAP_SEED, BOOTSTRAP_N,
24+
)
25+
from .data_loader import load_entries, load_controls, build_query
26+
27+
28+
def load_test_split() -> dict[str, set[tuple[str, str]]]:
    """Read the frozen test split into {entry_id: {(framework, control_id), ...}}.

    Exits with status 1 when the split file has not been generated yet.
    """
    split_file = SPLITS_DIR / "test.json"
    if not split_file.exists():
        print("ERROR: Test split not found. Run split_maker first.")
        sys.exit(1)
    records = json.loads(split_file.read_text("utf-8"))
    truth: dict[str, set[tuple[str, str]]] = {}
    for rec in records:
        pair = (rec["framework"], rec["control_id"])
        truth.setdefault(rec["entry_id"], set()).add(pair)
    return truth
41+
42+
43+
def load_index():
    """Load the FAISS index and its parallel metadata list from INDEX_DIR.

    Exits with status 1 when the index has not been built yet.
    """
    import faiss
    idx_file = INDEX_DIR / "controls.index"
    if not idx_file.exists():
        print("ERROR: Index not found. Run index_builder first.")
        sys.exit(1)
    meta_file = INDEX_DIR / "controls_meta.json"
    metadata = json.loads(meta_file.read_text("utf-8"))
    return faiss.read_index(str(idx_file)), metadata
54+
55+
56+
def retrieve_candidates(entry: dict, index, meta, model,
                        top_k: int = 50) -> list[dict]:
    """Retrieve the top-k candidate controls for *entry* with the bi-encoder.

    Embeds the entry's query text, searches the FAISS index, and returns one
    metadata dict per hit in rank order. Negative indices (FAISS padding for
    missing slots) are skipped.
    """
    query = build_query(entry["id"], entry["name"], entry["severity"])
    embedding = np.array(
        model.encode([query], normalize_embeddings=True), dtype=np.float32
    )

    k = min(top_k, index.ntotal)
    scores, indices = index.search(embedding, k)

    def as_candidate(score, row):
        # Flatten one metadata row plus its score into the candidate schema.
        return {
            "framework": row["framework"],
            "control_id": row["control_id"],
            "title": row["title"],
            "function": row.get("function"),
            "score": float(score),
            "text": f"{row['framework']} -- {row['control_id']}: {row['title']}",
        }

    return [
        as_candidate(s, meta[i])
        for s, i in zip(scores[0], indices[0])
        if i >= 0
    ]
80+
81+
82+
# ── Metrics ──────────────────────────────────────────────────────────────────
83+
84+
def precision_at_k(retrieved: list[tuple], relevant: set[tuple], k: int) -> float:
    """Precision@k: share of the first k retrieved items that are relevant.

    Returns 0.0 when nothing was retrieved; otherwise the denominator is k
    itself, so result lists shorter than k are penalized.
    """
    head = retrieved[:k]
    if not head:
        return 0.0
    hits = sum(item in relevant for item in head)
    return hits / k
90+
91+
92+
def recall_at_k(retrieved: list[tuple], relevant: set[tuple], k: int) -> float:
    """Recall@k: share of all relevant items found in the first k retrieved.

    Defined as 0.0 when the relevant set is empty.
    """
    if not relevant:
        return 0.0
    hits = sum(item in relevant for item in retrieved[:k])
    return hits / len(relevant)
98+
99+
100+
def average_precision(retrieved: list[tuple], relevant: set[tuple]) -> float:
    """Average precision for a single query (the per-query term of MAP).

    Accumulates precision at each rank where a relevant item appears, then
    normalizes by the total number of relevant items, so relevant items that
    were never retrieved drag the score down. Returns 0.0 when there are no
    relevant items.
    """
    if not relevant:
        return 0.0
    hits = 0
    precision_sum = 0.0
    for rank, item in enumerate(retrieved, start=1):
        if item in relevant:
            hits += 1
            precision_sum += hits / rank
    # `relevant` is guaranteed non-empty by the guard above, so the original
    # trailing `if relevant else 0.0` re-check was dead code and is removed.
    return precision_sum / len(relevant)
111+
112+
113+
def r_precision(retrieved: list[tuple], relevant: set[tuple]) -> float:
    """R-Precision: precision over the first R results, where R = |relevant|.

    Returns 0.0 when the relevant set is empty.
    """
    total = len(relevant)
    if not total:
        return 0.0
    hits = sum(item in relevant for item in retrieved[:total])
    return hits / total
120+
121+
122+
def compute_metrics(all_retrieved: dict[str, list[tuple]],
                    ground_truth: dict[str, set[tuple]],
                    k_values: list[int]) -> dict:
    """Aggregate P@k, R@k, MAP and R-Precision across all scorable queries.

    Only entry ids present in BOTH `all_retrieved` and `ground_truth` are
    evaluated; each reported metric is the mean of the per-query values.
    Returns an error marker dict when the two key sets do not overlap.
    """
    shared = sorted(all_retrieved.keys() & ground_truth.keys())
    if not shared:
        return {"n_queries": 0, "error": "No overlapping entries"}

    def mean_over(metric_fn, *extra) -> float:
        # Mean of metric_fn applied to every scorable (retrieved, truth) pair.
        return float(np.mean([
            metric_fn(all_retrieved[eid], ground_truth[eid], *extra)
            for eid in shared
        ]))

    results: dict = {"n_queries": len(shared)}
    for k in k_values:
        results[f"P@{k}"] = mean_over(precision_at_k, k)
        results[f"R@{k}"] = mean_over(recall_at_k, k)
    results["MAP"] = mean_over(average_precision)
    results["R-Precision"] = mean_over(r_precision)
    return results
147+
148+
149+
def bootstrap_ci(all_retrieved: dict[str, list[tuple]],
                 ground_truth: dict[str, set[tuple]],
                 k_values: list[int],
                 n_bootstrap: int = BOOTSTRAP_N,
                 seed: int = BOOTSTRAP_SEED) -> dict:
    """95% bootstrap confidence intervals (percentile method) per metric.

    For each metric, resamples the per-query scores with replacement
    `n_bootstrap` times and takes the 2.5th/97.5th percentiles of the
    resampled means. The RNG is seeded so reported intervals reproduce.
    Returns {} when no entries are scorable.
    """
    rng = np.random.RandomState(seed)
    shared = sorted(all_retrieved.keys() & ground_truth.keys())
    n = len(shared)
    if n == 0:
        return {}

    def per_query(metric_fn, *extra) -> np.ndarray:
        # Vector of metric_fn values, one per scorable entry, in sorted order.
        return np.array([
            metric_fn(all_retrieved[eid], ground_truth[eid], *extra)
            for eid in shared
        ])

    # Per-query score vectors, keyed and ordered as the report prints them.
    per_metric: dict[str, np.ndarray] = {}
    for k in k_values:
        per_metric[f"P@{k}"] = per_query(precision_at_k, k)
        per_metric[f"R@{k}"] = per_query(recall_at_k, k)
    per_metric["MAP"] = per_query(average_precision)
    per_metric["R-Precision"] = per_query(r_precision)

    intervals = {}
    for name, scores in per_metric.items():
        resampled_means = np.array([
            rng.choice(scores, size=n, replace=True).mean()
            for _ in range(n_bootstrap)
        ])
        lo, hi = np.percentile(resampled_means, [2.5, 97.5])
        intervals[name] = {
            "mean": float(scores.mean()),
            "ci_lo": float(lo),
            "ci_hi": float(hi),
        }
    return intervals
193+
194+
195+
# ── Main evaluation ─────────────────────────────────────────────────────────
196+
197+
def run_eval(use_reranker: bool = False, verbose: bool = False,
             output_path: str | None = None) -> dict:
    """Run the full evaluation pipeline and return the report dict.

    Retrieves candidates for every test-split entry (optionally reranked by
    the cross-encoder), computes P@k / R@k / MAP / R-Precision with 95%
    bootstrap CIs, prints a results table plus the pre-registered threshold
    checks, and saves a JSON report.

    Args:
        use_reranker: rerank the bi-encoder top-50 with the cross-encoder.
        verbose: print progress every 10 entries.
        output_path: report destination; defaults to
            SPLITS_DIR / "eval_report_<mode>.json".
    """
    print("=" * 60)
    print(" GenAI Security Crosswalk — Classifier Evaluation")
    print(f" Mode: {'Bi-encoder + Cross-encoder reranker' if use_reranker else 'Bi-encoder baseline'}")
    print("=" * 60)

    # Load data
    print("\nLoading test split...")
    ground_truth = load_test_split()
    print(f" {len(ground_truth)} entries with ground truth")
    total_gt = sum(len(v) for v in ground_truth.values())
    print(f" {total_gt} total ground-truth mappings")

    print("Loading index...")
    index, meta = load_index()
    print(f" {index.ntotal} controls indexed")

    print(f"Loading bi-encoder: {BIENCODER_MODEL}")
    bi_model = SentenceTransformer(BIENCODER_MODEL)

    reranker = None
    if use_reranker:
        from .reranker import rerank, _get_model
        from .config import CROSSENCODER_MODEL
        print(f"Loading cross-encoder: {CROSSENCODER_MODEL}")
        _get_model()  # warm up so model-load time is excluded from the timed loop
        reranker = rerank

    entries = load_entries()
    # Only evaluate entries that have ground-truth mappings
    eval_entries = [e for e in entries if e["id"] in ground_truth]
    print(f"\nEvaluating {len(eval_entries)} entries...")

    # Loop-invariant, hoisted: top-50 for reranking headroom, else plain top-k
    retrieve_k = 50 if use_reranker else max(K_VALUES)

    all_retrieved = {}
    t0 = time.time()

    for i, entry in enumerate(eval_entries):
        candidates = retrieve_candidates(entry, index, meta, bi_model, retrieve_k)

        # `reranker` is only non-None when use_reranker is set, so the
        # original `use_reranker and reranker` double-check was redundant.
        if reranker:
            query = build_query(entry["id"], entry["name"], entry["severity"])
            candidates = reranker(query, candidates, top_k=max(K_VALUES))

        # Keep (framework, control_id) tuples in rank order
        all_retrieved[entry["id"]] = [
            (c["framework"], c["control_id"]) for c in candidates
        ]

        if verbose and (i + 1) % 10 == 0:
            print(f" [{i+1}/{len(eval_entries)}]")

    elapsed = time.time() - t0
    print(f" Done in {elapsed:.1f}s")

    # Compute metrics
    print("\nComputing metrics...")
    metrics = compute_metrics(all_retrieved, ground_truth, K_VALUES)

    # BUGFIX: this message previously hard-coded "10,000" even though the
    # actual resample count follows the configured BOOTSTRAP_N.
    print(f"Computing 95% bootstrap CIs ({BOOTSTRAP_N:,} resamples)...")
    cis = bootstrap_ci(all_retrieved, ground_truth, K_VALUES)

    # Report
    mode = "reranker" if use_reranker else "biencoder"
    report = {
        "mode": mode,
        "model": BIENCODER_MODEL,
        "n_queries": metrics["n_queries"],
        "n_controls": index.ntotal,
        "n_ground_truth_mappings": total_gt,
        "elapsed_seconds": round(elapsed, 1),
        "metrics": metrics,
        "confidence_intervals": cis,
    }

    print("\n" + "=" * 60)
    print(f" RESULTS ({mode})")
    print("=" * 60)
    print(f"\n Queries: {metrics['n_queries']}")
    print(f" Ground truth mappings: {total_gt}")
    print()

    # Results table
    print(f" {'Metric':<15} {'Value':>8} {'95% CI':>20}")
    print(f" {'-'*15} {'-'*8} {'-'*20}")

    def _print_row(key: str) -> None:
        # One formatted table row: metric value plus its bootstrap CI.
        val = metrics.get(key, 0)
        ci = cis.get(key, {})
        ci_str = f"[{ci.get('ci_lo', 0):.4f}, {ci.get('ci_hi', 0):.4f}]" if ci else ""
        print(f" {key:<15} {val:>8.4f} {ci_str:>20}")

    for k in K_VALUES:
        _print_row(f"P@{k}")
        _print_row(f"R@{k}")
    _print_row("MAP")
    _print_row("R-Precision")

    print()

    # Check against pre-registered thresholds
    print(" Pre-registered thresholds:")
    r10 = metrics.get("R@10", 0)
    map_val = metrics.get("MAP", 0)
    print(f" R@10 >= 0.50: {r10:.4f} {'PASS' if r10 >= 0.50 else 'FAIL'}")
    print(f" MAP >= 0.25: {map_val:.4f} {'PASS' if map_val >= 0.25 else 'FAIL'}")
    print()

    # Save report
    if output_path:
        from pathlib import Path
        Path(output_path).write_text(json.dumps(report, indent=2), encoding="utf-8")
        print(f" Report saved to {output_path}")
    else:
        default_path = SPLITS_DIR / f"eval_report_{mode}.json"
        default_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
        print(f" Report saved to {default_path}")

    return report
327+
328+
329+
def main():
    """CLI entry point: parse flags and launch the evaluation run."""
    parser = argparse.ArgumentParser(description="Eval harness for classifier pipeline")
    for flag, flag_help in (
        ("--rerank", "Use cross-encoder reranker"),
        ("--verbose", "Show progress"),
    ):
        parser.add_argument(flag, action="store_true", help=flag_help)
    parser.add_argument("--output", default=None, help="Output report path")
    opts = parser.parse_args()

    run_eval(use_reranker=opts.rerank, verbose=opts.verbose, output_path=opts.output)
337+
338+
339+
# Entry point for `python -m classifier.eval_harness`.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)