Commit baa0822

fix: unit test and model download from huggingface
Signed-off-by: OneZero-Y <[email protected]>
1 parent ad44eff commit baa0822

File tree

5 files changed: +451 -216 lines changed


Makefile

Lines changed: 9 additions & 9 deletions

@@ -257,39 +257,39 @@ download-models:
 	fi
 
 	@if [ ! -d "lora_intent_classifier_bert-base-uncased_model" ]; then \
-		hf download OneZero-Y/lora_intent_classifier_bert-base-uncased_model --local-dir models/lora_intent_classifier_bert-base-uncased_model; \
+		hf download LLM-Semantic-Router/lora_intent_classifier_bert-base-uncased_model --local-dir models/lora_intent_classifier_bert-base-uncased_model; \
 	fi
 
 	@if [ ! -d "models/lora_intent_classifier_roberta-base_model" ]; then \
-		hf download OneZero-Y/lora_intent_classifier_roberta-base_model --local-dir models/lora_intent_classifier_roberta-base_model; \
+		hf download LLM-Semantic-Router/lora_intent_classifier_roberta-base_model --local-dir models/lora_intent_classifier_roberta-base_model; \
 	fi
 
 	@if [ ! -d "models/lora_intent_classifier_modernbert-base_model" ]; then \
-		hf download OneZero-Y/lora_intent_classifier_modernbert-base_model --local-dir models/lora_intent_classifier_modernbert-base_model; \
+		hf download LLM-Semantic-Router/lora_intent_classifier_modernbert-base_model --local-dir models/lora_intent_classifier_modernbert-base_model; \
 	fi
 
 	@if [ ! -d "models/lora_pii_detector_bert-base-uncased_model" ]; then \
-		hf download OneZero-Y/lora_pii_detector_bert-base-uncased_model --local-dir models/lora_pii_detector_bert-base-uncased_model; \
+		hf download LLM-Semantic-Router/lora_pii_detector_bert-base-uncased_model --local-dir models/lora_pii_detector_bert-base-uncased_model; \
 	fi
 
 	@if [ ! -d "models/lora_pii_detector_roberta-base_model" ]; then \
-		hf download OneZero-Y/lora_pii_detector_roberta-base_model --local-dir models/lora_pii_detector_roberta-base_model; \
+		hf download LLM-Semantic-Router/lora_pii_detector_roberta-base_model --local-dir models/lora_pii_detector_roberta-base_model; \
 	fi
 
 	@if [ ! -d "models/lora_pii_detector_modernbert-base_model" ]; then \
-		hf download OneZero-Y/lora_pii_detector_modernbert-base_model --local-dir models/lora_pii_detector_modernbert-base_model; \
+		hf download LLM-Semantic-Router/lora_pii_detector_modernbert-base_model --local-dir models/lora_pii_detector_modernbert-base_model; \
 	fi
 
 	@if [ ! -d "models/lora_jailbreak_classifier_bert-base-uncased_model" ]; then \
-		hf download OneZero-Y/lora_jailbreak_classifier_bert-base-uncased_model --local-dir models/lora_jailbreak_classifier_bert-base-uncased_model; \
+		hf download LLM-Semantic-Router/lora_jailbreak_classifier_bert-base-uncased_model --local-dir models/lora_jailbreak_classifier_bert-base-uncased_model; \
 	fi
 
 	@if [ ! -d "models/lora_jailbreak_classifier_roberta-base_model" ]; then \
-		hf download OneZero-Y/lora_jailbreak_classifier_roberta-base_model --local-dir models/lora_jailbreak_classifier_roberta-base_model; \
+		hf download LLM-Semantic-Router/lora_jailbreak_classifier_roberta-base_model --local-dir models/lora_jailbreak_classifier_roberta-base_model; \
 	fi
 
 	@if [ ! -d "models/lora_jailbreak_classifier_modernbert-base_model" ]; then \
-		hf download OneZero-Y/lora_jailbreak_classifier_modernbert-base_model --local-dir models/lora_jailbreak_classifier_modernbert-base_model; \
+		hf download LLM-Semantic-Router/lora_jailbreak_classifier_modernbert-base_model --local-dir models/lora_jailbreak_classifier_modernbert-base_model; \
 	fi
 
 # Milvus container management
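
Each block only fetches a model when its local directory is missing, so repeated runs of the download target are idempotent; this commit just repoints every repo from OneZero-Y to the LLM-Semantic-Router organization. For readers scripting the same behavior outside make, here is a minimal Rust sketch of that guard-then-download pattern; ensure_model is hypothetical (not part of this commit) and shells out to the same hf download CLI the Makefile calls:

use std::path::Path;
use std::process::Command;

// Hypothetical helper mirroring the Makefile logic: download a Hugging Face
// repo into local_dir only if that directory does not exist yet.
fn ensure_model(repo: &str, local_dir: &str) -> std::io::Result<()> {
    if !Path::new(local_dir).is_dir() {
        let status = Command::new("hf")
            .args(["download", repo, "--local-dir", local_dir])
            .status()?;
        if !status.success() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("hf download failed for {repo}"),
            ));
        }
    }
    Ok(())
}

fn main() -> std::io::Result<()> {
    ensure_model(
        "LLM-Semantic-Router/lora_intent_classifier_bert-base-uncased_model",
        "models/lora_intent_classifier_bert-base-uncased_model",
    )
}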

candle-binding/src/bert_official.rs

Lines changed: 205 additions & 66 deletions

@@ -18,6 +18,48 @@ pub struct CandleBertClassifier {
 }
 
 impl CandleBertClassifier {
+    /// Shared helper method for efficient batch tensor creation
+    fn create_batch_tensors(
+        &self,
+        texts: &[&str],
+    ) -> Result<(Tensor, Tensor, Tensor, Vec<tokenizers::Encoding>)> {
+        let encodings = self
+            .tokenizer
+            .encode_batch(texts.to_vec(), true)
+            .map_err(E::msg)?;
+
+        let batch_size = texts.len();
+        let max_len = encodings
+            .iter()
+            .map(|enc| enc.get_ids().len())
+            .max()
+            .unwrap_or(0);
+
+        let total_elements = batch_size * max_len;
+        let mut all_token_ids = Vec::with_capacity(total_elements);
+        let mut all_attention_masks = Vec::with_capacity(total_elements);
+
+        for encoding in &encodings {
+            let token_ids = encoding.get_ids();
+            let attention_mask = encoding.get_attention_mask();
+
+            all_token_ids.extend_from_slice(token_ids);
+            all_attention_masks.extend_from_slice(attention_mask);
+
+            let padding_needed = max_len - token_ids.len();
+            all_token_ids.extend(std::iter::repeat(0).take(padding_needed));
+            all_attention_masks.extend(std::iter::repeat(0).take(padding_needed));
+        }
+
+        let token_ids =
+            Tensor::new(all_token_ids.as_slice(), &self.device)?.reshape(&[batch_size, max_len])?;
+        let attention_mask = Tensor::new(all_attention_masks.as_slice(), &self.device)?
+            .reshape(&[batch_size, max_len])?;
+        let token_type_ids = Tensor::zeros(&[batch_size, max_len], DType::U32, &self.device)?;
+
+        Ok((token_ids, attention_mask, token_type_ids, encodings))
+    }
+
     pub fn new(model_path: &str, num_classes: usize, use_cpu: bool) -> Result<Self> {
         let device = if use_cpu {
             Device::Cpu
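
The new helper right-pads every sequence to the longest encoding in the batch, using token id 0 and attention mask 0 on the padded tail, then reshapes the two flat buffers into (batch_size, max_len) tensors. A dependency-free sketch of just that padding step (pad_batch is illustrative, not part of the commit; the real helper takes the attention mask from the tokenizer rather than assuming all ones):

// Flatten ragged per-text token ids into row-major (batch_size x max_len)
// buffers, padding shorter rows with id 0 and attention mask 0.
fn pad_batch(token_ids: &[Vec<u32>]) -> (Vec<u32>, Vec<u32>, usize) {
    let max_len = token_ids.iter().map(|t| t.len()).max().unwrap_or(0);
    let mut ids = Vec::with_capacity(token_ids.len() * max_len);
    let mut mask = Vec::with_capacity(token_ids.len() * max_len);
    for t in token_ids {
        ids.extend_from_slice(t);
        mask.extend(std::iter::repeat(1).take(t.len()));
        // Tail padding so every row has exactly max_len elements.
        ids.extend(std::iter::repeat(0).take(max_len - t.len()));
        mask.extend(std::iter::repeat(0).take(max_len - t.len()));
    }
    (ids, mask, max_len)
}

fn main() {
    // Two encodings of 3 and 5 tokens flatten into a 2 x 5 batch.
    let batch = vec![vec![101, 7592, 102], vec![101, 7592, 2088, 999, 102]];
    let (ids, mask, max_len) = pad_batch(&batch);
    assert_eq!(max_len, 5);
    assert_eq!(ids, vec![101, 7592, 102, 0, 0, 101, 7592, 2088, 999, 102]);
    assert_eq!(mask, vec![1, 1, 1, 0, 0, 1, 1, 1, 1, 1]);
}

With candle in scope, those two flat vectors map directly onto the Tensor::new(...).reshape(&[batch_size, max_len]) calls in the diff.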
@@ -137,6 +179,47 @@ impl CandleBertClassifier {
 
         Ok((predicted_class, confidence))
     }
+
+    /// True batch processing for multiple texts - significant performance improvement
+    pub fn classify_batch(&self, texts: &[&str]) -> Result<Vec<(usize, f32)>> {
+        if texts.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        // OPTIMIZATION: Use shared tensor creation method
+        let (token_ids, attention_mask, token_type_ids, _encodings) =
+            self.create_batch_tensors(texts)?;
+
+        // Batch BERT forward pass
+        let sequence_output =
+            self.bert
+                .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
+
+        // OPTIMIZATION: Use proper CLS token pooling instead of mean pooling
+        let cls_tokens = sequence_output.i((.., 0))?; // Extract CLS tokens for all samples
+        let pooled_output = self.pooler.forward(&cls_tokens)?;
+        let pooled_output = pooled_output.tanh()?;
+
+        let logits = self.classifier.forward(&pooled_output)?;
+        let probabilities = candle_nn::ops::softmax(&logits, 1)?;
+
+        // OPTIMIZATION: Batch result extraction
+        let probs_data = probabilities.to_vec2::<f32>()?;
+        let mut results = Vec::with_capacity(texts.len());
+
+        for row in probs_data {
+            let (predicted_class, confidence) = row
+                .iter()
+                .enumerate()
+                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
+                .map(|(idx, &conf)| (idx, conf))
+                .unwrap_or((0, 0.0));
+
+            results.push((predicted_class, confidence));
+        }
+
+        Ok(results)
+    }
 }
 
 /// BERT token classifier for PII detection
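
classify_batch tokenizes and runs the whole slice in a single forward pass, pools the per-sample [CLS] vectors through the trained pooler plus tanh, and takes the per-row softmax argmax. A hypothetical caller's view of the new API; the import path, model path, and class count below are assumptions, not taken from this commit:

use candle_binding::CandleBertClassifier; // assumed re-export location

fn main() -> anyhow::Result<()> {
    let classifier = CandleBertClassifier::new(
        "models/lora_intent_classifier_bert-base-uncased_model",
        14,   // placeholder: number of intent classes for the checkpoint
        true, // use_cpu
    )?;

    // One tokenization and one BERT forward pass for the whole slice,
    // instead of a per-text loop over the single-text method.
    let texts = ["what is the capital of France?", "solve x^2 - 4 = 0"];
    for (text, (class, conf)) in texts.iter().zip(classifier.classify_batch(&texts)?) {
        println!("{text:?} -> class {class} (confidence {conf:.3})");
    }
    Ok(())
}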
@@ -148,6 +231,48 @@ pub struct CandleBertTokenClassifier {
 }
 
 impl CandleBertTokenClassifier {
+    /// Shared helper method for efficient batch tensor creation
+    fn create_batch_tensors(
+        &self,
+        texts: &[&str],
+    ) -> Result<(Tensor, Tensor, Tensor, Vec<tokenizers::Encoding>)> {
+        let encodings = self
+            .tokenizer
+            .encode_batch(texts.to_vec(), true)
+            .map_err(E::msg)?;
+
+        let batch_size = texts.len();
+        let max_len = encodings
+            .iter()
+            .map(|enc| enc.get_ids().len())
+            .max()
+            .unwrap_or(0);
+
+        let total_elements = batch_size * max_len;
+        let mut all_token_ids = Vec::with_capacity(total_elements);
+        let mut all_attention_masks = Vec::with_capacity(total_elements);
+
+        for encoding in &encodings {
+            let token_ids = encoding.get_ids();
+            let attention_mask = encoding.get_attention_mask();
+
+            all_token_ids.extend_from_slice(token_ids);
+            all_attention_masks.extend_from_slice(attention_mask);
+
+            let padding_needed = max_len - token_ids.len();
+            all_token_ids.extend(std::iter::repeat(0).take(padding_needed));
+            all_attention_masks.extend(std::iter::repeat(0).take(padding_needed));
+        }
+
+        let token_ids =
+            Tensor::new(all_token_ids.as_slice(), &self.device)?.reshape(&[batch_size, max_len])?;
+        let attention_mask = Tensor::new(all_attention_masks.as_slice(), &self.device)?
+            .reshape(&[batch_size, max_len])?;
+        let token_type_ids = Tensor::zeros(&[batch_size, max_len], DType::U32, &self.device)?;
+
+        Ok((token_ids, attention_mask, token_type_ids, encodings))
+    }
+
     pub fn new(model_path: &str, num_classes: usize, use_cpu: bool) -> Result<Self> {
         let device = if use_cpu {
             Device::Cpu
@@ -208,95 +333,109 @@ impl CandleBertTokenClassifier {
         })
     }
 
-    pub fn classify_tokens(&self, text: &str) -> Result<Vec<(String, usize, f32)>> {
-        // Tokenize
-        let encoding = self.tokenizer.encode(text, true).map_err(E::msg)?;
-        let token_ids = encoding.get_ids().to_vec();
-        let attention_mask = encoding.get_attention_mask().to_vec();
-        let tokens = encoding.get_tokens();
-
-        // Create tensors
-        let token_ids = Tensor::new(&token_ids[..], &self.device)?.unsqueeze(0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        let attention_mask = Tensor::new(&attention_mask[..], &self.device)?.unsqueeze(0)?;
-
-        // Forward pass
-        let sequence_output =
-            self.bert
-                .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
-
-        // Apply token classifier to each token
-        let logits = self.classifier.forward(&sequence_output)?;
+    /// Helper method to extract entities from probabilities
+    fn extract_entities_from_probs(
+        &self,
+        probs: &Tensor,
+        tokens: &[String],
+        offsets: &[(usize, usize)],
+    ) -> Result<Vec<(String, usize, f32)>> {
+        let probs_vec = probs.to_vec2::<f32>()?;
+        let mut results = Vec::new();
 
-        // Get predictions for each token
-        let probabilities = candle_nn::ops::softmax(&logits, 2)?;
-        let probabilities = probabilities.squeeze(0)?;
-        let probabilities_vec = probabilities.to_vec2::<f32>()?;
+        for (token_idx, (token, token_probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
+            if token_idx >= offsets.len() {
+                break;
+            }
 
-        let mut results = Vec::new();
-        for (token, probs) in tokens.iter().zip(probabilities_vec.iter()) {
-            let (predicted_class, &confidence) = probs
+            let (predicted_class, &confidence) = token_probs
                 .iter()
                 .enumerate()
                 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
-                .unwrap();
+                .unwrap_or((0, &0.0));
+
+            // Skip padding tokens and special tokens
+            if token.starts_with("[PAD]")
+                || token.starts_with("[CLS]")
+                || token.starts_with("[SEP]")
+            {
+                continue;
+            }
 
             results.push((token.clone(), predicted_class, confidence));
         }
 
         Ok(results)
     }
 
-    pub fn classify_tokens_with_spans(
-        &self,
-        text: &str,
-    ) -> Result<Vec<(String, usize, f32, usize, usize)>> {
-        // Tokenize with offset mapping
-        let encoding = self.tokenizer.encode(text, true).map_err(E::msg)?;
-        let token_ids = encoding.get_ids().to_vec();
-        let attention_mask = encoding.get_attention_mask().to_vec();
-        let tokens = encoding.get_tokens();
-        let offsets = encoding.get_offsets();
+    /// True batch processing for token classification - significant performance improvement
+    pub fn classify_tokens_batch(&self, texts: &[&str]) -> Result<Vec<Vec<(String, usize, f32)>>> {
+        if texts.is_empty() {
+            return Ok(Vec::new());
+        }
 
-        // Create tensors
-        let token_ids = Tensor::new(&token_ids[..], &self.device)?.unsqueeze(0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        let attention_mask = Tensor::new(&attention_mask[..], &self.device)?.unsqueeze(0)?;
+        // OPTIMIZATION: Use shared tensor creation method
+        let (token_ids, attention_mask, token_type_ids, encodings) =
+            self.create_batch_tensors(texts)?;
 
-        // Forward pass
+        // Batch BERT forward pass
         let sequence_output =
             self.bert
                 .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
 
-        // Apply token classifier to each token
-        let logits = self.classifier.forward(&sequence_output)?;
-
-        // Get predictions for each token
+        // Batch token classification
+        let logits = self.classifier.forward(&sequence_output)?; // (batch_size, seq_len, num_labels)
         let probabilities = candle_nn::ops::softmax(&logits, 2)?;
-        let probabilities = probabilities.squeeze(0)?;
-        let probabilities_vec = probabilities.to_vec2::<f32>()?;
+
+        // OPTIMIZATION: More efficient result extraction
+        let mut batch_results = Vec::with_capacity(texts.len());
+        for i in 0..texts.len() {
+            let encoding = &encodings[i];
+            let tokens = encoding.get_tokens();
+            let offsets = encoding.get_offsets();
+
+            let text_probs = probabilities.get(i)?; // (seq_len, num_labels)
+            let text_results = self.extract_entities_from_probs(&text_probs, tokens, offsets)?;
+            batch_results.push(text_results);
+        }
+
+        Ok(batch_results)
+    }
+
+    /// Single text token classification with span information (for backward compatibility)
+    pub fn classify_tokens_with_spans(
+        &self,
+        text: &str,
+    ) -> Result<Vec<(String, usize, f32, usize, usize)>> {
+        // Use batch processing for single text
+        let batch_results = self.classify_tokens_batch(&[text])?;
+        if batch_results.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        // Get tokenization info for spans
+        let encoding = self.tokenizer.encode(text, true).map_err(E::msg)?;
+        let offsets = encoding.get_offsets();
 
         let mut results = Vec::new();
-        for ((token, offset), probs) in tokens
-            .iter()
-            .zip(offsets.iter())
-            .zip(probabilities_vec.iter())
-        {
-            let (predicted_class, &confidence) = probs
-                .iter()
-                .enumerate()
-                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
-                .unwrap();
-
-            results.push((
-                token.clone(),
-                predicted_class,
-                confidence,
-                offset.0,
-                offset.1,
-            ));
+        for (i, (token, class_id, confidence)) in batch_results[0].iter().enumerate() {
+            if i < offsets.len() {
+                let (start_char, end_char) = offsets[i];
+                results.push((token.clone(), *class_id, *confidence, start_char, end_char));
+            }
         }
 
         Ok(results)
     }
+
+    /// Single text token classification (for backward compatibility)
+    pub fn classify_tokens(&self, text: &str) -> Result<Vec<(String, usize, f32)>> {
+        // Use batch processing for single text
+        let batch_results = self.classify_tokens_batch(&[text])?;
+        if batch_results.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        Ok(batch_results.into_iter().next().unwrap())
+    }
 }
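
The single-text entry points are now thin wrappers over classify_tokens_batch, so single and batch callers share one tokenization and forward path, and extract_entities_from_probs drops [PAD]/[CLS]/[SEP] tokens before reporting. A hypothetical end-to-end sketch of the refactored API; the import path, model path, and label count are assumptions, not taken from this commit:

use candle_binding::CandleBertTokenClassifier; // assumed re-export location

fn main() -> anyhow::Result<()> {
    let detector = CandleBertTokenClassifier::new(
        "models/lora_pii_detector_bert-base-uncased_model",
        9,    // placeholder: number of PII token labels
        true, // use_cpu
    )?;

    // Batch path: one forward pass covers every text in the slice.
    let batches = detector.classify_tokens_batch(&[
        "My name is Jane and my card is 4111 1111 1111 1111",
        "No PII here",
    ])?;
    for (token, label, conf) in &batches[0] {
        println!("{token} -> label {label} ({conf:.2})");
    }

    // Backward-compatible single-text path delegates to the batch method.
    let spans = detector.classify_tokens_with_spans("Call me at 555-0100")?;
    for (token, label, conf, start, end) in spans {
        println!("{token} [{start}..{end}] -> label {label} ({conf:.2})");
    }
    Ok(())
}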
