Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup some CString logic #83

Merged
merged 2 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions crates/sherpa-rs/src/audio_tag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use eyre::{bail, Result};

use crate::{
get_default_provider,
utils::{cstr_to_string, RawCStr},
utils::{cstr_to_string, cstring_from_str},
};

#[derive(Debug, Default, Clone)]
Expand All @@ -25,10 +25,10 @@ impl AudioTag {
pub fn new(config: AudioTagConfig) -> Result<Self> {
let config_clone = config.clone();

let model = RawCStr::new(&config.model);
let ced = RawCStr::new(&config.ced.unwrap_or_default());
let labels = RawCStr::new(&config.labels);
let provider = RawCStr::new(&config.provider.unwrap_or(get_default_provider()));
let model = cstring_from_str(&config.model);
let ced = cstring_from_str(&config.ced.unwrap_or_default());
let labels = cstring_from_str(&config.labels);
let provider = cstring_from_str(&config.provider.unwrap_or(get_default_provider()));

let sherpa_config = sherpa_rs_sys::SherpaOnnxAudioTaggingConfig {
model: sherpa_rs_sys::SherpaOnnxAudioTaggingModelConfig {
Expand Down
8 changes: 4 additions & 4 deletions crates/sherpa-rs/src/diarize.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{get_default_provider, utils::RawCStr};
use crate::{get_default_provider, utils::cstring_from_str};
use eyre::{bail, Result};
use std::{path::Path, ptr::null_mut};

Expand Down Expand Up @@ -58,9 +58,9 @@ impl Diarize {
threshold: config.threshold.unwrap_or(0.5),
};

let embedding_model = RawCStr::new(embedding_model);
let provider = RawCStr::new(&provider.clone());
let segmentation_model = RawCStr::new(segmentation_model);
let embedding_model = cstring_from_str(embedding_model);
let provider = cstring_from_str(&provider.clone());
let segmentation_model = cstring_from_str(segmentation_model);

let config = sherpa_rs_sys::SherpaOnnxOfflineSpeakerDiarizationConfig {
embedding: sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingExtractorConfig {
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/embedding_manager.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::utils::{cstr_to_string, RawCStr};
use crate::utils::{cstr_to_string, cstring_from_str};
use eyre::{bail, Result};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -67,7 +67,7 @@ impl EmbeddingManager {
}

pub fn add(&mut self, name: String, embedding: &mut [f32]) -> Result<()> {
let name_c = RawCStr::new(&name.clone());
let name_c = cstring_from_str(&name.clone());
unsafe {
let status = sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingManagerAdd(
self.manager,
Expand Down
14 changes: 7 additions & 7 deletions crates/sherpa-rs/src/keyword_spot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::ptr::null;

use crate::{
get_default_provider,
utils::{cstr_to_string, RawCStr},
utils::{cstr_to_string, cstring_from_str},
};
use eyre::{bail, Result};

Expand Down Expand Up @@ -61,14 +61,14 @@ impl KeywordSpot {
// Create new keyword spotter along with stream
// Ready for streaming or regular use
pub fn new(config: KeywordSpotConfig) -> Result<Self> {
let provider = RawCStr::new(&config.provider.unwrap_or(get_default_provider()));
let provider = cstring_from_str(&config.provider.unwrap_or(get_default_provider()));

let zipformer_encoder = RawCStr::new(&config.zipformer_encoder);
let zipformer_decoder = RawCStr::new(&config.zipformer_decoder);
let zipformer_joiner = RawCStr::new(&config.zipformer_joiner);
let zipformer_encoder = cstring_from_str(&config.zipformer_encoder);
let zipformer_decoder = cstring_from_str(&config.zipformer_decoder);
let zipformer_joiner = cstring_from_str(&config.zipformer_joiner);

let tokens = RawCStr::new(&config.tokens);
let keywords = RawCStr::new(&config.keywords);
let tokens = cstring_from_str(&config.tokens);
let keywords = cstring_from_str(&config.keywords);

let sherpa_config = sherpa_rs_sys::SherpaOnnxKeywordSpotterConfig {
feat_config: sherpa_rs_sys::SherpaOnnxFeatureConfig {
Expand Down
8 changes: 4 additions & 4 deletions crates/sherpa-rs/src/language_id.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
get_default_provider,
utils::{cstr_to_string, RawCStr},
utils::{cstr_to_string, cstring_from_str},
};
use eyre::{bail, Result};

Expand All @@ -22,9 +22,9 @@ impl SpokenLanguageId {
pub fn new(config: SpokenLanguageIdConfig) -> Self {
let debug = config.debug.into();

let decoder = RawCStr::new(&config.decoder);
let encoder = RawCStr::new(&config.encoder);
let provider = RawCStr::new(&config.provider.unwrap_or(get_default_provider()));
let decoder = cstring_from_str(&config.decoder);
let encoder = cstring_from_str(&config.encoder);
let provider = cstring_from_str(&config.provider.unwrap_or(get_default_provider()));

let whisper = sherpa_rs_sys::SherpaOnnxSpokenLanguageIdentificationWhisperConfig {
decoder: decoder.as_ptr(),
Expand Down
4 changes: 2 additions & 2 deletions crates/sherpa-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ pub struct OfflineRecognizerResult {

impl OfflineRecognizerResult {
fn new(result: &sherpa_rs_sys::SherpaOnnxOfflineRecognizerResult) -> Self {
let lang = cstr_to_string(result.lang);
let text = cstr_to_string(result.text);
let lang = unsafe { cstr_to_string(result.lang) };
let text = unsafe { cstr_to_string(result.text) };
let count = result.count.try_into().unwrap();
let timestamps = if result.timestamps.is_null() {
Vec::new()
Expand Down
14 changes: 7 additions & 7 deletions crates/sherpa-rs/src/moonshine.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{get_default_provider, utils::RawCStr};
use crate::{get_default_provider, utils::cstring_from_str};
use eyre::{bail, Result};
use std::ptr::null;

Expand Down Expand Up @@ -46,15 +46,15 @@ impl MoonshineRecognizer {
let provider = config.provider.unwrap_or(get_default_provider());

// Onnx
let provider_ptr = RawCStr::new(&provider);
let provider_ptr = cstring_from_str(&provider);
let num_threads = config.num_threads.unwrap_or(2);

// Moonshine
let preprocessor_ptr = RawCStr::new(&config.preprocessor);
let encoder_ptr = RawCStr::new(&config.encoder);
let cached_decoder_ptr = RawCStr::new(&config.cached_decoder);
let uncached_decoder_ptr = RawCStr::new(&config.uncached_decoder);
let tokens_ptr = RawCStr::new(&config.tokens);
let preprocessor_ptr = cstring_from_str(&config.preprocessor);
let encoder_ptr = cstring_from_str(&config.encoder);
let cached_decoder_ptr = cstring_from_str(&config.cached_decoder);
let uncached_decoder_ptr = cstring_from_str(&config.uncached_decoder);
let tokens_ptr = cstring_from_str(&config.tokens);

let model_config = sherpa_rs_sys::SherpaOnnxOfflineModelConfig {
bpe_vocab: null(),
Expand Down
8 changes: 4 additions & 4 deletions crates/sherpa-rs/src/punctuate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use eyre::{bail, Result};

use crate::{
get_default_provider,
utils::{cstr_to_string, RawCStr},
utils::{cstr_to_string, cstring_from_str},
};

#[derive(Debug, Default, Clone)]
Expand All @@ -19,8 +19,8 @@ pub struct Punctuation {

impl Punctuation {
pub fn new(config: PunctuationConfig) -> Result<Self> {
let model = RawCStr::new(&config.model);
let provider = RawCStr::new(&config.provider.unwrap_or(if cfg!(target_os = "macos") {
let model = cstring_from_str(&config.model);
let provider = cstring_from_str(&config.provider.unwrap_or(if cfg!(target_os = "macos") {
// TODO: sherpa-onnx/issues/1448
"cpu".into()
} else {
Expand All @@ -45,7 +45,7 @@ impl Punctuation {
}

pub fn add_punctuation(&mut self, text: &str) -> String {
let text = RawCStr::new(text);
let text = cstring_from_str(text);
unsafe {
let text_with_punct_ptr = sherpa_rs_sys::SherpaOfflinePunctuationAddPunct(
self.audio_punctuation,
Expand Down
6 changes: 3 additions & 3 deletions crates/sherpa-rs/src/speaker_id.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use eyre::{bail, Result};
use std::path::PathBuf;

use crate::{get_default_provider, utils::RawCStr};
use crate::{get_default_provider, utils::cstring_from_str};

/// If similarity is greater than or equal to the threshold, then it's a match!
pub const DEFAULT_SIMILARITY_THRESHOLD: f32 = 0.5;
Expand Down Expand Up @@ -31,8 +31,8 @@ impl EmbeddingExtractor {
if !model_path.exists() {
bail!("model not found at {}", model_path.display())
}
let model = RawCStr::new(&config.model);
let provider = RawCStr::new(&provider);
let model = cstring_from_str(&config.model);
let provider = cstring_from_str(&provider);

let extractor_config = sherpa_rs_sys::SherpaOnnxSpeakerEmbeddingExtractorConfig {
debug,
Expand Down
12 changes: 6 additions & 6 deletions crates/sherpa-rs/src/tts/kokoro.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{mem, ptr::null};

use crate::{utils::RawCStr, OnnxConfig};
use crate::{utils::cstring_from_str, OnnxConfig};
use eyre::Result;
use sherpa_rs_sys;

Expand All @@ -24,12 +24,12 @@ pub struct KokoroTtsConfig {
impl KokoroTts {
pub fn new(config: KokoroTtsConfig) -> Self {
let tts = unsafe {
let model = RawCStr::new(&config.model);
let voices = RawCStr::new(&config.voices);
let tokens = RawCStr::new(&config.tokens);
let data_dir = RawCStr::new(&config.data_dir);
let model = cstring_from_str(&config.model);
let voices = cstring_from_str(&config.voices);
let tokens = cstring_from_str(&config.tokens);
let data_dir = cstring_from_str(&config.data_dir);

let provider = RawCStr::new(&config.onnx_config.provider);
let provider = cstring_from_str(&config.onnx_config.provider);

let tts_config = config.common_config.to_raw();

Expand Down
16 changes: 8 additions & 8 deletions crates/sherpa-rs/src/tts/matcha.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{mem, ptr::null};

use crate::{utils::RawCStr, OnnxConfig};
use crate::{utils::cstring_from_str, OnnxConfig};
use eyre::Result;
use sherpa_rs_sys;

Expand Down Expand Up @@ -30,15 +30,15 @@ pub struct MatchaTtsConfig {
impl MatchaTts {
pub fn new(config: MatchaTtsConfig) -> Self {
let tts = unsafe {
let tokens = RawCStr::new(&config.tokens);
let data_dir = RawCStr::new(&config.data_dir);
let lexicon = RawCStr::new(&config.lexicon);
let dict_dir = RawCStr::new(&config.dict_dir);
let tokens = cstring_from_str(&config.tokens);
let data_dir = cstring_from_str(&config.data_dir);
let lexicon = cstring_from_str(&config.lexicon);
let dict_dir = cstring_from_str(&config.dict_dir);

let vocoder = RawCStr::new(&config.vocoder);
let acoustic_model = RawCStr::new(&config.acoustic_model);
let vocoder = cstring_from_str(&config.vocoder);
let acoustic_model = cstring_from_str(&config.acoustic_model);

let provider = RawCStr::new(&config.onnx_config.provider);
let provider = cstring_from_str(&config.onnx_config.provider);

let tts_config = config.common_config.to_raw();

Expand Down
14 changes: 8 additions & 6 deletions crates/sherpa-rs/src/tts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ mod kokoro;
mod matcha;
mod vits;

use std::ffi::CString;

use eyre::{bail, Result};

pub use kokoro::{KokoroTts, KokoroTtsConfig};
pub use matcha::{MatchaTts, MatchaTtsConfig};
pub use vits::{VitsTts, VitsTtsConfig};

use crate::utils::RawCStr;
use crate::utils::cstring_from_str;

#[derive(Debug)]
pub struct TtsAudio {
Expand All @@ -25,8 +27,8 @@ pub struct CommonTtsConfig {
}

pub struct CommonTtsRaw {
pub rule_fars: Option<RawCStr>,
pub rule_fsts: Option<RawCStr>,
pub rule_fars: Option<CString>,
pub rule_fsts: Option<CString>,
pub max_num_sentences: i32,
}

Expand All @@ -35,13 +37,13 @@ impl CommonTtsConfig {
let rule_fars = if self.rule_fars.is_empty() {
None
} else {
Some(RawCStr::new(&self.rule_fars))
Some(cstring_from_str(&self.rule_fars))
};

let rule_fsts = if self.rule_fsts.is_empty() {
None
} else {
Some(RawCStr::new(&self.rule_fsts))
Some(cstring_from_str(&self.rule_fsts))
};

CommonTtsRaw {
Expand All @@ -61,7 +63,7 @@ pub unsafe fn create(
sid: i32,
speed: f32,
) -> Result<TtsAudio> {
let text = RawCStr::new(text);
let text = cstring_from_str(text);
let audio_ptr = sherpa_rs_sys::SherpaOnnxOfflineTtsGenerate(tts, text.as_ptr(), sid, speed);

if audio_ptr.is_null() {
Expand Down
14 changes: 7 additions & 7 deletions crates/sherpa-rs/src/tts/vits.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{mem, ptr::null};

use crate::{utils::RawCStr, OnnxConfig};
use crate::{utils::cstring_from_str, OnnxConfig};
use eyre::Result;
use sherpa_rs_sys;

Expand Down Expand Up @@ -28,13 +28,13 @@ pub struct VitsTtsConfig {
impl VitsTts {
pub fn new(config: VitsTtsConfig) -> Self {
let tts = unsafe {
let model = RawCStr::new(&config.model);
let tokens = RawCStr::new(&config.tokens);
let data_dir = RawCStr::new(&config.data_dir);
let lexicon = RawCStr::new(&config.lexicon);
let dict_dir = RawCStr::new(&config.dict_dir);
let model = cstring_from_str(&config.model);
let tokens = cstring_from_str(&config.tokens);
let data_dir = cstring_from_str(&config.data_dir);
let lexicon = cstring_from_str(&config.lexicon);
let dict_dir = cstring_from_str(&config.dict_dir);

let provider = RawCStr::new(&config.onnx_config.provider);
let provider = cstring_from_str(&config.onnx_config.provider);

let tts_config = config.tts_config.to_raw();

Expand Down
46 changes: 7 additions & 39 deletions crates/sherpa-rs/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,45 +1,13 @@
use std::ffi::{c_char, CString};

// Smart pointer for CString
pub struct RawCStr {
ptr: *mut std::ffi::c_char,
/// Builds an owned, NUL-terminated [`CString`] from a Rust string slice.
///
/// The returned `CString` owns its buffer, so any pointer obtained from it
/// with `as_ptr()` is only valid while the `CString` itself is still alive.
/// Bind the result to a local that outlives the FFI call that uses it.
///
/// # Panics
///
/// Panics if `s` contains an interior NUL byte, because such a string cannot
/// be represented as a C string.
pub fn cstring_from_str(s: &str) -> CString {
    CString::new(s).expect("CString::new failed")
}

impl RawCStr {
/// Creates a new `CStr` from a given Rust string slice.
pub fn new(s: &str) -> Self {
let cstr = CString::new(s).expect("CString::new failed");
let ptr = cstr.into_raw();
Self { ptr }
}

/// Returns the raw pointer to the internal C string.
///
/// # Safety
/// This function only returns the raw pointer and does not transfer ownership.
/// The pointer remains valid as long as the `CStr` instance exists.
/// Be cautious not to deallocate or modify the pointer after using `CStr::new`.
pub fn as_ptr(&self) -> *const c_char {
self.ptr
}
}

impl Drop for RawCStr {
fn drop(&mut self) {
if !self.ptr.is_null() {
unsafe {
let _ = CString::from_raw(self.ptr);
}
}
}
}

pub fn cstr_to_string(ptr: *const c_char) -> String {
unsafe {
if ptr.is_null() {
String::new()
} else {
std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned()
}
/// Copies a C string into an owned Rust [`String`].
///
/// A null `ptr` yields an empty string. Byte sequences that are not valid
/// UTF-8 are replaced with U+FFFD (lossy conversion), so the call never fails.
/// The C string is only read; ownership of the pointed-to memory is not taken.
///
/// # Safety
///
/// If `ptr` is non-null, it must point to a valid, NUL-terminated C string
/// that remains alive and is not mutated for the duration of this call.
pub unsafe fn cstr_to_string(ptr: *const c_char) -> String {
    if ptr.is_null() {
        return String::new();
    }
    // SAFETY: `ptr` is non-null (checked above) and the caller guarantees it
    // refers to a live, NUL-terminated C string.
    unsafe { std::ffi::CStr::from_ptr(ptr) }
        .to_string_lossy()
        .into_owned()
}
Loading