Skip to content

Commit c8863bb

Browse files
Copilot and jikk committed
Refactor: simplify model max length detection for better readability
Co-authored-by: jikk <862047+jikk@users.noreply.github.com>
1 parent 621a84d commit c8863bb

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

pylingual/models.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -52,15 +52,15 @@ def _translate_and_decode(self, translation_requests: TrackedDataset | list[str]
5252
# Check for inputs that exceed the model's maximum input length
5353
# T5 models typically have a max input length of 512 tokens
5454
# Try multiple possible attribute names for max length across different model types
55-
model_max_length = getattr(
56-
self.translator.model.config,
57-
'n_positions',
58-
getattr(
59-
self.translator.model.config,
60-
'max_position_embeddings',
61-
getattr(self.translator.tokenizer, 'model_max_length', 512)
62-
)
63-
)
55+
model_max_length = 512 # Default fallback
56+
for attr in ['n_positions', 'max_position_embeddings']:
57+
if hasattr(self.translator.model.config, attr):
58+
model_max_length = getattr(self.translator.model.config, attr)
59+
break
60+
else:
61+
# If config doesn't have the attributes, try tokenizer
62+
if hasattr(self.translator.tokenizer, 'model_max_length'):
63+
model_max_length = self.translator.tokenizer.model_max_length
6464

6565
for i, request in enumerate(translation_requests):
6666
tokenized = self.translator.tokenizer(request, return_tensors="pt", truncation=False)

0 commit comments

Comments (0)