@@ -6,6 +6,7 @@ use std::sync::Arc;
66use std:: sync:: Mutex ;
77
88pub mod modernbert;
9+ pub mod unified_classifier;
910
1011// Re-export ModernBERT functions and structures
1112pub use modernbert:: {
@@ -14,6 +15,12 @@ pub use modernbert::{
1415 init_modernbert_pii_classifier, ModernBertClassificationResult ,
1516} ;
1617
18+ // Re-export unified classifier functions and structures
19+ pub use unified_classifier:: {
20+ get_unified_classifier, BatchClassificationResult , IntentResult , PIIResult , SecurityResult ,
21+ UnifiedClassificationResult , UnifiedClassifier , UNIFIED_CLASSIFIER ,
22+ } ;
23+
1724use anyhow:: { Error as E , Result } ;
1825use candle_core:: { DType , Device , Tensor } ;
1926use candle_nn:: { Linear , VarBuilder } ;
@@ -1177,3 +1184,347 @@ pub extern "C" fn classify_jailbreak_text(text: *const c_char) -> Classification
11771184 }
11781185 }
11791186}
1187+
1188+ // ================================================================================================
1189+ // UNIFIED CLASSIFIER C INTERFACE
1190+ // ================================================================================================
1191+
1192+ /// C-compatible structure for unified batch results
1193+ #[ repr( C ) ]
1194+ pub struct UnifiedBatchResult {
1195+ pub intent_results : * mut CIntentResult ,
1196+ pub pii_results : * mut CPIIResult ,
1197+ pub security_results : * mut CSecurityResult ,
1198+ pub batch_size : i32 ,
1199+ pub error : bool ,
1200+ pub error_message : * mut c_char ,
1201+ }
1202+
/// Intent classification for a single text, laid out for consumption over the C ABI.
///
/// Both pointers are allocated on the Rust side when a batch result is converted and are
/// released by `free_unified_batch_result`; foreign code must not free them itself.
#[repr(C)]
pub struct CIntentResult {
    /// NUL-terminated name of the predicted intent category.
    pub category: *mut c_char,
    /// Confidence value reported by the classifier for this result.
    pub confidence: f32,
    /// Array of `num_probabilities` probability values.
    pub probabilities: *mut f32,
    /// Number of entries in `probabilities`.
    pub num_probabilities: i32,
}
1211+
/// PII detection outcome for a single text, laid out for consumption over the C ABI.
///
/// The `pii_types` array and every string it points to are allocated on the Rust side when a
/// batch result is converted and are released by `free_unified_batch_result`.
#[repr(C)]
pub struct CPIIResult {
    /// Whether the classifier reported PII for this text.
    pub has_pii: bool,
    /// Array of `num_pii_types` NUL-terminated PII type names.
    pub pii_types: *mut *mut c_char,
    /// Number of entries in `pii_types`.
    pub num_pii_types: i32,
    /// Confidence value reported by the classifier for this result.
    pub confidence: f32,
}
1220+
/// Security (jailbreak) detection outcome for a single text, laid out for the C ABI.
///
/// `threat_type` is allocated on the Rust side when a batch result is converted and is released
/// by `free_unified_batch_result`.
#[repr(C)]
pub struct CSecurityResult {
    /// Whether the classifier flagged this text as a jailbreak attempt.
    pub is_jailbreak: bool,
    /// NUL-terminated name of the detected threat type.
    pub threat_type: *mut c_char,
    /// Confidence value reported by the classifier for this result.
    pub confidence: f32,
}
1228+
1229+ impl UnifiedBatchResult {
1230+ /// Create an error result
1231+ fn error ( message : & str ) -> Self {
1232+ let error_msg =
1233+ CString :: new ( message) . unwrap_or_else ( |_| CString :: new ( "Unknown error" ) . unwrap ( ) ) ;
1234+ Self {
1235+ intent_results : std:: ptr:: null_mut ( ) ,
1236+ pii_results : std:: ptr:: null_mut ( ) ,
1237+ security_results : std:: ptr:: null_mut ( ) ,
1238+ batch_size : 0 ,
1239+ error : true ,
1240+ error_message : error_msg. into_raw ( ) ,
1241+ }
1242+ }
1243+
1244+ /// Convert from Rust BatchClassificationResult to C-compatible structure
1245+ fn from_batch_result ( result : BatchClassificationResult ) -> Self {
1246+ let batch_size = result. batch_size as i32 ;
1247+
1248+ // Convert intent results
1249+ let intent_results = result
1250+ . intent_results
1251+ . into_iter ( )
1252+ . map ( |r| {
1253+ let probs_len = r. probabilities . len ( ) ;
1254+ CIntentResult {
1255+ category : CString :: new ( r. category ) . unwrap ( ) . into_raw ( ) ,
1256+ confidence : r. confidence ,
1257+ probabilities : {
1258+ let mut probs = r. probabilities . into_boxed_slice ( ) ;
1259+ let ptr = probs. as_mut_ptr ( ) ;
1260+ std:: mem:: forget ( probs) ;
1261+ ptr
1262+ } ,
1263+ num_probabilities : probs_len as i32 ,
1264+ }
1265+ } )
1266+ . collect :: < Vec < _ > > ( )
1267+ . into_boxed_slice ( ) ;
1268+ let intent_ptr = Box :: into_raw ( intent_results) as * mut CIntentResult ;
1269+
1270+ // Convert PII results
1271+ let pii_results = result
1272+ . pii_results
1273+ . into_iter ( )
1274+ . map ( |r| {
1275+ let types_len = r. pii_types . len ( ) ;
1276+ CPIIResult {
1277+ has_pii : r. has_pii ,
1278+ pii_types : {
1279+ let types: Vec < * mut c_char > = r
1280+ . pii_types
1281+ . into_iter ( )
1282+ . map ( |t| CString :: new ( t) . unwrap ( ) . into_raw ( ) )
1283+ . collect ( ) ;
1284+ let mut types_box = types. into_boxed_slice ( ) ;
1285+ let ptr = types_box. as_mut_ptr ( ) ;
1286+ std:: mem:: forget ( types_box) ;
1287+ ptr
1288+ } ,
1289+ num_pii_types : types_len as i32 ,
1290+ confidence : r. confidence ,
1291+ }
1292+ } )
1293+ . collect :: < Vec < _ > > ( )
1294+ . into_boxed_slice ( ) ;
1295+ let pii_ptr = Box :: into_raw ( pii_results) as * mut CPIIResult ;
1296+
1297+ // Convert security results
1298+ let security_results = result
1299+ . security_results
1300+ . into_iter ( )
1301+ . map ( |r| CSecurityResult {
1302+ is_jailbreak : r. is_jailbreak ,
1303+ threat_type : CString :: new ( r. threat_type ) . unwrap ( ) . into_raw ( ) ,
1304+ confidence : r. confidence ,
1305+ } )
1306+ . collect :: < Vec < _ > > ( )
1307+ . into_boxed_slice ( ) ;
1308+ let security_ptr = Box :: into_raw ( security_results) as * mut CSecurityResult ;
1309+
1310+ Self {
1311+ intent_results : intent_ptr,
1312+ pii_results : pii_ptr,
1313+ security_results : security_ptr,
1314+ batch_size,
1315+ error : false ,
1316+ error_message : std:: ptr:: null_mut ( ) ,
1317+ }
1318+ }
1319+ }
1320+
1321+ /// Initialize unified classifier (called from Go)
1322+ #[ no_mangle]
1323+ pub extern "C" fn init_unified_classifier_c (
1324+ modernbert_path : * const c_char ,
1325+ intent_head_path : * const c_char ,
1326+ pii_head_path : * const c_char ,
1327+ security_head_path : * const c_char ,
1328+ intent_labels : * const * const c_char ,
1329+ intent_labels_count : usize ,
1330+ pii_labels : * const * const c_char ,
1331+ pii_labels_count : usize ,
1332+ security_labels : * const * const c_char ,
1333+ security_labels_count : usize ,
1334+ use_cpu : bool ,
1335+ ) -> bool {
1336+ let modernbert_path = unsafe {
1337+ match CStr :: from_ptr ( modernbert_path) . to_str ( ) {
1338+ Ok ( s) => s,
1339+ Err ( _) => return false ,
1340+ }
1341+ } ;
1342+
1343+ let intent_head_path = unsafe {
1344+ match CStr :: from_ptr ( intent_head_path) . to_str ( ) {
1345+ Ok ( s) => s,
1346+ Err ( _) => return false ,
1347+ }
1348+ } ;
1349+
1350+ let pii_head_path = unsafe {
1351+ match CStr :: from_ptr ( pii_head_path) . to_str ( ) {
1352+ Ok ( s) => s,
1353+ Err ( _) => return false ,
1354+ }
1355+ } ;
1356+
1357+ let security_head_path = unsafe {
1358+ match CStr :: from_ptr ( security_head_path) . to_str ( ) {
1359+ Ok ( s) => s,
1360+ Err ( _) => return false ,
1361+ }
1362+ } ;
1363+
1364+ // Convert C string arrays to Rust Vec<String>
1365+ let intent_labels_vec = unsafe {
1366+ std:: slice:: from_raw_parts ( intent_labels, intent_labels_count)
1367+ . iter ( )
1368+ . map ( |& ptr| CStr :: from_ptr ( ptr) . to_str ( ) . unwrap_or ( "" ) . to_string ( ) )
1369+ . collect :: < Vec < String > > ( )
1370+ } ;
1371+
1372+ let pii_labels_vec = unsafe {
1373+ std:: slice:: from_raw_parts ( pii_labels, pii_labels_count)
1374+ . iter ( )
1375+ . map ( |& ptr| CStr :: from_ptr ( ptr) . to_str ( ) . unwrap_or ( "" ) . to_string ( ) )
1376+ . collect :: < Vec < String > > ( )
1377+ } ;
1378+
1379+ let security_labels_vec = unsafe {
1380+ std:: slice:: from_raw_parts ( security_labels, security_labels_count)
1381+ . iter ( )
1382+ . map ( |& ptr| CStr :: from_ptr ( ptr) . to_str ( ) . unwrap_or ( "" ) . to_string ( ) )
1383+ . collect :: < Vec < String > > ( )
1384+ } ;
1385+
1386+ match UnifiedClassifier :: new (
1387+ modernbert_path,
1388+ intent_head_path,
1389+ pii_head_path,
1390+ security_head_path,
1391+ intent_labels_vec,
1392+ pii_labels_vec,
1393+ security_labels_vec,
1394+ use_cpu,
1395+ ) {
1396+ Ok ( classifier) => {
1397+ let mut global_classifier = UNIFIED_CLASSIFIER . lock ( ) . unwrap ( ) ;
1398+ * global_classifier = Some ( classifier) ;
1399+ true
1400+ }
1401+ Err ( e) => {
1402+ eprintln ! ( "Failed to initialize unified classifier: {e}" ) ;
1403+ false
1404+ }
1405+ }
1406+ }
1407+
1408+ /// Classify batch of texts using unified classifier (called from Go)
1409+ #[ no_mangle]
1410+ pub extern "C" fn classify_unified_batch (
1411+ texts_ptr : * const * const c_char ,
1412+ num_texts : i32 ,
1413+ ) -> UnifiedBatchResult {
1414+ if texts_ptr. is_null ( ) || num_texts <= 0 {
1415+ return UnifiedBatchResult :: error ( "Invalid input parameters" ) ;
1416+ }
1417+
1418+ // Convert C strings to Rust strings
1419+ let texts = unsafe {
1420+ std:: slice:: from_raw_parts ( texts_ptr, num_texts as usize )
1421+ . iter ( )
1422+ . map ( |& ptr| {
1423+ if ptr. is_null ( ) {
1424+ Err ( "Null text pointer" )
1425+ } else {
1426+ CStr :: from_ptr ( ptr) . to_str ( ) . map_err ( |_| "Invalid UTF-8" )
1427+ }
1428+ } )
1429+ . collect :: < Result < Vec < _ > , _ > > ( )
1430+ } ;
1431+
1432+ let texts = match texts {
1433+ Ok ( t) => t,
1434+ Err ( e) => return UnifiedBatchResult :: error ( e) ,
1435+ } ;
1436+
1437+ // Get unified classifier and perform batch classification
1438+ match get_unified_classifier ( ) {
1439+ Ok ( classifier_guard) => match classifier_guard. as_ref ( ) {
1440+ Some ( classifier) => match classifier. classify_batch ( & texts) {
1441+ Ok ( result) => UnifiedBatchResult :: from_batch_result ( result) ,
1442+ Err ( e) => UnifiedBatchResult :: error ( & format ! ( "Classification failed: {}" , e) ) ,
1443+ } ,
1444+ None => UnifiedBatchResult :: error ( "Unified classifier not initialized" ) ,
1445+ } ,
1446+ Err ( e) => UnifiedBatchResult :: error ( & format ! ( "Failed to get classifier: {}" , e) ) ,
1447+ }
1448+ }
1449+
1450+ /// Free unified batch result memory (called from Go)
1451+ #[ no_mangle]
1452+ pub extern "C" fn free_unified_batch_result ( result : UnifiedBatchResult ) {
1453+ if result. error {
1454+ if !result. error_message . is_null ( ) {
1455+ unsafe {
1456+ let _ = CString :: from_raw ( result. error_message ) ;
1457+ }
1458+ }
1459+ return ;
1460+ }
1461+
1462+ let batch_size = result. batch_size as usize ;
1463+
1464+ // Free intent results
1465+ if !result. intent_results . is_null ( ) {
1466+ unsafe {
1467+ let intent_slice = std:: slice:: from_raw_parts_mut ( result. intent_results , batch_size) ;
1468+ for intent in intent_slice {
1469+ if !intent. category . is_null ( ) {
1470+ let _ = CString :: from_raw ( intent. category ) ;
1471+ }
1472+ if !intent. probabilities . is_null ( ) {
1473+ let _ = Vec :: from_raw_parts (
1474+ intent. probabilities ,
1475+ intent. num_probabilities as usize ,
1476+ intent. num_probabilities as usize ,
1477+ ) ;
1478+ }
1479+ }
1480+ let _ = Box :: from_raw ( std:: slice:: from_raw_parts_mut (
1481+ result. intent_results ,
1482+ batch_size,
1483+ ) ) ;
1484+ }
1485+ }
1486+
1487+ // Free PII results
1488+ if !result. pii_results . is_null ( ) {
1489+ unsafe {
1490+ let pii_slice = std:: slice:: from_raw_parts_mut ( result. pii_results , batch_size) ;
1491+ for pii in pii_slice {
1492+ if !pii. pii_types . is_null ( ) {
1493+ let types_slice =
1494+ std:: slice:: from_raw_parts_mut ( pii. pii_types , pii. num_pii_types as usize ) ;
1495+ for & mut type_ptr in types_slice {
1496+ if !type_ptr. is_null ( ) {
1497+ let _ = CString :: from_raw ( type_ptr) ;
1498+ }
1499+ }
1500+ let _ = Vec :: from_raw_parts (
1501+ pii. pii_types ,
1502+ pii. num_pii_types as usize ,
1503+ pii. num_pii_types as usize ,
1504+ ) ;
1505+ }
1506+ }
1507+ let _ = Box :: from_raw ( std:: slice:: from_raw_parts_mut (
1508+ result. pii_results ,
1509+ batch_size,
1510+ ) ) ;
1511+ }
1512+ }
1513+
1514+ // Free security results
1515+ if !result. security_results . is_null ( ) {
1516+ unsafe {
1517+ let security_slice =
1518+ std:: slice:: from_raw_parts_mut ( result. security_results , batch_size) ;
1519+ for security in security_slice {
1520+ if !security. threat_type . is_null ( ) {
1521+ let _ = CString :: from_raw ( security. threat_type ) ;
1522+ }
1523+ }
1524+ let _ = Box :: from_raw ( std:: slice:: from_raw_parts_mut (
1525+ result. security_results ,
1526+ batch_size,
1527+ ) ) ;
1528+ }
1529+ }
1530+ }
0 commit comments