Skip to content

Commit 71fd1b3

Browse files
author
Yehudit Kerido
committed
pii detection fix
Signed-off-by: Yehudit Kerido <[email protected]>
1 parent 30801fa commit 71fd1b3

File tree

11 files changed

+363
-60
lines changed

11 files changed

+363
-60
lines changed

candle-binding/src/ffi/classify.rs

Lines changed: 41 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -654,26 +654,58 @@ pub extern "C" fn classify_candle_bert_tokens(
654654

655655
let entities_ptr = unsafe { allocate_bert_token_entity_array(&token_entities) };
656656

657-
BertTokenClassificationResult {
657+
return BertTokenClassificationResult {
658658
entities: entities_ptr,
659659
num_entities: token_entities.len() as i32,
660-
}
660+
};
661661
}
662662
Err(e) => {
663663
println!("Candle BERT token classification failed: {}", e);
664-
BertTokenClassificationResult {
664+
return BertTokenClassificationResult {
665665
entities: std::ptr::null_mut(),
666666
num_entities: 0,
667-
}
667+
};
668668
}
669669
}
670-
} else {
671-
println!("TraditionalBertTokenClassifier not initialized - call init function first");
672-
BertTokenClassificationResult {
673-
entities: std::ptr::null_mut(),
674-
num_entities: 0,
670+
}
671+
672+
// Fallback to ModernBERT token classifier (for PII detection with ModernBERT models)
673+
if let Some(classifier) = TRADITIONAL_MODERNBERT_TOKEN_CLASSIFIER.get() {
674+
let classifier = classifier.clone();
675+
match classifier.classify_tokens(text) {
676+
Ok(token_results) => {
677+
// Filter non-background classes; Go layer applies confidence threshold
678+
let token_entities: Vec<(String, String, f32)> = token_results
679+
.iter()
680+
.filter(|(_, class_idx, _, _, _)| *class_idx > 0)
681+
.map(|(token, class_idx, confidence, _, _)| {
682+
(token.clone(), format!("class_{}", class_idx), *confidence)
683+
})
684+
.collect();
685+
686+
let entities_ptr = unsafe { allocate_bert_token_entity_array(&token_entities) };
687+
688+
return BertTokenClassificationResult {
689+
entities: entities_ptr,
690+
num_entities: token_entities.len() as i32,
691+
};
692+
}
693+
Err(e) => {
694+
println!("ModernBERT token classification failed: {}", e);
695+
return BertTokenClassificationResult {
696+
entities: std::ptr::null_mut(),
697+
num_entities: 0,
698+
};
699+
}
675700
}
676701
}
702+
703+
// No classifier available
704+
println!("No token classifier initialized (Traditional BERT, ModernBERT, or LoRA) - call init function first");
705+
BertTokenClassificationResult {
706+
entities: std::ptr::null_mut(),
707+
num_entities: 0,
708+
}
677709
}
678710

679711
/// Classify text using Candle BERT

deploy/kubernetes/aibrix/semantic-router-values/values.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ config:
123123
- type: "pii"
124124
configuration:
125125
enabled: true
126-
pii_types_allowed: []
126+
pii_types_allowed:
127+
- "ORGANIZATION" # Allow - scientific terms like "photosynthesis" falsely detected as ORG
127128
- type: "system_prompt"
128129
configuration:
129130
enabled: true
@@ -189,7 +190,8 @@ config:
189190
- type: "pii"
190191
configuration:
191192
enabled: true
192-
pii_types_allowed: []
193+
pii_types_allowed:
194+
- "GPE" # Allow - country/city names in general knowledge questions
193195
- type: "semantic-cache"
194196
configuration:
195197
enabled: true

e2e/profiles/ai-gateway/profile.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ func (p *Profile) kubectlApply(ctx context.Context, kubeConfig, manifest string)
325325
}
326326

327327
func (p *Profile) kubectlDelete(ctx context.Context, kubeConfig, manifest string) error {
328-
return p.runKubectl(ctx, kubeConfig, "delete", "-f", manifest)
328+
return p.runKubectl(ctx, kubeConfig, "delete", "--ignore-not-found", "-f", manifest)
329329
}
330330

