-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
199 lines (166 loc) · 7.12 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
import os
import re
import glob
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from dataclasses import dataclass
from typing import Dict, List, Tuple, Any, Union
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
from torch.utils.data import Dataset
# import wandb
import evaluate
import transformers
from transformers import Trainer, Seq2SeqTrainingArguments
# from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration
from banglanlptoolkit import BnNLPNormalizer
# Module-level Bangla text normalizer from banglanlptoolkit, created once at
# import time; SprintDataset.loadSample calls its `normalize_bn` on each
# transcript before tokenization.
normalizer = BnNLPNormalizer()
def get_latest_checkpoint(model_folder):
    """Return the name of the newest ``checkpoint-<step>`` entry in a folder.

    Entries of ``model_folder`` whose name starts with ``'checkpoint'`` and
    ends in ``-<digits>`` are compared by their numeric step, so that
    ``checkpoint-1000`` is newer than ``checkpoint-900`` (a plain string
    comparison would pick ``checkpoint-900``).

    Args:
        model_folder: directory that contains the checkpoint folders.

    Returns:
        ``'checkpoint-<step>'`` for the largest step, with the step spelled
        exactly as it appears in the directory name.

    Raises:
        ValueError: if no entry with a numeric step suffix is found.
    """
    steps = []
    for entry in os.listdir(model_folder):
        if not entry.startswith('checkpoint'):
            continue
        suffix = entry.rsplit('-', 1)[-1]
        # Names without a numeric step (e.g. 'checkpoint' or 'checkpoint-best')
        # cannot be ordered against real steps, so they are skipped.
        if suffix.isdecimal():
            steps.append(suffix)
    if not steps:
        raise ValueError(f'no checkpoint-<step> entries found in {model_folder!r}')
    return f'checkpoint-{max(steps, key=int)}'
# Root folders of the train / valid audio files (relative paths).
train_root_dir = 'ben10/ben10/16_kHz_train_audio'
valid_root_dir = 'ben10/ben10/16_kHz_valid_audio'
# Maximum feature-extractor input length in samples: 30 * 16000, i.e. 30 s at
# the 16 kHz sampling rate passed to the feature extractor in SprintDataset.
max_input_length = 30*16000


def get_path(name):
    """Resolve an audio file name to its location under the train/valid root.

    Args:
        name: audio file name; its prefix ('train' or 'valid') selects the
            split directory.

    Returns:
        ``os.path.join(<split root>, name)``.

    Raises:
        ValueError: if ``name`` starts with neither 'train' nor 'valid'
            (previously this silently returned None).
    """
    if name.startswith('train'):
        return os.path.join(train_root_dir, name)
    if name.startswith('valid'):
        return os.path.join(valid_root_dir, name)
    raise ValueError(f"cannot resolve audio file {name!r}: expected a 'train' or 'valid' prefix")
