Skip to content

Commit f0c77a0

Browse files
authored
Remove warning about attn_implementation for new transformers versions (#63)
1 parent fffc18a commit f0c77a0

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

rxnmapper/core.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import logging
66
import os
7+
from contextlib import contextmanager
78
from typing import Any, Dict, Iterator, List, Optional, Tuple
89

910
import numpy as np
@@ -28,6 +29,17 @@
2829
_logger.addHandler(logging.NullHandler())
2930

3031

32+
@contextmanager
def suppress_transformers_warnings():
    """Temporarily silence warnings emitted by the ``transformers`` library.

    Raises the ``transformers`` logger threshold to ``ERROR`` for the
    duration of the ``with`` block, then restores whatever level was set
    before — even if the wrapped code raises.
    """
    transformers_logger = logging.getLogger("transformers")
    saved_level = transformers_logger.level
    transformers_logger.setLevel(logging.ERROR)
    try:
        yield
    finally:
        # Always restore the caller's logging configuration.
        transformers_logger.setLevel(saved_level)
42+
3143
class RXNMapper:
3244
"""Wrap the Transformer model, corresponding tokenizer, and attention scoring algorithms.
3345
@@ -136,8 +148,11 @@ def convert_batch_to_attns(
136148
f"Reaction SMILES has {max_input_length} tokens, should be at most {max_supported_by_model}."
137149
)
138150

139-
with torch.no_grad():
140-
output = self.model(**parsed_input)
151+
# suppress warning that suggests setting "attn_implementation"; doing
152+
# so would break compatibility for old `transformers` versions.
153+
with suppress_transformers_warnings():
154+
with torch.no_grad():
155+
output = self.model(**parsed_input)
141156
attentions = output[2]
142157
selected_attns = torch.cat(
143158
[a.unsqueeze(1) for i, a in enumerate(attentions) if i in use_layers],

0 commit comments

Comments
 (0)