From 430d5043ac2c2e3a75936f8cb46ce36159a2661b Mon Sep 17 00:00:00 2001
From: kadirnar
Date: Sat, 25 Nov 2023 14:44:36 +0300
Subject: [PATCH] Add speaker diarization functionality

---
 whisperplus/app.py | 117 ++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 112 insertions(+), 5 deletions(-)

diff --git a/whisperplus/app.py b/whisperplus/app.py
index e05ccc9..34d4f9d 100644
--- a/whisperplus/app.py
+++ b/whisperplus/app.py
@@ -1,10 +1,12 @@
 import gradio as gr
 
 from whisperplus.pipelines.whisper import SpeechToTextPipeline
+from whisperplus.pipelines.whisper_diarize import ASRDiarizationPipeline
 from whisperplus.utils.download_utils import download_and_convert_to_mp3
+from whisperplus.utils.text_utils import format_speech_to_dialogue
 
 
-def main(url, model_id, language_choice):
+def youtube_url_to_text(url, model_id, language_choice):
     """
     Main function that downloads and converts a video to MP3 format, performs speech-to-text conversion using
     a specified model, and returns the transcript along with the video path.
@@ -25,7 +27,40 @@ def main(url, model_id, language_choice):
     return transcript, video_path
 
 
-def app():
+def speaker_diarization(url, model_id, device, num_speakers, min_speaker, max_speaker):
+    """
+    Downloads and converts a video to MP3 format, performs speech-to-text conversion with speaker
+    diarization using the specified model, and returns the dialogue along with the audio path.
+
+    Args:
+        url (str): The URL of the video to download and convert.
+        model_id (str): The ID of the speech-to-text model to use.
+        device (str): The device to run inference on ("cpu", "cuda" or "mps").
+        num_speakers (int): The expected number of speakers.
+        min_speaker (int): The minimum number of speakers.
+        max_speaker (int): The maximum number of speakers.
+
+    Returns:
+        dialogue (str): The speaker-labelled dialogue of the audio.
+        audio_path (str): The path of the downloaded audio.
+    """
+
+    pipeline = ASRDiarizationPipeline.from_pretrained(
+        asr_model=model_id,
+        diarizer_model="pyannote/speaker-diarization",
+        use_auth_token=False,
+        chunk_length_s=30,
+        device=device,
+    )
+
+    audio_path = download_and_convert_to_mp3(url)
+    output_text = pipeline(
+        audio_path, num_speakers=num_speakers, min_speaker=min_speaker, max_speaker=max_speaker)
+    dialogue = format_speech_to_dialogue(output_text)
+    return dialogue, audio_path
+
+
+def youtube_url_to_text_app():
     with gr.Blocks():
         with gr.Row():
             with gr.Column():
@@ -63,7 +98,7 @@ def app():
                 output_audio = gr.Audio(label="Output Audio")
 
         whisperplus_in_predict.click(
-            fn=main,
+            fn=youtube_url_to_text,
             inputs=[
                 youtube_url_path,
                 whisper_model_id,
@@ -79,7 +114,7 @@ def app():
                     "English",
                 ],
             ],
-            fn=main,
+            fn=youtube_url_to_text,
             inputs=[
                 youtube_url_path,
                 whisper_model_id,
@@ -90,6 +125,75 @@ def app():
         )
 
 
+def speaker_diarization_app():
+    with gr.Blocks():
+        with gr.Row():
+            with gr.Column():
+                youtube_url_path = gr.Text(placeholder="Enter Youtube URL", label="Youtube URL")
+
+                whisper_model_id = gr.Dropdown(
+                    choices=[
+                        "openai/whisper-large-v3",
+                        "openai/whisper-large",
+                        "openai/whisper-medium",
+                        "openai/whisper-base",
+                        "openai/whisper-small",
+                        "openai/whisper-tiny",
+                    ],
+                    value="openai/whisper-large-v3",
+                    label="Whisper Model",
+                )
+                device = gr.Dropdown(
+                    choices=["cpu", "cuda", "mps"],
+                    value="cuda",
+                    label="Device",
+                )
+                num_speakers = gr.Number(value=2, label="Number of Speakers")
+                min_speaker = gr.Number(value=1, label="Minimum Number of Speakers")
+                max_speaker = gr.Number(value=2, label="Maximum Number of Speakers")
+                whisperplus_in_predict = gr.Button(value="Generator")
+
+            with gr.Column():
+                output_text = gr.Textbox(label="Output Text")
+                output_audio = gr.Audio(label="Output Audio")
+
+        whisperplus_in_predict.click(
+            fn=speaker_diarization,
+            inputs=[
+                youtube_url_path,
+                whisper_model_id,
+                device,
+                num_speakers,
+                min_speaker,
+                max_speaker,
+            ],
+            outputs=[output_text, output_audio],
+        )
+        gr.Examples(
+            examples=[
+                [
+                    "https://www.youtube.com/shorts/o8PgLUgte2k",
+                    "openai/whisper-large-v3",
+                    "mps",
+                    2,
+                    1,
+                    2,
+                ],
+            ],
+            fn=speaker_diarization,
+            inputs=[
+                youtube_url_path,
+                whisper_model_id,
+                device,
+                num_speakers,
+                min_speaker,
+                max_speaker,
+            ],
+            outputs=[output_text, output_audio],
+            cache_examples=False,
+        )
+
+
 gradio_app = gr.Blocks()
 with gradio_app:
     gr.HTML(
@@ -107,7 +211,10 @@ def app():
         """)
    with gr.Row():
        with gr.Column():
-            app()
+            with gr.Tab(label="Youtube URL to Text"):
+                youtube_url_to_text_app()
+            with gr.Tab(label="Speaker Diarization"):
+                speaker_diarization_app()
 
 gradio_app.queue()
 gradio_app.launch(debug=True)
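
Usage note (not part of the patch): a standalone sketch of the flow the new speaker_diarization() helper wires together, assembled from the calls added above. The URL, model ID, device, and speaker counts are the illustrative values from the gr.Examples entry, not a tested configuration.

    from whisperplus.pipelines.whisper_diarize import ASRDiarizationPipeline
    from whisperplus.utils.download_utils import download_and_convert_to_mp3
    from whisperplus.utils.text_utils import format_speech_to_dialogue

    # Build the combined ASR + diarization pipeline with the same arguments the patch uses.
    pipeline = ASRDiarizationPipeline.from_pretrained(
        asr_model="openai/whisper-large-v3",
        diarizer_model="pyannote/speaker-diarization",
        use_auth_token=False,
        chunk_length_s=30,
        device="cuda",  # illustrative; the app also offers "cpu" and "mps"
    )

    # Download the audio, transcribe it with speaker labels, and format the result as a dialogue.
    audio_path = download_and_convert_to_mp3("https://www.youtube.com/shorts/o8PgLUgte2k")
    output_text = pipeline(audio_path, num_speakers=2, min_speaker=1, max_speaker=2)
    dialogue = format_speech_to_dialogue(output_text)
    print(dialogue)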