Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support whisper language/task in various language bindings. #679

Merged
merged 4 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions dotnet-examples/offline-decode-files/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ class Options
[Option("whisper-decoder", Required = false, Default = "", HelpText = "Path to whisper decoder.onnx. Used only for whisper models")]
public string WhisperDecoder { get; set; }

[Option("whisper-language", Required = false, Default = "", HelpText = "Language of the input file. Can be empty")]
public string WhisperLanguage{ get; set; }

[Option("whisper-task", Required = false, Default = "transcribe", HelpText = "transcribe or translate")]
public string WhisperTask{ get; set; }

[Option("tdnn-model", Required = false, Default = "", HelpText = "Path to tdnn yesno model")]
public string TdnnModel { get; set; }

Expand Down Expand Up @@ -193,6 +199,8 @@ private static void Run(Options options)
{
config.ModelConfig.Whisper.Encoder = options.WhisperEncoder;
config.ModelConfig.Whisper.Decoder = options.WhisperDecoder;
config.ModelConfig.Whisper.Language = options.WhisperLanguage;
config.ModelConfig.Whisper.Task = options.WhisperTask;
}
else if (!String.IsNullOrEmpty(options.TdnnModel))
{
Expand Down
2 changes: 2 additions & 0 deletions go-api-examples/non-streaming-decode-files/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ func main() {

flag.StringVar(&config.ModelConfig.Whisper.Encoder, "whisper-encoder", "", "Path to the whisper encoder model")
flag.StringVar(&config.ModelConfig.Whisper.Decoder, "whisper-decoder", "", "Path to the whisper decoder model")
flag.StringVar(&config.ModelConfig.Whisper.Language, "whisper-language", "", "Language of the input wave. You can leave it empty ")
flag.StringVar(&config.ModelConfig.Whisper.Task, "whisper-task", "transcribe", "transcribe or translate")

flag.StringVar(&config.ModelConfig.Tdnn.Model, "tdnn-model", "", "Path to the tdnn model")

Expand Down
2 changes: 2 additions & 0 deletions nodejs-examples/test-offline-nemo-ctc.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ function createOfflineRecognizer() {
whisper: {
encoder: '',
decoder: '',
language: '',
task: '',
},
tdnn: {
model: '',
Expand Down
2 changes: 2 additions & 0 deletions nodejs-examples/test-offline-paraformer.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ function createOfflineRecognizer() {
whisper: {
encoder: '',
decoder: '',
language: '',
task: '',
},
tdnn: {
model: '',
Expand Down
2 changes: 2 additions & 0 deletions nodejs-examples/test-offline-transducer.js
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ function createOfflineRecognizer() {
whisper: {
encoder: '',
decoder: '',
language: '',
task: '',
},
tdnn: {
model: '',
Expand Down
2 changes: 2 additions & 0 deletions nodejs-examples/test-offline-whisper.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ function createOfflineRecognizer() {
whisper: {
encoder: './sherpa-onnx-whisper-tiny.en/tiny.en-encoder.int8.onnx',
decoder: './sherpa-onnx-whisper-tiny.en/tiny.en-decoder.int8.onnx',
language: '',
task: 'transcribe',
},
tdnn: {
model: '',
Expand Down
8 changes: 8 additions & 0 deletions scripts/dotnet/offline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,20 @@ public OfflineWhisperModelConfig()
{
Encoder = "";
Decoder = "";
Language = "";
Task = "transcribe";
}
[MarshalAs(UnmanagedType.LPStr)]
public string Encoder;

[MarshalAs(UnmanagedType.LPStr)]
public string Decoder;

[MarshalAs(UnmanagedType.LPStr)]
public string Language;

[MarshalAs(UnmanagedType.LPStr)]
public string Task;
}

[StructLayout(LayoutKind.Sequential)]
Expand Down
12 changes: 10 additions & 2 deletions scripts/go/sherpa_onnx.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,8 +326,10 @@ type OfflineNemoEncDecCtcModelConfig struct {
}

type OfflineWhisperModelConfig struct {
Encoder string
Decoder string
Encoder string
Decoder string
Language string
Task string
}

type OfflineTdnnModelConfig struct {
Expand Down Expand Up @@ -423,6 +425,12 @@ func NewOfflineRecognizer(config *OfflineRecognizerConfig) *OfflineRecognizer {
c.model_config.whisper.decoder = C.CString(config.ModelConfig.Whisper.Decoder)
defer C.free(unsafe.Pointer(c.model_config.whisper.decoder))

c.model_config.whisper.language = C.CString(config.ModelConfig.Whisper.Language)
defer C.free(unsafe.Pointer(c.model_config.whisper.language))

c.model_config.whisper.task = C.CString(config.ModelConfig.Whisper.Task)
defer C.free(unsafe.Pointer(c.model_config.whisper.task))

c.model_config.tdnn.model = C.CString(config.ModelConfig.Tdnn.Model)
defer C.free(unsafe.Pointer(c.model_config.tdnn.model))

Expand Down
50 changes: 24 additions & 26 deletions sherpa-onnx/c-api/c-api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@

#include "sherpa-onnx/csrc/circular-buffer.h"
#include "sherpa-onnx/csrc/display.h"
#include "sherpa-onnx/csrc/keyword-spotter.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/offline-tts.h"
#include "sherpa-onnx/csrc/online-recognizer.h"
#include "sherpa-onnx/csrc/voice-activity-detector.h"
#include "sherpa-onnx/csrc/wave-writer.h"
#include "sherpa-onnx/csrc/keyword-spotter.h"

struct SherpaOnnxOnlineRecognizer {
std::unique_ptr<sherpa_onnx::OnlineRecognizer> impl;
Expand Down Expand Up @@ -301,6 +301,9 @@ SherpaOnnxOfflineRecognizer *CreateOfflineRecognizer(
recognizer_config.model_config.whisper.language =
SHERPA_ONNX_OR(config->model_config.whisper.language, "");

recognizer_config.model_config.whisper.task =
SHERPA_ONNX_OR(config->model_config.whisper.task, "transcribe");

recognizer_config.model_config.tdnn.model =
SHERPA_ONNX_OR(config->model_config.tdnn.model, "");

Expand Down Expand Up @@ -422,8 +425,8 @@ struct SherpaOnnxKeywordSpotter {
std::unique_ptr<sherpa_onnx::KeywordSpotter> impl;
};

SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
const SherpaOnnxKeywordSpotterConfig* config) {
SherpaOnnxKeywordSpotter *CreateKeywordSpotter(
const SherpaOnnxKeywordSpotterConfig *config) {
sherpa_onnx::KeywordSpotterConfig spotter_config;

spotter_config.feat_config.sampling_rate =
Expand Down Expand Up @@ -457,20 +460,17 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
spotter_config.model_config.debug =
SHERPA_ONNX_OR(config->model_config.debug, 0);

spotter_config.max_active_paths =
SHERPA_ONNX_OR(config->max_active_paths, 4);
spotter_config.max_active_paths = SHERPA_ONNX_OR(config->max_active_paths, 4);

spotter_config.num_trailing_blanks =
SHERPA_ONNX_OR(config->num_trailing_blanks , 1);
SHERPA_ONNX_OR(config->num_trailing_blanks, 1);

spotter_config.keywords_score =
SHERPA_ONNX_OR(config->keywords_score, 1.0);
spotter_config.keywords_score = SHERPA_ONNX_OR(config->keywords_score, 1.0);

spotter_config.keywords_threshold =
SHERPA_ONNX_OR(config->keywords_threshold, 0.25);

spotter_config.keywords_file =
SHERPA_ONNX_OR(config->keywords_file, "");
spotter_config.keywords_file = SHERPA_ONNX_OR(config->keywords_file, "");

if (config->model_config.debug) {
SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str());
Expand All @@ -481,39 +481,37 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
return nullptr;
}

SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter;
SherpaOnnxKeywordSpotter *spotter = new SherpaOnnxKeywordSpotter;

spotter->impl =
std::make_unique<sherpa_onnx::KeywordSpotter>(spotter_config);
spotter->impl = std::make_unique<sherpa_onnx::KeywordSpotter>(spotter_config);

return spotter;
}

void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) {
void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter) {
delete spotter;
}

SherpaOnnxOnlineStream* CreateKeywordStream(
const SherpaOnnxKeywordSpotter* spotter) {
SherpaOnnxOnlineStream* stream =
SherpaOnnxOnlineStream *CreateKeywordStream(
const SherpaOnnxKeywordSpotter *spotter) {
SherpaOnnxOnlineStream *stream =
new SherpaOnnxOnlineStream(spotter->impl->CreateStream());
return stream;
}

int32_t IsKeywordStreamReady(
SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream) {
int32_t IsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter,
SherpaOnnxOnlineStream *stream) {
return spotter->impl->IsReady(stream->impl.get());
}

void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter,
SherpaOnnxOnlineStream* stream) {
void DecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter,
SherpaOnnxOnlineStream *stream) {
return spotter->impl->DecodeStream(stream->impl.get());
}

void DecodeMultipleKeywordStreams(
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams,
int32_t n) {
std::vector<sherpa_onnx::OnlineStream*> ss(n);
void DecodeMultipleKeywordStreams(SherpaOnnxKeywordSpotter *spotter,
SherpaOnnxOnlineStream **streams, int32_t n) {
std::vector<sherpa_onnx::OnlineStream *> ss(n);
for (int32_t i = 0; i != n; ++i) {
ss[i] = streams[i]->impl.get();
}
Expand All @@ -522,7 +520,7 @@ void DecodeMultipleKeywordStreams(

const SherpaOnnxKeywordResult *GetKeywordResult(
SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) {
const sherpa_onnx::KeywordResult& result =
const sherpa_onnx::KeywordResult &result =
spotter->impl->GetResult(stream->impl.get());
const auto &keyword = result.keyword;

Expand Down
35 changes: 17 additions & 18 deletions sherpa-onnx/c-api/c-api.h
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineWhisperModelConfig {
const char *encoder;
const char *decoder;
const char *language;
const char *task;
} SherpaOnnxOfflineWhisperModelConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTdnnModelConfig {
Expand Down Expand Up @@ -483,19 +484,19 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
/// For Chinese, it consists of Chinese words without spaces.
/// Example 1: "hello world"
/// Example 2: "你好世界"
const char* keyword;
const char *keyword;

/// Decoded results at the token level.
/// For instance, for BPE-based models it consists of a list of BPE tokens.
const char* tokens;
const char *tokens;

const char* const* tokens_arr;
const char *const *tokens_arr;

int32_t count;

/// timestamps.size() == tokens.size()
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
float* timestamps;
float *timestamps;

/// Starting time of this segment.
/// When an endpoint is detected, it will change
Expand All @@ -511,7 +512,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordResult {
* "start_time": x,
* }
*/
const char* json;
const char *json;
} SherpaOnnxKeywordResult;

SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig {
Expand All @@ -521,7 +522,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotterConfig {
int32_t num_trailing_blanks;
float keywords_score;
float keywords_threshold;
const char* keywords_file;
const char *keywords_file;
} SherpaOnnxKeywordSpotterConfig;

SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
Expand All @@ -530,36 +531,35 @@ SHERPA_ONNX_API typedef struct SherpaOnnxKeywordSpotter
/// @param config Config for the keyword spotter.
/// @return Return a pointer to the spotter. The user has to invoke
/// DestroyKeywordSpotter() to free it to avoid memory leak.
SHERPA_ONNX_API SherpaOnnxKeywordSpotter* CreateKeywordSpotter(
const SherpaOnnxKeywordSpotterConfig* config);
SHERPA_ONNX_API SherpaOnnxKeywordSpotter *CreateKeywordSpotter(
const SherpaOnnxKeywordSpotterConfig *config);

/// Free a pointer returned by CreateKeywordSpotter()
///
/// @param spotter A pointer returned by CreateKeywordSpotter()
SHERPA_ONNX_API void DestroyKeywordSpotter(
SherpaOnnxKeywordSpotter* spotter);
SHERPA_ONNX_API void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter *spotter);

/// Create an online stream for accepting wave samples.
///
/// @param spotter A pointer returned by CreateKeywordSpotter()
/// @return Return a pointer to an OnlineStream. The user has to invoke
/// DestroyOnlineStream() to free it to avoid memory leak.
SHERPA_ONNX_API SherpaOnnxOnlineStream* CreateKeywordStream(
const SherpaOnnxKeywordSpotter* spotter);
SHERPA_ONNX_API SherpaOnnxOnlineStream *CreateKeywordStream(
const SherpaOnnxKeywordSpotter *spotter);

/// Return 1 if there are enough feature frames for decoding.
/// Return 0 otherwise.
///
/// @param spotter A pointer returned by CreateKeywordSpotter
/// @param stream A pointer returned by CreateKeywordStream
SHERPA_ONNX_API int32_t IsKeywordStreamReady(
SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream);
SHERPA_ONNX_API int32_t IsKeywordStreamReady(SherpaOnnxKeywordSpotter *spotter,
SherpaOnnxOnlineStream *stream);

/// Call this function to run the neural network model and decoding.
///
/// Precondition for this function: IsKeywordStreamReady() MUST return 1.
SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter,
SherpaOnnxOnlineStream* stream);
SHERPA_ONNX_API void DecodeKeywordStream(SherpaOnnxKeywordSpotter *spotter,
SherpaOnnxOnlineStream *stream);

/// This function is similar to DecodeKeywordStream(). It decodes multiple
/// OnlineStream objects in parallel.
Expand Down Expand Up @@ -588,8 +588,7 @@ SHERPA_ONNX_API const SherpaOnnxKeywordResult *GetKeywordResult(
/// Destroy the pointer returned by GetKeywordResult().
///
/// @param r A pointer returned by GetKeywordResult()
SHERPA_ONNX_API void DestroyKeywordResult(
const SherpaOnnxKeywordResult *r);
SHERPA_ONNX_API void DestroyKeywordResult(const SherpaOnnxKeywordResult *r);

// ============================================================
// For VAD
Expand Down
3 changes: 2 additions & 1 deletion sherpa-onnx/csrc/offline-tts-vits-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ class OfflineTtsVitsModel::Impl {
inputs.push_back(std::move(length_scale_tensor));
inputs.push_back(std::move(noise_scale_w_tensor));

if (input_names_.size() == 6 && input_names_.back() == "sid") {
if (input_names_.size() == 6 &&
(input_names_.back() == "sid" || input_names_.back() == "speaker")) {
inputs.push_back(std::move(sid_tensor));
}

Expand Down
4 changes: 3 additions & 1 deletion sherpa-onnx/csrc/transducer-keyword-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
//
// Copyright (c) 2023-2024 Xiaomi Corporation

#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <utility>
#include <vector>

#include "sherpa-onnx/csrc/log.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/transducer-keyword-decoder.h"

namespace sherpa_onnx {

Expand Down
Loading
Loading