Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 50 additions & 9 deletions candle-binding/src/ffi/classify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::ffi::memory::{
allocate_lora_intent_array, allocate_lora_pii_array, allocate_lora_security_array,
allocate_modernbert_token_entity_array,
};
use crate::ffi::types::BertTokenEntity;
use crate::ffi::types::*;
use crate::model_architectures::traditional::bert::{
TRADITIONAL_BERT_CLASSIFIER, TRADITIONAL_BERT_TOKEN_CLASSIFIER,
Expand Down Expand Up @@ -654,26 +655,66 @@ pub extern "C" fn classify_candle_bert_tokens(

let entities_ptr = unsafe { allocate_bert_token_entity_array(&token_entities) };

BertTokenClassificationResult {
return BertTokenClassificationResult {
entities: entities_ptr,
num_entities: token_entities.len() as i32,
}
};
}
Err(e) => {
println!("Candle BERT token classification failed: {}", e);
BertTokenClassificationResult {
return BertTokenClassificationResult {
entities: std::ptr::null_mut(),
num_entities: 0,
}
};
}
}
} else {
println!("TraditionalBertTokenClassifier not initialized - call init function first");
BertTokenClassificationResult {
entities: std::ptr::null_mut(),
num_entities: 0,
}

// Fallback to ModernBERT token classifier (for PII detection with ModernBERT models)
if let Some(classifier) = TRADITIONAL_MODERNBERT_TOKEN_CLASSIFIER.get() {
let classifier = classifier.clone();
match classifier.classify_tokens(text) {
Ok(token_results) => {
// Filter non-background classes; Go layer applies confidence threshold
// Keep real positions (start, end) for accurate entity extraction
let token_entities: Vec<(String, String, f32, usize, usize)> = token_results
.iter()
.filter(|(_, class_idx, _, _, _)| *class_idx > 0)
.map(|(token, class_idx, confidence, start, end)| {
(
token.clone(),
format!("class_{}", class_idx),
*confidence,
*start,
*end,
)
})
.collect();

let entities_ptr =
unsafe { allocate_modernbert_token_entity_array(&token_entities) };

return BertTokenClassificationResult {
entities: entities_ptr as *mut BertTokenEntity,
num_entities: token_entities.len() as i32,
};
}
Err(e) => {
println!("ModernBERT token classification failed: {}", e);
return BertTokenClassificationResult {
entities: std::ptr::null_mut(),
num_entities: 0,
};
}
}
}

// No classifier available
println!("No token classifier initialized (Traditional BERT, ModernBERT, or LoRA) - call init function first");
BertTokenClassificationResult {
entities: std::ptr::null_mut(),
num_entities: 0,
}
}

/// Classify text using Candle BERT
Expand Down
6 changes: 4 additions & 2 deletions deploy/kubernetes/aibrix/semantic-router-values/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ config:
- type: "pii"
configuration:
enabled: true
pii_types_allowed: []
pii_types_allowed:
- "ORGANIZATION" # Allow - scientific terms like "photosynthesis" falsely detected as ORG
- type: "system_prompt"
configuration:
enabled: true
Expand Down Expand Up @@ -189,7 +190,8 @@ config:
- type: "pii"
configuration:
enabled: true
pii_types_allowed: []
pii_types_allowed:
- "GPE" # Allow - country/city names in general knowledge questions
- type: "semantic-cache"
configuration:
enabled: true
Expand Down
2 changes: 1 addition & 1 deletion e2e/profiles/ai-gateway/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ func (p *Profile) kubectlApply(ctx context.Context, kubeConfig, manifest string)
}

func (p *Profile) kubectlDelete(ctx context.Context, kubeConfig, manifest string) error {
return p.runKubectl(ctx, kubeConfig, "delete", "-f", manifest)
return p.runKubectl(ctx, kubeConfig, "delete", "--ignore-not-found", "-f", manifest)
}

func (p *Profile) runKubectl(ctx context.Context, kubeConfig string, args ...string) error {
Expand Down
12 changes: 9 additions & 3 deletions e2e/profiles/ai-gateway/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ config:
- type: "pii"
configuration:
enabled: true
pii_types_allowed: []
pii_types_allowed:
- "ORGANIZATION" # Allow - scientific terms like "photosynthesis" falsely detected as ORG
- type: "system_prompt"
configuration:
enabled: true
Expand Down Expand Up @@ -396,6 +397,10 @@ config:
lora_name: general-expert
use_reasoning: false
plugins:
- type: "pii"
configuration:
enabled: true
pii_types_allowed: []
- type: "system_prompt"
configuration:
enabled: true
Expand Down Expand Up @@ -441,7 +446,8 @@ config:
- type: "pii"
configuration:
enabled: true
pii_types_allowed: []
pii_types_allowed:
- "GPE" # Allow - country/city names in general knowledge questions
- type: "semantic-cache"
configuration:
enabled: true
Expand Down Expand Up @@ -529,7 +535,7 @@ config:
case_sensitive: false

- name: "sensitive_keywords"
operator: "AND"
operator: "OR"
keywords: ["SSN", "credit card"]
case_sensitive: false

Expand Down
2 changes: 1 addition & 1 deletion e2e/profiles/aibrix/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ func (p *Profile) kubectlApply(ctx context.Context, kubeConfig, manifest string)
}

func (p *Profile) kubectlDelete(ctx context.Context, kubeConfig, manifest string) error {
return p.runKubectl(ctx, kubeConfig, "delete", "-f", manifest)
return p.runKubectl(ctx, kubeConfig, "delete", "--ignore-not-found", "-f", manifest)
}

