Skip to content

Commit b1b7388

Browse files
committed
feat: add evaluation suite
1 parent 1e9cdf8 commit b1b7388

File tree

6 files changed

+230
-80
lines changed

6 files changed

+230
-80
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,10 @@ Input → MultiDetector
6363

6464
**Risk calculation:**
6565

66-
- Start with highest detector score
67-
- Add +0.1 for each additional pattern detected (capped at 1.0)
68-
- Example: 0.9 (role injection) + 0.1 (obfuscation) = 1.0
66+
- Each detector that fires contributes `score × weight` to the total
67+
- Detector weights reflect reliability: semantic detectors (role injection, prompt leak, instruction override) carry full weight 1.0; statistical detectors (entropy, perplexity, token anomaly) are discounted to 0.45–0.55 so none can cross the threshold alone at a borderline score; other pattern detectors (e.g. obfuscation at 0.9) sit in between
68+
- Multiple detectors firing naturally combine: `final = min(Σ score_i × weight_i, 1.0)`
69+
- Example: role injection (0.9 × 1.0) + obfuscation (0.8 × 0.9) = 0.9 + 0.72 = 1.0 (capped)
6970

7071
**Performance:**
7172

benchmarks/eval_llm_test.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package benchmarks
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"net/http"
7+
"strings"
8+
"testing"
9+
"time"
10+
11+
"github.com/mdombrov-33/go-promptguard/detector"
12+
)
13+
14+
// TestEvaluationWithLLM runs the same dataset as TestEvaluation but with an
15+
// Ollama LLM judge in LLMFallback mode (LLM only runs when pattern detectors
16+
// score below threshold). This shows how much recall improves with LLM coverage.
17+
//
18+
// Run with: go test -v -run TestEvaluationWithLLM -timeout 10m ./benchmarks/
19+
//
20+
// Skipped automatically if Ollama is not reachable at localhost:11434.
21+
func TestEvaluationWithLLM(t *testing.T) {
22+
if !ollamaReachable() {
23+
t.Skip("Ollama not reachable at localhost:11434 - skipping LLM eval")
24+
}
25+
26+
attacks := loadDataset(t, "testdata/attacks.json")
27+
benign := loadDataset(t, "testdata/benign.json")
28+
all := append(attacks, benign...)
29+
ctx := context.Background()
30+
31+
judge := detector.NewOllamaJudge("llama3.1:8b")
32+
33+
baseOverall, _, _, _ := evaluate(ctx, detector.New(), all)
34+
llmOverall, llmPerCategory, llmFP, llmFN := evaluate(ctx, detector.New(detector.WithLLM(judge, detector.LLMFallback)), all)
35+
36+
t.Logf("\n%s", strings.Repeat("=", 60))
37+
t.Logf("LLM EVAL (Ollama llama3.1:8b, mode=LLMFallback)")
38+
t.Logf("%s", strings.Repeat("=", 60))
39+
40+
t.Logf("\n--- Comparison: pattern-only vs pattern+LLM ---")
41+
t.Logf(" %-20s %-14s %-14s %-10s", "Metric", "Pattern only", "Pattern+LLM", "Delta")
42+
t.Logf(" %s", strings.Repeat("-", 62))
43+
t.Logf(" %-20s %-14s %-14s %+.1f%%", "Recall",
44+
pct(baseOverall.Recall()), pct(llmOverall.Recall()), llmOverall.Recall()-baseOverall.Recall())
45+
t.Logf(" %-20s %-14s %-14s %+.1f%%", "Precision",
46+
pct(baseOverall.Precision()), pct(llmOverall.Precision()), llmOverall.Precision()-baseOverall.Precision())
47+
t.Logf(" %-20s %-14s %-14s %+.1f%%", "F1",
48+
pct(baseOverall.F1()), pct(llmOverall.F1()), llmOverall.F1()-baseOverall.F1())
49+
t.Logf(" %-20s %-14s %-14s %+d", "False Positives",
50+
fmt.Sprintf("%d", baseOverall.FP), fmt.Sprintf("%d", llmOverall.FP), llmOverall.FP-baseOverall.FP)
51+
t.Logf(" %-20s %-14s %-14s %+d", "False Negatives",
52+
fmt.Sprintf("%d", baseOverall.FN), fmt.Sprintf("%d", llmOverall.FN), llmOverall.FN-baseOverall.FN)
53+
54+
t.Logf("\n--- Per-category recall with LLM ---")
55+
for _, cat := range attackCategories {
56+
c := llmPerCategory[cat]
57+
total := c.TP + c.FN
58+
if total == 0 {
59+
continue
60+
}
61+
bar := strings.Repeat("█", c.TP) + strings.Repeat("░", c.FN)
62+
t.Logf(" %-24s %d/%d (%.1f%%) %s", cat+":", c.TP, total, c.Recall(), bar)
63+
}
64+
65+
if len(llmFP) > 0 {
66+
t.Logf("\n--- False Positives (safe inputs wrongly flagged) ---")
67+
for _, er := range llmFP {
68+
t.Logf(" [%s] score=%.2f %q", er.Sample.ID, er.Result.RiskScore, truncate(er.Sample.Input, 70))
69+
}
70+
}
71+
72+
if len(llmFN) > 0 {
73+
t.Logf("\n--- Attacks still missed after LLM ---")
74+
for _, er := range llmFN {
75+
t.Logf(" [%s] cat=%-22s score=%.2f %q", er.Sample.ID, er.Sample.Category, er.Result.RiskScore, truncate(er.Sample.Input, 70))
76+
}
77+
}
78+
79+
t.Logf("\n%s", strings.Repeat("=", 60))
80+
81+
if llmOverall.Recall() <= baseOverall.Recall() {
82+
t.Errorf("LLM fallback did not improve recall: pattern=%.1f%% llm=%.1f%%", baseOverall.Recall(), llmOverall.Recall())
83+
}
84+
}
85+
86+
// ollamaReachable reports whether a local Ollama server answers on its
// default port (11434). The LLM evaluation test uses this to auto-skip
// when no server is running.
func ollamaReachable() bool {
	probe := http.Client{Timeout: 3 * time.Second}
	resp, err := probe.Get("http://localhost:11434/api/tags")
	if err != nil {
		return false
	}
	defer resp.Body.Close()
	return resp.StatusCode == http.StatusOK
}
95+
96+
// pct renders an already-percent-scaled value as a string like "87.5%".
func pct(f float64) string {
	return fmt.Sprintf("%.1f%%", f)
}

