Add speaker diarization functionality #26

Merged
merged 1 commit on Nov 25, 2023
114 changes: 109 additions & 5 deletions whisperplus/app.py
@@ -1,10 +1,12 @@
import gradio as gr

from whisperplus.pipelines.whisper import SpeechToTextPipeline
from whisperplus.pipelines.whisper_diarize import ASRDiarizationPipeline
from whisperplus.utils.download_utils import download_and_convert_to_mp3
from whisperplus.utils.text_utils import format_speech_to_dialogue


def main(url, model_id, language_choice):
def youtube_url_to_text(url, model_id, language_choice):
"""
Main function that downloads and converts a video to MP3 format, performs speech-to-text conversion using
a specified model, and returns the transcript along with the video path.
@@ -25,7 +27,37 @@ def main(url, model_id, language_choice):
return transcript, video_path


def app():
def speaker_diarization(url, model_id, device, num_speakers, min_speaker, max_speaker):
"""
Downloads a video from a URL, converts it to MP3, runs speech-to-text with speaker
diarization using the specified Whisper and pyannote models, and returns the dialogue
along with the audio path.

Args:
url (str): The URL of the video to download and convert.
model_id (str): The ID of the Whisper model to use for transcription.
device (str): The device to run the pipeline on ("cpu", "cuda" or "mps").
num_speakers (int): The expected number of speakers.
min_speaker (int): The minimum number of speakers.
max_speaker (int): The maximum number of speakers.

Returns:
dialogue (str): The diarized transcript formatted as a dialogue.
audio_path (str): The path of the downloaded audio file.
"""

pipeline = ASRDiarizationPipeline.from_pretrained(
asr_model=model_id,
diarizer_model="pyannote/speaker-diarization",
use_auth_token=False,
chunk_length_s=30,
device=device,
)

audio_path = download_and_convert_to_mp3(url)
output_text = pipeline(
audio_path, num_speakers=num_speakers, min_speaker=min_speaker, max_speaker=max_speaker)
dialogue = format_speech_to_dialogue(output_text)
return dialogue, audio_path


def youtube_url_to_text_app():
with gr.Blocks():
with gr.Row():
with gr.Column():
@@ -63,7 +95,7 @@ def app():
output_audio = gr.Audio(label="Output Audio")

whisperplus_in_predict.click(
fn=main,
fn=youtube_url_to_text,
inputs=[
youtube_url_path,
whisper_model_id,
@@ -79,7 +111,7 @@ def app():
"English",
],
],
fn=main,
fn=youtube_url_to_text,
inputs=[
youtube_url_path,
whisper_model_id,
@@ -90,6 +122,75 @@ def app():
)


def speaker_diarization_app():
with gr.Blocks():
with gr.Row():
with gr.Column():
youtube_url_path = gr.Text(placeholder="Enter Youtube URL", label="Youtube URL")

whisper_model_id = gr.Dropdown(
choices=[
"openai/whisper-large-v3",
"openai/whisper-large",
"openai/whisper-medium",
"openai/whisper-base",
"openai/whisper-small",
"openai/whisper-tiny",
],
value="openai/whisper-large-v3",
label="Whisper Model",
)
device = gr.Dropdown(
choices=["cpu", "cuda", "mps"],
value="cuda",
label="Device",
)
num_speakers = gr.Number(value=2, label="Number of Speakers")
min_speaker = gr.Number(value=1, label="Minimum Number of Speakers")
max_speaker = gr.Number(value=2, label="Maximum Number of Speakers")
whisperplus_in_predict = gr.Button(value="Generator")

with gr.Column():
output_text = gr.Textbox(label="Output Text")
output_audio = gr.Audio(label="Output Audio")

whisperplus_in_predict.click(
fn=speaker_diarization,
inputs=[
youtube_url_path,
whisper_model_id,
device,
num_speakers,
min_speaker,
max_speaker,
],
outputs=[output_text, output_audio],
)
gr.Examples(
examples=[
[
"https://www.youtube.com/shorts/o8PgLUgte2k",
"openai/whisper-large-v3",
"mps",
2,
1,
2,
],
],
fn=speaker_diarization,
inputs=[
youtube_url_path,
whisper_model_id,
device,
num_speakers,
min_speaker,
max_speaker,
],
outputs=[output_text, output_audio],
cache_examples=False,
)


gradio_app = gr.Blocks()
with gradio_app:
gr.HTML(
@@ -107,7 +208,10 @@ def app():
""")
with gr.Row():
with gr.Column():
app()
with gr.Tab(label="Youtube URL to Text"):
youtube_url_to_text_app()
with gr.Tab(label="Speaker Diarization"):
speaker_diarization_app()

gradio_app.queue()
gradio_app.launch(debug=True)
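
As a quick sanity check of the new code path, the diarization helpers introduced in this diff can also be exercised outside Gradio. The sketch below simply mirrors the calls made by speaker_diarization() above; the argument names, the "pyannote/speaker-diarization" model, use_auth_token, and chunk_length_s are taken from the diff itself and may differ in other versions of whisperplus, and running it in practice will likely also require access to the gated pyannote model on Hugging Face.

```python
# Standalone sketch of the new diarization path, mirroring speaker_diarization()
# from this diff. Model IDs and keyword arguments are assumptions copied from the
# diff, not a guaranteed public API.
from whisperplus.pipelines.whisper_diarize import ASRDiarizationPipeline
from whisperplus.utils.download_utils import download_and_convert_to_mp3
from whisperplus.utils.text_utils import format_speech_to_dialogue

# Build the combined ASR + diarization pipeline.
pipeline = ASRDiarizationPipeline.from_pretrained(
    asr_model="openai/whisper-large-v3",
    diarizer_model="pyannote/speaker-diarization",
    use_auth_token=False,
    chunk_length_s=30,
    device="cuda",
)

# Download the example video used in the Gradio examples and run diarization on it.
audio_path = download_and_convert_to_mp3("https://www.youtube.com/shorts/o8PgLUgte2k")
output_text = pipeline(audio_path, num_speakers=2, min_speaker=1, max_speaker=2)

# Collapse the per-chunk output into the speaker-labelled dialogue shown in the app.
print(format_speech_to_dialogue(output_text))
```

The final format_speech_to_dialogue call is the same post-processing step the app uses before displaying the transcript in its output textbox.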