Skip to content

Commit

Permalink
Merge pull request #319 from makaveli10/upgrade_silero_vad_v5
Browse files Browse the repository at this point in the history
Upgrade silero vad v5
  • Loading branch information
makaveli10 authored Jan 13, 2025
2 parents 953a88c + 5e4589c commit c1b249a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
2 changes: 1 addition & 1 deletion whisper_live/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ def update_timestamp_offset(self, last_segment, duration):
elif self.transcript[-1]["text"].strip() != last_segment:
self.transcript.append({"text": last_segment + " "})

with self.lock():
with self.lock:
self.timestamp_offset += duration

def speech_to_text(self):
Expand Down
45 changes: 30 additions & 15 deletions whisper_live/vad.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# original: https://github.com/snakers4/silero-vad/blob/master/utils_vad.py

import os
import subprocess
import torch
import numpy as np
import onnxruntime
import warnings


class VoiceActivityDetection():
Expand All @@ -24,7 +23,11 @@ def __init__(self, force_onnx_cpu=True):
self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts)

self.reset_states()
self.sample_rates = [8000, 16000]
if '16k' in path:
warnings.warn('This model support only 16000 sampling rate!')
self.sample_rates = [16000]
else:
self.sample_rates = [8000, 16000]

def _validate_input(self, x, sr: int):
if x.dim() == 1:
Expand All @@ -34,27 +37,32 @@ def _validate_input(self, x, sr: int):

if sr != 16000 and (sr % 16000 == 0):
step = sr // 16000
x = x[:, ::step]
x = x[:,::step]
sr = 16000

if sr not in self.sample_rates:
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")

if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")

return x, sr

def reset_states(self, batch_size=1):
self._h = np.zeros((2, batch_size, 64)).astype('float32')
self._c = np.zeros((2, batch_size, 64)).astype('float32')
self._state = torch.zeros((2, batch_size, 128)).float()
self._context = torch.zeros(0)
self._last_sr = 0
self._last_batch_size = 0

def __call__(self, x, sr: int):

x, sr = self._validate_input(x, sr)
num_samples = 512 if sr == 16000 else 256

if x.shape[-1] != num_samples:
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

batch_size = x.shape[0]
context_size = 64 if sr == 16000 else 32

if not self._last_batch_size:
self.reset_states(batch_size)
Expand All @@ -63,28 +71,35 @@ def __call__(self, x, sr: int):
if (self._last_batch_size) and (self._last_batch_size != batch_size):
self.reset_states(batch_size)

if not len(self._context):
self._context = torch.zeros(batch_size, context_size)

x = torch.cat([self._context, x], dim=1)
if sr in [8000, 16000]:
ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')}
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr, dtype='int64')}
ort_outs = self.session.run(None, ort_inputs)
out, self._h, self._c = ort_outs
out, state = ort_outs
self._state = torch.from_numpy(state)
else:
raise ValueError()

self._context = x[..., -context_size:]
self._last_sr = sr
self._last_batch_size = batch_size

out = torch.tensor(out)
out = torch.from_numpy(out)
return out

def audio_forward(self, x, sr: int, num_samples: int = 512):
def audio_forward(self, x, sr: int):
outs = []
x, sr = self._validate_input(x, sr)
self.reset_states()
num_samples = 512 if sr == 16000 else 256

if x.shape[1] % num_samples:
pad_num = num_samples - (x.shape[1] % num_samples)
x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

self.reset_states(x.shape[0])
for i in range(0, x.shape[1], num_samples):
wavs_batch = x[:, i:i+num_samples]
out_chunk = self.__call__(wavs_batch, sr)
Expand All @@ -94,7 +109,7 @@ def audio_forward(self, x, sr: int, num_samples: int = 512):
return stacked.cpu()

@staticmethod
def download(model_url="https://github.com/snakers4/silero-vad/raw/v4.0/files/silero_vad.onnx"):
def download(model_url="https://github.com/snakers4/silero-vad/raw/v5.0/files/silero_vad.onnx"):
target_dir = os.path.expanduser("~/.cache/whisper-live/")

# Ensure the target directory exists
Expand Down Expand Up @@ -138,5 +153,5 @@ def __call__(self, audio_frame):
bool: True if the speech probability exceeds the threshold, indicating the presence of voice activity;
False otherwise.
"""
speech_prob = self.model(torch.from_numpy(audio_frame), self.frame_rate).item()
return speech_prob > self.threshold
speech_probs = self.model.audio_forward(torch.from_numpy(audio_frame.copy()), self.frame_rate)[0]
return torch.any(speech_probs > self.threshold).item()

0 comments on commit c1b249a

Please # to comment.