forked from ibatra/BERT-Keyword-Extractor
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkeyword-extractor.py
More file actions
40 lines (33 loc) · 1.58 KB
/
keyword-extractor.py
File metadata and controls
40 lines (33 loc) · 1.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import argparse

import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertConfig
from pytorch_pretrained_bert import BertForTokenClassification, BertAdam
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Command-line interface: the sentence to tag and the fine-tuned model file.
parser = argparse.ArgumentParser(description='BERT Keyword Extractor')
parser.add_argument('--sentence', type=str, default=' ',
                    help='sentence to get keywords')
parser.add_argument('--path', type=str, default='model.pt',
                    help='path to load model')
args = parser.parse_args()

# IOB tag vocabulary used at training time; tag2idx maps each tag to its index.
tags_vals = ['B', 'I', 'O']
tag2idx = {tag: idx for idx, tag in enumerate(tags_vals)}

# Pre-trained BERT pieces; the fine-tuned weights are loaded later from --path.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = BertForTokenClassification.from_pretrained("bert-base-uncased", num_labels=len(tag2idx))
def keywordextract(sentence, path):
    """Print the sub-word tokens of *sentence* that the fine-tuned BERT
    token-classification model tags as keywords.

    Args:
        sentence: raw text to tokenize and tag.
        path: filesystem path to a model saved with ``torch.save``.

    Side effects:
        Prints one ``token tag_index`` line per keyword token, where the
        tag index is 0 ('B') or 1 ('I'); index 2 ('O') marks non-keywords.
    """
    tokens = tokenizer.tokenize(sentence)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokens)
    tokens_tensor = torch.tensor([indexed_tokens]).to(device)
    # FIX: the original passed an all-zero segment-id tensor as
    # attention_mask, which by BERT convention means "mask every token".
    # It only produced sensible output because softmax is invariant to a
    # uniform shift when there is no padding. An all-ones mask states the
    # intent correctly (attend to every token) and gives the same result.
    attention_mask = torch.ones_like(tokens_tensor).to(device)
    # Renamed from `model` to avoid shadowing the module-level base model.
    trained_model = torch.load(path)
    trained_model.eval()
    # Inference only: no_grad skips gradient bookkeeping.
    with torch.no_grad():
        logit = trained_model(tokens_tensor, token_type_ids=None,
                              attention_mask=attention_mask)
    logit = logit.detach().cpu().numpy()
    # logit has shape (1, seq_len, num_labels); argmax picks the tag index.
    prediction = [list(p) for p in np.argmax(logit, axis=2)]
    for k, j in enumerate(prediction[0]):
        if j == 1 or j == 0:
            print(tokenizer.convert_ids_to_tokens(tokens_tensor[0].to('cpu').numpy())[k], j)
keywordextract(args.sentence, args.path)