@@ -3,7 +3,6 @@ package benchmarks
33import (
44 "context"
55 "encoding/json"
6- "fmt"
76 "os"
87 "sort"
98 "strings"
@@ -26,8 +25,12 @@ type Dataset struct {
2625 Samples []Sample `json:"samples"`
2726}
2827
29- // Metrics types
30- //
28+ // EvalResult pairs a sample with the detector result so we never call Detect twice.
29+ type EvalResult struct {
30+ Sample Sample
31+ Result detector.Result
32+ }
33+
3134// Confusion matrix counts and derived metrics
3235type Counts struct {
3336 TP , FP , TN , FN int
@@ -48,8 +51,7 @@ func (c Counts) Recall() float64 {
4851}
4952
5053func (c Counts ) F1 () float64 {
51- p := c .Precision ()
52- r := c .Recall ()
54+ p , r := c .Precision (), c .Recall ()
5355 if p + r == 0 {
5456 return 0
5557 }
@@ -64,6 +66,12 @@ func (c Counts) Accuracy() float64 {
6466 return float64 (c .TP + c .TN ) / float64 (total ) * 100
6567}
6668
69+ // Shared category lists used by multiple tests.
70+ var (
71+ attackCategories = []string {"role_injection" , "prompt_leak" , "instruction_override" , "obfuscation" , "normalization" , "delimiter" , "multi_vector" }
72+ benignCategories = []string {"general_question" , "coding" , "technical" , "writing" , "creative" , "summarization" , "explanation" , "translation" , "edge_case" }
73+ )
74+
6775// Helpers
6876func loadDataset (t * testing.T , path string ) []Sample {
6977 t .Helper ()
@@ -78,16 +86,18 @@ func loadDataset(t *testing.T, path string) []Sample {
7886 return ds .Samples
7987}
8088
81- func evaluate (ctx context.Context , guard * detector.MultiDetector , samples []Sample ) (Counts , map [string ]Counts , []Sample , []Sample ) {
89+ // evaluate runs all samples through the guard once and returns counts + per-category
90+ // breakdown + slices of false positives and false negatives with their cached results.
91+ func evaluate (ctx context.Context , guard * detector.MultiDetector , samples []Sample ) (Counts , map [string ]Counts , []EvalResult , []EvalResult ) {
8292 var overall Counts
8393 perCategory := make (map [string ]Counts )
84- var falsePositives []Sample
85- var falseNegatives []Sample
94+ var falsePositives , falseNegatives []EvalResult
8695
8796 for _ , s := range samples {
8897 result := guard .Detect (ctx , s .Input )
8998 isAttack := ! result .Safe
9099 shouldBeAttack := s .Label == "attack"
100+ er := EvalResult {Sample : s , Result : result }
91101
92102 c := perCategory [s .Category ]
93103 switch {
@@ -97,14 +107,14 @@ func evaluate(ctx context.Context, guard *detector.MultiDetector, samples []Samp
97107 case isAttack && ! shouldBeAttack :
98108 overall .FP ++
99109 c .FP ++
100- falsePositives = append (falsePositives , s )
110+ falsePositives = append (falsePositives , er )
101111 case ! isAttack && ! shouldBeAttack :
102112 overall .TN ++
103113 c .TN ++
104114 case ! isAttack && shouldBeAttack :
105115 overall .FN ++
106116 c .FN ++
107- falseNegatives = append (falseNegatives , s )
117+ falseNegatives = append (falseNegatives , er )
108118 }
109119 perCategory [s .Category ] = c
110120 }
@@ -113,15 +123,14 @@ func evaluate(ctx context.Context, guard *detector.MultiDetector, samples []Samp
113123}
114124
115125// Tests
116-
117126// TestEvaluation is the main evaluation test.
118127// Run with: go test -v -run TestEvaluation ./benchmarks/
119128func TestEvaluation (t * testing.T ) {
120129 attacks := loadDataset (t , "testdata/attacks.json" )
121130 benign := loadDataset (t , "testdata/benign.json" )
122131 all := append (attacks , benign ... )
123132
124- guard := detector .New () // default threshold 0.7
133+ guard := detector .New ()
125134 ctx := context .Background ()
126135
127136 overall , perCategory , falsePositives , falseNegatives := evaluate (ctx , guard , all )
@@ -139,9 +148,7 @@ func TestEvaluation(t *testing.T) {
139148 t .Logf (" False Positives: %d/%d benign flagged (%.1f%%)" , overall .FP , len (benign ), float64 (overall .FP )/ float64 (len (benign ))* 100 )
140149 t .Logf (" False Negatives: %d/%d attacks missed (%.1f%%)" , overall .FN , len (attacks ), float64 (overall .FN )/ float64 (len (attacks ))* 100 )
141150
142- // Per-category for attack samples
143151 t .Logf ("\n --- Per-category recall (attacks) ---" )
144- attackCategories := []string {"role_injection" , "prompt_leak" , "instruction_override" , "obfuscation" , "normalization" , "delimiter" , "multi_vector" }
145152 for _ , cat := range attackCategories {
146153 c := perCategory [cat ]
147154 total := c .TP + c .FN
@@ -152,49 +159,41 @@ func TestEvaluation(t *testing.T) {
152159 t .Logf (" %-24s %d/%d (%.1f%%) %s" , cat + ":" , c .TP , total , c .Recall (), bar )
153160 }
154161
155- // Per-category for benign samples
156162 t .Logf ("\n --- Per-category false positive rate (benign) ---" )
157- benignCategories := []string {"general_question" , "coding" , "technical" , "writing" , "creative" , "summarization" , "explanation" , "translation" , "edge_case" }
158163 for _ , cat := range benignCategories {
159164 c := perCategory [cat ]
160165 total := c .TN + c .FP
161166 if total == 0 {
162167 continue
163168 }
164- fpRate := float64 (c .FP ) / float64 (total ) * 100
165- t .Logf (" %-24s %d FP / %d total (%.1f%% FP rate)" , cat + ":" , c .FP , total , fpRate )
169+ t .Logf (" %-24s %d FP / %d total (%.1f%% FP rate)" , cat + ":" , c .FP , total , float64 (c .FP )/ float64 (total )* 100 )
166170 }
167171
168- // Confusion matrix
169172 t .Logf ("\n --- Confusion Matrix ---" )
170173 t .Logf (" %26s ATTACK SAFE" , "Predicted →" )
171174 t .Logf (" Actual ATTACK %5d %5d" , overall .TP , overall .FN )
172175 t .Logf (" Actual SAFE %5d %5d" , overall .FP , overall .TN )
173176
174- // False positives detail
175177 if len (falsePositives ) > 0 {
176178 t .Logf ("\n --- False Positives (safe inputs wrongly flagged) ---" )
177- for _ , s := range falsePositives {
178- result := guard .Detect (ctx , s .Input )
179- t .Logf (" [%s] score=%.2f %q" , s .ID , result .RiskScore , truncate (s .Input , 70 ))
180- t .Logf (" note: %s" , s .Notes )
179+ for _ , er := range falsePositives {
180+ t .Logf (" [%s] score=%.2f %q" , er .Sample .ID , er .Result .RiskScore , truncate (er .Sample .Input , 70 ))
181+ t .Logf (" note: %s" , er .Sample .Notes )
181182 }
182183 }
183184
184- // False negatives detail
185185 if len (falseNegatives ) > 0 {
186186 t .Logf ("\n --- False Negatives (attacks missed) ---" )
187- for _ , s := range falseNegatives {
188- result := guard .Detect (ctx , s .Input )
189- t .Logf (" [%s] cat=%-22s score=%.2f %q" , s .ID , s .Category , result .RiskScore , truncate (s .Input , 70 ))
187+ for _ , er := range falseNegatives {
188+ t .Logf (" [%s] cat=%-22s score=%.2f %q" , er .Sample .ID , er .Sample .Category , er .Result .RiskScore , truncate (er .Sample .Input , 70 ))
190189 }
191190 }
192191
193192 t .Logf ("\n %s" , strings .Repeat ("=" , 60 ))
194193
195- // Sanity checks — these fail the test if something is very wrong
196- if overall .Recall () < 70 .0 {
197- t .Errorf ("Recall %.1f%% is too low — more than 30%% of attacks are being missed " , overall .Recall ())
194+ // Regression guard: recall below 40% means something is catastrophically broken.
195+ if overall .Recall () < 40 .0 {
196+ t .Errorf ("Recall %.1f%% is critically low — likely a regression " , overall .Recall ())
198197 }
199198 if overall .FP > len (benign )/ 3 {
200199 t .Errorf ("False positive rate too high: %d/%d benign inputs wrongly flagged" , overall .FP , len (benign ))
@@ -217,8 +216,7 @@ func TestThresholdSweep(t *testing.T) {
217216 t .Logf (" %-10s %-10s %-10s %-10s %-5s %-5s" , "Threshold" , "Precision" , "Recall" , "F1" , "FP" , "FN" )
218217 t .Logf (" %s" , strings .Repeat ("-" , 58 ))
219218
220- bestF1 := 0.0
221- bestThreshold := 0.0
219+ bestF1 , bestThreshold := 0.0 , 0.0
222220
223221 for _ , threshold := range thresholds {
224222 guard := detector .New (detector .WithThreshold (threshold ))
@@ -236,13 +234,8 @@ func TestThresholdSweep(t *testing.T) {
236234 }
237235
238236 t .Logf (" %-10.1f %-10.1f %-10.1f %-10.1f %-5d %-5d%s" ,
239- threshold ,
240- overall .Precision (),
241- overall .Recall (),
242- f1 ,
243- overall .FP ,
244- overall .FN ,
245- marker ,
237+ threshold , overall .Precision (), overall .Recall (), f1 ,
238+ overall .FP , overall .FN , marker ,
246239 )
247240 }
248241
@@ -257,12 +250,8 @@ func TestPerCategoryPrecision(t *testing.T) {
257250 benign := loadDataset (t , "testdata/benign.json" )
258251 all := append (attacks , benign ... )
259252
260- guard := detector .New ()
261- ctx := context .Background ()
262-
263- _ , perCategory , _ , _ := evaluate (ctx , guard , all )
253+ _ , perCategory , _ , _ := evaluate (context .Background (), detector .New (), all )
264254
265- // Collect categories that appear in attacks
266255 seen := map [string ]bool {}
267256 for _ , s := range attacks {
268257 seen [s .Category ] = true
@@ -281,30 +270,20 @@ func TestPerCategoryPrecision(t *testing.T) {
281270
282271 for _ , cat := range categories {
283272 c := perCategory [cat ]
284- total := c .TP + c .FN
285- if total == 0 {
273+ if c .TP + c .FN == 0 {
286274 continue
287275 }
288276 t .Logf (" %-24s %5.1f%% %5.1f%% %5.1f%% %4d %4d" ,
289- cat ,
290- c .Recall (),
291- c .Precision (),
292- c .F1 (),
293- c .TP ,
294- c .FN ,
277+ cat , c .Recall (), c .Precision (), c .F1 (), c .TP , c .FN ,
295278 )
296279 }
297280
298281 t .Logf ("%s" , strings .Repeat ("=" , 60 ))
299282}
300283
301- // Utilities
302284func truncate (s string , n int ) string {
303285 if len (s ) <= n {
304286 return s
305287 }
306288 return s [:n ] + "..."
307289}
308-
309- // Prevent unused import error for fmt if all t.Logf are used
310- var _ = fmt .Sprintf
0 commit comments