Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 71 additions & 16 deletions src/semantic-router/pkg/utils/classification/classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,48 @@ import (
"github.com/vllm-project/semantic-router/semantic-router/pkg/metrics"
)

// CategoryInference is the category-classification backend used by Classifier.
// An implementation takes the raw text and returns the candle_binding
// classification result, or an error if inference failed.
type CategoryInference interface {
// Classify runs category inference on text.
Classify(text string) (candle_binding.ClassResult, error)
}

// LinearCategoryInference is the CategoryInference implementation backed by
// the linear category classifier. NewClassifier selects it when
// cfg.Classifier.CategoryModel.UseModernBERT is false.
type LinearCategoryInference struct{}

// Classify forwards text to candle_binding.ClassifyText and returns its
// result and error unchanged. The receiver carries no state.
func (*LinearCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
	result, err := candle_binding.ClassifyText(text)
	return result, err
}

// ModernBertCategoryInference is the CategoryInference implementation backed
// by the ModernBERT category classifier. NewClassifier selects it when
// cfg.Classifier.CategoryModel.UseModernBERT is true.
type ModernBertCategoryInference struct{}

// Classify forwards text to candle_binding.ClassifyModernBertText and returns
// its result and error unchanged. The receiver carries no state.
func (*ModernBertCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
	result, err := candle_binding.ClassifyModernBertText(text)
	return result, err
}

// JailbreakInference is the jailbreak-detection backend used by Classifier.
// An implementation takes the raw text and returns the candle_binding
// classification result, or an error if inference failed.
type JailbreakInference interface {
// Classify runs jailbreak inference on text.
Classify(text string) (candle_binding.ClassResult, error)
}

// LinearJailbreakInference is the JailbreakInference implementation backed by
// the linear jailbreak classifier. NewClassifier selects it when
// cfg.PromptGuard.UseModernBERT is false.
type LinearJailbreakInference struct{}

// Classify forwards text to candle_binding.ClassifyJailbreakText and returns
// its result and error unchanged. The receiver carries no state.
func (*LinearJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
	result, err := candle_binding.ClassifyJailbreakText(text)
	return result, err
}

// ModernBertJailbreakInference is the JailbreakInference implementation backed
// by the ModernBERT jailbreak classifier. NewClassifier selects it when
// cfg.PromptGuard.UseModernBERT is true.
type ModernBertJailbreakInference struct{}

// Classify forwards text to candle_binding.ClassifyModernBertJailbreakText and
// returns its result and error unchanged. The receiver carries no state.
func (*ModernBertJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
	result, err := candle_binding.ClassifyModernBertJailbreakText(text)
	return result, err
}

// PIIInference is the PII token-classification backend used by Classifier.
// An implementation takes the raw text and a model config path and returns the
// candle_binding token classification result, or an error if inference failed.
type PIIInference interface {
// ClassifyTokens runs token-level PII inference on text. Callers in this
// file build configPath as "<PIIModel.ModelID>/config.json".
ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error)
}

// ModernBertPIIInference is the PIIInference implementation backed by the
// ModernBERT PII token classifier. NewClassifier always installs it as the
// PII backend.
type ModernBertPIIInference struct{}

// ClassifyTokens forwards text and configPath to
// candle_binding.ClassifyModernBertPIITokens and returns its result and error
// unchanged. The receiver carries no state.
func (*ModernBertPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) {
	tokens, err := candle_binding.ClassifyModernBertPIITokens(text, configPath)
	return tokens, err
}

// JailbreakDetection represents the result of jailbreak analysis for a piece of content
type JailbreakDetection struct {
Content string `json:"content"`
Expand Down Expand Up @@ -40,6 +82,11 @@ type PIIAnalysisResult struct {

// Classifier handles text classification, model selection, and jailbreak detection functionality
type Classifier struct {
// Dependencies
categoryInference CategoryInference
jailbreakInference JailbreakInference
piiInference PIIInference

Config *config.RouterConfig
CategoryMapping *CategoryMapping
PIIMapping *PIIMapping
Expand All @@ -54,7 +101,27 @@ type Classifier struct {

// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
var categoryInference CategoryInference
if cfg.Classifier.CategoryModel.UseModernBERT {
categoryInference = &ModernBertCategoryInference{}
} else {
categoryInference = &LinearCategoryInference{}
}

var jailbreakInference JailbreakInference
if cfg.PromptGuard.UseModernBERT {
jailbreakInference = &ModernBertJailbreakInference{}
} else {
jailbreakInference = &LinearJailbreakInference{}
}

piiInference := &ModernBertPIIInference{}

return &Classifier{
categoryInference: categoryInference,
jailbreakInference: jailbreakInference,
piiInference: piiInference,

Config: cfg,
CategoryMapping: categoryMapping,
PIIMapping: piiMapping,
Expand Down Expand Up @@ -117,13 +184,7 @@ func (c *Classifier) CheckForJailbreak(text string) (bool, string, float32, erro
var err error

start := time.Now()
if c.Config.PromptGuard.UseModernBERT {
// Use ModernBERT jailbreak classifier
result, err = candle_binding.ClassifyModernBertJailbreakText(text)
} else {
// Use linear jailbreak classifier
result, err = candle_binding.ClassifyJailbreakText(text)
}
result, err = c.jailbreakInference.Classify(text)
metrics.RecordClassifierLatency("jailbreak", time.Since(start).Seconds())

if err != nil {
Expand Down Expand Up @@ -200,13 +261,7 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
var err error

start := time.Now()
if c.Config.Classifier.CategoryModel.UseModernBERT {
// Use ModernBERT classifier
result, err = candle_binding.ClassifyModernBertText(text)
} else {
// Use linear classifier
result, err = candle_binding.ClassifyText(text)
}
result, err = c.categoryInference.Classify(text)
metrics.RecordClassifierLatency("category", time.Since(start).Seconds())

if err != nil {
Expand Down Expand Up @@ -249,7 +304,7 @@ func (c *Classifier) ClassifyPII(text string) ([]string, error) {
// Use ModernBERT PII token classifier for entity detection
configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID)
start := time.Now()
tokenResult, err := candle_binding.ClassifyModernBertPIITokens(text, configPath)
tokenResult, err := c.piiInference.ClassifyTokens(text, configPath)
metrics.RecordClassifierLatency("pii", time.Since(start).Seconds())
if err != nil {
return nil, fmt.Errorf("PII token classification error: %w", err)
Expand Down Expand Up @@ -331,7 +386,7 @@ func (c *Classifier) AnalyzeContentForPII(contentList []string) (bool, []PIIAnal
// Use ModernBERT PII token classifier for detailed analysis
configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID)
start := time.Now()
tokenResult, err := candle_binding.ClassifyModernBertPIITokens(content, configPath)
tokenResult, err := c.piiInference.ClassifyTokens(content, configPath)
metrics.RecordClassifierLatency("pii", time.Since(start).Seconds())
if err != nil {
log.Printf("Error analyzing content %d: %v", i, err)
Expand Down
Loading
Loading