func (p *Profile) runKubectl(ctx context.Context, kubeConfig string, args ...string) error {
Expand Down
10 changes: 8 additions & 2 deletions e2e/profiles/dynamic-config/crds/intelligentroute.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ spec:
loraName: "general-expert"
useReasoning: false
plugins:
- type: "pii"
configuration:
enabled: true
pii_types_allowed: []
- type: "header_mutation"
configuration:
add:
Expand Down Expand Up @@ -269,7 +273,8 @@ spec:
- type: "pii"
configuration:
enabled: true
pii_types_allowed: []
pii_types_allowed:
- "ORGANIZATION" # Allow - scientific terms like "photosynthesis" falsely detected as ORG
- type: "system_prompt"
configuration:
enabled: true
Expand Down Expand Up @@ -503,7 +508,8 @@ spec:
- type: "pii"
configuration:
enabled: true
pii_types_allowed: []
pii_types_allowed:
- "GPE" # Allow - country/city names like "France" in general knowledge questions
- type: "semantic-cache"
configuration:
enabled: true
Expand Down
68 changes: 34 additions & 34 deletions e2e/testcases/plugin_chain_execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package testcases
import (
"bytes"
"context"
"embed"
"encoding/json"
"fmt"
"io"
Expand All @@ -13,6 +14,9 @@ import (
"k8s.io/client-go/kubernetes"
)

//go:embed testdata/plugin_chain_cases.json
var pluginChainTestData embed.FS

func init() {
pkgtestcases.Register("plugin-chain-execution", pkgtestcases.TestCase{
Description: "Test plugin chain execution order and blocking behavior",
Expand All @@ -23,14 +27,29 @@ func init() {

// PluginChainCase represents a test case for plugin chain execution
type PluginChainCase struct {
Description string `json:"description"`
Query string `json:"query"`
ExpectPIIBlock bool `json:"expect_pii_block"`
ExpectCacheUsed bool `json:"expect_cache_used"`
ExpectPromptApplied bool `json:"expect_prompt_applied"`
Description string `json:"description"`
PIITypes []string `json:"pii_types"` // Expected PII types detected
}

// loadPluginChainCases reads the embedded testdata/plugin_chain_cases.json
// fixture and decodes it into a slice of PluginChainCase values.
// A wrapped error is returned when the file cannot be read or its
// contents are not valid JSON for that slice type.
func loadPluginChainCases() ([]PluginChainCase, error) {
	raw, readErr := pluginChainTestData.ReadFile("testdata/plugin_chain_cases.json")
	if readErr != nil {
		return nil, fmt.Errorf("failed to read plugin chain cases: %w", readErr)
	}

	var parsed []PluginChainCase
	parseErr := json.Unmarshal(raw, &parsed)
	if parseErr != nil {
		return nil, fmt.Errorf("failed to parse plugin chain cases: %w", parseErr)
	}

	return parsed, nil
}

// PluginChainResult tracks the result of a single plugin chain test
type PluginChainResult struct {
Query string
Expand Down Expand Up @@ -58,38 +77,14 @@ func testPluginChainExecution(ctx context.Context, client *kubernetes.Clientset,
}
defer stopPortForward()

// Define test cases
testCases := []PluginChainCase{
{
Query: "My social security number is 123-45-6789",
ExpectPIIBlock: true,
ExpectCacheUsed: false, // PII blocks before cache
ExpectPromptApplied: false, // PII blocks before prompt
Description: "PII (SSN) should block entire plugin chain",
PIITypes: []string{"US_SSN"},
},
{
Query: "Contact me at [email protected]",
ExpectPIIBlock: true,
ExpectCacheUsed: false,
ExpectPromptApplied: false,
Description: "PII (EMAIL) should block entire plugin chain",
PIITypes: []string{"EMAIL"},
},
{
Query: "What is 5 + 7?",
ExpectPIIBlock: false,
ExpectCacheUsed: false, // First request, cache miss
ExpectPromptApplied: true, // Should apply math expert prompt
Description: "Clean query should pass PII and apply prompt",
},
{
Query: "Tell me about photosynthesis",
ExpectPIIBlock: false,
ExpectCacheUsed: false,
ExpectPromptApplied: true,
Description: "Biology query should pass PII plugin",
},
// Load test cases from JSON file
testCases, err := loadPluginChainCases()
if err != nil {
return fmt.Errorf("failed to load test cases: %w", err)
}

if opts.Verbose {
fmt.Printf("[Test] Loaded %d test cases from testdata/plugin_chain_cases.json\n", len(testCases))
}

// Run plugin chain tests
Expand Down Expand Up @@ -181,7 +176,12 @@ func testSinglePluginChain(ctx context.Context, testCase PluginChainCase, localP

// Extract plugin execution headers
piiViolationHeader := resp.Header.Get("x-vsr-pii-violation")
result.PIIDetected = piiViolationHeader // Store for display purposes
piiTypesHeader := resp.Header.Get("x-vsr-pii-types")
if piiTypesHeader != "" {
result.PIIDetected = piiTypesHeader // Store detected PII types for display
} else {
result.PIIDetected = piiViolationHeader // Fallback to boolean
}
result.PIIBlocked = (resp.StatusCode == http.StatusForbidden || piiViolationHeader == "true")

// Check cache headers (x-vsr-cache-hit or similar)
Expand Down
Loading
Loading