@@ -78,22 +78,42 @@ typedef struct {
7878 float confidence;
7979} ClassificationResult;
8080
81+ // Classification result with full probability distribution structure
82+ typedef struct {
83+ int class;
84+ float confidence;
85+ float* probabilities;
86+ int num_classes;
87+ } ClassificationResultWithProbs;
88+
8189// ModernBERT Classification result structure
8290typedef struct {
8391 int class;
8492 float confidence;
8593} ModernBertClassificationResult;
8694
95+ // ModernBERT Classification result with full probability distribution structure
96+ typedef struct {
97+ int class;
98+ float confidence;
99+ float* probabilities;
100+ int num_classes;
101+ } ModernBertClassificationResultWithProbs;
102+
87103extern SimilarityResult find_most_similar(const char* query, const char** candidates, int num_candidates, int max_length);
88104extern EmbeddingResult get_text_embedding(const char* text, int max_length);
89105extern TokenizationResult tokenize_text(const char* text, int max_length);
90106extern void free_cstring(char* s);
91107extern void free_embedding(float* data, int length);
92108extern void free_tokenization_result(TokenizationResult result);
93109extern ClassificationResult classify_text(const char* text);
110+ extern ClassificationResultWithProbs classify_text_with_probabilities(const char* text);
111+ extern void free_probabilities(float* probabilities, int num_classes);
94112extern ClassificationResult classify_pii_text(const char* text);
95113extern ClassificationResult classify_jailbreak_text(const char* text);
96114extern ModernBertClassificationResult classify_modernbert_text(const char* text);
115+ extern ModernBertClassificationResultWithProbs classify_modernbert_text_with_probabilities(const char* text);
116+ extern void free_modernbert_probabilities(float* probabilities, int num_classes);
97117extern ModernBertClassificationResult classify_modernbert_pii_text(const char* text);
98118extern ModernBertClassificationResult classify_modernbert_jailbreak_text(const char* text);
99119*/
@@ -137,6 +157,14 @@ type ClassResult struct {
137157 Confidence float32 // Confidence score
138158}
139159
160+ // ClassResultWithProbs represents the result of a text classification with full probability distribution
161+ type ClassResultWithProbs struct {
162+ Class int // Class index
163+ Confidence float32 // Confidence score
164+ Probabilities []float32 // Full probability distribution
165+ NumClasses int // Number of classes
166+ }
167+
140168// TokenEntity represents a single detected entity in token classification
141169type TokenEntity struct {
142170 EntityType string // Type of entity (e.g., "PERSON", "EMAIL", "PHONE")
@@ -452,6 +480,36 @@ func ClassifyText(text string) (ClassResult, error) {
452480 }, nil
453481}
454482
483+ // ClassifyTextWithProbabilities classifies the provided text and returns the predicted class, confidence, and full probability distribution
484+ func ClassifyTextWithProbabilities (text string ) (ClassResultWithProbs , error ) {
485+ cText := C .CString (text )
486+ defer C .free (unsafe .Pointer (cText ))
487+
488+ result := C .classify_text_with_probabilities (cText )
489+
490+ if result .class < 0 {
491+ return ClassResultWithProbs {}, fmt .Errorf ("failed to classify text with probabilities" )
492+ }
493+
494+ // Convert C array to Go slice
495+ probabilities := make ([]float32 , int (result .num_classes ))
496+ if result .probabilities != nil && result .num_classes > 0 {
497+ probsSlice := (* [1 << 30 ]C.float )(unsafe .Pointer (result .probabilities ))[:result .num_classes :result .num_classes ]
498+ for i , prob := range probsSlice {
499+ probabilities [i ] = float32 (prob )
500+ }
501+ // Free the C-allocated memory
502+ C .free_probabilities (result .probabilities , result .num_classes )
503+ }
504+
505+ return ClassResultWithProbs {
506+ Class : int (result .class ),
507+ Confidence : float32 (result .confidence ),
508+ Probabilities : probabilities ,
509+ NumClasses : int (result .num_classes ),
510+ }, nil
511+ }
512+
455513// ClassifyPIIText classifies the provided text for PII detection and returns the predicted class and confidence
456514func ClassifyPIIText (text string ) (ClassResult , error ) {
457515 cText := C .CString (text )
@@ -599,6 +657,36 @@ func ClassifyModernBertText(text string) (ClassResult, error) {
599657 }, nil
600658}
601659
660+ // ClassifyModernBertTextWithProbabilities classifies the provided text using ModernBERT and returns the predicted class, confidence, and full probability distribution
661+ func ClassifyModernBertTextWithProbabilities (text string ) (ClassResultWithProbs , error ) {
662+ cText := C .CString (text )
663+ defer C .free (unsafe .Pointer (cText ))
664+
665+ result := C .classify_modernbert_text_with_probabilities (cText )
666+
667+ if result .class < 0 {
668+ return ClassResultWithProbs {}, fmt .Errorf ("failed to classify text with probabilities using ModernBERT" )
669+ }
670+
671+ // Convert C array to Go slice
672+ probabilities := make ([]float32 , int (result .num_classes ))
673+ if result .probabilities != nil && result .num_classes > 0 {
674+ probsSlice := (* [1 << 30 ]C.float )(unsafe .Pointer (result .probabilities ))[:result .num_classes :result .num_classes ]
675+ for i , prob := range probsSlice {
676+ probabilities [i ] = float32 (prob )
677+ }
678+ // Free the C-allocated memory
679+ C .free_modernbert_probabilities (result .probabilities , result .num_classes )
680+ }
681+
682+ return ClassResultWithProbs {
683+ Class : int (result .class ),
684+ Confidence : float32 (result .confidence ),
685+ Probabilities : probabilities ,
686+ NumClasses : int (result .num_classes ),
687+ }, nil
688+ }
689+
602690// ClassifyModernBertPIIText classifies the provided text for PII detection using ModernBERT and returns the predicted class and confidence
603691func ClassifyModernBertPIIText (text string ) (ClassResult , error ) {
604692 cText := C .CString (text )
0 commit comments