Skip to content

Commit

Permalink
removing unused params from vad_iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
mgonzs13 committed Dec 26, 2024
1 parent dd32b5e commit 870c655
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 63 deletions.
8 changes: 2 additions & 6 deletions whisper_bringup/launch/silero-vad.launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,9 @@ def run_silero_vad(context: LaunchContext, repo, file, model_path):
"frame_size_ms": LaunchConfiguration("frame_size_ms", default=32),
"threshold": LaunchConfiguration("threshold", default=0.5),
"min_silence_ms": LaunchConfiguration(
"min_silence_ms", default=0
),
"speech_pad_ms": LaunchConfiguration("speech_pad_ms", default=32),
"min_speech_ms": LaunchConfiguration("min_speech_ms", default=32),
"max_speech_s": LaunchConfiguration(
"max_speech_s", default=float("inf")
"min_silence_ms", default=100
),
"speech_pad_ms": LaunchConfiguration("speech_pad_ms", default=30),
}
],
remappings=[("audio", "/audio/in")],
Expand Down
2 changes: 0 additions & 2 deletions whisper_ros/include/silero_vad/silero_vad_node.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ class SileroVadNode : public rclcpp_lifecycle::LifecycleNode {
float threshold_;
int min_silence_ms_;
int speech_pad_ms_;
int min_speech_ms_;
float max_speech_s_;

rclcpp::Publisher<std_msgs::msg::Float32MultiArray>::SharedPtr publisher_;
rclcpp::Subscription<audio_common_msgs::msg::AudioStamped>::SharedPtr
Expand Down
10 changes: 3 additions & 7 deletions whisper_ros/include/silero_vad/vad_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ class VadIterator {
public:
VadIterator(const std::string &model_path, int sample_rate = 16000,
int frame_size_ms = 32, float threshold = 0.5f,
int min_silence_ms = 0, int speech_pad_ms = 32,
int min_speech_ms = 32,
float max_speech_s = std::numeric_limits<float>::infinity());
int min_silence_ms = 100, int speech_pad_ms = 30);

void reset_states();
Timestamp predict(const std::vector<float> &data);
Expand All @@ -58,11 +56,9 @@ class VadIterator {
int sample_rate;
int sr_per_ms;
int64_t window_size_samples;
int min_speech_samples;
int speech_pad_samples;
float max_speech_samples;
unsigned int min_silence_samples;
unsigned int min_silence_samples_at_max_speech;
int context_size;

// Model state
bool triggered = false;
Expand All @@ -75,9 +71,9 @@ class VadIterator {
std::vector<const char *> input_node_names = {"input", "state", "sr"};

std::vector<float> input;
std::vector<float> context;
std::vector<float> state;
std::vector<int64_t> sr;
std::vector<float> context;

int64_t input_node_dims[2] = {};
const int64_t state_node_dims[3] = {2, 1, 128};
Expand Down
14 changes: 2 additions & 12 deletions whisper_ros/src/silero_vad/silero_vad_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,6 @@ SileroVadNode::SileroVadNode()
this->declare_parameter<float>("threshold", 0.5f);
this->declare_parameter<int>("min_silence_ms", 100);
this->declare_parameter<int>("speech_pad_ms", 30);
this->declare_parameter<int>("min_speech_ms", 32);
this->declare_parameter<float>("max_speech_s",
std::numeric_limits<float>::infinity());
}

rclcpp_lifecycle::node_interfaces::LifecycleNodeInterface::CallbackReturn
Expand All @@ -62,8 +59,6 @@ SileroVadNode::on_configure(const rclcpp_lifecycle::State &) {
this->get_parameter("threshold", this->threshold_);
this->get_parameter("min_silence_ms", this->min_silence_ms_);
this->get_parameter("speech_pad_ms", this->speech_pad_ms_);
this->get_parameter("min_speech_ms", this->min_speech_ms_);
this->get_parameter("max_speech_s", this->max_speech_s_);

RCLCPP_INFO(get_logger(), "[%s] Configured", this->get_name());

Expand All @@ -79,8 +74,7 @@ SileroVadNode::on_activate(const rclcpp_lifecycle::State &) {
// create silero-vad
this->vad_iterator = std::make_unique<VadIterator>(
this->model_path_, this->sample_rate_, this->frame_size_ms_,
this->threshold_, this->min_silence_ms_, this->speech_pad_ms_,
this->min_speech_ms_, this->max_speech_s_);
this->threshold_, this->min_silence_ms_, this->speech_pad_ms_);

this->publisher_ =
this->create_publisher<std_msgs::msg::Float32MultiArray>("vad", 10);
Expand Down Expand Up @@ -185,8 +179,6 @@ void SileroVadNode::audio_callback(

// Predict if speech starts or ends
auto timestamp = this->vad_iterator->predict(data);
// RCLCPP_INFO(this->get_logger(), "Timestampt: %s",
// timestamp.to_string().c_str());

// Check if speech starts
if (timestamp.start != -1 && timestamp.end == -1 && !this->listening) {
Expand All @@ -209,9 +201,7 @@ void SileroVadNode::audio_callback(
if (this->data.size() / msg->audio.info.rate < 1.0) {
int pad_size =
msg->audio.info.chunk + msg->audio.info.rate - this->data.size();
for (int i = 0; i < pad_size; i++) {
this->data.push_back(0.0);
}
this->data.insert(this->data.end(), pad_size, 0.0f);
}

this->listening.store(false);
Expand Down
2 changes: 1 addition & 1 deletion whisper_ros/src/silero_vad/timestamp.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// MIT License

// Copyright (c) 2023 Miguel Ángel González Santamarta
// Copyright (c) 2024 Miguel Ángel González Santamarta

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand Down
77 changes: 42 additions & 35 deletions whisper_ros/src/silero_vad/vad_iterator.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// MIT License

// Copyright (c) 2023 Miguel Ángel González Santamarta
// Copyright (c) 2024 Miguel Ángel González Santamarta

// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
Expand All @@ -20,6 +20,7 @@
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#include <algorithm>
#include <limits>
#include <memory>
#include <string>
Expand All @@ -31,32 +32,39 @@ using namespace silero_vad;

VadIterator::VadIterator(const std::string &model_path, int sample_rate,
int frame_size_ms, float threshold, int min_silence_ms,
int speech_pad_ms, int min_speech_ms,
float max_speech_s)
int speech_pad_ms)
: env(ORT_LOGGING_LEVEL_WARNING, "VadIterator"), threshold(threshold),
sample_rate(sample_rate), sr_per_ms(sample_rate / 1000),
window_size_samples(frame_size_ms * sr_per_ms),
min_speech_samples(sr_per_ms * min_speech_ms),
speech_pad_samples(sr_per_ms * speech_pad_ms),
max_speech_samples(sample_rate * max_speech_s - window_size_samples -
2 * speech_pad_samples),
min_silence_samples(sr_per_ms * min_silence_ms),
min_silence_samples_at_max_speech(sr_per_ms * 98),
state(2 * 1 * 128, 0.0f), sr(1, sample_rate), context(64, 0.0f) {
context_size(sample_rate == 16000 ? 64 : 32), context(context_size, 0.0f),
state(2 * 1 * 128, 0.0f), sr(1, sample_rate) {

// this->input.resize(window_size_samples);
this->input_node_dims[0] = 1;
this->input_node_dims[1] = window_size_samples;
this->init_onnx_model(model_path);

try {
this->init_onnx_model(model_path);
} catch (const std::exception &e) {
throw std::runtime_error("Failed to initialize ONNX model: " +
std::string(e.what()));
}
}

void VadIterator::init_onnx_model(const std::string &model_path) {
this->session_options.SetIntraOpNumThreads(1);
this->session_options.SetInterOpNumThreads(1);
this->session_options.SetGraphOptimizationLevel(
GraphOptimizationLevel::ORT_ENABLE_ALL);
this->session = std::make_shared<Ort::Session>(this->env, model_path.c_str(),
this->session_options);

try {
this->session = std::make_shared<Ort::Session>(
this->env, model_path.c_str(), this->session_options);
} catch (const std::exception &e) {
throw std::runtime_error("Failed to create ONNX session: " +
std::string(e.what()));
}
}

void VadIterator::reset_states() {
Expand All @@ -68,16 +76,13 @@ void VadIterator::reset_states() {
}

Timestamp VadIterator::predict(const std::vector<float> &data) {
// Create input tensors
// Pre-fill input with context
this->input.clear();
for (auto ele : this->context) {
this->input.push_back(ele);
}

for (auto ele : data) {
this->input.push_back(ele);
}
this->input.reserve(context.size() + data.size());
this->input.insert(input.end(), context.begin(), context.end());
this->input.insert(input.end(), data.begin(), data.end());

// Create input tensors
Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
this->memory_info, this->input.data(), this->input.size(),
this->input_node_dims, 2);
Expand All @@ -95,20 +100,23 @@ Timestamp VadIterator::predict(const std::vector<float> &data) {
this->ort_inputs.emplace_back(std::move(sr_tensor));

// Run inference
this->ort_outputs = this->session->Run(
Ort::RunOptions{nullptr}, this->input_node_names.data(),
this->ort_inputs.data(), this->ort_inputs.size(),
this->output_node_names.data(), this->output_node_names.size());
try {
this->ort_outputs = session->Run(
Ort::RunOptions{nullptr}, this->input_node_names.data(),
this->ort_inputs.data(), this->ort_inputs.size(),
this->output_node_names.data(), this->output_node_names.size());
} catch (const std::exception &e) {
throw std::runtime_error("ONNX inference failed: " + std::string(e.what()));
}

// Process output
float speech_prob = this->ort_outputs[0].GetTensorMutableData<float>()[0];
float *updated_state = this->ort_outputs[1].GetTensorMutableData<float>();
std::copy(updated_state, updated_state + this->state.size(),
this->state.begin());

for (int i = 64; i > 0; i--) {
this->context.push_back(data.at(data.size() - i));
}
// Update context with the last context_size samples of data
this->context.assign(data.end() - context_size, data.end());

// Handle result
this->current_sample += this->window_size_samples;
Expand All @@ -119,10 +127,10 @@ Timestamp VadIterator::predict(const std::vector<float> &data) {
}

if (!this->triggered) {
int start_timestwamp = this->current_sample - this->speech_pad_samples -
this->window_size_samples;
this->triggered = true;
return Timestamp(this->current_sample - this->speech_pad_samples -
this->window_size_samples,
-1, speech_prob);
return Timestamp(start_timestwamp, -1, speech_prob);
}

} else if (speech_prob < this->threshold - 0.15 && this->triggered) {
Expand All @@ -131,12 +139,11 @@ Timestamp VadIterator::predict(const std::vector<float> &data) {
}

if (this->current_sample - this->temp_end >= this->min_silence_samples) {
this->temp_end = 0;
int end_timestamp =
this->temp_end + this->speech_pad_samples - this->window_size_samples;
this->triggered = false;
return Timestamp(-1,
this->temp_end + this->speech_pad_samples -
this->window_size_samples,
speech_prob);
this->temp_end = 0;
return Timestamp(-1, end_timestamp, speech_prob);
}
}

Expand Down

0 comments on commit 870c655

Please sign in to comment.