WordAlignment.py
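"""Word-to-word alignment between two sentences.

Embeds each sentence with a BERT-style encoder, averages word-piece
vectors back into one vector per word, and aligns each word of the
first sentence to its most cosine-similar word in the second.
"""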
from typing import List, Optional, Tuple

import torch
from torch.nn.functional import cosine_similarity
from torch_scatter import scatter_mean
from transformers import AutoModel, AutoTokenizer

class WordAlignment:
    def __init__(self, model_name: str, tokenizer_name: str, device: str, fp16: bool):
        # Half precision is only supported on accelerator devices.
        assert not (device == "cpu" and fp16), "You can't use fp16 with a CPU device."
        self.model = AutoModel.from_pretrained(model_name, output_hidden_states=True).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.model.eval()
        if fp16:
            self.model = self.model.half()

    def bert_tokenizer(self, sentence: List[str]):
        # Encode the whole sentence at once and record, for every word piece,
        # the index of the word it came from.
        sentence_tokenized = self.tokenizer(" ".join(sentence), return_tensors="pt").to(self.model.device)
        indices: List[int] = self.indices_word_pieces(sentence)
        return sentence_tokenized, indices

    def get_sentence_representation(self, sentence: List[str]) -> torch.Tensor:
        # Average the word-piece vectors belonging to the same word, yielding
        # one vector per word: shape (1, n_words, hidden_size).
        encoded_input, indices = self.bert_tokenizer(sentence)
        out_bert: torch.Tensor = self.bert_forward(encoded_input)
        return scatter_mean(out_bert, index=torch.LongTensor(indices).to(self.model.device), dim=1)

    @staticmethod
    def obtain_cosine_similarity_matrix(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Pairwise cosine similarities: entry [i, j] compares target word i
        # with source word j, giving a (len_target, len_source) matrix.
        return cosine_similarity(source[..., None, :, :], target[..., :, None, :], dim=-1)[0]

    def indices_word_pieces(self, sentence: List[str]) -> List[int]:
        # Map every word piece to the index of the word that produced it,
        # e.g. ["embed", "##ding"] as word 3 contributes [3, 3].
        indices: List[int] = []
        for idx_word, word in enumerate(sentence):
            indices.extend([idx_word] * len(self.tokenizer.tokenize(word)))
        return indices

    def __repr__(self):
        return f"Bert Model: {self.model.name_or_path}, Device: {self.model.device}"

    @staticmethod
    def mean_pooling_strategy(bert_output: Tuple[torch.Tensor, ...], dimension: int = 4) -> torch.Tensor:
        # Average all hidden states except the last `dimension` layers, then
        # drop the [CLS] and [SEP] positions.
        return torch.mean(torch.stack(bert_output[:-dimension], dim=-1), dim=-1)[:, 1:-1, :]

    def bert_forward(self, encoded_sentence) -> torch.Tensor:
        # Forward pass without gradients; the model was loaded with
        # output_hidden_states=True, so all layer outputs are available.
        with torch.no_grad():
            bert_output: Tuple[torch.Tensor, ...] = self.model(**encoded_sentence)["hidden_states"]
        return self.mean_pooling_strategy(bert_output)

    @staticmethod
    def pad_sentence(source: List[str], target: List[str]) -> None:
        # Pad the shorter word list in place with [PAD] tokens so that both
        # sentences end up the same length.
        diff = abs(len(source) - len(target))
        if diff == 0:
            return
        pad_vector: List[str] = ["[PAD]"] * diff
        if len(source) > len(target):
            target.extend(pad_vector)
        else:
            source.extend(pad_vector)

    def decode(self, indices_align: torch.Tensor, sentence1: List[str], sentence2: List[str]) -> List[List[str]]:
        # Pair each word of sentence1 with the sentence2 word it aligns to.
        return [[word, sentence2[idx]] for idx, word in zip(indices_align, sentence1)]

    def get_alignment(
        self, first_sentence: List[str], second_sentence: List[str], calculate_decode: bool = True
    ) -> Tuple[torch.Tensor, Optional[List[List[str]]]]:
        # Pad copies of both sentences to equal length, embed them, and align
        # each word of the first sentence to its most similar word in the second.
        len_sentence1 = len(first_sentence)
        sentence1 = first_sentence.copy()
        sentence2 = second_sentence.copy()
        self.pad_sentence(sentence1, sentence2)
        sentence1_vector: torch.Tensor = self.get_sentence_representation(sentence1)
        sentence2_vector: torch.Tensor = self.get_sentence_representation(sentence2)
        cosine_similarity_matrix: torch.Tensor = self.obtain_cosine_similarity_matrix(sentence1_vector, sentence2_vector)
        # Transpose to (len_sentence1, len_sentence2) and take, for every word
        # of sentence1, the index of the best-matching word of sentence2.
        indices_align: torch.Tensor = torch.argmax(cosine_similarity_matrix.T, dim=-1)
        decoded = self.decode(indices_align, sentence1, sentence2)[:len_sentence1] if calculate_decode else None
        # Truncate to drop any trailing [PAD] alignments introduced by padding.
        return indices_align[:len_sentence1], decoded
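

# A minimal usage sketch. The checkpoint name below is an assumption for
# illustration; any BERT-style encoder whose hidden states follow the
# [CLS] ... [SEP] layout should work.
if __name__ == "__main__":
    aligner = WordAlignment(
        model_name="bert-base-multilingual-cased",  # assumed checkpoint
        tokenizer_name="bert-base-multilingual-cased",
        device="cpu",
        fp16=False,
    )
    indices, pairs = aligner.get_alignment(
        ["the", "cat", "sleeps"], ["il", "gatto", "dorme"]
    )
    print(indices)  # tensor of sentence2 indices, one per sentence1 word
    print(pairs)    # e.g. [["the", "il"], ["cat", "gatto"], ["sleeps", "dorme"]]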