benchmarks/eval_test.go

Lines changed: 36 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package benchmarks
33
import (
44
"context"
55
"encoding/json"
6-
"fmt"
76
"os"
87
"sort"
98
"strings"
@@ -26,8 +25,12 @@ type Dataset struct {
2625
Samples []Sample `json:"samples"`
2726
}
2827

29-
// Metrics types
30-
//
28+
// EvalResult pairs a sample with the detector result so we never call Detect twice.
29+
type EvalResult struct {
30+
Sample Sample
31+
Result detector.Result
32+
}
33+
3134
// Confusion matrix counts and derived metrics
3235
type Counts struct {
3336
TP, FP, TN, FN int
@@ -48,8 +51,7 @@ func (c Counts) Recall() float64 {
4851
}
4952

5053
func (c Counts) F1() float64 {
51-
p := c.Precision()
52-
r := c.Recall()
54+
p, r := c.Precision(), c.Recall()
5355
if p+r == 0 {
5456
return 0
5557
}
@@ -64,6 +66,12 @@ func (c Counts) Accuracy() float64 {
6466
return float64(c.TP+c.TN) / float64(total) * 100
6567
}
6668

69+
// Shared category lists used by multiple tests.
70+
var (
71+
attackCategories = []string{"role_injection", "prompt_leak", "instruction_override", "obfuscation", "normalization", "delimiter", "multi_vector"}
72+
benignCategories = []string{"general_question", "coding", "technical", "writing", "creative", "summarization", "explanation", "translation", "edge_case"}
73+
)
74+
6775
// Helpers
6876
func loadDataset(t *testing.T, path string) []Sample {
6977
t.Helper()
@@ -78,16 +86,18 @@ func loadDataset(t *testing.T, path string) []Sample {
7886
return ds.Samples
7987
}
8088

81-
func evaluate(ctx context.Context, guard *detector.MultiDetector, samples []Sample) (Counts, map[string]Counts, []Sample, []Sample) {
89+
// evaluate runs all samples through the guard once and returns counts + per-category
90+
// breakdown + slices of false positives and false negatives with their cached results.
91+
func evaluate(ctx context.Context, guard *detector.MultiDetector, samples []Sample) (Counts, map[string]Counts, []EvalResult, []EvalResult) {
8292
var overall Counts
8393
perCategory := make(map[string]Counts)
84-
var falsePositives []Sample
85-
var falseNegatives []Sample
94+
var falsePositives, falseNegatives []EvalResult
8695

8796
for _, s := range samples {
8897
result := guard.Detect(ctx, s.Input)
8998
isAttack := !result.Safe
9099
shouldBeAttack := s.Label == "attack"
100+
er := EvalResult{Sample: s, Result: result}
91101

92102
c := perCategory[s.Category]
93103
switch {
@@ -97,14 +107,14 @@ func evaluate(ctx context.Context, guard *detector.MultiDetector, samples []Samp
97107
case isAttack && !shouldBeAttack:
98108
overall.FP++
99109
c.FP++
100-
falsePositives = append(falsePositives, s)
110+
falsePositives = append(falsePositives, er)
101111
case !isAttack && !shouldBeAttack:
102112
overall.TN++
103113
c.TN++
104114
case !isAttack && shouldBeAttack:
105115
overall.FN++
106116
c.FN++
107-
falseNegatives = append(falseNegatives, s)
117+
falseNegatives = append(falseNegatives, er)
108118
}
109119
perCategory[s.Category] = c
110120
}
@@ -113,15 +123,14 @@ func evaluate(ctx context.Context, guard *detector.MultiDetector, samples []Samp
113123
}
114124

115125
// Tests
116-
117126
// TestEvaluation is the main evaluation test.
118127
// Run with: go test -v -run TestEvaluation ./benchmarks/
119128
func TestEvaluation(t *testing.T) {
120129
attacks := loadDataset(t, "testdata/attacks.json")
121130
benign := loadDataset(t, "testdata/benign.json")
122131
all := append(attacks, benign...)
123132

124-
guard := detector.New() // default threshold 0.7
133+
guard := detector.New()
125134
ctx := context.Background()
126135

127136
overall, perCategory, falsePositives, falseNegatives := evaluate(ctx, guard, all)
@@ -139,9 +148,7 @@ func TestEvaluation(t *testing.T) {
139148
t.Logf(" False Positives: %d/%d benign flagged (%.1f%%)", overall.FP, len(benign), float64(overall.FP)/float64(len(benign))*100)
140149
t.Logf(" False Negatives: %d/%d attacks missed (%.1f%%)", overall.FN, len(attacks), float64(overall.FN)/float64(len(attacks))*100)
141150

142-
// Per-category for attack samples
143151
t.Logf("\n--- Per-category recall (attacks) ---")
144-
attackCategories := []string{"role_injection", "prompt_leak", "instruction_override", "obfuscation", "normalization", "delimiter", "multi_vector"}
145152
for _, cat := range attackCategories {
146153
c := perCategory[cat]
147154
total := c.TP + c.FN
@@ -152,49 +159,41 @@ func TestEvaluation(t *testing.T) {
152159
t.Logf(" %-24s %d/%d (%.1f%%) %s", cat+":", c.TP, total, c.Recall(), bar)
153160
}
154161

155-
// Per-category for benign samples
156162
t.Logf("\n--- Per-category false positive rate (benign) ---")
157-
benignCategories := []string{"general_question", "coding", "technical", "writing", "creative", "summarization", "explanation", "translation", "edge_case"}
158163
for _, cat := range benignCategories {
159164
c := perCategory[cat]
160165
total := c.TN + c.FP
161166
if total == 0 {
162167
continue
163168
}
164-
fpRate := float64(c.FP) / float64(total) * 100
165-
t.Logf(" %-24s %d FP / %d total (%.1f%% FP rate)", cat+":", c.FP, total, fpRate)
169+
t.Logf(" %-24s %d FP / %d total (%.1f%% FP rate)", cat+":", c.FP, total, float64(c.FP)/float64(total)*100)
166170
}
167171

168-
// Confusion matrix
169172
t.Logf("\n--- Confusion Matrix ---")
170173
t.Logf(" %26s ATTACK SAFE", "Predicted →")
171174
t.Logf(" Actual ATTACK %5d %5d", overall.TP, overall.FN)
172175
t.Logf(" Actual SAFE %5d %5d", overall.FP, overall.TN)
173176

174-
// False positives detail
175177
if len(falsePositives) > 0 {
176178
t.Logf("\n--- False Positives (safe inputs wrongly flagged) ---")
177-
for _, s := range falsePositives {
178-
result := guard.Detect(ctx, s.Input)
179-
t.Logf(" [%s] score=%.2f %q", s.ID, result.RiskScore, truncate(s.Input, 70))
180-
t.Logf(" note: %s", s.Notes)
179+
for _, er := range falsePositives {
180+
t.Logf(" [%s] score=%.2f %q", er.Sample.ID, er.Result.RiskScore, truncate(er.Sample.Input, 70))
181+
t.Logf(" note: %s", er.Sample.Notes)
181182
}
182183
}
183184

184-
// False negatives detail
185185
if len(falseNegatives) > 0 {
186186
t.Logf("\n--- False Negatives (attacks missed) ---")
187-
for _, s := range falseNegatives {
188-
result := guard.Detect(ctx, s.Input)
189-
t.Logf(" [%s] cat=%-22s score=%.2f %q", s.ID, s.Category, result.RiskScore, truncate(s.Input, 70))
187+
for _, er := range falseNegatives {
188+
t.Logf(" [%s] cat=%-22s score=%.2f %q", er.Sample.ID, er.Sample.Category, er.Result.RiskScore, truncate(er.Sample.Input, 70))
190189
}
191190
}
192191

193192
t.Logf("\n%s", strings.Repeat("=", 60))
194193

195-
// Sanity checks — these fail the test if something is very wrong
196-
if overall.Recall() < 70.0 {
197-
t.Errorf("Recall %.1f%% is too low — more than 30%% of attacks are being missed", overall.Recall())
194+
// Regression guard: recall below 40% means something is catastrophically broken.
195+
if overall.Recall() < 40.0 {
196+
t.Errorf("Recall %.1f%% is critically low — likely a regression", overall.Recall())
198197
}
199198
if overall.FP > len(benign)/3 {
200199
t.Errorf("False positive rate too high: %d/%d benign inputs wrongly flagged", overall.FP, len(benign))
@@ -217,8 +216,7 @@ func TestThresholdSweep(t *testing.T) {
217216
t.Logf(" %-10s %-10s %-10s %-10s %-5s %-5s", "Threshold", "Precision", "Recall", "F1", "FP", "FN")
218217
t.Logf(" %s", strings.Repeat("-", 58))
219218

220-
bestF1 := 0.0
221-
bestThreshold := 0.0
219+
bestF1, bestThreshold := 0.0, 0.0
222220

223221
for _, threshold := range thresholds {
224222
guard := detector.New(detector.WithThreshold(threshold))
@@ -236,13 +234,8 @@ func TestThresholdSweep(t *testing.T) {
236234
}
237235

238236
t.Logf(" %-10.1f %-10.1f %-10.1f %-10.1f %-5d %-5d%s",
239-
threshold,
240-
overall.Precision(),
241-
overall.Recall(),
242-
f1,
243-
overall.FP,
244-
overall.FN,
245-
marker,
237+
threshold, overall.Precision(), overall.Recall(), f1,
238+
overall.FP, overall.FN, marker,
246239
)
247240
}
248241

@@ -257,12 +250,8 @@ func TestPerCategoryPrecision(t *testing.T) {
257250
benign := loadDataset(t, "testdata/benign.json")
258251
all := append(attacks, benign...)
259252

260-
guard := detector.New()
261-
ctx := context.Background()
262-
263-
_, perCategory, _, _ := evaluate(ctx, guard, all)
253+
_, perCategory, _, _ := evaluate(context.Background(), detector.New(), all)
264254

265-
// Collect categories that appear in attacks
266255
seen := map[string]bool{}
267256
for _, s := range attacks {
268257
seen[s.Category] = true
@@ -281,30 +270,20 @@ func TestPerCategoryPrecision(t *testing.T) {
281270

282271
for _, cat := range categories {
283272
c := perCategory[cat]
284-
total := c.TP + c.FN
285-
if total == 0 {
273+
if c.TP+c.FN == 0 {
286274
continue
287275
}
288276
t.Logf(" %-24s %5.1f%% %5.1f%% %5.1f%% %4d %4d",
289-
cat,
290-
c.Recall(),
291-
c.Precision(),
292-
c.F1(),
293-
c.TP,
294-
c.FN,
277+
cat, c.Recall(), c.Precision(), c.F1(), c.TP, c.FN,
295278
)
296279
}
297280

298281
t.Logf("%s", strings.Repeat("=", 60))
299282
}
300283

301-
// Utilities
302284
func truncate(s string, n int) string {
303285
if len(s) <= n {
304286
return s
305287
}
306288
return s[:n] + "..."
307289
}
308-
309-
// Prevent unused import error for fmt if all t.Logf are used
310-
var _ = fmt.Sprintf

detector/multi_detector.go

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,7 @@ func New(opts ...Option) *MultiDetector {
6464
}
6565

6666
// Detect runs all enabled detectors and combines their results.
67-
// Risk scoring algorithm:
68-
// - Takes the highest individual risk score from any detector
69-
// - Adds a 0.1 bonus for each additional pattern detected (capped at 1.0)
70-
// - Confidence represents certainty of classification:
71-
// - When detectors find patterns: max confidence + 0.05 bonus if multiple detectors agree
72-
// - When no patterns found: high confidence (~0.85-0.90) it's safe
73-
//
67+
// Risk score is computed by computeWeightedScore (see scoring.go).
7468
// The input is considered unsafe if the final risk score >= threshold.
7569
func (md *MultiDetector) Detect(ctx context.Context, input string) Result {
7670
if md.config.MaxInputLength > 0 && len(input) > md.config.MaxInputLength {
@@ -117,13 +111,7 @@ func (md *MultiDetector) Detect(ctx context.Context, input string) Result {
117111
}
118112
}
119113

120-
// Calculate final risk score using our algorithm:
121-
// final_score = max(individual_scores) + 0.1 × (num_additional_patterns - 1)
122-
finalScore := maxScore
123-
if len(allPatterns) > 1 {
124-
bonus := 0.1 * float64(len(allPatterns)-1)
125-
finalScore = min(finalScore+bonus, 1.0)
126-
}
114+
finalScore := computeWeightedScore(allPatterns)
127115

128116
finalConfidence := 0.0
129117
if detectorsTriggered > 0 {
@@ -175,11 +163,7 @@ func (md *MultiDetector) Detect(ctx context.Context, input string) Result {
175163
maxScore = llmResult.RiskScore
176164
}
177165

178-
finalScore = maxScore
179-
if len(allPatterns) > 1 {
180-
bonus := 0.1 * float64(len(allPatterns)-1)
181-
finalScore = min(finalScore+bonus, 1.0)
182-
}
166+
finalScore = computeWeightedScore(allPatterns)
183167

184168
// Recalculate confidence including LLM result
185169
if llmResult.RiskScore > 0 {

0 commit comments

Comments
 (0)