Skip to content

Commit 76b941f

Browse files
committed
refactor: Implement modular candle-binding architecture
- Restructure codebase into modular layers (core/, ffi/, model_architectures/, classifiers/)
- Add unified error handling and configuration loading systems
- Implement dual-path architecture for traditional and LoRA models
- Add comprehensive FFI layer with memory safety

Maintains backward compatibility while enabling future model integrations.

Signed-off-by: OneZero-Y <[email protected]>
1 parent c5d1425 commit 76b941f

File tree

10 files changed

+735
-174
lines changed

10 files changed

+735
-174
lines changed

candle-binding/src/classifiers/lora/intent_lora.rs

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,15 @@ impl IntentLoRAClassifier {
4747
candle_core::Error::from(unified_err)
4848
})?;
4949

50+
// Load threshold from global config instead of hardcoding
51+
let confidence_threshold = {
52+
use crate::core::config_loader::GlobalConfigLoader;
53+
GlobalConfigLoader::load_intent_threshold().unwrap_or(0.6) // Default from config.yaml classifier.category_model.threshold
54+
};
55+
5056
Ok(Self {
5157
bert_classifier: classifier,
52-
confidence_threshold: 0.7,
58+
confidence_threshold,
5359
intent_labels,
5460
model_path: model_path.to_string(),
5561
})
@@ -81,11 +87,21 @@ impl IntentLoRAClassifier {
8187
candle_core::Error::from(unified_err)
8288
})?;
8389

84-
// Map class index to intent label
90+
// Map class index to intent label - fail if class not found
8591
let intent = if predicted_class < self.intent_labels.len() {
8692
self.intent_labels[predicted_class].clone()
8793
} else {
88-
format!("UNKNOWN_{}", predicted_class)
94+
let unified_err = model_error!(
95+
ModelErrorType::LoRA,
96+
"intent classification",
97+
format!(
98+
"Invalid class index {} not found in labels (max: {})",
99+
predicted_class,
100+
self.intent_labels.len()
101+
),
102+
text
103+
);
104+
return Err(candle_core::Error::from(unified_err));
89105
};
90106

