Skip to content

Commit 4585d11

Browse files
committed
Reapply "fix(lora): add explicit tokenizer truncation to handle inputs >512 tokens"
This reverts commit c1d68b6.
1 parent c1d68b6 commit 4585d11

File tree

2 files changed: +33 −3 lines changed

candle-binding/src/core/tokenization.rs

Lines changed: 13 additions & 1 deletion
```diff
@@ -387,7 +387,19 @@ impl DualPathTokenizer for UnifiedTokenizer {
         let encoding = tokenizer
             .encode(text, self.config.add_special_tokens)
             .map_err(E::msg)?;
-        Ok(self.encoding_to_result(&encoding))
+
+        // Explicitly enforce max_length truncation for LoRA models
+        // This is a safety check to ensure we never exceed the model's position embedding size
+        let mut result = self.encoding_to_result(&encoding);
+        let max_len = self.config.max_length;
+        if result.token_ids.len() > max_len {
+            result.token_ids.truncate(max_len);
+            result.token_ids_u32.truncate(max_len);
+            result.attention_mask.truncate(max_len);
+            result.tokens.truncate(max_len);
+        }
+
+        Ok(result)
     }

     fn tokenize_batch_smart(
```

candle-binding/src/model_architectures/lora/bert_lora.rs

Lines changed: 20 additions & 2 deletions
```diff
@@ -499,9 +499,18 @@ impl HighPerformanceBertClassifier {

         // Load tokenizer
         let tokenizer_path = Path::new(model_path).join("tokenizer.json");
-        let tokenizer = Tokenizer::from_file(&tokenizer_path)
+        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
             .map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?;

+        // Configure truncation to max 512 tokens (BERT's position embedding limit)
+        use tokenizers::TruncationParams;
+        tokenizer
+            .with_truncation(Some(TruncationParams {
+                max_length: 512,
+                ..Default::default()
+            }))
+            .map_err(E::msg)?;
+
         // Load model weights
         let weights_path = if Path::new(model_path).join("model.safetensors").exists() {
             Path::new(model_path).join("model.safetensors")
```
```diff
@@ -690,9 +699,18 @@ impl HighPerformanceBertTokenClassifier {

         // Load tokenizer
         let tokenizer_path = Path::new(model_path).join("tokenizer.json");
-        let tokenizer = Tokenizer::from_file(&tokenizer_path)
+        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
             .map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?;

+        // Configure truncation to max 512 tokens (BERT's position embedding limit)
+        use tokenizers::TruncationParams;
+        tokenizer
+            .with_truncation(Some(TruncationParams {
+                max_length: 512,
+                ..Default::default()
+            }))
+            .map_err(E::msg)?;
+
         // Load model weights
         let weights_path = if Path::new(model_path).join("model.safetensors").exists() {
             Path::new(model_path).join("model.safetensors")
```

0 commit comments

Comments (0)