Skip to content

Commit 61be8cd

Browse files
committed
Feature Enhancement: Batch Inference Support in candle-binding
Feature Enhancement: Batch Inference Support in candle-binding Signed-off-by: OneZero-Y <[email protected]> fix: unified_classifier_test Signed-off-by: OneZero-Y <[email protected]> fix: unified_classifier_test Signed-off-by: OneZero-Y <[email protected]> fix: unit_test Signed-off-by: OneZero-Y <[email protected]>
1 parent ea63386 commit 61be8cd

File tree

14 files changed

+3004
-361
lines changed

14 files changed

+3004
-361
lines changed

candle-binding/src/lib.rs

Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::sync::Arc;
66
use std::sync::Mutex;
77

88
pub mod modernbert;
9+
pub mod unified_classifier;
910

1011
// Re-export ModernBERT functions and structures
1112
pub use modernbert::{
@@ -14,6 +15,12 @@ pub use modernbert::{
1415
init_modernbert_pii_classifier, ModernBertClassificationResult,
1516
};
1617

18+
// Re-export unified classifier functions and structures
19+
pub use unified_classifier::{
20+
get_unified_classifier, BatchClassificationResult, IntentResult, PIIResult, SecurityResult,
21+
UnifiedClassificationResult, UnifiedClassifier, UNIFIED_CLASSIFIER,
22+
};
23+
1724
use anyhow::{Error as E, Result};
1825
use candle_core::{DType, Device, Tensor};
1926
use candle_nn::{Linear, VarBuilder};
@@ -1312,3 +1319,347 @@ pub extern "C" fn classify_jailbreak_text(text: *const c_char) -> Classification
13121319
}
13131320
}
13141321
}
1322+
1323+
// ================================================================================================
1324+
// UNIFIED CLASSIFIER C INTERFACE
1325+
// ================================================================================================
1326+
1327+
/// C-compatible structure for unified batch results
1328+
#[repr(C)]
1329+
pub struct UnifiedBatchResult {
1330+
pub intent_results: *mut CIntentResult,
1331+
pub pii_results: *mut CPIIResult,
1332+
pub security_results: *mut CSecurityResult,
1333+
pub batch_size: i32,
1334+
pub error: bool,
1335+
pub error_message: *mut c_char,
1336+
}
1337+
1338+
/// C-compatible intent result
1339+
#[repr(C)]
1340+
pub struct CIntentResult {
1341+
pub category: *mut c_char,
1342+
pub confidence: f32,
1343+
pub probabilities: *mut f32,
1344+
pub num_probabilities: i32,
1345+
}
1346+
1347+
/// C-compatible PII result
1348+
#[repr(C)]
1349+
pub struct CPIIResult {
1350+
pub has_pii: bool,
1351+
pub pii_types: *mut *mut c_char,
1352+
pub num_pii_types: i32,
1353+
pub confidence: f32,
1354+
}
1355+
1356+
/// C-compatible security result
1357+
#[repr(C)]
1358+
pub struct CSecurityResult {
1359+
pub is_jailbreak: bool,
1360+
pub threat_type: *mut c_char,
1361+
pub confidence: f32,
1362+
}
1363+
1364+
impl UnifiedBatchResult {
1365+
/// Create an error result
1366+
fn error(message: &str) -> Self {
1367+
let error_msg =
1368+
CString::new(message).unwrap_or_else(|_| CString::new("Unknown error").unwrap());
1369+
Self {
1370+
intent_results: std::ptr::null_mut(),
1371+
pii_results: std::ptr::null_mut(),
1372+
security_results: std::ptr::null_mut(),
1373+
batch_size: 0,
1374+
error: true,
1375+
error_message: error_msg.into_raw(),
1376+
}
1377+
}
1378+
1379+
/// Convert from Rust BatchClassificationResult to C-compatible structure
1380+
fn from_batch_result(result: BatchClassificationResult) -> Self {
1381+
let batch_size = result.batch_size as i32;
1382+
1383+
// Convert intent results
1384+
let intent_results = result
1385+
.intent_results
1386+
.into_iter()
1387+
.map(|r| {
1388+
let probs_len = r.probabilities.len();
1389+
CIntentResult {
1390+
category: CString::new(r.category).unwrap().into_raw(),
1391+
confidence: r.confidence,
1392+
probabilities: {
1393+
let mut probs = r.probabilities.into_boxed_slice();
1394+
let ptr = probs.as_mut_ptr();
1395+
std::mem::forget(probs);
1396+
ptr
1397+
},
1398+
num_probabilities: probs_len as i32,
1399+
}
1400+
})
1401+
.collect::<Vec<_>>()
1402+
.into_boxed_slice();
1403+
let intent_ptr = Box::into_raw(intent_results) as *mut CIntentResult;
1404+
1405+
// Convert PII results
1406+
let pii_results = result
1407+
.pii_results
1408+
.into_iter()
1409+
.map(|r| {
1410+
let types_len = r.pii_types.len();
1411+
CPIIResult {
1412+
has_pii: r.has_pii,
1413+
pii_types: {
1414+
let types: Vec<*mut c_char> = r
1415+
.pii_types
1416+
.into_iter()
1417+
.map(|t| CString::new(t).unwrap().into_raw())
1418+
.collect();
1419+
let mut types_box = types.into_boxed_slice();
1420+
let ptr = types_box.as_mut_ptr();
1421+
std::mem::forget(types_box);
1422+
ptr
1423+
},
1424+
num_pii_types: types_len as i32,
1425+
confidence: r.confidence,
1426+
}
1427+
})
1428+
.collect::<Vec<_>>()
1429+
.into_boxed_slice();
1430+
let pii_ptr = Box::into_raw(pii_results) as *mut CPIIResult;
1431+
1432+
// Convert security results
1433+
let security_results = result
1434+
.security_results
1435+
.into_iter()
1436+
.map(|r| CSecurityResult {
1437+
is_jailbreak: r.is_jailbreak,
1438+
threat_type: CString::new(r.threat_type).unwrap().into_raw(),
1439+
confidence: r.confidence,
1440+
})
1441+
.collect::<Vec<_>>()
1442+
.into_boxed_slice();
1443+
let security_ptr = Box::into_raw(security_results) as *mut CSecurityResult;
1444+
1445+
Self {
1446+
intent_results: intent_ptr,
1447+
pii_results: pii_ptr,
1448+
security_results: security_ptr,
1449+
batch_size,
1450+
error: false,
1451+
error_message: std::ptr::null_mut(),
1452+
}
1453+
}
1454+
}
1455+
1456+
/// Initialize unified classifier (called from Go)
1457+
#[no_mangle]
1458+
pub extern "C" fn init_unified_classifier_c(
1459+
modernbert_path: *const c_char,
1460+
intent_head_path: *const c_char,
1461+
pii_head_path: *const c_char,
1462+
security_head_path: *const c_char,
1463+
intent_labels: *const *const c_char,
1464+
intent_labels_count: usize,
1465+
pii_labels: *const *const c_char,
1466+
pii_labels_count: usize,
1467+
security_labels: *const *const c_char,
1468+
security_labels_count: usize,
1469+
use_cpu: bool,
1470+
) -> bool {
1471+
let modernbert_path = unsafe {
1472+
match CStr::from_ptr(modernbert_path).to_str() {
1473+
Ok(s) => s,
1474+
Err(_) => return false,
1475+
}
1476+
};
1477+
1478+
let intent_head_path = unsafe {
1479+
match CStr::from_ptr(intent_head_path).to_str() {
1480+
Ok(s) => s,
1481+
Err(_) => return false,
1482+
}
1483+
};
1484+
1485+
let pii_head_path = unsafe {
1486+
match CStr::from_ptr(pii_head_path).to_str() {
1487+
Ok(s) => s,
1488+
Err(_) => return false,
1489+
}
1490+
};
1491+
1492+
let security_head_path = unsafe {
1493+
match CStr::from_ptr(security_head_path).to_str() {
1494+
Ok(s) => s,
1495+
Err(_) => return false,
1496+
}
1497+
};
1498+
1499+
// Convert C string arrays to Rust Vec<String>
1500+
let intent_labels_vec = unsafe {
1501+
std::slice::from_raw_parts(intent_labels, intent_labels_count)
1502+
.iter()
1503+
.map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap_or("").to_string())
1504+
.collect::<Vec<String>>()
1505+
};
1506+
1507+
let pii_labels_vec = unsafe {
1508+
std::slice::from_raw_parts(pii_labels, pii_labels_count)
1509+
.iter()
1510+
.map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap_or("").to_string())
1511+
.collect::<Vec<String>>()
1512+
};
1513+
1514+
let security_labels_vec = unsafe {
1515+
std::slice::from_raw_parts(security_labels, security_labels_count)
1516+
.iter()
1517+
.map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap_or("").to_string())
1518+
.collect::<Vec<String>>()
1519+
};
1520+
1521+
match UnifiedClassifier::new(
1522+
modernbert_path,
1523+
intent_head_path,
1524+
pii_head_path,
1525+
security_head_path,
1526+
intent_labels_vec,
1527+
pii_labels_vec,
1528+
security_labels_vec,
1529+
use_cpu,
1530+
) {
1531+
Ok(classifier) => {
1532+
let mut global_classifier = UNIFIED_CLASSIFIER.lock().unwrap();
1533+
*global_classifier = Some(classifier);
1534+
true
1535+
}
1536+
Err(e) => {
1537+
eprintln!("Failed to initialize unified classifier: {e}");
1538+
false
1539+
}
1540+
}
1541+
}
1542+
1543+
/// Classify batch of texts using unified classifier (called from Go)
1544+
#[no_mangle]
1545+
pub extern "C" fn classify_unified_batch(
1546+
texts_ptr: *const *const c_char,
1547+
num_texts: i32,
1548+
) -> UnifiedBatchResult {
1549+
if texts_ptr.is_null() || num_texts <= 0 {
1550+
return UnifiedBatchResult::error("Invalid input parameters");
1551+
}
1552+
1553+
// Convert C strings to Rust strings
1554+
let texts = unsafe {
1555+
std::slice::from_raw_parts(texts_ptr, num_texts as usize)
1556+
.iter()
1557+
.map(|&ptr| {
1558+
if ptr.is_null() {
1559+
Err("Null text pointer")
1560+
} else {
1561+
CStr::from_ptr(ptr).to_str().map_err(|_| "Invalid UTF-8")
1562+
}
1563+
})
1564+
.collect::<Result<Vec<_>, _>>()
1565+
};
1566+
1567+
let texts = match texts {
1568+
Ok(t) => t,
1569+
Err(e) => return UnifiedBatchResult::error(e),
1570+
};
1571+
1572+
// Get unified classifier and perform batch classification
1573+
match get_unified_classifier() {
1574+
Ok(classifier_guard) => match classifier_guard.as_ref() {
1575+
Some(classifier) => match classifier.classify_batch(&texts) {
1576+
Ok(result) => UnifiedBatchResult::from_batch_result(result),
1577+
Err(e) => UnifiedBatchResult::error(&format!("Classification failed: {}", e)),
1578+
},
1579+
None => UnifiedBatchResult::error("Unified classifier not initialized"),
1580+
},
1581+
Err(e) => UnifiedBatchResult::error(&format!("Failed to get classifier: {}", e)),
1582+
}
1583+
}
1584+
1585+
/// Free unified batch result memory (called from Go)
1586+
#[no_mangle]
1587+
pub extern "C" fn free_unified_batch_result(result: UnifiedBatchResult) {
1588+
if result.error {
1589+
if !result.error_message.is_null() {
1590+
unsafe {
1591+
let _ = CString::from_raw(result.error_message);
1592+
}
1593+
}
1594+
return;
1595+
}
1596+
1597+
let batch_size = result.batch_size as usize;
1598+
1599+
// Free intent results
1600+
if !result.intent_results.is_null() {
1601+
unsafe {
1602+
let intent_slice = std::slice::from_raw_parts_mut(result.intent_results, batch_size);
1603+
for intent in intent_slice {
1604+
if !intent.category.is_null() {
1605+
let _ = CString::from_raw(intent.category);
1606+
}
1607+
if !intent.probabilities.is_null() {
1608+
let _ = Vec::from_raw_parts(
1609+
intent.probabilities,
1610+
intent.num_probabilities as usize,
1611+
intent.num_probabilities as usize,
1612+
);
1613+
}
1614+
}
1615+
let _ = Box::from_raw(std::slice::from_raw_parts_mut(
1616+
result.intent_results,
1617+
batch_size,
1618+
));
1619+
}
1620+
}
1621+
1622+
// Free PII results
1623+
if !result.pii_results.is_null() {
1624+
unsafe {
1625+
let pii_slice = std::slice::from_raw_parts_mut(result.pii_results, batch_size);
1626+
for pii in pii_slice {
1627+
if !pii.pii_types.is_null() {
1628+
let types_slice =
1629+
std::slice::from_raw_parts_mut(pii.pii_types, pii.num_pii_types as usize);
1630+
for &mut type_ptr in types_slice {
1631+
if !type_ptr.is_null() {
1632+
let _ = CString::from_raw(type_ptr);
1633+
}
1634+
}
1635+
let _ = Vec::from_raw_parts(
1636+
pii.pii_types,
1637+
pii.num_pii_types as usize,
1638+
pii.num_pii_types as usize,
1639+
);
1640+
}
1641+
}
1642+
let _ = Box::from_raw(std::slice::from_raw_parts_mut(
1643+
result.pii_results,
1644+
batch_size,
1645+
));
1646+
}
1647+
}
1648+
1649+
// Free security results
1650+
if !result.security_results.is_null() {
1651+
unsafe {
1652+
let security_slice =
1653+
std::slice::from_raw_parts_mut(result.security_results, batch_size);
1654+
for security in security_slice {
1655+
if !security.threat_type.is_null() {
1656+
let _ = CString::from_raw(security.threat_type);
1657+
}
1658+
}
1659+
let _ = Box::from_raw(std::slice::from_raw_parts_mut(
1660+
result.security_results,
1661+
batch_size,
1662+
));
1663+
}
1664+
}
1665+
}

0 commit comments

Comments
 (0)