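"""wsd.py: word-sense disambiguation via contextual BERT embeddings.

A target word is embedded in context with multilingual BERT (an LMMS-style
weighted sum of hidden layers), then matched against precomputed sense
embeddings by dot-product similarity.
"""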
from typing import Any


def load_wsd_model(device: str):
    from transformers import AutoModel, AutoTokenizer
    model_name = "google-bert/bert-base-multilingual-cased"
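    # output_hidden_states=True exposes every encoder layer; sentence_embedding()
    # needs them for its weighted layer sum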
    model = AutoModel.from_pretrained(
        model_name,
        output_hidden_states=True,
        dtype="auto",
        device_map=None if device == "cpu" else "auto",
    )
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


def sentence_embedding(model, tokenizer, sentences: list[str]):
    import torch
    encodings = tokenizer(
        sentences,
        return_offsets_mapping=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(model.device)
    with torch.no_grad():
        hidden_states = model(
            input_ids=encodings["input_ids"], attention_mask=encodings["attention_mask"]
        ).hidden_states
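    # hidden_states is a tuple with one (batch, seq_len, hidden_size) tensor
    # per layer, embedding output included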
    # drop special tokens and padding: their offset mapping is (0, 0)
    masks = []
    for sen_offsets in encodings["offset_mapping"]:
        masks.append(
            [0 if t_offset.tolist() == [0, 0] else 1 for t_offset in sen_offsets]
        )
    masks = torch.tensor(masks, device=model.device)
    # layer weights for the weighted sum; reference:
    # https://github.com/danlou/LMMS/blob/master/data/weights/lmms-sp-wsd.bert-large-cased.weights.txt
    weights = torch.tensor(
        [0.01473, 0.05975, 0.36144, 0.53920], dtype=torch.float32, device=model.device
    ).view(4, 1, 1, 1)
    # weighted sum over the four selected hidden layers (torch.stack puts the layer axis on dim 0)
    weight_sum_layers = (weights * torch.stack(hidden_states[-5:-1])).sum(dim=0)
    sent_embeds = []
    filtered_offsets = []
    for sent_embed, offset, mask in zip(
        weight_sum_layers, encodings["offset_mapping"], masks
    ):
        sent_embeds.append(sent_embed[mask.bool()])
        filtered_offsets.append(offset[mask.bool()])
    return sent_embeds, filtered_offsets


# single-entry cache mapping the most recent sentence to its token embeddings and offsets
EMBED_CACHE: dict[str, tuple[Any, Any]] = {}


def wsd(model, tokenizer, sent: str, word_offset: tuple[int, int], sense_embeds) -> int:
    import numpy as np
    if sent in EMBED_CACHE:
        batch_embeds, batch_offsets = EMBED_CACHE[sent]
    else:
        batch_embeds, batch_offsets = sentence_embedding(model, tokenizer, [sent])
        # keep only the most recent sentence in the cache
        EMBED_CACHE.clear()
        EMBED_CACHE[sent] = (batch_embeds, batch_offsets)
    word_start, word_end = word_offset
    vec = []
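    # collect embeddings of the (sub)word tokens whose character span overlaps the target word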
    for embed, (token_start, token_end) in zip(batch_embeds[0], batch_offsets[0]):
        if token_start < word_end and token_end > word_start:
            vec.append(embed.cpu().numpy())
    if len(vec) == 0:
        # no token overlaps the given offsets: fall back to the first sense
        return 0
    target_embedding = np.array(vec).mean(axis=0)
    target_embedding /= np.linalg.norm(target_embedding)
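    # parse each sense embedding from its whitespace-separated string of floats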
    sense_embeds = [
        np.array(list(map(float, sense_embed.split())), dtype=np.float32)
        for sense_embed in sense_embeds
    ]
    # dot product against the unit-normalised target embedding; return the index of the best sense
    sims = np.dot(sense_embeds, target_embedding)
    return int(sims.argmax())
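

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original pipeline): the two sense
    # vectors below are random placeholders standing in for real precomputed
    # sense embeddings, i.e. whitespace-separated floats with the model's
    # hidden size (768 for bert-base-multilingual-cased).
    import numpy as np

    model, tokenizer = load_wsd_model("cpu")
    rng = np.random.default_rng(0)
    fake_senses = [" ".join(f"{x:.4f}" for x in rng.normal(size=768)) for _ in range(2)]
    sentence = "He sat by the bank of the river."
    # "bank" occupies characters 14..18 of the sentence above
    print("best sense index:", wsd(model, tokenizer, sentence, (14, 18), fake_senses))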