91107
let processing_time = start_time.elapsed().as_millis() as u64;
@@ -119,16 +135,23 @@ impl IntentLoRAClassifier {
119135
let processing_time = start_time.elapsed().as_millis() as u64;
120136

121137
let mut results = Vec::new();
122-
for (predicted_class, confidence) in batch_results {
123-
let intent = if predicted_class < self.intent_labels.len() {
124-
self.intent_labels[predicted_class].clone()
138+
for (i, (predicted_class, confidence)) in batch_results.iter().enumerate() {
139+
let intent = if *predicted_class < self.intent_labels.len() {
140+
self.intent_labels[*predicted_class].clone()
125141
} else {
126-
format!("UNKNOWN_{}", predicted_class)
142+
let unified_err = model_error!(
143+
ModelErrorType::LoRA,
144+
"batch intent classification",
145+
format!("Invalid class index {} not found in labels (max: {}) for text at position {}",
146+
predicted_class, self.intent_labels.len(), i),
147+
&format!("batch[{}]", i)
148+
);
149+
return Err(candle_core::Error::from(unified_err));
127150
};
128151

129152
results.push(IntentResult {
130153
intent,
131-
confidence,
154+
confidence: *confidence,
132155
processing_time_ms: processing_time,
133156
});
134157
}

candle-binding/src/classifiers/lora/pii_lora.rs

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,23 @@ pub struct PIILoRAClassifier {
2020
model_path: String,
2121
}
2222

23-
/// PII detection result
23+
/// Individual PII occurrence with its own confidence
24+
#[derive(Debug, Clone)]
25+
pub struct PIIOccurrence {
26+
pub pii_type: String,
27+
pub confidence: f32,
28+
pub token: String,
29+
pub start_pos: usize,
30+
pub end_pos: usize,
31+
}
32+
33+
/// PII detection result with individual occurrence confidences
2434
#[derive(Debug, Clone)]
2535
pub struct PIIResult {
2636
pub has_pii: bool,
27-
pub pii_types: Vec<String>,
28-
pub confidence: f32,
37+
pub pii_types: Vec<String>, // Keep for backward compatibility
38+
pub confidence: f32, // Overall confidence (average or max)
39+
pub occurrences: Vec<PIIOccurrence>, // Individual occurrences with their own confidence
2940
pub processing_time_ms: u64,
3041
}
3142

@@ -86,12 +97,13 @@ impl PIILoRAClassifier {
8697
candle_core::Error::from(unified_err)
8798
})?;
8899

89-
// Analyze token results to determine PII presence
100+
// Create individual occurrences with their own confidence scores
101+
let mut occurrences = Vec::new();
90102
let mut detected_types = Vec::new();
91-
let mut max_confidence = 0.0f32;
103+
let mut confidence_scores = Vec::new();
92104
let mut has_pii = false;
93105

94-
// Calculate confidence for "O" class before processing
106+
// Calculate confidence for "O" class for non-PII tokens
95107
let o_confidences: Vec<f32> = token_results
96108
.iter()
97109
.filter(|(_, class_idx, _)| *class_idx == 0) // "O" class
@@ -103,25 +115,35 @@ impl PIILoRAClassifier {
103115
o_confidences.iter().sum::<f32>() / o_confidences.len() as f32
104116
};
105117

106-
for (_token, class_idx, confidence) in token_results {
118+
// Process each token with its individual confidence
119+
for (i, (token, class_idx, confidence)) in token_results.iter().enumerate() {
107120
// Skip "O" (Outside) labels - class 0 typically means no PII
108-
if class_idx > 0 && class_idx < self.pii_types.len() {
121+
if *class_idx > 0 && *class_idx < self.pii_types.len() {
109122
has_pii = true;
110-
max_confidence = max_confidence.max(confidence);
123+
confidence_scores.push(*confidence);
111124

112-
let pii_type = &self.pii_types[class_idx];
125+
let pii_type = &self.pii_types[*class_idx];
113126
if !detected_types.contains(pii_type) {
114127
detected_types.push(pii_type.clone());
115128
}
129+
130+
// Create individual occurrence with its own confidence
131+
occurrences.push(PIIOccurrence {
132+
pii_type: pii_type.clone(),
133+
confidence: *confidence, // Each occurrence keeps its individual confidence
134+
token: token.clone(),
135+
start_pos: i, // Token position in sequence
136+
end_pos: i + 1,
137+
});
116138
}
117139
}
118140

119-
// Use real confidence from model inference - no hardcoded values
141+
// Calculate overall confidence without inflating individual confidences
120142
let final_confidence = if has_pii {
121-
max_confidence
143+
// Use average confidence instead of max to avoid inflating significance
144+
confidence_scores.iter().sum::<f32>() / confidence_scores.len() as f32
122145
} else {
123146
// For no PII detected, use the confidence of the "O" (Outside) class
124-
// This comes from the actual model's softmax output for class 0
125147
avg_o_confidence
126148
};
127149

@@ -131,6 +153,7 @@ impl PIILoRAClassifier {
131153
has_pii,
132154
pii_types: detected_types,
133155
confidence: final_confidence,
156+
occurrences, // Include individual occurrences with their own confidences
134157
processing_time_ms: processing_time,
135158
})
136159
}

candle-binding/src/classifiers/lora/security_lora.rs

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,15 @@ impl SecurityLoRAClassifier {
4949
candle_core::Error::from(unified_err)
5050
})?;
5151

52+
// Load threshold from global config instead of hardcoding
53+
let confidence_threshold = {
54+
use crate::core::config_loader::GlobalConfigLoader;
55+
GlobalConfigLoader::load_security_threshold().unwrap_or(0.7) // Default from config.yaml prompt_guard.threshold
56+
};
57+
5258
Ok(Self {
5359
bert_classifier,
54-
confidence_threshold: 0.5,
60+
confidence_threshold,
5561
threat_types,
5662
model_path: model_path.to_string(),
5763
})
@@ -83,22 +89,38 @@ impl SecurityLoRAClassifier {
8389
candle_core::Error::from(unified_err)
8490
})?;
8591

