Skip to content

Commit bfb87d9

Browse files
committed
Feature Enhancement: Batch Inference Support in candle-binding
Feature Enhancement: Batch Inference Support in candle-binding Signed-off-by: OneZero-Y <[email protected]> fix: unified_classifier_test Signed-off-by: OneZero-Y <[email protected]> fix: unified_classifier_test Signed-off-by: OneZero-Y <[email protected]> fix: unit_test Signed-off-by: OneZero-Y <[email protected]>
1 parent 734c995 commit bfb87d9

File tree

15 files changed

+3016
-354
lines changed

15 files changed

+3016
-354
lines changed

candle-binding/src/lib.rs

Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::sync::Arc;
66
use std::sync::Mutex;
77

88
pub mod modernbert;
9+
pub mod unified_classifier;
910

1011
// Re-export ModernBERT functions and structures
1112
pub use modernbert::{
@@ -14,6 +15,12 @@ pub use modernbert::{
1415
init_modernbert_pii_classifier, ModernBertClassificationResult,
1516
};
1617

18+
// Re-export unified classifier functions and structures
19+
pub use unified_classifier::{
20+
get_unified_classifier, BatchClassificationResult, IntentResult, PIIResult, SecurityResult,
21+
UnifiedClassificationResult, UnifiedClassifier, UNIFIED_CLASSIFIER,
22+
};
23+
1724
use anyhow::{Error as E, Result};
1825
use candle_core::{DType, Device, Tensor};
1926
use candle_nn::{Linear, VarBuilder};
@@ -1177,3 +1184,347 @@ pub extern "C" fn classify_jailbreak_text(text: *const c_char) -> Classification
11771184
}
11781185
}
11791186
}
1187+
1188+
// ================================================================================================
1189+
// UNIFIED CLASSIFIER C INTERFACE
1190+
// ================================================================================================
1191+
1192+
/// C-compatible structure for unified batch results
1193+
#[repr(C)]
1194+
pub struct UnifiedBatchResult {
1195+
pub intent_results: *mut CIntentResult,
1196+
pub pii_results: *mut CPIIResult,
1197+
pub security_results: *mut CSecurityResult,
1198+
pub batch_size: i32,
1199+
pub error: bool,
1200+
pub error_message: *mut c_char,
1201+
}
1202+
1203+
/// C-compatible intent result
1204+
#[repr(C)]
1205+
pub struct CIntentResult {
1206+
pub category: *mut c_char,
1207+
pub confidence: f32,
1208+
pub probabilities: *mut f32,
1209+
pub num_probabilities: i32,
1210+
}
1211+
1212+
/// C-compatible PII result
1213+
#[repr(C)]
1214+
pub struct CPIIResult {
1215+
pub has_pii: bool,
1216+
pub pii_types: *mut *mut c_char,
1217+
pub num_pii_types: i32,
1218+
pub confidence: f32,
1219+
}
1220+
1221+
/// C-compatible security result
1222+
#[repr(C)]
1223+
pub struct CSecurityResult {
1224+
pub is_jailbreak: bool,
1225+
pub threat_type: *mut c_char,
1226+
pub confidence: f32,
1227+
}
1228+
1229+
impl UnifiedBatchResult {
1230+
/// Create an error result
1231+
fn error(message: &str) -> Self {
1232+
let error_msg =
1233+
CString::new(message).unwrap_or_else(|_| CString::new("Unknown error").unwrap());
1234+
Self {
1235+
intent_results: std::ptr::null_mut(),
1236+
pii_results: std::ptr::null_mut(),
1237+
security_results: std::ptr::null_mut(),
1238+
batch_size: 0,
1239+
error: true,
1240+
error_message: error_msg.into_raw(),
1241+
}
1242+
}
1243+
1244+
/// Convert from Rust BatchClassificationResult to C-compatible structure
1245+
fn from_batch_result(result: BatchClassificationResult) -> Self {
1246+
let batch_size = result.batch_size as i32;
1247+
1248+
// Convert intent results
1249+
let intent_results = result
1250+
.intent_results
1251+
.into_iter()
1252+
.map(|r| {
1253+
let probs_len = r.probabilities.len();
1254+
CIntentResult {
1255+
category: CString::new(r.category).unwrap().into_raw(),
1256+
confidence: r.confidence,
1257+
probabilities: {
1258+
let mut probs = r.probabilities.into_boxed_slice();
1259+
let ptr = probs.as_mut_ptr();
1260+
std::mem::forget(probs);
1261+
ptr
1262+
},
1263+
num_probabilities: probs_len as i32,
1264+
}
1265+
})
1266+
.collect::<Vec<_>>()
1267+
.into_boxed_slice();
1268+
let intent_ptr = Box::into_raw(intent_results) as *mut CIntentResult;
1269+
1270+
// Convert PII results
1271+
let pii_results = result
1272+
.pii_results
1273+
.into_iter()
1274+
.map(|r| {
1275+
let types_len = r.pii_types.len();
1276+
CPIIResult {
1277+
has_pii: r.has_pii,
1278+
pii_types: {
1279+
let types: Vec<*mut c_char> = r
1280+
.pii_types
1281+
.into_iter()
1282+
.map(|t| CString::new(t).unwrap().into_raw())
1283+
.collect();
1284+
let mut types_box = types.into_boxed_slice();
1285+
let ptr = types_box.as_mut_ptr();
1286+
std::mem::forget(types_box);
1287+
ptr
1288+
},
1289+
num_pii_types: types_len as i32,
1290+
confidence: r.confidence,
1291+
}
1292+
})
1293+
.collect::<Vec<_>>()
1294+
.into_boxed_slice();
1295+
let pii_ptr = Box::into_raw(pii_results) as *mut CPIIResult;
1296+
1297+
// Convert security results
1298+
let security_results = result
1299+
.security_results
1300+
.into_iter()
1301+
.map(|r| CSecurityResult {
1302+
is_jailbreak: r.is_jailbreak,
1303+
threat_type: CString::new(r.threat_type).unwrap().into_raw(),
1304+
confidence: r.confidence,
1305+
})
1306+
.collect::<Vec<_>>()
1307+
.into_boxed_slice();
1308+
let security_ptr = Box::into_raw(security_results) as *mut CSecurityResult;
1309+
1310+
Self {
1311+
intent_results: intent_ptr,
1312+
pii_results: pii_ptr,
1313+
security_results: security_ptr,
1314+
batch_size,
1315+
error: false,
1316+
error_message: std::ptr::null_mut(),
1317+
}
1318+
}
1319+
}
1320+
1321+
/// Initialize unified classifier (called from Go)
1322+
#[no_mangle]
1323+
pub extern "C" fn init_unified_classifier_c(
1324+
modernbert_path: *const c_char,
1325+
intent_head_path: *const c_char,
1326+
pii_head_path: *const c_char,
1327+
security_head_path: *const c_char,
1328+
intent_labels: *const *const c_char,
1329+
intent_labels_count: usize,
1330+
pii_labels: *const *const c_char,
1331+
pii_labels_count: usize,
1332+
security_labels: *const *const c_char,
1333+
security_labels_count: usize,
1334+
use_cpu: bool,
1335+
) -> bool {
1336+
let modernbert_path = unsafe {
1337+
match CStr::from_ptr(modernbert_path).to_str() {
1338+
Ok(s) => s,
1339+
Err(_) => return false,
1340+
}
1341+
};
1342+
1343+
let intent_head_path = unsafe {
1344+
match CStr::from_ptr(intent_head_path).to_str() {
1345+
Ok(s) => s,
1346+
Err(_) => return false,
1347+
}
1348+
};
1349+
1350+
let pii_head_path = unsafe {
1351+
match CStr::from_ptr(pii_head_path).to_str() {
1352+
Ok(s) => s,
1353+
Err(_) => return false,
1354+
}
1355+
};
1356+
1357+
let security_head_path = unsafe {
1358+
match CStr::from_ptr(security_head_path).to_str() {
1359+
Ok(s) => s,
1360+
Err(_) => return false,
1361+
}
1362+
};
1363+
1364+
// Convert C string arrays to Rust Vec<String>
1365+
let intent_labels_vec = unsafe {
1366+
std::slice::from_raw_parts(intent_labels, intent_labels_count)
1367+
.iter()
1368+
.map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap_or("").to_string())
1369+
.collect::<Vec<String>>()
1370+
};
1371+
1372+
let pii_labels_vec = unsafe {
1373+
std::slice::from_raw_parts(pii_labels, pii_labels_count)
1374+
.iter()
1375+
.map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap_or("").to_string())
1376+
.collect::<Vec<String>>()
1377+
};
1378+
1379+
let security_labels_vec = unsafe {
1380+
std::slice::from_raw_parts(security_labels, security_labels_count)
1381+
.iter()
1382+
.map(|&ptr| CStr::from_ptr(ptr).to_str().unwrap_or("").to_string())
1383+
.collect::<Vec<String>>()
1384+
};
1385+
1386+
match UnifiedClassifier::new(
1387+
modernbert_path,
1388+
intent_head_path,
1389+
pii_head_path,
1390+
security_head_path,
1391+
intent_labels_vec,
1392+
pii_labels_vec,
1393+
security_labels_vec,
1394+
use_cpu,
1395+
) {
1396+
Ok(classifier) => {
1397+
let mut global_classifier = UNIFIED_CLASSIFIER.lock().unwrap();
1398+
*global_classifier = Some(classifier);
1399+
true
1400+
}
1401+
Err(e) => {
1402+
eprintln!("Failed to initialize unified classifier: {e}");
1403+
false
1404+
}
1405+
}
1406+
}
1407+
1408+
/// Classify batch of texts using unified classifier (called from Go)
1409+
#[no_mangle]
1410+
pub extern "C" fn classify_unified_batch(
1411+
texts_ptr: *const *const c_char,
1412+
num_texts: i32,
1413+
) -> UnifiedBatchResult {
1414+
if texts_ptr.is_null() || num_texts <= 0 {
1415+
return UnifiedBatchResult::error("Invalid input parameters");
1416+
}
1417+
1418+
// Convert C strings to Rust strings
1419+
let texts = unsafe {
1420+
std::slice::from_raw_parts(texts_ptr, num_texts as usize)
1421+
.iter()
1422+
.map(|&ptr| {
1423+
if ptr.is_null() {
1424+
Err("Null text pointer")
1425+
} else {
1426+
CStr::from_ptr(ptr).to_str().map_err(|_| "Invalid UTF-8")
1427+
}
1428+
})
1429+
.collect::<Result<Vec<_>, _>>()
1430+
};
1431+
1432+
let texts = match texts {
1433+
Ok(t) => t,
1434+
Err(e) => return UnifiedBatchResult::error(e),
1435+
};
1436+
1437+
// Get unified classifier and perform batch classification
1438+
match get_unified_classifier() {
1439+
Ok(classifier_guard) => match classifier_guard.as_ref() {
1440+
Some(classifier) => match classifier.classify_batch(&texts) {
1441+
Ok(result) => UnifiedBatchResult::from_batch_result(result),
1442+
Err(e) => UnifiedBatchResult::error(&format!("Classification failed: {}", e)),
1443+
},
1444+
None => UnifiedBatchResult::error("Unified classifier not initialized"),
1445+
},
1446+
Err(e) => UnifiedBatchResult::error(&format!("Failed to get classifier: {}", e)),
1447+
}
1448+
}
1449+
1450+
/// Free unified batch result memory (called from Go)
1451+
#[no_mangle]
1452+
pub extern "C" fn free_unified_batch_result(result: UnifiedBatchResult) {
1453+
if result.error {
1454+
if !result.error_message.is_null() {
1455+
unsafe {
1456+
let _ = CString::from_raw(result.error_message);
1457+
}
1458+
}
1459+
return;
1460+
}
1461+
1462+
let batch_size = result.batch_size as usize;
1463+
1464+
// Free intent results
1465+
if !result.intent_results.is_null() {
1466+
unsafe {
1467+
let intent_slice = std::slice::from_raw_parts_mut(result.intent_results, batch_size);
1468+
for intent in intent_slice {
1469+
if !intent.category.is_null() {
1470+
let _ = CString::from_raw(intent.category);
1471+
}
1472+
if !intent.probabilities.is_null() {
1473+
let _ = Vec::from_raw_parts(
1474+
intent.probabilities,
1475+
intent.num_probabilities as usize,
1476+
intent.num_probabilities as usize,
1477+
);
1478+
}
1479+
}
1480+
let _ = Box::from_raw(std::slice::from_raw_parts_mut(
1481+
result.intent_results,
1482+
batch_size,
1483+
));
1484+
}
1485+
}
1486+
1487+
// Free PII results
1488+
if !result.pii_results.is_null() {
1489+
unsafe {
1490+
let pii_slice = std::slice::from_raw_parts_mut(result.pii_results, batch_size);
1491+
for pii in pii_slice {
1492+
if !pii.pii_types.is_null() {
1493+
let types_slice =
1494+
std::slice::from_raw_parts_mut(pii.pii_types, pii.num_pii_types as usize);
1495+
for &mut type_ptr in types_slice {
1496+
if !type_ptr.is_null() {
1497+
let _ = CString::from_raw(type_ptr);
1498+
}
1499+
}
1500+
let _ = Vec::from_raw_parts(
1501+
pii.pii_types,
1502+
pii.num_pii_types as usize,
1503+
pii.num_pii_types as usize,
1504+
);
1505+
}
1506+
}
1507+
let _ = Box::from_raw(std::slice::from_raw_parts_mut(
1508+
result.pii_results,
1509+
batch_size,
1510+
));
1511+
}
1512+
}
1513+
1514+
// Free security results
1515+
if !result.security_results.is_null() {
1516+
unsafe {
1517+
let security_slice =
1518+
std::slice::from_raw_parts_mut(result.security_results, batch_size);
1519+
for security in security_slice {
1520+
if !security.threat_type.is_null() {
1521+
let _ = CString::from_raw(security.threat_type);
1522+
}
1523+
}
1524+
let _ = Box::from_raw(std::slice::from_raw_parts_mut(
1525+
result.security_results,
1526+
batch_size,
1527+
));
1528+
}
1529+
}
1530+
}

0 commit comments

Comments
 (0)