331331
func (p *Profile) runKubectl(ctx context.Context, kubeConfig string, args ...string) error {

e2e/profiles/ai-gateway/values.yaml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ config:
142142
- type: "pii"
143143
configuration:
144144
enabled: true
145-
pii_types_allowed: []
145+
pii_types_allowed:
146+
- "ORGANIZATION" # Allow - scientific terms like "photosynthesis" falsely detected as ORG
146147
- type: "system_prompt"
147148
configuration:
148149
enabled: true
@@ -396,6 +397,10 @@ config:
396397
lora_name: general-expert
397398
use_reasoning: false
398399
plugins:
400+
- type: "pii"
401+
configuration:
402+
enabled: true
403+
pii_types_allowed: []
399404
- type: "system_prompt"
400405
configuration:
401406
enabled: true
@@ -441,7 +446,8 @@ config:
441446
- type: "pii"
442447
configuration:
443448
enabled: true
444-
pii_types_allowed: []
449+
pii_types_allowed:
450+
- "GPE" # Allow - country/city names in general knowledge questions
445451
- type: "semantic-cache"
446452
configuration:
447453
enabled: true
@@ -529,7 +535,7 @@ config:
529535
case_sensitive: false
530536

531537
- name: "sensitive_keywords"
532-
operator: "AND"
538+
operator: "OR"
533539
keywords: ["SSN", "credit card"]
534540
case_sensitive: false
535541

e2e/profiles/aibrix/profile.go

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -290,14 +290,47 @@ func (p *Profile) deployAIBrixCore(ctx context.Context, opts *framework.SetupOpt
290290
}
291291

292292
func (p *Profile) deployGatewayResources(ctx context.Context, opts *framework.SetupOptions) error {
293-
// Apply base model (Demo LLM)
294-
if err := p.kubectlApply(ctx, opts.KubeConfig, "deploy/kubernetes/aibrix/aigw-resources/base-model.yaml"); err != nil {
295-
return fmt.Errorf("failed to apply base model: %w", err)
293+
// Apply base model (Demo LLM) with retry logic
294+
// This deployment can fail randomly in CI due to resource constraints
295+
maxRetries := 3
296+
var lastErr error
297+
298+
for attempt := 1; attempt <= maxRetries; attempt++ {
299+
if attempt > 1 {
300+
p.log("Retrying Demo LLM deployment (attempt %d/%d)...", attempt, maxRetries)
301+
// Exponential backoff: 10s, 20s, 40s
302+
backoff := time.Duration(10*(1<<(attempt-1))) * time.Second
303+
select {
304+
case <-ctx.Done():
305+
return ctx.Err()
306+
case <-time.After(backoff):
307+
}
308+
309+
// Clean up any partial deployment before retrying
310+
p.kubectlDelete(ctx, opts.KubeConfig, "deploy/kubernetes/aibrix/aigw-resources/base-model.yaml")
311+
}
312+
313+
// Apply base model (Demo LLM)
314+
if err := p.kubectlApply(ctx, opts.KubeConfig, "deploy/kubernetes/aibrix/aigw-resources/base-model.yaml"); err != nil {
315+
lastErr = fmt.Errorf("failed to apply base model: %w", err)
316+
p.log("Warning: Demo LLM apply failed (attempt %d/%d): %v", attempt, maxRetries, err)
317+
continue
318+
}
319+
320+
// Wait for Demo LLM deployment
321+
if err := p.waitForDeployment(ctx, opts, "default", deploymentDemoLLM, timeoutComponentDeploy); err != nil {
322+
lastErr = fmt.Errorf("demo LLM deployment not ready: %w", err)
323+
p.log("Warning: Demo LLM deployment not ready (attempt %d/%d): %v", attempt, maxRetries, err)
324+
continue
325+
}
326+
327+
// Success - break out of retry loop
328+
lastErr = nil
329+
break
296330
}
297331

298-
// Wait for Demo LLM deployment
299-
if err := p.waitForDeployment(ctx, opts, "default", deploymentDemoLLM, timeoutComponentDeploy); err != nil {
300-
return fmt.Errorf("demo LLM deployment not ready: %w", err)
332+
if lastErr != nil {
333+
return lastErr
301334
}
302335

303336
// Apply gateway API resources
@@ -468,7 +501,7 @@ func (p *Profile) kubectlApply(ctx context.Context, kubeConfig, manifest string)
468501
}
469502

470503
func (p *Profile) kubectlDelete(ctx context.Context, kubeConfig, manifest string) error {
471-
return p.runKubectl(ctx, kubeConfig, "delete", "-f", manifest)
504+
return p.runKubectl(ctx, kubeConfig, "delete", "--ignore-not-found", "-f", manifest)
472505
}
473506

474507
func (p *Profile) runKubectl(ctx context.Context, kubeConfig string, args ...string) error {

e2e/profiles/dynamic-config/crds/intelligentroute.yaml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ spec:
100100
loraName: "general-expert"
101101
useReasoning: false
102102
plugins:
103+
- type: "pii"
104+
configuration:
105+
enabled: true
106+
pii_types_allowed: []
103107
- type: "header_mutation"
104108
configuration:
105109
add:
@@ -269,7 +273,8 @@ spec:
269273
- type: "pii"
270274
configuration:
271275
enabled: true
272-
pii_types_allowed: []
276+
pii_types_allowed:
277+
- "ORGANIZATION" # Allow - scientific terms like "photosynthesis" falsely detected as ORG
273278
- type: "system_prompt"
274279
configuration:
275280
enabled: true
@@ -503,7 +508,8 @@ spec:
503508
- type: "pii"
504509
configuration:
505510
enabled: true
506-
pii_types_allowed: []
511+
pii_types_allowed:
512+
- "GPE" # Allow - country/city names like "France" in general knowledge questions
507513
- type: "semantic-cache"
508514
configuration:
509515
enabled: true

e2e/testcases/plugin_chain_execution.go

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package testcases
33
import (
44
"bytes"
55
"context"
6+
"embed"
67
"encoding/json"
78
"fmt"
89
"io"
@@ -13,6 +14,9 @@ import (
1314
"k8s.io/client-go/kubernetes"
1415
)
1516

17+
//go:embed testdata/plugin_chain_cases.json
18+
var pluginChainTestData embed.FS
19+
1620
func init() {
1721
pkgtestcases.Register("plugin-chain-execution", pkgtestcases.TestCase{
1822
Description: "Test plugin chain execution order and blocking behavior",
@@ -23,14 +27,29 @@ func init() {
2327

2428
// PluginChainCase represents a test case for plugin chain execution
2529
type PluginChainCase struct {
30+
Description string `json:"description"`
2631
Query string `json:"query"`
2732
ExpectPIIBlock bool `json:"expect_pii_block"`
2833
ExpectCacheUsed bool `json:"expect_cache_used"`
2934
ExpectPromptApplied bool `json:"expect_prompt_applied"`
30-
Description string `json:"description"`
3135
PIITypes []string `json:"pii_types"` // Expected PII types detected
3236
}
3337

38+
// loadPluginChainCases loads test cases from the embedded JSON file
39+
func loadPluginChainCases() ([]PluginChainCase, error) {
40+
data, err := pluginChainTestData.ReadFile("testdata/plugin_chain_cases.json")
41+
if err != nil {
42+
return nil, fmt.Errorf("failed to read plugin chain cases: %w", err)
43+
}
44+
45+
var cases []PluginChainCase
46+
if err := json.Unmarshal(data, &cases); err != nil {
47+
return nil, fmt.Errorf("failed to parse plugin chain cases: %w", err)
48+
}
49+
50+
return cases, nil
51+
}
52+
3453
// PluginChainResult tracks the result of a single plugin chain test
3554
type PluginChainResult struct {
3655
Query string
@@ -58,38 +77,14 @@ func testPluginChainExecution(ctx context.Context, client *kubernetes.Clientset,
5877
}
5978
defer stopPortForward()
6079

61-
// Define test cases
62-
testCases := []PluginChainCase{
63-
{
64-
Query: "My social security number is 123-45-6789",
65-
ExpectPIIBlock: true,
66-
ExpectCacheUsed: false, // PII blocks before cache
67-
ExpectPromptApplied: false, // PII blocks before prompt
68-
Description: "PII (SSN) should block entire plugin chain",
69-
PIITypes: []string{"US_SSN"},
70-
},
71-
{
72-
Query: "Contact me at [email protected]",
73-
ExpectPIIBlock: true,
74-
ExpectCacheUsed: false,
75-
ExpectPromptApplied: false,
76-
Description: "PII (EMAIL) should block entire plugin chain",
77-
PIITypes: []string{"EMAIL"},
78-
},
79-
{
80-
Query: "What is 5 + 7?",
81-
ExpectPIIBlock: false,
82-
ExpectCacheUsed: false, // First request, cache miss
83-
ExpectPromptApplied: true, // Should apply math expert prompt
84-
Description: "Clean query should pass PII and apply prompt",
85-
},
86-
{
87-
Query: "Tell me about photosynthesis",
88-
ExpectPIIBlock: false,
89-
ExpectCacheUsed: false,
90-
ExpectPromptApplied: true,
91-
Description: "Biology query should pass PII plugin",
92-
},
80+
// Load test cases from JSON file
81+
testCases, err := loadPluginChainCases()
82+
if err != nil {
83+
return fmt.Errorf("failed to load test cases: %w", err)
84+
}
85+
86+
if opts.Verbose {
87+
fmt.Printf("[Test] Loaded %d test cases from testdata/plugin_chain_cases.json\n", len(testCases))
9388
}
9489

9590
// Run plugin chain tests
@@ -181,7 +176,12 @@ func testSinglePluginChain(ctx context.Context, testCase PluginChainCase, localP
181176

182177
// Extract plugin execution headers
183178
piiViolationHeader := resp.Header.Get("x-vsr-pii-violation")
184-
result.PIIDetected = piiViolationHeader // Store for display purposes
179+
piiTypesHeader := resp.Header.Get("x-vsr-pii-types")
180+
if piiTypesHeader != "" {
181+
result.PIIDetected = piiTypesHeader // Store detected PII types for display
182+
} else {
183+
result.PIIDetected = piiViolationHeader // Fallback to boolean
184+
}
185185
result.PIIBlocked = (resp.StatusCode == http.StatusForbidden || piiViolationHeader == "true")
186186

187187
// Check cache headers (x-vsr-cache-hit or similar)

0 commit comments

Comments
 (0)