86-
// Determine if threat is detected based on predicted class
87-
let is_threat = predicted_class > 0; // Assuming class 0 is "benign" or "safe"
92+
// Map class index to threat type label - fail if class not found
93+
let threat_type = if predicted_class < self.threat_types.len() {
94+
self.threat_types[predicted_class].clone()
95+
} else {
96+
let unified_err = model_error!(
97+
ModelErrorType::LoRA,
98+
"security classification",
99+
format!(
100+
"Invalid class index {} not found in labels (max: {})",
101+
predicted_class,
102+
self.threat_types.len()
103+
),
104+
text
105+
);
106+
return Err(candle_core::Error::from(unified_err));
107+
};
88108

89-
// Get detected threat types
90-
let mut detected_threats = Vec::new();
91-
if is_threat && predicted_class < self.threat_types.len() {
92-
detected_threats.push(self.threat_types[predicted_class].clone());
93-
}
109+
// Determine if threat is detected based on class label (instead of hardcoded index)
110+
let is_threat = !threat_type.to_lowercase().contains("safe")
111+
&& !threat_type.to_lowercase().contains("benign")
112+
&& !threat_type.to_lowercase().contains("no_threat");
94113

95-
// Calculate severity score based on confidence and threat type
96-
let severity_score = if is_threat {
97-
confidence * 0.9 // High severity for detected threats
114+
// Get detected threat types
115+
let detected_threats = if is_threat {
116+
vec![threat_type]
98117
} else {
99-
0.0 // No severity for safe content
118+
Vec::new()
100119
};
101120

121+
// Use confidence as severity score (no artificial scaling)
122+
let severity_score = if is_threat { confidence } else { 0.0 };
123+
102124
let processing_time = start_time.elapsed().as_millis() as u64;
103125

104126
Ok(SecurityResult {
@@ -129,24 +151,41 @@ impl SecurityLoRAClassifier {
129151
let processing_time = start_time.elapsed().as_millis() as u64;
130152

131153
let mut results = Vec::new();
132-
for (predicted_class, confidence) in batch_results {
133-
// Determine if threat is detected
134-
let is_threat = predicted_class > 0; // Assuming class 0 is "benign"
154+
for (i, (predicted_class, confidence)) in batch_results.iter().enumerate() {
155+
// Map class index to threat type label - fail if class not found
156+
let threat_type = if *predicted_class < self.threat_types.len() {
157+
self.threat_types[*predicted_class].clone()
158+
} else {
159+
let unified_err = model_error!(
160+
ModelErrorType::LoRA,
161+
"batch security classification",
162+
format!("Invalid class index {} not found in labels (max: {}) for text at position {}",
163+
predicted_class, self.threat_types.len(), i),
164+
&format!("batch[{}]", i)
165+
);
166+
return Err(candle_core::Error::from(unified_err));
167+
};
168+
169+
// Determine if threat is detected based on class label
170+
let is_threat = !threat_type.to_lowercase().contains("safe")
171+
&& !threat_type.to_lowercase().contains("benign")
172+
&& !threat_type.to_lowercase().contains("no_threat");
135173

136174
// Get detected threat types
137-
let mut detected_threats = Vec::new();
138-
if is_threat && predicted_class < self.threat_types.len() {
139-
detected_threats.push(self.threat_types[predicted_class].clone());
140-
}
175+
let detected_threats = if is_threat {
176+
vec![threat_type]
177+
} else {
178+
Vec::new()
179+
};
141180

142-
// Calculate severity score
143-
let severity_score = if is_threat { confidence * 0.9 } else { 0.0 };
181+
// Use confidence as severity score (no artificial scaling)
182+
let severity_score = if is_threat { *confidence } else { 0.0 };
144183

145184
results.push(SecurityResult {
146185
is_threat,
147186
threat_types: detected_threats,
148187
severity_score,
149-
confidence,
188+
confidence: *confidence,
150189
processing_time_ms: processing_time,
151190
});
152191
}

0 commit comments

Comments
 (0)