@@ -6,6 +6,7 @@ use std::sync::Arc;
66use std:: sync:: Mutex ;
77
88pub mod modernbert;
9+ pub mod unified_classifier;
910
1011// Re-export ModernBERT functions and structures
1112pub use modernbert:: {
@@ -14,6 +15,12 @@ pub use modernbert::{
1415 init_modernbert_pii_classifier, ModernBertClassificationResult ,
1516} ;
1617
18+ // Re-export unified classifier functions and structures
19+ pub use unified_classifier:: {
20+ get_unified_classifier, BatchClassificationResult , IntentResult , PIIResult , SecurityResult ,
21+ UnifiedClassificationResult , UnifiedClassifier , UNIFIED_CLASSIFIER ,
22+ } ;
23+
1724use anyhow:: { Error as E , Result } ;
1825use candle_core:: { DType , Device , Tensor } ;
1926use candle_nn:: { Linear , VarBuilder } ;
@@ -1312,3 +1319,347 @@ pub extern "C" fn classify_jailbreak_text(text: *const c_char) -> Classification
13121319 }
13131320 }
13141321}
1322+
1323+ // ================================================================================================
1324+ // UNIFIED CLASSIFIER C INTERFACE
1325+ // ================================================================================================
1326+
1327+ /// C-compatible structure for unified batch results
1328+ #[ repr( C ) ]
1329+ pub struct UnifiedBatchResult {
1330+ pub intent_results : * mut CIntentResult ,
1331+ pub pii_results : * mut CPIIResult ,
1332+ pub security_results : * mut CSecurityResult ,
1333+ pub batch_size : i32 ,
1334+ pub error : bool ,
1335+ pub error_message : * mut c_char ,
1336+ }
1337+
1338+ /// C-compatible intent result
1339+ #[ repr( C ) ]
1340+ pub struct CIntentResult {
1341+ pub category : * mut c_char ,
1342+ pub confidence : f32 ,
1343+ pub probabilities : * mut f32 ,
1344+ pub num_probabilities : i32 ,
1345+ }
1346+
1347+ /// C-compatible PII result
1348+ #[ repr( C ) ]
1349+ pub struct CPIIResult {
1350+ pub has_pii : bool ,
1351+ pub pii_types : * mut * mut c_char ,
1352+ pub num_pii_types : i32 ,
1353+ pub confidence : f32 ,
1354+ }
1355+
1356+ /// C-compatible security result
1357+ #[ repr( C ) ]
1358+ pub struct CSecurityResult {
1359+ pub is_jailbreak : bool ,
1360+ pub threat_type : * mut c_char ,
1361+ pub confidence : f32 ,
1362+ }
1363+
1364+ impl UnifiedBatchResult {
1365+ /// Create an error result
1366+ fn error ( message : & str ) -> Self {
1367+ let error_msg =
1368+ CString :: new ( message) . unwrap_or_else ( |_| CString :: new ( "Unknown error" ) . unwrap ( ) ) ;
1369+ Self {
1370+ intent_results : std:: ptr:: null_mut ( ) ,
1371+ pii_results : std:: ptr:: null_mut ( ) ,
1372+ security_results : std:: ptr:: null_mut ( ) ,
1373+ batch_size : 0 ,
1374+ error : true ,
1375+ error_message : error_msg. into_raw ( ) ,
1376+ }
1377+ }
1378+
1379+ /// Convert from Rust BatchClassificationResult to C-compatible structure
1380+ fn from_batch_result ( result : BatchClassificationResult ) -> Self {
1381+ let batch_size = result. batch_size as i32 ;
1382+
1383+ // Convert intent results
1384+ let intent_results = result
1385+ . intent_results
1386+ . into_iter ( )
1387+ . map ( |r| {
1388+ let probs_len = r. probabilities . len ( ) ;
1389+ CIntentResult {
1390+ category : CString :: new ( r. category ) . unwrap ( ) . into_raw ( ) ,
1391+ confidence : r. confidence ,
1392+ probabilities : {
1393+ let mut probs = r. probabilities . into_boxed_slice ( ) ;
1394+ let ptr = probs. as_mut_ptr ( ) ;
1395+ std:: mem:: forget ( probs) ;
1396+ ptr
1397+ } ,
1398+ num_probabilities : probs_len as i32 ,
1399+ }
1400+ } )
1401+ . collect :: < Vec < _ > > ( )
1402+ . into_boxed_slice ( ) ;
1403+ let intent_ptr = Box :: into_raw ( intent_results) as * mut CIntentResult ;
1404+
1405+ // Convert PII results
1406+ let pii_results = result
1407+ . pii_results
1408+ . into_iter ( )
1409+ . map ( |r| {
1410+ let types_len = r. pii_types . len ( ) ;
1411+ CPIIResult {
1412+ has_pii : r. has_pii ,
1413+ pii_types : {
1414+ let types: Vec < * mut c_char > = r
1415+ . pii_types
1416+ . into_iter ( )
1417+ . map ( |t| CString :: new ( t) . unwrap ( ) . into_raw ( ) )
1418+ . collect ( ) ;
1419+ let mut types_box = types. into_boxed_slice ( ) ;
1420+ let ptr = types_box. as_mut_ptr ( ) ;
1421+ std:: mem:: forget ( types_box) ;
1422+ ptr
1423+ } ,
1424+ num_pii_types : types_len as i32 ,
1425+ confidence : r. confidence ,
1426+ }
1427+ } )
1428+ . collect :: < Vec < _ > > ( )
1429+ . into_boxed_slice ( ) ;
1430+ let pii_ptr = Box :: into_raw ( pii_results) as * mut CPIIResult ;
1431+
1432+ // Convert security results
1433+ let security_results = result
1434+ . security_results
1435+ . into_iter ( )
1436+ . map ( |r| CSecurityResult {
1437+ is_jailbreak : r. is_jailbreak ,
1438+ threat_type : CString :: new ( r. threat_type ) . unwrap ( ) . into_raw ( ) ,
1439+ confidence : r. confidence ,
1440+ } )
1441+ . collect :: < Vec < _ > > ( )
1442+ . into_boxed_slice ( ) ;
1443+ let security_ptr = Box :: into_raw ( security_results) as * mut CSecurityResult ;
1444+
1445+ Self {
1446+ intent_results : intent_ptr,
1447+ pii_results : pii_ptr,
1448+ security_results : security_ptr,
1449+ batch_size,
1450+ error : false ,
1451+ error_message : std:: ptr:: null_mut ( ) ,
1452+ }
1453+ }
1454+ }
1455+
1456+ /// Initialize unified classifier (called from Go)
1457+ #[ no_mangle]
1458+ pub extern "C" fn init_unified_classifier_c (
1459+ modernbert_path : * const c_char ,
1460+ intent_head_path : * const c_char ,
1461+ pii_head_path : * const c_char ,
1462+ security_head_path : * const c_char ,
1463+ intent_labels : * const * const c_char ,
1464+ intent_labels_count : usize ,
1465+ pii_labels : * const * const c_char ,
1466+ pii_labels_count : usize ,
1467+ security_labels : * const * const c_char ,
1468+ security_labels_count : usize ,
1469+ use_cpu : bool ,
1470+ ) -> bool {
1471+ let modernbert_path = unsafe {
1472+ match CStr :: from_ptr ( modernbert_path) . to_str ( ) {
1473+ Ok ( s) => s,
1474+ Err ( _) => return false ,
1475+ }
1476+ } ;
1477+
1478+ let intent_head_path = unsafe {
1479+ match CStr :: from_ptr ( intent_head_path) . to_str ( ) {
1480+ Ok ( s) => s,
1481+ Err ( _) => return false ,
1482+ }
1483+ } ;
1484+
1485+ let pii_head_path = unsafe {
1486+ match CStr :: from_ptr ( pii_head_path) . to_str ( ) {
1487+ Ok ( s) => s,
1488+ Err ( _) => return false ,
1489+ }
1490+ } ;
1491+
1492+ let security_head_path = unsafe {
1493+ match CStr :: from_ptr ( security_head_path) . to_str ( ) {
1494+ Ok ( s) => s,
1495+ Err ( _) => return false ,
1496+ }
1497+ } ;
1498+
1499+ // Convert C string arrays to Rust Vec<String>
1500+ let intent_labels_vec = unsafe {
1501+ std:: slice:: from_raw_parts ( intent_labels, intent_labels_count)
1502+ . iter ( )
1503+ . map ( |& ptr| CStr :: from_ptr ( ptr) . to_str ( ) . unwrap_or ( "" ) . to_string ( ) )
1504+ . collect :: < Vec < String > > ( )
1505+ } ;
1506+
1507+ let pii_labels_vec = unsafe {
1508+ std:: slice:: from_raw_parts ( pii_labels, pii_labels_count)
1509+ . iter ( )
1510+ . map ( |& ptr| CStr :: from_ptr ( ptr) . to_str ( ) . unwrap_or ( "" ) . to_string ( ) )
1511+ . collect :: < Vec < String > > ( )
1512+ } ;
1513+
1514+ let security_labels_vec = unsafe {
1515+ std:: slice:: from_raw_parts ( security_labels, security_labels_count)
1516+ . iter ( )
1517+ . map ( |& ptr| CStr :: from_ptr ( ptr) . to_str ( ) . unwrap_or ( "" ) . to_string ( ) )
1518+ . collect :: < Vec < String > > ( )
1519+ } ;
1520+
1521+ match UnifiedClassifier :: new (
1522+ modernbert_path,
1523+ intent_head_path,
1524+ pii_head_path,
1525+ security_head_path,
1526+ intent_labels_vec,
1527+ pii_labels_vec,
1528+ security_labels_vec,
1529+ use_cpu,
1530+ ) {
1531+ Ok ( classifier) => {
1532+ let mut global_classifier = UNIFIED_CLASSIFIER . lock ( ) . unwrap ( ) ;
1533+ * global_classifier = Some ( classifier) ;
1534+ true
1535+ }
1536+ Err ( e) => {
1537+ eprintln ! ( "Failed to initialize unified classifier: {e}" ) ;
1538+ false
1539+ }
1540+ }
1541+ }
1542+
1543+ /// Classify batch of texts using unified classifier (called from Go)
1544+ #[ no_mangle]
1545+ pub extern "C" fn classify_unified_batch (
1546+ texts_ptr : * const * const c_char ,
1547+ num_texts : i32 ,
1548+ ) -> UnifiedBatchResult {
1549+ if texts_ptr. is_null ( ) || num_texts <= 0 {
1550+ return UnifiedBatchResult :: error ( "Invalid input parameters" ) ;
1551+ }
1552+
1553+ // Convert C strings to Rust strings
1554+ let texts = unsafe {
1555+ std:: slice:: from_raw_parts ( texts_ptr, num_texts as usize )
1556+ . iter ( )
1557+ . map ( |& ptr| {
1558+ if ptr. is_null ( ) {
1559+ Err ( "Null text pointer" )
1560+ } else {
1561+ CStr :: from_ptr ( ptr) . to_str ( ) . map_err ( |_| "Invalid UTF-8" )
1562+ }
1563+ } )
1564+ . collect :: < Result < Vec < _ > , _ > > ( )
1565+ } ;
1566+
1567+ let texts = match texts {
1568+ Ok ( t) => t,
1569+ Err ( e) => return UnifiedBatchResult :: error ( e) ,
1570+ } ;
1571+
1572+ // Get unified classifier and perform batch classification
1573+ match get_unified_classifier ( ) {
1574+ Ok ( classifier_guard) => match classifier_guard. as_ref ( ) {
1575+ Some ( classifier) => match classifier. classify_batch ( & texts) {
1576+ Ok ( result) => UnifiedBatchResult :: from_batch_result ( result) ,
1577+ Err ( e) => UnifiedBatchResult :: error ( & format ! ( "Classification failed: {}" , e) ) ,
1578+ } ,
1579+ None => UnifiedBatchResult :: error ( "Unified classifier not initialized" ) ,
1580+ } ,
1581+ Err ( e) => UnifiedBatchResult :: error ( & format ! ( "Failed to get classifier: {}" , e) ) ,
1582+ }
1583+ }
1584+
1585+ /// Free unified batch result memory (called from Go)
1586+ #[ no_mangle]
1587+ pub extern "C" fn free_unified_batch_result ( result : UnifiedBatchResult ) {
1588+ if result. error {
1589+ if !result. error_message . is_null ( ) {
1590+ unsafe {
1591+ let _ = CString :: from_raw ( result. error_message ) ;
1592+ }
1593+ }
1594+ return ;
1595+ }
1596+
1597+ let batch_size = result. batch_size as usize ;
1598+
1599+ // Free intent results
1600+ if !result. intent_results . is_null ( ) {
1601+ unsafe {
1602+ let intent_slice = std:: slice:: from_raw_parts_mut ( result. intent_results , batch_size) ;
1603+ for intent in intent_slice {
1604+ if !intent. category . is_null ( ) {
1605+ let _ = CString :: from_raw ( intent. category ) ;
1606+ }
1607+ if !intent. probabilities . is_null ( ) {
1608+ let _ = Vec :: from_raw_parts (
1609+ intent. probabilities ,
1610+ intent. num_probabilities as usize ,
1611+ intent. num_probabilities as usize ,
1612+ ) ;
1613+ }
1614+ }
1615+ let _ = Box :: from_raw ( std:: slice:: from_raw_parts_mut (
1616+ result. intent_results ,
1617+ batch_size,
1618+ ) ) ;
1619+ }
1620+ }
1621+
1622+ // Free PII results
1623+ if !result. pii_results . is_null ( ) {
1624+ unsafe {
1625+ let pii_slice = std:: slice:: from_raw_parts_mut ( result. pii_results , batch_size) ;
1626+ for pii in pii_slice {
1627+ if !pii. pii_types . is_null ( ) {
1628+ let types_slice =
1629+ std:: slice:: from_raw_parts_mut ( pii. pii_types , pii. num_pii_types as usize ) ;
1630+ for & mut type_ptr in types_slice {
1631+ if !type_ptr. is_null ( ) {
1632+ let _ = CString :: from_raw ( type_ptr) ;
1633+ }
1634+ }
1635+ let _ = Vec :: from_raw_parts (
1636+ pii. pii_types ,
1637+ pii. num_pii_types as usize ,
1638+ pii. num_pii_types as usize ,
1639+ ) ;
1640+ }
1641+ }
1642+ let _ = Box :: from_raw ( std:: slice:: from_raw_parts_mut (
1643+ result. pii_results ,
1644+ batch_size,
1645+ ) ) ;
1646+ }
1647+ }
1648+
1649+ // Free security results
1650+ if !result. security_results . is_null ( ) {
1651+ unsafe {
1652+ let security_slice =
1653+ std:: slice:: from_raw_parts_mut ( result. security_results , batch_size) ;
1654+ for security in security_slice {
1655+ if !security. threat_type . is_null ( ) {
1656+ let _ = CString :: from_raw ( security. threat_type ) ;
1657+ }
1658+ }
1659+ let _ = Box :: from_raw ( std:: slice:: from_raw_parts_mut (
1660+ result. security_results ,
1661+ batch_size,
1662+ ) ) ;
1663+ }
1664+ }
1665+ }
0 commit comments