@@ -52,15 +52,15 @@ def _translate_and_decode(self, translation_requests: TrackedDataset | list[str]
5252 # Check for inputs that exceed the model's maximum input length
5353 # T5 models typically have a max input length of 512 tokens
5454 # Try multiple possible attribute names for max length across different model types
55- model_max_length = getattr(
56- self.translator.model.config,
57- 'n_positions',
58- getattr(
59- self.translator.model.config,
60- 'max_position_embeddings',
61- getattr(self.translator.tokenizer, 'model_max_length', 512)
62- )
63- )
55+ model_max_length = 512  # Default fallback
56+ for attr in ['n_positions', 'max_position_embeddings']:
57+ if hasattr(self.translator.model.config, attr):
58+ model_max_length = getattr(self.translator.model.config, attr)
59+ break
60+ else:
61+ # If config doesn't have the attributes, try tokenizer
62+ if hasattr(self.translator.tokenizer, 'model_max_length'):
63+ model_max_length = self.translator.tokenizer.model_max_length
6464
6565 for i, request in enumerate(translation_requests):
6666 tokenized = self.translator.tokenizer(request, return_tensors="pt", truncation=False)
0 commit comments