@@ -18,6 +18,48 @@ pub struct CandleBertClassifier {
}

impl CandleBertClassifier {
+    /// Shared helper method for efficient batch tensor creation
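+    ///
+    /// All sequences are padded to the longest sequence in the batch. As an
+    /// illustrative example (lengths are hypothetical): for token lengths
+    /// [5, 3], `max_len` is 5, so the shorter row gets two trailing `0` ids
+    /// and two `0` attention-mask entries, which the model then ignores.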
+    fn create_batch_tensors(
+        &self,
+        texts: &[&str],
+    ) -> Result<(Tensor, Tensor, Tensor, Vec<tokenizers::Encoding>)> {
+        let encodings = self
+            .tokenizer
+            .encode_batch(texts.to_vec(), true)
+            .map_err(E::msg)?;
+
+        let batch_size = texts.len();
+        let max_len = encodings
+            .iter()
+            .map(|enc| enc.get_ids().len())
+            .max()
+            .unwrap_or(0);
+
+        let total_elements = batch_size * max_len;
+        let mut all_token_ids = Vec::with_capacity(total_elements);
+        let mut all_attention_masks = Vec::with_capacity(total_elements);
+
+        for encoding in &encodings {
+            let token_ids = encoding.get_ids();
+            let attention_mask = encoding.get_attention_mask();
+
+            all_token_ids.extend_from_slice(token_ids);
+            all_attention_masks.extend_from_slice(attention_mask);
+
+            let padding_needed = max_len - token_ids.len();
+            all_token_ids.extend(std::iter::repeat(0).take(padding_needed));
+            all_attention_masks.extend(std::iter::repeat(0).take(padding_needed));
+        }
+
+        let token_ids =
+            Tensor::new(all_token_ids.as_slice(), &self.device)?.reshape(&[batch_size, max_len])?;
+        let attention_mask = Tensor::new(all_attention_masks.as_slice(), &self.device)?
+            .reshape(&[batch_size, max_len])?;
+        let token_type_ids = Tensor::zeros(&[batch_size, max_len], DType::U32, &self.device)?;
+
+        Ok((token_ids, attention_mask, token_type_ids, encodings))
+    }
+
    pub fn new(model_path: &str, num_classes: usize, use_cpu: bool) -> Result<Self> {
        let device = if use_cpu {
            Device::Cpu
@@ -137,6 +179,47 @@ impl CandleBertClassifier {

        Ok((predicted_class, confidence))
    }
+
+    /// True batch processing for multiple texts - significant performance improvement
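+    ///
+    /// A minimal usage sketch (the model path and class count below are
+    /// illustrative placeholders, not values shipped with this crate):
+    ///
+    /// ```ignore
+    /// let classifier = CandleBertClassifier::new("./models/category-bert", 4, true)?;
+    /// let results = classifier.classify_batch(&["hello there", "transfer $100 to savings"])?;
+    /// for (class_id, confidence) in results {
+    ///     println!("predicted class {class_id} ({confidence:.3})");
+    /// }
+    /// ```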
+    pub fn classify_batch(&self, texts: &[&str]) -> Result<Vec<(usize, f32)>> {
+        if texts.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        // OPTIMIZATION: Use shared tensor creation method
+        let (token_ids, attention_mask, token_type_ids, _encodings) =
+            self.create_batch_tensors(texts)?;
+
+        // Batch BERT forward pass
+        let sequence_output =
+            self.bert
+                .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
+
+        // OPTIMIZATION: Use proper CLS token pooling instead of mean pooling
+        let cls_tokens = sequence_output.i((.., 0))?; // Extract CLS tokens for all samples
+        let pooled_output = self.pooler.forward(&cls_tokens)?;
+        let pooled_output = pooled_output.tanh()?;
+
+        let logits = self.classifier.forward(&pooled_output)?;
+        let probabilities = candle_nn::ops::softmax(&logits, 1)?;
+
+        // OPTIMIZATION: Batch result extraction
+        let probs_data = probabilities.to_vec2::<f32>()?;
+        let mut results = Vec::with_capacity(texts.len());
+
+        for row in probs_data {
+            let (predicted_class, confidence) = row
+                .iter()
+                .enumerate()
+                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
+                .map(|(idx, &conf)| (idx, conf))
+                .unwrap_or((0, 0.0));
+
+            results.push((predicted_class, confidence));
+        }
+
+        Ok(results)
+    }
}

/// BERT token classifier for PII detection
@@ -148,6 +231,48 @@ pub struct CandleBertTokenClassifier {
}

impl CandleBertTokenClassifier {
+    /// Shared helper method for efficient batch tensor creation
+    fn create_batch_tensors(
+        &self,
+        texts: &[&str],
+    ) -> Result<(Tensor, Tensor, Tensor, Vec<tokenizers::Encoding>)> {
+        let encodings = self
+            .tokenizer
+            .encode_batch(texts.to_vec(), true)
+            .map_err(E::msg)?;
+
+        let batch_size = texts.len();
+        let max_len = encodings
+            .iter()
+            .map(|enc| enc.get_ids().len())
+            .max()
+            .unwrap_or(0);
+
+        let total_elements = batch_size * max_len;
+        let mut all_token_ids = Vec::with_capacity(total_elements);
+        let mut all_attention_masks = Vec::with_capacity(total_elements);
+
+        for encoding in &encodings {
+            let token_ids = encoding.get_ids();
+            let attention_mask = encoding.get_attention_mask();
+
+            all_token_ids.extend_from_slice(token_ids);
+            all_attention_masks.extend_from_slice(attention_mask);
+
+            let padding_needed = max_len - token_ids.len();
+            all_token_ids.extend(std::iter::repeat(0).take(padding_needed));
+            all_attention_masks.extend(std::iter::repeat(0).take(padding_needed));
+        }
+
+        let token_ids =
+            Tensor::new(all_token_ids.as_slice(), &self.device)?.reshape(&[batch_size, max_len])?;
+        let attention_mask = Tensor::new(all_attention_masks.as_slice(), &self.device)?
+            .reshape(&[batch_size, max_len])?;
+        let token_type_ids = Tensor::zeros(&[batch_size, max_len], DType::U32, &self.device)?;
+
+        Ok((token_ids, attention_mask, token_type_ids, encodings))
+    }
+
    pub fn new(model_path: &str, num_classes: usize, use_cpu: bool) -> Result<Self> {
        let device = if use_cpu {
            Device::Cpu
@@ -208,95 +333,109 @@ impl CandleBertTokenClassifier {
        })
    }

-    pub fn classify_tokens(&self, text: &str) -> Result<Vec<(String, usize, f32)>> {
-        // Tokenize
-        let encoding = self.tokenizer.encode(text, true).map_err(E::msg)?;
-        let token_ids = encoding.get_ids().to_vec();
-        let attention_mask = encoding.get_attention_mask().to_vec();
-        let tokens = encoding.get_tokens();
-
-        // Create tensors
-        let token_ids = Tensor::new(&token_ids[..], &self.device)?.unsqueeze(0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        let attention_mask = Tensor::new(&attention_mask[..], &self.device)?.unsqueeze(0)?;
-
-        // Forward pass
-        let sequence_output =
-            self.bert
-                .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
-
-        // Apply token classifier to each token
-        let logits = self.classifier.forward(&sequence_output)?;
+    /// Helper method to extract entities from probabilities
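+    ///
+    /// Expects `probs` of shape `(seq_len, num_labels)`; returns one
+    /// `(token, class_id, confidence)` triple per token, skipping
+    /// `[PAD]`/`[CLS]`/`[SEP]` entries.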
+    fn extract_entities_from_probs(
+        &self,
+        probs: &Tensor,
+        tokens: &[String],
+        offsets: &[(usize, usize)],
+    ) -> Result<Vec<(String, usize, f32)>> {
+        let probs_vec = probs.to_vec2::<f32>()?;
+        let mut results = Vec::new();

-        // Get predictions for each token
-        let probabilities = candle_nn::ops::softmax(&logits, 2)?;
-        let probabilities = probabilities.squeeze(0)?;
-        let probabilities_vec = probabilities.to_vec2::<f32>()?;
+        for (token_idx, (token, token_probs)) in tokens.iter().zip(probs_vec.iter()).enumerate() {
+            if token_idx >= offsets.len() {
+                break;
+            }

-        let mut results = Vec::new();
-        for (token, probs) in tokens.iter().zip(probabilities_vec.iter()) {
-            let (predicted_class, &confidence) = probs
+            let (predicted_class, &confidence) = token_probs
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
-                .unwrap();
+                .unwrap_or((0, &0.0));
+
+            // Skip padding tokens and special tokens
+            if token.starts_with("[PAD]")
+                || token.starts_with("[CLS]")
+                || token.starts_with("[SEP]")
+            {
+                continue;
+            }

            results.push((token.clone(), predicted_class, confidence));
        }

        Ok(results)
    }

-    pub fn classify_tokens_with_spans(
-        &self,
-        text: &str,
-    ) -> Result<Vec<(String, usize, f32, usize, usize)>> {
-        // Tokenize with offset mapping
-        let encoding = self.tokenizer.encode(text, true).map_err(E::msg)?;
-        let token_ids = encoding.get_ids().to_vec();
-        let attention_mask = encoding.get_attention_mask().to_vec();
-        let tokens = encoding.get_tokens();
-        let offsets = encoding.get_offsets();
+    /// True batch processing for token classification - significant performance improvement
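+    ///
+    /// A minimal usage sketch (model path and label count are illustrative
+    /// placeholders):
+    ///
+    /// ```ignore
+    /// let classifier = CandleBertTokenClassifier::new("./models/pii-bert", 9, true)?;
+    /// let per_text = classifier.classify_tokens_batch(&["John lives in Paris"])?;
+    /// for (token, class_id, confidence) in &per_text[0] {
+    ///     println!("{token}: label {class_id} ({confidence:.3})");
+    /// }
+    /// ```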
+    pub fn classify_tokens_batch(&self, texts: &[&str]) -> Result<Vec<Vec<(String, usize, f32)>>> {
+        if texts.is_empty() {
+            return Ok(Vec::new());
+        }

-        // Create tensors
-        let token_ids = Tensor::new(&token_ids[..], &self.device)?.unsqueeze(0)?;
-        let token_type_ids = token_ids.zeros_like()?;
-        let attention_mask = Tensor::new(&attention_mask[..], &self.device)?.unsqueeze(0)?;
+        // OPTIMIZATION: Use shared tensor creation method
+        let (token_ids, attention_mask, token_type_ids, encodings) =
+            self.create_batch_tensors(texts)?;

-        // Forward pass
+        // Batch BERT forward pass
        let sequence_output =
            self.bert
                .forward(&token_ids, &token_type_ids, Some(&attention_mask))?;

-        // Apply token classifier to each token
-        let logits = self.classifier.forward(&sequence_output)?;
-
-        // Get predictions for each token
+        // Batch token classification
+        let logits = self.classifier.forward(&sequence_output)?; // (batch_size, seq_len, num_labels)
        let probabilities = candle_nn::ops::softmax(&logits, 2)?;
-        let probabilities = probabilities.squeeze(0)?;
-        let probabilities_vec = probabilities.to_vec2::<f32>()?;
+
+        // OPTIMIZATION: More efficient result extraction
+        let mut batch_results = Vec::with_capacity(texts.len());
+        for i in 0..texts.len() {
+            let encoding = &encodings[i];
+            let tokens = encoding.get_tokens();
+            let offsets = encoding.get_offsets();
+
+            let text_probs = probabilities.get(i)?; // (seq_len, num_labels)
+            let text_results = self.extract_entities_from_probs(&text_probs, tokens, offsets)?;
+            batch_results.push(text_results);
+        }
+
+        Ok(batch_results)
+    }
+
+    /// Single text token classification with span information (for backward compatibility).
+    /// Spans come from the tokenizer's offset mapping, filtered of special tokens
+    /// so they stay aligned with the batch results.
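+    ///
+    /// Illustrative use (`classifier` as constructed above; output values are
+    /// hypothetical):
+    ///
+    /// ```ignore
+    /// let spans = classifier.classify_tokens_with_spans("John lives in Paris")?;
+    /// for (token, class_id, confidence, start, end) in spans {
+    ///     println!("{token} [{start}..{end}]: label {class_id} ({confidence:.2})");
+    /// }
+    /// ```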
+    pub fn classify_tokens_with_spans(
+        &self,
+        text: &str,
+    ) -> Result<Vec<(String, usize, f32, usize, usize)>> {
+        // Use batch processing for single text
+        let batch_results = self.classify_tokens_batch(&[text])?;
+        if batch_results.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        // Get tokenization info for spans. The batch path skips
+        // [PAD]/[CLS]/[SEP] tokens, so filter the offset mapping the same way
+        // to keep tokens and spans aligned.
+        let encoding = self.tokenizer.encode(text, true).map_err(E::msg)?;
+        let content_offsets: Vec<(usize, usize)> = encoding
+            .get_tokens()
+            .iter()
+            .zip(encoding.get_offsets().iter())
+            .filter(|(token, _)| {
+                !(token.starts_with("[PAD]")
+                    || token.starts_with("[CLS]")
+                    || token.starts_with("[SEP]"))
+            })
+            .map(|(_, &offset)| offset)
+            .collect();

        let mut results = Vec::new();
-        for ((token, offset), probs) in tokens
-            .iter()
-            .zip(offsets.iter())
-            .zip(probabilities_vec.iter())
-        {
-            let (predicted_class, &confidence) = probs
-                .iter()
-                .enumerate()
-                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
-                .unwrap();
-
-            results.push((
-                token.clone(),
-                predicted_class,
-                confidence,
-                offset.0,
-                offset.1,
-            ));
+        for ((token, class_id, confidence), &(start_char, end_char)) in
+            batch_results[0].iter().zip(content_offsets.iter())
+        {
+            results.push((token.clone(), *class_id, *confidence, start_char, end_char));
        }

        Ok(results)
    }
+
+    /// Single text token classification (for backward compatibility)
+    pub fn classify_tokens(&self, text: &str) -> Result<Vec<(String, usize, f32)>> {
+        // Use batch processing for single text
+        let batch_results = self.classify_tokens_batch(&[text])?;
+        if batch_results.is_empty() {
+            return Ok(Vec::new());
+        }
+
+        Ok(batch_results.into_iter().next().unwrap())
+    }
}