Add support for Useful Sensors Moonshine model. #1808

Open · njeffrie wants to merge 11 commits into base: master

Conversation

njeffrie

For context on the Moonshine model, please see the Useful Sensors Moonshine repo.

Adds the following:

  • C++ Moonshine model
  • pybind for the Python Moonshine model
  • Moonshine model spec
  • support for multi-dimensional LayerNorm on CPU
  • support for broadcasting LayerNorm weights for multi-dimensional LayerNorm on CPU (see the reference sketch after this list)
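
To make the LayerNorm changes concrete, here is a minimal PyTorch reference sketch of the intended semantics, assuming normalization jointly over the last two axes with per-channel weights broadcast across the other axes; the shapes and axis choices are illustrative assumptions, not the PR's actual C++/CUDA implementation.

import torch

def multi_axis_layer_norm(x, weight, bias, eps=1e-5):
    # Normalize jointly over the last two axes of x (e.g. [batch, time, channels]).
    dims = (-2, -1)
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + eps)
    # weight/bias have fewer dimensions than x, so broadcasting applies them per channel.
    return x_hat * weight + bias

x = torch.randn(2, 100, 288)   # [batch, time, channels], illustrative shape
w = torch.ones(288)            # per-channel scale, broadcast over batch and time
b = torch.zeros(288)
print(multi_axis_layer_norm(x, w, b).shape)  # torch.Size([2, 100, 288])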

For now, the Moonshine converter (safetensors -> CTranslate2 binary) will live in the Moonshine repo. I plan to add a Transformers converter once Moonshine is part of the transformers library.
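
As an illustration only, once Moonshine is in transformers and wired into CTranslate2's Transformers converter, the conversion could presumably be driven the same way as existing models; the model id, output directory, and options below are assumptions, and the safetensors converter in the Moonshine repo remains the supported path for now.

from ctranslate2.converters import TransformersConverter

# Hypothetical invocation: Moonshine support in TransformersConverter does not exist yet.
converter = TransformersConverter("UsefulSensors/moonshine-base")
converter.convert("moonshine-base-ct2", quantization="float16", force=True)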

@BBC-Esq

BBC-Esq commented Oct 26, 2024

I checked out your repo but didn't see anywhere to actually download the Moonshine models. How are the CTranslate2 maintainers supposed to evaluate whether to incorporate this pull request if the models can't be tested?

@njeffrie
Author

njeffrie commented Oct 26, 2024

Thanks for taking a look - I've uploaded CTranslate2 models for Moonshine base and tiny to the UsefulSensors Hugging Face hub. In case it's helpful for testing, here is a minimal Python script that transcribes a WAV file with CTranslate2 Moonshine base (assuming the model was downloaded to ./ctranslate2/base):

from ctranslate2.models import Moonshine
from ctranslate2 import StorageView
import torchaudio
import tokenizers

# Load the tokenizer and the converted CTranslate2 model.
tokenizer = tokenizers.Tokenizer.from_file("ctranslate2/base/tokenizer.json")
model = Moonshine('ctranslate2/base', device='cpu')

# Load the audio and resample to the 16 kHz rate the model expects.
audio, sr = torchaudio.load('foo.wav')
if sr != 16000:
    audio = torchaudio.functional.resample(audio, sr, 16000)
audio_sv = StorageView.from_array(audio.numpy())

# Beam-search decode, with [[1]] as the initial decoder token for a batch of one.
result = model.generate(audio_sv, [[1]], beam_size=5)[0]
tokens = result.sequences_ids[0]
text = tokenizer.decode(tokens).strip()
print(text)
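
For completeness, here is one way to fetch the converted files with huggingface_hub; the repo id matches the UsefulSensors Hugging Face hub upload mentioned above, and the ctranslate2/base subfolder layout is assumed.

from huggingface_hub import snapshot_download

# Download only the CTranslate2 base-model files from the UsefulSensors/moonshine repo.
snapshot_download(
    repo_id="UsefulSensors/moonshine",
    allow_patterns=["ctranslate2/base/*"],
    local_dir=".",
)
# Afterwards ./ctranslate2/base should contain the converted model weights and tokenizer.json.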

@BBC-Esq

BBC-Esq commented Oct 26, 2024

Thanks for the info, but unfortunately I'm not knowledgeable enough to know how to use .h5 files. I think what you linked is what I was asking about, though...


Also, unfortunately, I have no decision-making power regarding CTranslate2 either... But I would recommend that, if you can't get a response from the CTranslate2 people relatively quickly, you reach out to a guy named @MahmoudAshraf97; although he's not officially with "Systran," he's also interested in all things CTranslate2/TTS, is pretty good about responding, and has a good rapport with them.

As I said, I'm just one annoying fan of this technology, so... Good luck!

@njeffrie
Author

Just uploaded CT2 models for moonshine tiny and base: https://huggingface.co/UsefulSensors/moonshine/tree/main/ctranslate2

@njeffrie
Author

@minhthuc2502 perhaps you could take a look or assign somebody to review? Landing this in CTranslate2 is currently blocking us from releasing a faster-whisper-style model as part of usefulsensors/moonshine.

Thanks!

@minhthuc2502
Collaborator

minhthuc2502 commented Nov 25, 2024

Could you add CUDA support by implementing it in layer_norm_gpu.cu? Additionally, I noticed there isn't a converter to transform the original model into CTranslate2's format, apart from the added spec.

And please try to fix the CI pipeline.

Thank you.

@njeffrie
Author

Thanks for taking a look. I'll add CUDA support for the layernorm changes, add our safetensors -> CTranslate2 converter, and look into what's going on with the presubmit pipeline.

Additionally, I've added a fix to support batching to address this issue.
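
To sketch what the batching fix enables: the stacking and per-item prompt format below are assumptions extrapolated from the single-clip example earlier in this thread (two equal-length 16 kHz clips, one start-token prompt per item), not a confirmed API contract.

import numpy as np
from ctranslate2 import StorageView
from ctranslate2.models import Moonshine

model = Moonshine("ctranslate2/base", device="cpu")
# Two equal-length 16 kHz clips stacked along the batch axis (placeholder silence here).
batch = np.stack([np.zeros(16000, dtype=np.float32),
                  np.zeros(16000, dtype=np.float32)])
# One start-token prompt per batch item; generate() returns one result per item.
results = model.generate(StorageView.from_array(batch), [[1], [1]], beam_size=5)
for r in results:
    print(r.sequences_ids[0])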

Commits added:

Adds the following:
  • C++ Moonshine model
  • pybind for the Python Moonshine model
  • Moonshine model spec
  • safetensors Moonshine model converter
  • support for GroupNorm-style weights for LayerNorm
  • support for multi-axis CUDA LayerNorm
  • a define to prevent quantizing the first conv layers in the Moonshine preprocessor
  • options to enable rotary positional embeddings in the Transformer Encoder spec

Fixes bug when batch size > 1.

Converts safetensors model def + tokenizer_config.json to CTranslate2 model spec for Moonshine.
@BBC-Esq

BBC-Esq commented Dec 4, 2024

Thanks for taking a look. I'll add CUDA support for the layernorm changes, add our safetensors -> CTranslate2 converter, and look into what's going on with the presubmit pipeline.

Additionally, I've added a fix to support batching to address this issue.

Can you please post when it's ready to review? I'm actually kind of curious to test out these models, but I won't do it until all the changes are near final. Thanks!

@njeffrie
Author

njeffrie commented Dec 5, 2024

Should be ready to go @minhthuc2502, @BBC-Esq.

@broke-end-dev

@guillaumekln @minhthuc2502 could you make some time for this? I really want to try this out in my realtime transcriber project.

@guynich

guynich commented Mar 2, 2025

@guillaumekln @minhthuc2502 could you make some time for this? I really want to try this out in my realtime transcriber project.

+1

@BBC-Esq

BBC-Esq commented Mar 2, 2025

If I had write approval I'd move things along, but alas...

@BBC-Esq

BBC-Esq commented Mar 2, 2025

@njeffrie can you take a look at this script and make sure that I'm running the Moonshine model correctly? I had to implement some preprocessing of the audio using the "av" library, as well as some chunking, to address several error messages. I'm now getting some pretty decent results... but I'm wondering why I had to break the audio into chunks?

Granted, this is using the transformers library... but I'm still waiting on the pull request, like you, for CTranslate2. Hopefully that'll come soon. I'm also looking forward to trying out your batch processing capabilities. Anyway, here's the entire script:

Note: I'm using a pip-installed version of CUDA, not a systemwide installation... so you'd obviously just delete the set_cuda_paths functionality if relying on a systemwide installation. Thanks!

Script:
import sys
import os
from pathlib import Path

def set_cuda_paths():
    venv_base = Path(sys.executable).parent.parent
    nvidia_base_path = venv_base / 'Lib' / 'site-packages' / 'nvidia'
    cuda_path_runtime = nvidia_base_path / 'cuda_runtime' / 'bin'
    cuda_path_runtime_lib = nvidia_base_path / 'cuda_runtime' / 'bin' / 'lib' / 'x64'
    cuda_path_runtime_include = nvidia_base_path / 'cuda_runtime' / 'include'
    cublas_path = nvidia_base_path / 'cublas' / 'bin'
    cudnn_path = nvidia_base_path / 'cudnn' / 'bin'
    nvrtc_path = nvidia_base_path / 'cuda_nvrtc' / 'bin'
    nvcc_path = nvidia_base_path / 'cuda_nvcc' / 'bin'

    paths_to_add = [
        str(cuda_path_runtime),
        str(cuda_path_runtime_lib),
        str(cuda_path_runtime_include),
        str(cublas_path),
        str(cudnn_path),
        str(nvrtc_path),
        str(nvcc_path),
    ]

    current_value = os.environ.get('PATH', '')
    new_value = os.pathsep.join(paths_to_add + [current_value] if current_value else paths_to_add)
    os.environ['PATH'] = new_value

    triton_cuda_path = nvidia_base_path / 'cuda_runtime'
    os.environ['CUDA_PATH'] = str(triton_cuda_path)

set_cuda_paths()

import torch
import time
import gc
import textwrap
import av
import numpy as np
import threading
from transformers import MoonshineForConditionalGeneration, AutoProcessor

try:
    import pynvml
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    pynvml_available = True
except (ImportError, ModuleNotFoundError):
    # pynvml is not installed; referencing pynvml.NVMLError here would raise NameError.
    pynvml_available = False
    print("pynvml not available - VRAM usage tracking disabled")
except pynvml.NVMLError:
    # pynvml imported but NVML failed to initialize (e.g. no NVIDIA driver present).
    pynvml_available = False
    print("NVML initialization failed - VRAM usage tracking disabled")

AUDIO_FILE_PATH = r"D:\Scripts\test_moonshine\test_flac.flac"
output_filename = os.path.splitext(os.path.basename(AUDIO_FILE_PATH))[0] + "_transcription.txt"
OUTPUT_FILE_PATH = os.path.join(os.path.dirname(AUDIO_FILE_PATH), output_filename)

def poll_vram_usage(stop_event, vram_readings):
    while not stop_event.is_set():
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        vram_usage = memory_info.used / 1024**2
        vram_readings.append(vram_usage)
        time.sleep(0.1)

def load_audio(file_path):
    print(f"Loading audio from {file_path}...")
    audio_array = []

    with av.open(str(file_path)) as container:
        audio_stream = container.streams.audio[0]
        sample_rate = audio_stream.rate

        resampler = av.AudioResampler(
            format='s16',
            layout='mono',
            rate=16000
        )
        sample_rate = 16000

        for frame in container.decode(audio=0):
            frames = resampler.resample(frame)
            if frames:
                for new_frame in frames:
                    arr = new_frame.to_ndarray().flatten()
                    audio_array.append(arr)

    audio_data = np.concatenate(audio_array)
    if audio_data.dtype == np.int16:
        audio_data = audio_data.astype(np.float32) / 32768.0

    print(f"Loaded audio: {len(audio_data)} samples, {sample_rate}Hz")
    return torch.tensor(audio_data), sample_rate

def detect_speech_segments(audio_data, sample_rate, energy_threshold=0.01, min_silence_duration=0.5, min_speech_duration=1.0):
    frame_length = int(0.025 * sample_rate)  # 25ms frames
    hop_length = int(0.010 * sample_rate)    # 10ms hop
    min_silence_frames = int(min_silence_duration / (hop_length / sample_rate))
    min_speech_frames = int(min_speech_duration / (hop_length / sample_rate))

    audio_np = audio_data.numpy() if isinstance(audio_data, torch.Tensor) else audio_data

    # Calculate energy for each frame
    num_frames = 1 + (len(audio_np) - frame_length) // hop_length
    energy = np.zeros(num_frames)

    for i in range(num_frames):
        start = i * hop_length
        end = start + frame_length
        frame = audio_np[start:end]
        energy[i] = np.sqrt(np.mean(frame**2))

    # Normalize energy and apply threshold
    energy_norm = energy / np.max(energy) if np.max(energy) > 0 else energy
    is_speech = energy_norm > energy_threshold

    # Apply minimum speech and silence duration constraints
    for i in range(1, len(is_speech) - min_speech_frames):
        # Ensure minimum speech duration
        if is_speech[i] and not is_speech[i-1]:  # Start of speech
            if not np.all(is_speech[i:i+min_speech_frames]):
                is_speech[i:i+min_speech_frames] = False

        # Ensure minimum silence duration
        if not is_speech[i] and is_speech[i-1]:  # Start of silence
            if not np.all(~is_speech[i:i+min_silence_frames]):
                is_speech[i:i+min_silence_frames] = True

    # Find speech segments
    segments = []
    in_speech = False
    speech_start = 0

    for i, speech in enumerate(is_speech):
        if speech and not in_speech:
            in_speech = True
            speech_start = i * hop_length
        elif not speech and in_speech:
            in_speech = False
            speech_end = i * hop_length
            segments.append((speech_start, speech_end))

    # Add final segment if needed
    if in_speech:
        segments.append((speech_start, len(audio_np)))

    return segments

def create_audio_chunks(audio_data, sample_rate, max_chunk_duration=30, overlap_duration=2):
    max_chunk_samples = int(max_chunk_duration * sample_rate)
    overlap_samples = int(overlap_duration * sample_rate)

    # Find speech segments
    speech_segments = detect_speech_segments(audio_data, sample_rate)

    # If no speech segments found or if segments are too short, default to time-based chunking
    if not speech_segments or len(speech_segments) < 3:
        print("No clear speech segments detected, using time-based chunking")
        total_samples = len(audio_data)
        step_size = max_chunk_samples - overlap_samples
        num_chunks = max(1, (total_samples - overlap_samples) // step_size + 1)

        chunks = []
        for i in range(num_chunks):
            start = i * step_size
            end = min(start + max_chunk_samples, total_samples)
            chunks.append((start, end))
        return chunks

    # Merge speech segments into optimal chunks
    chunks = []
    current_start = speech_segments[0][0]
    current_end = speech_segments[0][1]

    for start, end in speech_segments[1:]:
        # If adding this segment exceeds max duration, create a new chunk
        if end - current_start > max_chunk_samples:
            # Ensure chunks overlap at natural boundaries
            chunks.append((current_start, current_end))
            current_start = max(current_end - overlap_samples, 0)
            current_end = end
        else:
            # Otherwise, extend current chunk
            current_end = end

    # Add the final chunk
    if current_end > current_start:
        chunks.append((current_start, current_end))

    # If chunks are too small, merge adjacent chunks
    i = 0
    while i < len(chunks) - 1:
        current_chunk = chunks[i]
        next_chunk = chunks[i + 1]

        if next_chunk[1] - current_chunk[0] <= max_chunk_samples:
            chunks[i] = (current_chunk[0], next_chunk[1])
            chunks.pop(i + 1)
        else:
            i += 1

    return chunks

def transcribe_audio_chunks(model, processor, audio_data, sample_rate, device, torch_dtype, vram_readings):
    chunks = create_audio_chunks(audio_data, sample_rate)
    print(f"Audio will be processed in {len(chunks)} chunks")

    all_transcriptions = []

    for i, (start, end) in enumerate(chunks):
        chunk_data = audio_data[start:end]
        chunk_duration = len(chunk_data) / sample_rate
        print(f"Processing chunk {i+1}/{len(chunks)}: {chunk_duration:.2f} seconds ({start/sample_rate:.2f}s - {end/sample_rate:.2f}s)")

        # Save the intermediate transcriptions to disk as well
        chunk_file = f"chunk_{i+1}_of_{len(chunks)}.txt"
        chunk_path = os.path.join(os.path.dirname(OUTPUT_FILE_PATH), chunk_file)

        try:
            inputs = processor(
                chunk_data, 
                return_tensors="pt",
                sampling_rate=sample_rate
            )
            inputs = inputs.to(device, torch_dtype)

            # Heuristic decoding budget: ~6.5 tokens per second of audio, capped at 512 tokens.
            token_limit_factor = 6.5 / processor.feature_extractor.sampling_rate
            seq_lens = inputs.attention_mask.sum(dim=-1)
            max_length = min(int((seq_lens * token_limit_factor).max().item()), 512)

            # Setup VRAM monitoring
            stop_event = threading.Event()
            if pynvml_available and torch.cuda.is_available():
                poll_thread = threading.Thread(target=poll_vram_usage, args=(stop_event, vram_readings))
                poll_thread.start()

            # Generate with deterministic settings
            with torch.no_grad():
                generated_ids = model.generate(
                    **inputs, 
                    max_length=max_length,
                    do_sample=False,
                    num_beams=1,
                    temperature=1.0
                )

            # Stop VRAM monitoring
            if pynvml_available and torch.cuda.is_available():
                stop_event.set()
                if 'poll_thread' in locals():
                    poll_thread.join()

            transcription = processor.decode(generated_ids[0], skip_special_tokens=True)

            # Write chunk transcription to file
            with open(chunk_path, 'w', encoding='utf-8') as f:
                f.write(transcription)

            all_transcriptions.append(transcription)
            print(f"Chunk {i+1} transcription: {len(transcription)} characters")

            # Clear memory
            del inputs, generated_ids
            torch.cuda.empty_cache() if torch.cuda.is_available() else None
            
        except Exception as e:
            print(f"Error processing chunk {i+1}: {str(e)}")
            with open(chunk_path, 'w', encoding='utf-8') as f:
                f.write(f"ERROR: {str(e)}")

    return " ".join(all_transcriptions)

def main():
    if torch.cuda.is_available():
        device = "cuda:0"
        compute_capability = torch.cuda.get_device_capability(0)
        major, minor = compute_capability
        
        if major >= 8:
            torch_dtype = torch.bfloat16
            print(f"Using CUDA device with compute capability {major}.{minor} - using bfloat16")
        else:
            torch_dtype = torch.float16
            print(f"Using CUDA device with compute capability {major}.{minor} - using float16")
    else:
        device = "cpu"
        torch_dtype = torch.float32
        print("CUDA not available - using CPU with float32")

    baseline_vram_usage = 0
    if pynvml_available and torch.cuda.is_available():
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        baseline_vram_usage = memory_info.used / 1024**2
        print(f"Baseline VRAM usage: {baseline_vram_usage:.2f} MB")

    total_start_time = time.time()

    print("Loading model and processor...")
    model_load_start = time.time()
    model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-base").to(device).to(torch_dtype)
    processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-base")
    model_load_time = time.time() - model_load_start
    print(f"Model loading time: {model_load_time:.2f} seconds")

    model_vram_usage = 0
    if pynvml_available and torch.cuda.is_available():
        memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        model_vram_usage = memory_info.used / 1024**2 - baseline_vram_usage
        print(f"Model VRAM usage: {model_vram_usage:.2f} MB")

    try:
        audio_start_time = time.time()
        audio_data, sampling_rate = load_audio(AUDIO_FILE_PATH)
        audio_load_time = time.time() - audio_start_time
        print(f"Audio loading time: {audio_load_time:.2f} seconds")
        print(f"Audio duration: {len(audio_data)/sampling_rate:.2f} seconds")

        vram_readings = []

        transcription_start_time = time.time()
        transcription = transcribe_audio_chunks(model, processor, audio_data, sampling_rate, device, torch_dtype, vram_readings)
        transcription_time = time.time() - transcription_start_time

        total_elapsed_time = time.time() - total_start_time

        max_vram_usage = 0
        if vram_readings:
            max_vram_usage = max(vram_readings) - baseline_vram_usage

        wrapped_transcription = textwrap.fill(transcription, width=100)
        with open(OUTPUT_FILE_PATH, 'w', encoding='utf-8') as f:
            f.write(wrapped_transcription)

        print("\n==== RESULTS ====")
        print(f"Transcription saved to: {OUTPUT_FILE_PATH}")
        print(f"Transcription length: {len(transcription)} characters")
        print("\n==== PERFORMANCE METRICS ====")
        print(f"Total time: {total_elapsed_time:.2f} seconds")
        print(f"Model loading time: {model_load_time:.2f} seconds")
        print(f"Audio loading time: {audio_load_time:.2f} seconds")
        print(f"Transcription time: {transcription_time:.2f} seconds")
        print(f"Transcription speed: {len(audio_data)/sampling_rate/transcription_time:.2f}x realtime")
        print(f"Characters per second: {len(transcription)/transcription_time:.2f}")

        if pynvml_available and torch.cuda.is_available():
            print("\n==== VRAM USAGE ====")
            print(f"Baseline VRAM usage: {baseline_vram_usage:.2f} MB")
            print(f"Model VRAM usage: {model_vram_usage:.2f} MB")
            print(f"Maximum net VRAM usage during inference: {max_vram_usage:.2f} MB")

    except Exception as e:
        print(f"Error processing audio file: {str(e)}")
        import traceback
        traceback.print_exc()

    finally:
        del model
        del processor
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

if __name__ == "__main__":
    main()