class SprintDataset(Dataset):
    """Map-style dataset yielding speech input features and label token ids.

    Each item is built from one row of ``df``. The audio file named in the
    ``'file_name'`` column is loaded with ``audioConverter.getAudio`` and
    passed through ``feature_extractor``. The matching ``'transcripts'`` entry
    is normalized with the module-level ``normalizer`` and tokenized into
    ``labels``.

    Args:
        df: DataFrame with ``'file_name'`` and ``'transcripts'`` columns. Any
            index is accepted; rows are addressed by position.
        processor: object exposing ``.tokenizer`` (presumably a
            WhisperProcessor — confirm with the training script).
        audioConverter: object whose ``getAudio(path)`` result is indexed with
            ``[0]`` to obtain the waveform (assumed; TODO confirm its return type).
        feature_extractor: callable returning an object with ``.input_features``.
        loopDataset: how many times the rows are repeated. ``len()`` reports
            ``len(df) * loopDataset`` and indices wrap modulo ``len(df)``.
    """

    def __init__(self, df, processor, audioConverter, feature_extractor, loopDataset=1):
        self.df = df
        # Reset to a 0..n-1 RangeIndex so position ``idx`` is also the label.
        # ``series[idx]`` on a filtered or shuffled DataFrame is a label lookup
        # and would raise KeyError or return the wrong row.
        self.paths = df['file_name'].reset_index(drop=True)
        self.sentences = df['transcripts'].reset_index(drop=True)
        self.len = len(self.df) * loopDataset
        self.processor = processor
        self.ac = audioConverter
        self.feature_extractor = feature_extractor

    def __len__(self):
        return self.len

    def _row(self, idx):
        """Return ``(file_name, transcript)`` for the (possibly looped) position ``idx``."""
        idx %= len(self.df)
        return self.paths.iloc[idx], self.sentences.iloc[idx]

    def loadSample(self, idx):
        """Load, featurize and tokenize the sample at position ``idx``.

        Returns:
            dict with ``'input_features'`` (torch tensor built from the first
            entry of ``input_features``), ``'input_length'`` (``len`` of that
            tensor, i.e. its first-dimension size) and ``'labels'`` (token ids,
            at most 448 of them).
        """
        audio_path, sentence = self._row(idx)
        waveform = self.ac.getAudio(get_path(audio_path))[0]
        features = self.feature_extractor(
            waveform, sampling_rate=16000, max_length=max_input_length
        ).input_features[0]
        input_values = torch.from_numpy(features)
        input_length = len(input_values)
        # Normalize the Bangla transcript (punctuation replaced by ''), then
        # drop any literal '<>' left in the normalized text.
        text = normalizer.normalize_bn([sentence], punct_replacement_token='')[0].replace('<>', '')
        # Truncate explicitly: max_length alone only truncates implicitly,
        # and with a warning.
        labels = self.processor.tokenizer(text, max_length=448, truncation=True).input_ids
        return {
            'input_features': input_values,
            'input_length': input_length,
            'labels': labels,
        }

    def __getitem__(self, idx):
        if idx >= self.len:
            raise IndexError('index out of range')
        return self.loadSample(idx)
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """Collate speech seq2seq samples into a padded training batch.

    ``input_features`` are padded with ``processor.feature_extractor.pad`` and
    ``labels`` with ``processor.tokenizer.pad``. Label padding positions are
    set to -100 so the loss ignores them. If every label row starts with the
    decoder start token, that token is removed, because it is added again
    when labels are shifted into decoder inputs.

    Attributes:
        processor: object exposing ``feature_extractor.pad`` and
            ``tokenizer.pad`` / ``tokenizer.bos_token_id``.
        decoder_start_token_id: token stripped from the start of the labels.
            ``None`` falls back to ``processor.tokenizer.bos_token_id`` (the
            original behaviour). For Whisper, pass
            ``model.config.decoder_start_token_id``: Whisper labels begin
            with ``<|startoftranscript|>``, which is not the tokenizer's bos
            token, so the bos check never strips it.
    """

    processor: Any
    # Optional override of the leading token to strip; see class docstring.
    decoder_start_token_id: Union[int, None] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Audio inputs and labels have different lengths and padding rules,
        # so they are padded separately.
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # Replace padding with -100 so those positions are ignored by the loss.
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        start_id = self.decoder_start_token_id
        if start_id is None:
            start_id = self.processor.tokenizer.bos_token_id
        # Strip the start token only when every row begins with it; the
        # model prepends it again when building decoder inputs.
        if (
            start_id is not None
            and labels.size(1) > 0
            and (labels[:, 0] == start_id).all().cpu().item()
        ):
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
# One or more ZWNJ/ZWJ sitting between BENGALI LETTER RA (U+09B0) and
# VIRAMA + YA (U+09CD U+09AF); such a run is collapsed to a single ZWJ.
STANDARDIZE_ZW = re.compile(r'(?<=\u09b0)[\u200c\u200d]+(?=\u09cd\u09af)')
# A single ZWNJ/ZWJ that is neither preceded by RA nor followed by VIRAMA + YA.
DELETE_ZW = re.compile(r'(?<!\u09b0)[\u200c\u200d](?!\u09cd\u09af)')
# One punctuation character (? . । ; : , !), captured as group 1.
PUNC = re.compile(r'([\?\.।;:,!])')


def removeOptionalZW(text):
    """Strip optional zero-width (non-)joiners from Bangla ``text``.

    First, every run of ZWNJ/ZWJ between RA and VIRAMA + YA becomes one ZWJ.
    Then every ZWNJ/ZWJ that is neither preceded by RA nor followed by
    VIRAMA + YA is deleted.
    """
    return DELETE_ZW.sub('', STANDARDIZE_ZW.sub('\u200D', text))


def separatePunc(text):
    """Pad each punctuation mark with spaces, then collapse all whitespace runs to one space."""
    spaced = PUNC.sub(r" \1 ", text)
    return " ".join(spaced.split())


def removePunc(text):
    """Delete every punctuation mark matched by ``PUNC`` from ``text``."""
    return PUNC.sub(r"", text)
