feat(stt): add support for local Whisper models
- Added `use_whisper_local_stt` function to support local Whisper models via Python with reticulate.
- Added `use_mlx_whisper_local_stt` function for MLX Whisper models, optimized for macOS with Apple Silicon.
- Updated `perform_speech_to_text` to use `whisper_local` as the default model.
- Enhanced `speech_to_summary_workflow` to display the selected speech-to-text model.
- Updated documentation and NAMESPACE to export the new functions.
- Added `reticulate` to the Suggests field in DESCRIPTION for Python integration.
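
For illustration, a minimal sketch of how the new default is picked up ("meeting.wav" is a hypothetical file path; `minutemaker_stt_model` is the option read by `perform_speech_to_text`, as shown in the diff below):

# Set the STT backend explicitly (this is also the new built-in default)
options(minutemaker_stt_model = "whisper_local")
perform_speech_to_text(audio_path = "meeting.wav", language = "en")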
bakaburg1 committed Oct 11, 2024
1 parent a49a408 commit 69e4f5e
Showing 7 changed files with 291 additions and 15 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
@@ -30,6 +30,7 @@ Suggests:
av (>= 0.9.0),
devtools (>= 2.4.5),
parallel (>= 4.3.2),
reticulate (>= 1.38.0),
testthat (>= 3.0.0),
text2vec (>= 0.6.4),
tictoc (>= 1.2),
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -19,6 +19,8 @@ export(speech_to_summary_workflow)
export(split_audio)
export(summarise_full_meeting)
export(summarise_transcript)
export(use_mlx_whisper_local_stt)
export(use_whisper_local_stt)
export(validate_agenda)
import(dplyr)
importFrom(stats,setNames)
1 change: 1 addition & 0 deletions R/data_management.R
@@ -1239,6 +1239,7 @@ speech_to_summary_workflow <- function(
) {

message("\n### Performing speech to text...\n")
message("(stt model: ", stt_model, ")\n")

# A speech-to-text model is required
if (is.null(stt_model)) {
226 changes: 212 additions & 14 deletions R/speech_to_text.R
@@ -33,7 +33,7 @@
perform_speech_to_text <- function(
audio_path,
output_dir = file.path(dirname(audio_path), "transcription_output_data"),
model,
model = getOption("minutemaker_stt_model", "whisper_local"),
initial_prompt = NULL, overwrite = FALSE,
language = "en",
...
@@ -395,25 +395,25 @@ use_azure_whisper_stt(
warning("Error ", response$status_code, " in Azure Whisper API request: ",
httr::content(response, "text"), call. = FALSE, immediate. = TRUE)

    wait_for <- stringr::str_extract(
      httr::content(response, "text", encoding = "UTF-8"),
      "\\d+(?= seconds)") |> as.numeric()

    if (is.na(wait_for) && !interactive()) stop()

    if (is.na(wait_for)) wait_for <- 30

    message("Retrying in ", wait_for, " seconds...")

    Sys.sleep(wait_for)

    res <- use_azure_whisper_stt(
      audio_file = audio_file,
      language = language,
      initial_prompt = initial_prompt,
      temperature = temperature)

    return(res)
}

# Return the response
@@ -489,3 +489,201 @@ use_openai_whisper_stt(
# Return the response
res <- httr::content(response)
}

#' Use Local Whisper Model for Speech-to-Text
#'
#' This function uses a local Whisper model via Python with reticulate to
#' transcribe audio. It can use the official OpenAI Whisper package or any
#' compatible Python package.
#'
#' @param audio_file The path to the audio file to transcribe.
#' @param language The language of the input audio. Default is "en" for English.
#' If NULL, Whisper will attempt to detect the language.
#' @param initial_prompt Text to guide the model's style or continue a previous
#' segment.
#' @param model The Whisper model to use. Default is "turbo". Check
#' https://github.com/openai/whisper for other available models.
#' @param whisper_package The Python package to use for Whisper (default:
#' "openai-whisper").
#'
#' @return A list with the full transcript and the transcription by segments.
#'
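#' @examples
#' \dontrun{
#' # A minimal usage sketch: "meeting.wav" is a hypothetical path; the first
#' # call may install Miniconda, a conda environment, and Whisper itself.
#' res <- use_whisper_local_stt("meeting.wav", language = "en")
#' res$text
#' vapply(res$segments, function(seg) seg$text, character(1))
#' }
#'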
#' @export
use_whisper_local_stt <- function(
audio_file,
language = "en",
initial_prompt = "",
model = "turbo",
whisper_package = getOption(
"minutemaker_whisper_package", "openai-whisper")
) {
# Check if reticulate is installed
if (!rlang::is_installed("reticulate")) {
stop("Package 'reticulate' is required. ",
"Please install it using install.packages('reticulate')")
}

# Check if Miniconda is installed
if (length(list.files(reticulate::miniconda_path())) == 0) {
message("Miniconda not found. Installing it now...")
reticulate::install_miniconda()
}

conda_env <- "minutemaker_env"

# Check if the conda environment exists
if (!reticulate::condaenv_exists(conda_env)) {
message(
"Conda environment '", conda_env, "' does not exist. Creating it now...")

reticulate::conda_create(conda_env, python_version = "3.9")
}

# Use the conda environment
reticulate::use_miniconda(conda_env, required = TRUE)

# Check if Whisper is already installed
if (!reticulate::py_module_available("whisper")) {
message("Whisper not found. Installing dependencies...")

# Install the required packages
reticulate::conda_install(
conda_env,
c("numpy==1.23.5", "numba==0.56.4", "llvmlite==0.39.1", whisper_package),
pip = TRUE)
}

# Import the Whisper module
whisper <- reticulate::import("whisper")

# Load the Whisper model
model <- whisper$load_model(model)

# Prepare transcription options
options <- list(
language = language,
initial_prompt = initial_prompt,
fp16 = FALSE
)

# Remove NULL values from options
options <- options[!sapply(options, is.null)]

# Perform transcription
result <- do.call(model$transcribe, c(list(audio_file), options))

# Extract segments
segments <- lapply(result$segments, function(seg) {
list(
id = seg$id,
start = seg$start,
end = seg$end,
text = seg$text
)
})

# Return results in the expected format
list(
text = result$text,
segments = segments
)
}

#' Use MLX Whisper Local Model for Speech-to-Text (macOS only)
#'
#' This function uses a local MLX Whisper model via Python with reticulate to
#' transcribe audio. It is specifically designed to work with the MLX Whisper
#' package. MLX allows faster inference on macOS with Apple Silicon.
#'
#' @param audio_file The path to the audio file to transcribe.
#' @param language The language of the input audio. Default is "en" for English.
#' If NULL, Whisper will attempt to detect the language.
#' @param initial_prompt Text to guide the model's style or continue a previous
#' segment.
#' @param model The MLX Whisper model to use. Default is
#' "mlx-community/whisper-large-v3-turbo".
#' @param whisper_package The Python package to use for MLX Whisper (default:
#' "mlx_whisper").
#'
#' @return A list with the full transcript and the transcription by segments.
#'
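#' @examples
#' \dontrun{
#' # A minimal usage sketch (Apple Silicon only): "meeting.wav" is a
#' # hypothetical path; the model weights are fetched from Hugging Face on
#' # first use.
#' res <- use_mlx_whisper_local_stt("meeting.wav", language = "en")
#' res$text
#' }
#'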
#' @export
use_mlx_whisper_local_stt <- function(
audio_file,
language = "en",
initial_prompt = "",
model = "mlx-community/distil-whisper-large-v3",
whisper_package = getOption("minutemaker_whisper_package", "mlx_whisper")
) {
# Check if reticulate is installed
if (!rlang::is_installed("reticulate")) {
stop("Package 'reticulate' is required. ",
"Please install it using install.packages('reticulate')")
}

# Check if Miniconda is installed
if (length(list.files(reticulate::miniconda_path())) == 0) {
message("Miniconda not found. Installing it now...")
reticulate::install_miniconda()
}

conda_env <- "minutemaker_env"

# Check if the conda environment exists
if (!reticulate::condaenv_exists(conda_env)) {
message(
"Conda environment '", conda_env, "' does not exist. Creating it now...")

reticulate::conda_create(conda_env, python_version = "3.9")
}

# Use the conda environment
reticulate::use_condaenv(conda_env, required = TRUE)

# Check if Whisper is already installed
if (!reticulate::py_module_available(whisper_package)) {
message("Whisper not found. Installing dependencies...")

    # Install the required package
    reticulate::conda_install(conda_env, whisper_package, pip = TRUE)
}

# Import the Whisper module
mlx_whisper <- reticulate::import(whisper_package)

# Prepare transcription options
decode_options <- list(
language = language,
initial_prompt = initial_prompt
)

# Remove NULL values from options
decode_options <- decode_options[!sapply(decode_options, is.null)]

  # Perform transcription. The decode options are spliced in with do.call()
  # (as in use_whisper_local_stt), since `!!!` splicing does not work in a
  # plain function call.
  result <- do.call(
    mlx_whisper$transcribe,
    c(
      list(
        audio_file,
        path_or_hf_repo = model,
        fp16 = FALSE,
        word_timestamps = TRUE
      ),
      decode_options
    )
  )

# Extract segments
segments <- lapply(result$segments, function(seg) {
list(
id = seg$id,
start = seg$start,
end = seg$end,
text = seg$text
)
})

# Return results in the expected format
list(
text = result$text,
segments = segments
)
}
2 changes: 1 addition & 1 deletion man/perform_speech_to_text.Rd


37 changes: 37 additions & 0 deletions man/use_mlx_whisper_local_stt.Rd


37 changes: 37 additions & 0 deletions man/use_whisper_local_stt.Rd

