Skip to content

Commit 7128765

Browse files
authored
feat: add comprehensive unit tests for entropy-based routing. Tests c… (#112)
* feat: add comprehensive unit tests for entropy-based routing. Tests cover all 5 key scenarios and probability distribution validation. Signed-off-by: Huamin Chen <[email protected]> * review feedback Signed-off-by: Huamin Chen <[email protected]> * review feedback Signed-off-by: Huamin Chen <[email protected]> * review feedback Signed-off-by: Huamin Chen <[email protected]> * review feedback Signed-off-by: Huamin Chen <[email protected]> --------- Signed-off-by: Huamin Chen <[email protected]>
1 parent 35f0d70 commit 7128765

File tree

11 files changed

+1879
-52
lines changed

11 files changed

+1879
-52
lines changed

candle-binding/semantic-router.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,42 @@ typedef struct {
7878
float confidence;
7979
} ClassificationResult;
8080
81+
// Classification result with full probability distribution structure.
// Returned by value from classify_text_with_probabilities. The Go wrapper in
// this file treats class < 0 as a failed classification, reads num_classes
// floats from probabilities, and then hands the buffer back through
// free_probabilities(probabilities, num_classes).
82+
typedef struct {
83+
int class; // predicted class index; negative is treated as failure by the Go caller
84+
float confidence; // confidence score reported for the predicted class
85+
float* probabilities; // natively allocated array, read as num_classes entries; may be NULL
86+
int num_classes; // number of entries the caller reads from probabilities
87+
} ClassificationResultWithProbs;
88+
8189
// ModernBERT Classification result structure
8290
typedef struct {
8391
int class;
8492
float confidence;
8593
} ModernBertClassificationResult;
8694
95+
// ModernBERT Classification result with full probability distribution structure.
// Returned by value from classify_modernbert_text_with_probabilities. The Go
// wrapper in this file treats class < 0 as a failed classification, reads
// num_classes floats from probabilities, and then hands the buffer back
// through free_modernbert_probabilities(probabilities, num_classes).
96+
typedef struct {
97+
int class; // predicted class index; negative is treated as failure by the Go caller
98+
float confidence; // confidence score reported for the predicted class
99+
float* probabilities; // natively allocated array, read as num_classes entries; may be NULL
100+
int num_classes; // number of entries the caller reads from probabilities
101+
} ModernBertClassificationResultWithProbs;
102+
87103
extern SimilarityResult find_most_similar(const char* query, const char** candidates, int num_candidates, int max_length);
88104
extern EmbeddingResult get_text_embedding(const char* text, int max_length);
89105
extern TokenizationResult tokenize_text(const char* text, int max_length);
90106
extern void free_cstring(char* s);
91107
extern void free_embedding(float* data, int length);
92108
extern void free_tokenization_result(TokenizationResult result);
93109
extern ClassificationResult classify_text(const char* text);
110+
extern ClassificationResultWithProbs classify_text_with_probabilities(const char* text);
111+
extern void free_probabilities(float* probabilities, int num_classes);
94112
extern ClassificationResult classify_pii_text(const char* text);
95113
extern ClassificationResult classify_jailbreak_text(const char* text);
96114
extern ModernBertClassificationResult classify_modernbert_text(const char* text);
115+
extern ModernBertClassificationResultWithProbs classify_modernbert_text_with_probabilities(const char* text);
116+
extern void free_modernbert_probabilities(float* probabilities, int num_classes);
97117
extern ModernBertClassificationResult classify_modernbert_pii_text(const char* text);
98118
extern ModernBertClassificationResult classify_modernbert_jailbreak_text(const char* text);
99119
*/
@@ -137,6 +157,14 @@ type ClassResult struct {
137157
Confidence float32 // Confidence score
138158
}
139159

160+
// ClassResultWithProbs represents the result of a text classification together
// with the full probability distribution over all classes.
//
// It is the return type of ClassifyTextWithProbabilities and
// ClassifyModernBertTextWithProbabilities. Probabilities is a Go-owned copy of
// the native buffer, so the value stays valid after the C memory is freed.
161+
type ClassResultWithProbs struct {
162+
Class int // Index of the predicted class
163+
Confidence float32 // Confidence score reported for the predicted class
164+
Probabilities []float32 // Full probability distribution; allocated with NumClasses entries
165+
NumClasses int // Number of classes reported by the native classifier
166+
}
167+
140168
// TokenEntity represents a single detected entity in token classification
141169
type TokenEntity struct {
142170
EntityType string // Type of entity (e.g., "PERSON", "EMAIL", "PHONE")
@@ -452,6 +480,36 @@ func ClassifyText(text string) (ClassResult, error) {
452480
}, nil
453481
}
454482

483+
// ClassifyTextWithProbabilities classifies the provided text and returns the predicted class, confidence, and full probability distribution
484+
func ClassifyTextWithProbabilities(text string) (ClassResultWithProbs, error) {
485+
cText := C.CString(text)
486+
defer C.free(unsafe.Pointer(cText))
487+
488+
result := C.classify_text_with_probabilities(cText)
489+
490+
if result.class < 0 {
491+
return ClassResultWithProbs{}, fmt.Errorf("failed to classify text with probabilities")
492+
}
493+
494+
// Convert C array to Go slice
495+
probabilities := make([]float32, int(result.num_classes))
496+
if result.probabilities != nil && result.num_classes > 0 {
497+
probsSlice := (*[1 << 30]C.float)(unsafe.Pointer(result.probabilities))[:result.num_classes:result.num_classes]
498+
for i, prob := range probsSlice {
499+
probabilities[i] = float32(prob)
500+
}
501+
// Free the C-allocated memory
502+
C.free_probabilities(result.probabilities, result.num_classes)
503+
}
504+
505+
return ClassResultWithProbs{
506+
Class: int(result.class),
507+
Confidence: float32(result.confidence),
508+
Probabilities: probabilities,
509+
NumClasses: int(result.num_classes),
510+
}, nil
511+
}
512+
455513
// ClassifyPIIText classifies the provided text for PII detection and returns the predicted class and confidence
456514
func ClassifyPIIText(text string) (ClassResult, error) {
457515
cText := C.CString(text)
@@ -599,6 +657,36 @@ func ClassifyModernBertText(text string) (ClassResult, error) {
599657
}, nil
600658
}
601659

660+
// ClassifyModernBertTextWithProbabilities classifies the provided text using ModernBERT and returns the predicted class, confidence, and full probability distribution
661+
func ClassifyModernBertTextWithProbabilities(text string) (ClassResultWithProbs, error) {
662+
cText := C.CString(text)
663+
defer C.free(unsafe.Pointer(cText))
664+
665+
result := C.classify_modernbert_text_with_probabilities(cText)
666+
667+
if result.class < 0 {
668+
return ClassResultWithProbs{}, fmt.Errorf("failed to classify text with probabilities using ModernBERT")
669+
}
670+
671+
// Convert C array to Go slice
672+
probabilities := make([]float32, int(result.num_classes))
673+
if result.probabilities != nil && result.num_classes > 0 {
674+
probsSlice := (*[1 << 30]C.float)(unsafe.Pointer(result.probabilities))[:result.num_classes:result.num_classes]
675+
for i, prob := range probsSlice {
676+
probabilities[i] = float32(prob)
677+
}
678+
// Free the C-allocated memory
679+
C.free_modernbert_probabilities(result.probabilities, result.num_classes)
680+
}
681+
682+
return ClassResultWithProbs{
683+
Class: int(result.class),
684+
Confidence: float32(result.confidence),
685+
Probabilities: probabilities,
686+
NumClasses: int(result.num_classes),
687+
}, nil
688+
}
689+
602690
// ClassifyModernBertPIIText classifies the provided text for PII detection using ModernBERT and returns the predicted class and confidence
603691
func ClassifyModernBertPIIText(text string) (ClassResult, error) {
604692
cText := C.CString(text)

0 commit comments

Comments
 (0)