# Translation table: quote marks and ZWJ are deleted, hyphen / em dash become
# a space. Single-pass equivalent of the chained replacements, since no
# replacement character is itself in the deletion set.
_QUOTE_DASH_TABLE = str.maketrans({
    '"': None,
    "'": None,
    '”': None,
    '\u200d': None,
    '-': ' ',
    '—': ' ',
})


def remove_multiple_strings(cur_string):
    """Drop quote marks and ZWJ, and turn dashes into spaces.

    Deleted: ``"``, ``'``, ``”`` (U+201D) and ZWJ (U+200D).
    Replaced by a space: ``-`` and ``—`` (U+2014).
    """
    return cur_string.translate(_QUOTE_DASH_TABLE)
def normalizeUnicodextra(text):
    """Apply ``normalizeUnicode`` after stripping quotes and dashes.

    ``remove_multiple_strings`` runs first: it deletes ``"``, ``'``, ``”``
    and ZWJ, and turns ``-`` / ``—`` into spaces. The ``normalizeUnicode``
    steps run afterwards. The result equals running the U+098C replacement,
    ``remove_multiple_strings``, ``removeOptionalZW`` and ``removePunc`` in
    that order, because the first two steps touch disjoint characters.

    Note: despite the name, no Unicode normalization form (NFC/NFD) is applied.
    """
    return normalizeUnicode(remove_multiple_strings(text))


def normalizeUnicode(text):
    """Light clean-up of Bangla ``text``.

    Steps, in order:
      1. replace BENGALI LETTER VOCALIC L (U+098C) with BENGALI DIGIT NINE
         (U+09EF), presumably because the look-alike letter appears where
         the digit was meant (confirm against the data);
      2. ``removeOptionalZW``: collapse or delete optional ZWNJ/ZWJ characters
         (not all zero-width joiners are removed);
      3. ``removePunc``: delete punctuation matched by ``PUNC``.

    Note: despite the name, no Unicode normalization form (NFC/NFD) is
    applied; use ``unicodedata.normalize`` separately if that is needed.
    """
    text = text.replace(u"\u098c", u"\u09ef")
    text = removeOptionalZW(text)
    return removePunc(text)
# Word- and character-error-rate metrics from the Hugging Face `evaluate`
# library, loaded once at import time (WER first, then CER) and shared by the
# compute_metrics_* functions.
wer_metric, cer_metric = evaluate.load("wer"), evaluate.load("cer")
def compute_metrics_wav2vec2(processor, pred):
    """Compute WER and CER for a CTC (wav2vec2-style) evaluation pass.

    Args:
        processor: processor providing ``batch_decode`` and
            ``tokenizer.pad_token_id``.
        pred: evaluation output with ``predictions`` (logits; the argmax over
            the last axis gives token ids) and ``label_ids`` (a numpy array
            in which -100 marks ignored positions). ``pred`` is not modified.

    Returns:
        ``{"wer": ..., "cer": ...}`` as computed by ``wer_metric`` / ``cer_metric``.
    """
    pred_ids = np.argmax(pred.predictions, axis=-1)
    # Map the -100 loss-ignore marker to the pad id on a copy, so the caller's
    # label array is left untouched (the previous code mutated it in place).
    label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)
    pred_str = processor.batch_decode(pred_ids)
    print(f'This is prediction: {pred_str}')
    # Labels are not CTC output, so repeated tokens must not be grouped.
    label_str = processor.batch_decode(label_ids, group_tokens=False)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer, "cer": cer}
def compute_metrics_whisper(tokenizer, pred):
    """Compute WER and CER from generated token ids (Whisper / seq2seq).

    Args:
        tokenizer: tokenizer providing ``batch_decode`` and ``pad_token_id``.
        pred: evaluation output whose ``predictions`` are generated token ids
            (assumes predict_with_generate-style integer arrays; TODO confirm)
            and whose ``label_ids`` use -100 for ignored positions. ``pred``
            is not modified.

    Returns:
        ``{"wer": ..., "cer": ...}`` as computed by ``wer_metric`` / ``cer_metric``.
    """
    pad_id = tokenizer.pad_token_id
    # -100 marks label padding and can also pad predictions concatenated
    # across batches; it is not a decodable id. Replace on copies so the
    # caller's arrays stay untouched.
    pred_ids = np.where(pred.predictions == -100, pad_id, pred.predictions)
    label_ids = np.where(pred.label_ids == -100, pad_id, pred.label_ids)
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    print(f"At this evaluation, WER is: {wer} and CER is: {cer}")
    return {"wer": wer, "cer": cer}