@@ -6,6 +6,7 @@ use std::sync::Arc;
66use  std:: sync:: Mutex ; 
77
88pub  mod  modernbert; 
9+ pub  mod  unified_classifier; 
910
1011// Re-export ModernBERT functions and structures 
1112pub  use  modernbert:: { 
@@ -14,6 +15,12 @@ pub use modernbert::{
1415    init_modernbert_pii_classifier,  ModernBertClassificationResult , 
1516} ; 
1617
18+ // Re-export unified classifier functions and structures 
19+ pub  use  unified_classifier:: { 
20+     get_unified_classifier,  BatchClassificationResult ,  IntentResult ,  PIIResult ,  SecurityResult , 
21+     UnifiedClassificationResult ,  UnifiedClassifier ,  UNIFIED_CLASSIFIER , 
22+ } ; 
23+ 
1724use  anyhow:: { Error  as  E ,  Result } ; 
1825use  candle_core:: { DType ,  Device ,  Tensor } ; 
1926use  candle_nn:: { Linear ,  VarBuilder } ; 
@@ -1312,3 +1319,347 @@ pub extern "C" fn classify_jailbreak_text(text: *const c_char) -> Classification
13121319        } 
13131320    } 
13141321} 
1322+ 
1323+ // ================================================================================================ 
1324+ // UNIFIED CLASSIFIER C INTERFACE 
1325+ // ================================================================================================ 
1326+ 
1327+ /// C-compatible structure for unified batch results 
1328+ #[ repr( C ) ]  
1329+ pub  struct  UnifiedBatchResult  { 
1330+     pub  intent_results :  * mut  CIntentResult , 
1331+     pub  pii_results :  * mut  CPIIResult , 
1332+     pub  security_results :  * mut  CSecurityResult , 
1333+     pub  batch_size :  i32 , 
1334+     pub  error :  bool , 
1335+     pub  error_message :  * mut  c_char , 
1336+ } 
1337+ 
/// Intent classification outcome for a single text, laid out for C callers.
#[repr(C)]
pub struct CIntentResult {
    /// NUL-terminated name of the predicted category.
    pub category: *mut c_char,
    /// Confidence reported for the predicted category.
    pub confidence: f32,
    /// First element of an array holding `num_probabilities` values.
    pub probabilities: *mut f32,
    /// Number of entries reachable through `probabilities`.
    pub num_probabilities: i32,
}
1346+ 
/// PII detection outcome for a single text, laid out for C callers.
#[repr(C)]
pub struct CPIIResult {
    /// Whether any PII was reported for the text.
    pub has_pii: bool,
    /// First element of an array of `num_pii_types` NUL-terminated strings.
    pub pii_types: *mut *mut c_char,
    /// Number of entries reachable through `pii_types`.
    pub num_pii_types: i32,
    /// Confidence reported for the PII decision.
    pub confidence: f32,
}
1355+ 
/// Security (jailbreak) detection outcome for a single text, laid out for C callers.
#[repr(C)]
pub struct CSecurityResult {
    /// Whether the text was flagged as a jailbreak attempt.
    pub is_jailbreak: bool,
    /// NUL-terminated label describing the detected threat type.
    pub threat_type: *mut c_char,
    /// Confidence reported for the security decision.
    pub confidence: f32,
}
1363+ 
1364+ impl  UnifiedBatchResult  { 
1365+     /// Create an error result 
1366+      fn  error ( message :  & str )  -> Self  { 
1367+         let  error_msg =
1368+             CString :: new ( message) . unwrap_or_else ( |_| CString :: new ( "Unknown error" ) . unwrap ( ) ) ; 
1369+         Self  { 
1370+             intent_results :  std:: ptr:: null_mut ( ) , 
1371+             pii_results :  std:: ptr:: null_mut ( ) , 
1372+             security_results :  std:: ptr:: null_mut ( ) , 
1373+             batch_size :  0 , 
1374+             error :  true , 
1375+             error_message :  error_msg. into_raw ( ) , 
1376+         } 
1377+     } 
1378+ 
1379+     /// Convert from Rust BatchClassificationResult to C-compatible structure 
1380+      fn  from_batch_result ( result :  BatchClassificationResult )  -> Self  { 
1381+         let  batch_size = result. batch_size  as  i32 ; 
1382+ 
1383+         // Convert intent results 
1384+         let  intent_results = result
1385+             . intent_results 
1386+             . into_iter ( ) 
1387+             . map ( |r| { 
1388+                 let  probs_len = r. probabilities . len ( ) ; 
1389+                 CIntentResult  { 
1390+                     category :  CString :: new ( r. category ) . unwrap ( ) . into_raw ( ) , 
1391+                     confidence :  r. confidence , 
1392+                     probabilities :  { 
1393+                         let  mut  probs = r. probabilities . into_boxed_slice ( ) ; 
1394+                         let  ptr = probs. as_mut_ptr ( ) ; 
1395+                         std:: mem:: forget ( probs) ; 
1396+                         ptr
1397+                     } , 
1398+                     num_probabilities :  probs_len as  i32 , 
1399+                 } 
1400+             } ) 
1401+             . collect :: < Vec < _ > > ( ) 
1402+             . into_boxed_slice ( ) ; 
1403+         let  intent_ptr = Box :: into_raw ( intent_results)  as  * mut  CIntentResult ; 
1404+ 
1405+         // Convert PII results 
1406+         let  pii_results = result
1407+             . pii_results 
1408+             . into_iter ( ) 
1409+             . map ( |r| { 
1410+                 let  types_len = r. pii_types . len ( ) ; 
1411+                 CPIIResult  { 
1412+                     has_pii :  r. has_pii , 
1413+                     pii_types :  { 
1414+                         let  types:  Vec < * mut  c_char >  = r
1415+                             . pii_types 
1416+                             . into_iter ( ) 
1417+                             . map ( |t| CString :: new ( t) . unwrap ( ) . into_raw ( ) ) 
1418+                             . collect ( ) ; 
1419+                         let  mut  types_box = types. into_boxed_slice ( ) ; 
1420+                         let  ptr = types_box. as_mut_ptr ( ) ; 
1421+                         std:: mem:: forget ( types_box) ; 
1422+                         ptr
1423+                     } , 
1424+                     num_pii_types :  types_len as  i32 , 
1425+                     confidence :  r. confidence , 
1426+                 } 
1427+             } ) 
1428+             . collect :: < Vec < _ > > ( ) 
1429+             . into_boxed_slice ( ) ; 
1430+         let  pii_ptr = Box :: into_raw ( pii_results)  as  * mut  CPIIResult ; 
1431+ 
1432+         // Convert security results 
1433+         let  security_results = result
1434+             . security_results 
1435+             . into_iter ( ) 
1436+             . map ( |r| CSecurityResult  { 
1437+                 is_jailbreak :  r. is_jailbreak , 
1438+                 threat_type :  CString :: new ( r. threat_type ) . unwrap ( ) . into_raw ( ) , 
1439+                 confidence :  r. confidence , 
1440+             } ) 
1441+             . collect :: < Vec < _ > > ( ) 
1442+             . into_boxed_slice ( ) ; 
1443+         let  security_ptr = Box :: into_raw ( security_results)  as  * mut  CSecurityResult ; 
1444+ 
1445+         Self  { 
1446+             intent_results :  intent_ptr, 
1447+             pii_results :  pii_ptr, 
1448+             security_results :  security_ptr, 
1449+             batch_size, 
1450+             error :  false , 
1451+             error_message :  std:: ptr:: null_mut ( ) , 
1452+         } 
1453+     } 
1454+ } 
1455+ 
1456+ /// Initialize unified classifier (called from Go) 
1457+ #[ no_mangle]  
1458+ pub  extern  "C"  fn  init_unified_classifier_c ( 
1459+     modernbert_path :  * const  c_char , 
1460+     intent_head_path :  * const  c_char , 
1461+     pii_head_path :  * const  c_char , 
1462+     security_head_path :  * const  c_char , 
1463+     intent_labels :  * const  * const  c_char , 
1464+     intent_labels_count :  usize , 
1465+     pii_labels :  * const  * const  c_char , 
1466+     pii_labels_count :  usize , 
1467+     security_labels :  * const  * const  c_char , 
1468+     security_labels_count :  usize , 
1469+     use_cpu :  bool , 
1470+ )  -> bool  { 
1471+     let  modernbert_path = unsafe  { 
1472+         match  CStr :: from_ptr ( modernbert_path) . to_str ( )  { 
1473+             Ok ( s)  => s, 
1474+             Err ( _)  => return  false , 
1475+         } 
1476+     } ; 
1477+ 
1478+     let  intent_head_path = unsafe  { 
1479+         match  CStr :: from_ptr ( intent_head_path) . to_str ( )  { 
1480+             Ok ( s)  => s, 
1481+             Err ( _)  => return  false , 
1482+         } 
1483+     } ; 
1484+ 
1485+     let  pii_head_path = unsafe  { 
1486+         match  CStr :: from_ptr ( pii_head_path) . to_str ( )  { 
1487+             Ok ( s)  => s, 
1488+             Err ( _)  => return  false , 
1489+         } 
1490+     } ; 
1491+ 
1492+     let  security_head_path = unsafe  { 
1493+         match  CStr :: from_ptr ( security_head_path) . to_str ( )  { 
1494+             Ok ( s)  => s, 
1495+             Err ( _)  => return  false , 
1496+         } 
1497+     } ; 
1498+ 
1499+     // Convert C string arrays to Rust Vec<String> 
1500+     let  intent_labels_vec = unsafe  { 
1501+         std:: slice:: from_raw_parts ( intent_labels,  intent_labels_count) 
1502+             . iter ( ) 
1503+             . map ( |& ptr| CStr :: from_ptr ( ptr) . to_str ( ) . unwrap_or ( "" ) . to_string ( ) ) 
1504+             . collect :: < Vec < String > > ( ) 
1505+     } ; 
1506+ 
1507+     let  pii_labels_vec = unsafe  { 
1508+         std:: slice:: from_raw_parts ( pii_labels,  pii_labels_count) 
1509+             . iter ( ) 
1510+             . map ( |& ptr| CStr :: from_ptr ( ptr) . to_str ( ) . unwrap_or ( "" ) . to_string ( ) ) 
1511+             . collect :: < Vec < String > > ( ) 
1512+     } ; 
1513+ 
1514+     let  security_labels_vec = unsafe  { 
1515+         std:: slice:: from_raw_parts ( security_labels,  security_labels_count) 
1516+             . iter ( ) 
1517+             . map ( |& ptr| CStr :: from_ptr ( ptr) . to_str ( ) . unwrap_or ( "" ) . to_string ( ) ) 
1518+             . collect :: < Vec < String > > ( ) 
1519+     } ; 
1520+ 
1521+     match  UnifiedClassifier :: new ( 
1522+         modernbert_path, 
1523+         intent_head_path, 
1524+         pii_head_path, 
1525+         security_head_path, 
1526+         intent_labels_vec, 
1527+         pii_labels_vec, 
1528+         security_labels_vec, 
1529+         use_cpu, 
1530+     )  { 
1531+         Ok ( classifier)  => { 
1532+             let  mut  global_classifier = UNIFIED_CLASSIFIER . lock ( ) . unwrap ( ) ; 
1533+             * global_classifier = Some ( classifier) ; 
1534+             true 
1535+         } 
1536+         Err ( e)  => { 
1537+             eprintln ! ( "Failed to initialize unified classifier: {e}" ) ; 
1538+             false 
1539+         } 
1540+     } 
1541+ } 
1542+ 
1543+ /// Classify batch of texts using unified classifier (called from Go) 
1544+ #[ no_mangle]  
1545+ pub  extern  "C"  fn  classify_unified_batch ( 
1546+     texts_ptr :  * const  * const  c_char , 
1547+     num_texts :  i32 , 
1548+ )  -> UnifiedBatchResult  { 
1549+     if  texts_ptr. is_null ( )  || num_texts <= 0  { 
1550+         return  UnifiedBatchResult :: error ( "Invalid input parameters" ) ; 
1551+     } 
1552+ 
1553+     // Convert C strings to Rust strings 
1554+     let  texts = unsafe  { 
1555+         std:: slice:: from_raw_parts ( texts_ptr,  num_texts as  usize ) 
1556+             . iter ( ) 
1557+             . map ( |& ptr| { 
1558+                 if  ptr. is_null ( )  { 
1559+                     Err ( "Null text pointer" ) 
1560+                 }  else  { 
1561+                     CStr :: from_ptr ( ptr) . to_str ( ) . map_err ( |_| "Invalid UTF-8" ) 
1562+                 } 
1563+             } ) 
1564+             . collect :: < Result < Vec < _ > ,  _ > > ( ) 
1565+     } ; 
1566+ 
1567+     let  texts = match  texts { 
1568+         Ok ( t)  => t, 
1569+         Err ( e)  => return  UnifiedBatchResult :: error ( e) , 
1570+     } ; 
1571+ 
1572+     // Get unified classifier and perform batch classification 
1573+     match  get_unified_classifier ( )  { 
1574+         Ok ( classifier_guard)  => match  classifier_guard. as_ref ( )  { 
1575+             Some ( classifier)  => match  classifier. classify_batch ( & texts)  { 
1576+                 Ok ( result)  => UnifiedBatchResult :: from_batch_result ( result) , 
1577+                 Err ( e)  => UnifiedBatchResult :: error ( & format ! ( "Classification failed: {}" ,  e) ) , 
1578+             } , 
1579+             None  => UnifiedBatchResult :: error ( "Unified classifier not initialized" ) , 
1580+         } , 
1581+         Err ( e)  => UnifiedBatchResult :: error ( & format ! ( "Failed to get classifier: {}" ,  e) ) , 
1582+     } 
1583+ } 
1584+ 
1585+ /// Free unified batch result memory (called from Go) 
1586+ #[ no_mangle]  
1587+ pub  extern  "C"  fn  free_unified_batch_result ( result :  UnifiedBatchResult )  { 
1588+     if  result. error  { 
1589+         if  !result. error_message . is_null ( )  { 
1590+             unsafe  { 
1591+                 let  _ = CString :: from_raw ( result. error_message ) ; 
1592+             } 
1593+         } 
1594+         return ; 
1595+     } 
1596+ 
1597+     let  batch_size = result. batch_size  as  usize ; 
1598+ 
1599+     // Free intent results 
1600+     if  !result. intent_results . is_null ( )  { 
1601+         unsafe  { 
1602+             let  intent_slice = std:: slice:: from_raw_parts_mut ( result. intent_results ,  batch_size) ; 
1603+             for  intent in  intent_slice { 
1604+                 if  !intent. category . is_null ( )  { 
1605+                     let  _ = CString :: from_raw ( intent. category ) ; 
1606+                 } 
1607+                 if  !intent. probabilities . is_null ( )  { 
1608+                     let  _ = Vec :: from_raw_parts ( 
1609+                         intent. probabilities , 
1610+                         intent. num_probabilities  as  usize , 
1611+                         intent. num_probabilities  as  usize , 
1612+                     ) ; 
1613+                 } 
1614+             } 
1615+             let  _ = Box :: from_raw ( std:: slice:: from_raw_parts_mut ( 
1616+                 result. intent_results , 
1617+                 batch_size, 
1618+             ) ) ; 
1619+         } 
1620+     } 
1621+ 
1622+     // Free PII results 
1623+     if  !result. pii_results . is_null ( )  { 
1624+         unsafe  { 
1625+             let  pii_slice = std:: slice:: from_raw_parts_mut ( result. pii_results ,  batch_size) ; 
1626+             for  pii in  pii_slice { 
1627+                 if  !pii. pii_types . is_null ( )  { 
1628+                     let  types_slice =
1629+                         std:: slice:: from_raw_parts_mut ( pii. pii_types ,  pii. num_pii_types  as  usize ) ; 
1630+                     for  & mut  type_ptr in  types_slice { 
1631+                         if  !type_ptr. is_null ( )  { 
1632+                             let  _ = CString :: from_raw ( type_ptr) ; 
1633+                         } 
1634+                     } 
1635+                     let  _ = Vec :: from_raw_parts ( 
1636+                         pii. pii_types , 
1637+                         pii. num_pii_types  as  usize , 
1638+                         pii. num_pii_types  as  usize , 
1639+                     ) ; 
1640+                 } 
1641+             } 
1642+             let  _ = Box :: from_raw ( std:: slice:: from_raw_parts_mut ( 
1643+                 result. pii_results , 
1644+                 batch_size, 
1645+             ) ) ; 
1646+         } 
1647+     } 
1648+ 
1649+     // Free security results 
1650+     if  !result. security_results . is_null ( )  { 
1651+         unsafe  { 
1652+             let  security_slice =
1653+                 std:: slice:: from_raw_parts_mut ( result. security_results ,  batch_size) ; 
1654+             for  security in  security_slice { 
1655+                 if  !security. threat_type . is_null ( )  { 
1656+                     let  _ = CString :: from_raw ( security. threat_type ) ; 
1657+                 } 
1658+             } 
1659+             let  _ = Box :: from_raw ( std:: slice:: from_raw_parts_mut ( 
1660+                 result. security_results , 
1661+                 batch_size, 
1662+             ) ) ; 
1663+         } 
1664+     } 
1665+ } 
0 commit comments