Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions candle-binding/src/ffi/init.rs
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,11 @@ pub extern "C" fn init_lora_unified_classifier(
}
};

// Check if already initialized - return success if so
if PARALLEL_LORA_ENGINE.get().is_some() {
return true;
}

// Load labels dynamically from model configurations
let _intent_labels_vec = load_labels_from_model_config(intent_path).unwrap_or_else(|e| {
eprintln!(
Expand Down Expand Up @@ -723,7 +728,9 @@ pub extern "C" fn init_lora_unified_classifier(
) {
Ok(engine) => {
// Store in global static variable (Arc for efficient cloning during concurrent access)
// Return true even if already set (race condition)
PARALLEL_LORA_ENGINE.set(Arc::new(engine)).is_ok()
|| PARALLEL_LORA_ENGINE.get().is_some()
}
Err(e) => {
eprintln!(
Expand Down
24 changes: 13 additions & 11 deletions src/semantic-router/pkg/classification/classifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
mcpclient "github.com/vllm-project/semantic-router/src/semantic-router/pkg/mcp"
)

const testModelsDir = "../../../../models"

func TestClassifier(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "Classifier Suite")
Expand Down Expand Up @@ -2909,13 +2911,11 @@ func createMockModelFile(t *testing.T, dir, filename string) {

func TestAutoDiscoverModels_RealModels(t *testing.T) {
// Test with real models directory
modelsDir := "../../../../../models"
modelsDir := testModelsDir

paths, err := AutoDiscoverModels(modelsDir)
if err != nil {
// Skip this test in environments without the real models directory
t.Logf("AutoDiscoverModels() failed in real-models test: %v", err)
t.Skip("Skipping real-models discovery test because models directory is unavailable")
t.Fatalf("AutoDiscoverModels() failed: %v (models directory should exist at %s)", err, modelsDir)
}

t.Logf("Discovered paths:")
Expand Down Expand Up @@ -2964,10 +2964,9 @@ func TestAutoDiscoverModels_RealModels(t *testing.T) {
// TestAutoInitializeUnifiedClassifier tests the full initialization process
func TestAutoInitializeUnifiedClassifier(t *testing.T) {
// Test with real models directory
classifier, err := AutoInitializeUnifiedClassifier("../../../../../models")
classifier, err := AutoInitializeUnifiedClassifier(testModelsDir)
if err != nil {
t.Logf("AutoInitializeUnifiedClassifier() failed in real-models test: %v", err)
t.Skip("Skipping unified classifier init test because real models are unavailable")
t.Fatalf("AutoInitializeUnifiedClassifier() failed: %v (models directory should exist at %s)", err, testModelsDir)
}

if classifier == nil {
Expand Down Expand Up @@ -3305,7 +3304,7 @@ var (
// getTestClassifier returns a shared classifier instance for all integration tests
func getTestClassifier(t *testing.T) *UnifiedClassifier {
globalTestClassifierOnce.Do(func() {
classifier, err := AutoInitializeUnifiedClassifier("../../../../../models")
classifier, err := AutoInitializeUnifiedClassifier(testModelsDir)
if err != nil {
t.Logf("Failed to initialize classifier: %v", err)
return
Expand All @@ -3323,8 +3322,11 @@ func TestUnifiedClassifier_Integration(t *testing.T) {
// Get shared classifier instance
classifier := getTestClassifier(t)
if classifier == nil {
t.Skip("Skipping integration tests - classifier not available")
return
t.Fatal("Classifier initialization failed")
}

if !classifier.useLoRA {
t.Fatal("LoRA models not detected")
}

t.Run("RealBatchClassification", func(t *testing.T) {
Expand Down Expand Up @@ -3472,7 +3474,7 @@ func TestUnifiedClassifier_Integration(t *testing.T) {
func getBenchmarkClassifier(b *testing.B) *UnifiedClassifier {
// Reuse the global test classifier for benchmarks
globalTestClassifierOnce.Do(func() {
classifier, err := AutoInitializeUnifiedClassifier("../../../../../models")
classifier, err := AutoInitializeUnifiedClassifier(testModelsDir)
if err != nil {
b.Logf("Failed to initialize classifier: %v", err)
return
Expand Down
11 changes: 11 additions & 0 deletions tools/make/models.mk
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ download-models: ## Download models (full or minimal set depending on CI_MINIMAL
# - PII token classifier (ModernBERT Presidio)
# - Jailbreak classifier (ModernBERT)
# - Optional plain PII classifier mapping (small)
# - LoRA models (BERT architecture) for unified classifier tests

download-models-minimal:
download-models-minimal: ## Pre-download minimal set of models for CI tests
Expand All @@ -47,6 +48,16 @@ download-models-minimal: ## Pre-download minimal set of models for CI tests
@if [ ! -f "models/pii_classifier_modernbert-base_model/.downloaded" ] || [ ! -d "models/pii_classifier_modernbert-base_model" ]; then \
hf download LLM-Semantic-Router/pii_classifier_modernbert-base_model --local-dir models/pii_classifier_modernbert-base_model && printf '%s\n' "$$(date -u +%Y-%m-%dT%H:%M:%SZ)" > models/pii_classifier_modernbert-base_model/.downloaded; \
fi
# Download LoRA models for unified classifier integration tests
@if [ ! -f "models/lora_intent_classifier_bert-base-uncased_model/.downloaded" ] || [ ! -d "models/lora_intent_classifier_bert-base-uncased_model" ]; then \
hf download LLM-Semantic-Router/lora_intent_classifier_bert-base-uncased_model --local-dir models/lora_intent_classifier_bert-base-uncased_model && printf '%s\n' "$$(date -u +%Y-%m-%dT%H:%M:%SZ)" > models/lora_intent_classifier_bert-base-uncased_model/.downloaded; \
fi
@if [ ! -f "models/lora_pii_detector_bert-base-uncased_model/.downloaded" ] || [ ! -d "models/lora_pii_detector_bert-base-uncased_model" ]; then \
hf download LLM-Semantic-Router/lora_pii_detector_bert-base-uncased_model --local-dir models/lora_pii_detector_bert-base-uncased_model && printf '%s\n' "$$(date -u +%Y-%m-%dT%H:%M:%SZ)" > models/lora_pii_detector_bert-base-uncased_model/.downloaded; \
fi
@if [ ! -f "models/lora_jailbreak_classifier_bert-base-uncased_model/.downloaded" ] || [ ! -d "models/lora_jailbreak_classifier_bert-base-uncased_model" ]; then \
hf download LLM-Semantic-Router/lora_jailbreak_classifier_bert-base-uncased_model --local-dir models/lora_jailbreak_classifier_bert-base-uncased_model && printf '%s\n' "$$(date -u +%Y-%m-%dT%H:%M:%SZ)" > models/lora_jailbreak_classifier_bert-base-uncased_model/.downloaded; \
fi

# Full model set for local development and docs

Expand Down
Loading