Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 76 additions & 16 deletions src/semantic-router/pkg/utils/classification/classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,69 @@ import (
"github.com/vllm-project/semantic-router/semantic-router/pkg/metrics"
)

type CategoryInference interface {
Classify(text string) (candle_binding.ClassResult, error)
}

type LinearCategoryInference struct{}

func (c *LinearCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
return candle_binding.ClassifyText(text)
}

type ModernBertCategoryInference struct{}

func (c *ModernBertCategoryInference) Classify(text string) (candle_binding.ClassResult, error) {
return candle_binding.ClassifyModernBertText(text)
}

// createCategoryInference creates the appropriate category inference based on configuration
func createCategoryInference(useModernBERT bool) CategoryInference {
if useModernBERT {
return &ModernBertCategoryInference{}
}
return &LinearCategoryInference{}
}

type JailbreakInference interface {
Classify(text string) (candle_binding.ClassResult, error)
}

type LinearJailbreakInference struct{}

func (c *LinearJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
return candle_binding.ClassifyJailbreakText(text)
}

type ModernBertJailbreakInference struct{}

func (c *ModernBertJailbreakInference) Classify(text string) (candle_binding.ClassResult, error) {
return candle_binding.ClassifyModernBertJailbreakText(text)
}

// createJailbreakInference creates the appropriate jailbreak inference based on configuration
func createJailbreakInference(useModernBERT bool) JailbreakInference {
if useModernBERT {
return &ModernBertJailbreakInference{}
}
return &LinearJailbreakInference{}
}

type PIIInference interface {
ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error)
}

type ModernBertPIIInference struct{}

func (c *ModernBertPIIInference) ClassifyTokens(text string, configPath string) (candle_binding.TokenClassificationResult, error) {
return candle_binding.ClassifyModernBertPIITokens(text, configPath)
}

// createPIIInference creates the appropriate PII inference (currently only ModernBERT)
func createPIIInference() PIIInference {
return &ModernBertPIIInference{}
}

// JailbreakDetection represents the result of jailbreak analysis for a piece of content
type JailbreakDetection struct {
Content string `json:"content"`
Expand Down Expand Up @@ -40,6 +103,11 @@ type PIIAnalysisResult struct {

// Classifier handles text classification, model selection, and jailbreak detection functionality
type Classifier struct {
// Dependencies
categoryInference CategoryInference
jailbreakInference JailbreakInference
piiInference PIIInference

Config *config.RouterConfig
CategoryMapping *CategoryMapping
PIIMapping *PIIMapping
Expand All @@ -55,6 +123,10 @@ type Classifier struct {
// NewClassifier creates a new classifier with model selection and jailbreak detection capabilities
func NewClassifier(cfg *config.RouterConfig, categoryMapping *CategoryMapping, piiMapping *PIIMapping, jailbreakMapping *JailbreakMapping, modelTTFT map[string]float64) *Classifier {
return &Classifier{
categoryInference: createCategoryInference(cfg.Classifier.CategoryModel.UseModernBERT),
jailbreakInference: createJailbreakInference(cfg.PromptGuard.UseModernBERT),
piiInference: createPIIInference(),

Config: cfg,
CategoryMapping: categoryMapping,
PIIMapping: piiMapping,
Expand Down Expand Up @@ -117,13 +189,7 @@ func (c *Classifier) CheckForJailbreak(text string) (bool, string, float32, erro
var err error

start := time.Now()
if c.Config.PromptGuard.UseModernBERT {
// Use ModernBERT jailbreak classifier
result, err = candle_binding.ClassifyModernBertJailbreakText(text)
} else {
// Use linear jailbreak classifier
result, err = candle_binding.ClassifyJailbreakText(text)
}
result, err = c.jailbreakInference.Classify(text)
metrics.RecordClassifierLatency("jailbreak", time.Since(start).Seconds())

if err != nil {
Expand Down Expand Up @@ -200,13 +266,7 @@ func (c *Classifier) ClassifyCategory(text string) (string, float64, error) {
var err error

start := time.Now()
if c.Config.Classifier.CategoryModel.UseModernBERT {
// Use ModernBERT classifier
result, err = candle_binding.ClassifyModernBertText(text)
} else {
// Use linear classifier
result, err = candle_binding.ClassifyText(text)
}
result, err = c.categoryInference.Classify(text)
metrics.RecordClassifierLatency("category", time.Since(start).Seconds())

if err != nil {
Expand Down Expand Up @@ -249,7 +309,7 @@ func (c *Classifier) ClassifyPII(text string) ([]string, error) {
// Use ModernBERT PII token classifier for entity detection
configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID)
start := time.Now()
tokenResult, err := candle_binding.ClassifyModernBertPIITokens(text, configPath)
tokenResult, err := c.piiInference.ClassifyTokens(text, configPath)
metrics.RecordClassifierLatency("pii", time.Since(start).Seconds())
if err != nil {
return nil, fmt.Errorf("PII token classification error: %w", err)
Expand Down Expand Up @@ -331,7 +391,7 @@ func (c *Classifier) AnalyzeContentForPII(contentList []string) (bool, []PIIAnal
// Use ModernBERT PII token classifier for detailed analysis
configPath := fmt.Sprintf("%s/config.json", c.Config.Classifier.PIIModel.ModelID)
start := time.Now()
tokenResult, err := candle_binding.ClassifyModernBertPIITokens(content, configPath)
tokenResult, err := c.piiInference.ClassifyTokens(content, configPath)
metrics.RecordClassifierLatency("pii", time.Since(start).Seconds())
if err != nil {
log.Printf("Error analyzing content %d: %v", i, err)
Expand Down
Loading
Loading