Commit cf8691b

fix(lora): add explicit tokenizer truncation to handle inputs >512 tokens
This commit fixes LoRA tokenization errors that occurred when processing inputs exceeding 512 tokens, which caused "index-select invalid index 512 with dim size 512" errors and resulted in empty predictions.

Changes:
- Added explicit truncation configuration to the BertLoRAClassifier tokenizer
- Added a safety check in UnifiedTokenizer::tokenize_for_lora()
- Ensures all inputs are properly truncated to BERT's 512-token limit

Test results:
- LoRA accuracy improved from ~40% (with empty predictions) to 80.36%
- 0 tokenization errors on 280 MMLU-Pro test cases
- 0 empty predictions

Fixes the accuracy regression reported in vllm-project#726

Signed-off-by: Yossi Ovadia <[email protected]>
1 parent e62acbf commit cf8691b
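For context on the quoted failure: BERT-base has a learned position-embedding table of exactly 512 rows, so an un-truncated sequence of 513+ tokens asks for a position index that does not exist. A minimal sketch of that lookup with candle-core (illustrative only, not code from this repository; it assumes the CPU backend and a 512×768 embedding table):

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // BERT-base learns exactly 512 position embeddings (indices 0..=511).
    let position_embeddings = Tensor::zeros((512, 768), DType::F32, &device)?;
    // An un-truncated 513-token input asks for position id 512, which has no row.
    let position_ids = Tensor::arange(0u32, 513u32, &device)?;
    // The out-of-range lookup is expected to surface as an error like the one
    // quoted in the commit message rather than a valid embedding slice.
    match position_embeddings.index_select(&position_ids, 0) {
        Ok(t) => println!("unexpected success, shape {:?}", t.shape()),
        Err(e) => println!("lookup fails for the 513th token: {e}"),
    }
    Ok(())
}

Truncating inputs to 512 tokens before the forward pass, as the diffs below do, avoids this lookup entirely.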

File tree (2 files changed, +33 −3 lines):

- candle-binding/src/core/tokenization.rs
- candle-binding/src/model_architectures/lora/bert_lora.rs
candle-binding/src/core/tokenization.rs

Lines changed: 13 additions & 1 deletion
@@ -387,7 +387,19 @@ impl DualPathTokenizer for UnifiedTokenizer {
         let encoding = tokenizer
             .encode(text, self.config.add_special_tokens)
             .map_err(E::msg)?;
-        Ok(self.encoding_to_result(&encoding))
+
+        // Explicitly enforce max_length truncation for LoRA models
+        // This is a safety check to ensure we never exceed the model's position embedding size
+        let mut result = self.encoding_to_result(&encoding);
+        let max_len = self.config.max_length;
+        if result.token_ids.len() > max_len {
+            result.token_ids.truncate(max_len);
+            result.token_ids_u32.truncate(max_len);
+            result.attention_mask.truncate(max_len);
+            result.tokens.truncate(max_len);
+        }
+
+        Ok(result)
     }

     fn tokenize_batch_smart(
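Because the safety check added to tokenize_for_lora() operates on plain vectors, its behavior can be exercised without loading a model. A minimal sketch under assumptions: the TokenizationResult field names and element types are inferred from the diff above, and enforce_max_length is a hypothetical standalone helper, not part of the crate:

// Hypothetical standalone version of the truncation guard shown in the diff.
struct TokenizationResult {
    token_ids: Vec<i64>,        // field names taken from the diff; types assumed
    token_ids_u32: Vec<u32>,
    attention_mask: Vec<i64>,
    tokens: Vec<String>,
}

fn enforce_max_length(result: &mut TokenizationResult, max_len: usize) {
    // Truncate every parallel vector so they stay the same length.
    if result.token_ids.len() > max_len {
        result.token_ids.truncate(max_len);
        result.token_ids_u32.truncate(max_len);
        result.attention_mask.truncate(max_len);
        result.tokens.truncate(max_len);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn truncates_overlong_input_to_512() {
        let mut result = TokenizationResult {
            token_ids: vec![0; 600],
            token_ids_u32: vec![0; 600],
            attention_mask: vec![1; 600],
            tokens: vec![String::from("tok"); 600],
        };
        enforce_max_length(&mut result, 512);
        assert_eq!(result.token_ids.len(), 512);
        assert_eq!(result.attention_mask.len(), 512);
    }
}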

candle-binding/src/model_architectures/lora/bert_lora.rs

Lines changed: 20 additions & 2 deletions
@@ -499,9 +499,18 @@ impl HighPerformanceBertClassifier {

         // Load tokenizer
         let tokenizer_path = Path::new(model_path).join("tokenizer.json");
-        let tokenizer = Tokenizer::from_file(&tokenizer_path)
+        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
             .map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?;

+        // Configure truncation to max 512 tokens (BERT's position embedding limit)
+        use tokenizers::TruncationParams;
+        tokenizer
+            .with_truncation(Some(TruncationParams {
+                max_length: 512,
+                ..Default::default()
+            }))
+            .map_err(E::msg)?;
+
         // Load model weights
         let weights_path = if Path::new(model_path).join("model.safetensors").exists() {
             Path::new(model_path).join("model.safetensors")

@@ -690,9 +699,18 @@ impl HighPerformanceBertTokenClassifier {

         // Load tokenizer
         let tokenizer_path = Path::new(model_path).join("tokenizer.json");
-        let tokenizer = Tokenizer::from_file(&tokenizer_path)
+        let mut tokenizer = Tokenizer::from_file(&tokenizer_path)
             .map_err(|e| E::msg(format!("Failed to load tokenizer: {}", e)))?;

+        // Configure truncation to max 512 tokens (BERT's position embedding limit)
+        use tokenizers::TruncationParams;
+        tokenizer
+            .with_truncation(Some(TruncationParams {
+                max_length: 512,
+                ..Default::default()
+            }))
+            .map_err(E::msg)?;
+
         // Load model weights
         let weights_path = if Path::new(model_path).join("model.safetensors").exists() {
             Path::new(model_path).join("model.safetensors")
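The tokenizer-side half of the fix can be checked in isolation with the tokenizers crate. A small sketch of the same configuration as the hunks above (illustrative only; the tokenizer.json path is a placeholder, and it assumes a tokenizers version where with_truncation returns a Result, as the diff implies):

use tokenizers::{Tokenizer, TruncationParams};

fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
    // Placeholder path: any BERT-style tokenizer.json works here.
    let mut tokenizer = Tokenizer::from_file("path/to/tokenizer.json")?;

    // Same configuration as the diff above: cap sequences at BERT's 512-token limit.
    tokenizer.with_truncation(Some(TruncationParams {
        max_length: 512,
        ..Default::default()
    }))?;

    // A deliberately overlong input, well past 512 wordpieces.
    let long_input = "the quick brown fox ".repeat(500);
    let encoding = tokenizer.encode(long_input.as_str(), true)?;

    // With truncation configured, the encoding never exceeds 512 ids.
    assert!(encoding.get_ids().len() <= 512);
    println!("tokens after truncation: {}", encoding.get_ids().len());
    Ok(())
}

With TruncationParams left at its default strategy, everything past the 512th token is simply dropped, so downstream position-id lookups stay inside the embedding table.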
