@@ -49,9 +49,15 @@ impl SecurityLoRAClassifier {
49
49
candle_core:: Error :: from ( unified_err)
50
50
} ) ?;
51
51
52
+ // Load threshold from global config instead of hardcoding
53
+ let confidence_threshold = {
54
+ use crate :: core:: config_loader:: GlobalConfigLoader ;
55
+ GlobalConfigLoader :: load_security_threshold ( ) . unwrap_or ( 0.7 ) // Default from config.yaml prompt_guard.threshold
56
+ } ;
57
+
52
58
Ok ( Self {
53
59
bert_classifier,
54
- confidence_threshold : 0.5 ,
60
+ confidence_threshold,
55
61
threat_types,
56
62
model_path : model_path. to_string ( ) ,
57
63
} )
@@ -83,22 +89,38 @@ impl SecurityLoRAClassifier {
83
89
candle_core:: Error :: from ( unified_err)
84
90
} ) ?;
85
91
86
- // Determine if threat is detected based on predicted class
87
- let is_threat = predicted_class > 0 ; // Assuming class 0 is "benign" or "safe"
92
+ // Map class index to threat type label - fail if class not found
93
+ let threat_type = if predicted_class < self . threat_types . len ( ) {
94
+ self . threat_types [ predicted_class] . clone ( )
95
+ } else {
96
+ let unified_err = model_error ! (
97
+ ModelErrorType :: LoRA ,
98
+ "security classification" ,
99
+ format!(
100
+ "Invalid class index {} not found in labels (max: {})" ,
101
+ predicted_class,
102
+ self . threat_types. len( )
103
+ ) ,
104
+ text
105
+ ) ;
106
+ return Err ( candle_core:: Error :: from ( unified_err) ) ;
107
+ } ;
88
108
89
- // Get detected threat types
90
- let mut detected_threats = Vec :: new ( ) ;
91
- if is_threat && predicted_class < self . threat_types . len ( ) {
92
- detected_threats. push ( self . threat_types [ predicted_class] . clone ( ) ) ;
93
- }
109
+ // Determine if threat is detected based on class label (instead of hardcoded index)
110
+ let is_threat = !threat_type. to_lowercase ( ) . contains ( "safe" )
111
+ && !threat_type. to_lowercase ( ) . contains ( "benign" )
112
+ && !threat_type. to_lowercase ( ) . contains ( "no_threat" ) ;
94
113
95
- // Calculate severity score based on confidence and threat type
96
- let severity_score = if is_threat {
97
- confidence * 0.9 // High severity for detected threats
114
+ // Get detected threat types
115
+ let detected_threats = if is_threat {
116
+ vec ! [ threat_type ]
98
117
} else {
99
- 0.0 // No severity for safe content
118
+ Vec :: new ( )
100
119
} ;
101
120
121
+ // Use confidence as severity score (no artificial scaling)
122
+ let severity_score = if is_threat { confidence } else { 0.0 } ;
123
+
102
124
let processing_time = start_time. elapsed ( ) . as_millis ( ) as u64 ;
103
125
104
126
Ok ( SecurityResult {
@@ -129,24 +151,41 @@ impl SecurityLoRAClassifier {
129
151
let processing_time = start_time. elapsed ( ) . as_millis ( ) as u64 ;
130
152
131
153
let mut results = Vec :: new ( ) ;
132
- for ( predicted_class, confidence) in batch_results {
133
- // Determine if threat is detected
134
- let is_threat = predicted_class > 0 ; // Assuming class 0 is "benign"
154
+ for ( i, ( predicted_class, confidence) ) in batch_results. iter ( ) . enumerate ( ) {
155
+ // Map class index to threat type label - fail if class not found
156
+ let threat_type = if * predicted_class < self . threat_types . len ( ) {
157
+ self . threat_types [ * predicted_class] . clone ( )
158
+ } else {
159
+ let unified_err = model_error ! (
160
+ ModelErrorType :: LoRA ,
161
+ "batch security classification" ,
162
+ format!( "Invalid class index {} not found in labels (max: {}) for text at position {}" ,
163
+ predicted_class, self . threat_types. len( ) , i) ,
164
+ & format!( "batch[{}]" , i)
165
+ ) ;
166
+ return Err ( candle_core:: Error :: from ( unified_err) ) ;
167
+ } ;
168
+
169
+ // Determine if threat is detected based on class label
170
+ let is_threat = !threat_type. to_lowercase ( ) . contains ( "safe" )
171
+ && !threat_type. to_lowercase ( ) . contains ( "benign" )
172
+ && !threat_type. to_lowercase ( ) . contains ( "no_threat" ) ;
135
173
136
174
// Get detected threat types
137
- let mut detected_threats = Vec :: new ( ) ;
138
- if is_threat && predicted_class < self . threat_types . len ( ) {
139
- detected_threats. push ( self . threat_types [ predicted_class] . clone ( ) ) ;
140
- }
175
+ let detected_threats = if is_threat {
176
+ vec ! [ threat_type]
177
+ } else {
178
+ Vec :: new ( )
179
+ } ;
141
180
142
- // Calculate severity score
143
- let severity_score = if is_threat { confidence * 0.9 } else { 0.0 } ;
181
+ // Use confidence as severity score (no artificial scaling)
182
+ let severity_score = if is_threat { * confidence } else { 0.0 } ;
144
183
145
184
results. push ( SecurityResult {
146
185
is_threat,
147
186
threat_types : detected_threats,
148
187
severity_score,
149
- confidence,
188
+ confidence : * confidence ,
150
189
processing_time_ms : processing_time,
151
190
} ) ;
152
191
}
0 commit comments