Skip to content

Commit

Permalink
can be initialize (register) audio_filter plugin inside Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
typester committed Feb 24, 2025
1 parent db5709e commit 8799552
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 85 deletions.
5 changes: 3 additions & 2 deletions livekit-ffi/protocol/audio_frame.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ message NewAudioStreamRequest {
required AudioStreamType type = 2;
optional uint32 sample_rate = 3;
optional uint32 num_channels = 4;
optional string audio_filter_module_id = 5;
optional string audio_filter_module_id = 5; // Unique identifier passed in LoadAudioFilterPluginRequest
optional string audio_filter_options = 6;
}
message NewAudioStreamResponse { required OwnedAudioStream stream = 1; }
Expand Down Expand Up @@ -258,8 +258,9 @@ message OwnedSoxResampler {
message LoadAudioFilterPluginRequest {
required string plugin_path = 1; // path for ffi audio filter plugin
repeated string dependencies = 2; // Optional: paths for dependency dylibs
required string module_id = 3; // Unique identifier of the plugin
}

message LoadAudioFilterPluginResponse {
required FfiOwnedHandle handle = 1;
optional string error = 1;
}
1 change: 0 additions & 1 deletion livekit-ffi/protocol/room.proto
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,6 @@ message RoomOptions {
optional E2eeOptions e2ee = 4;
optional RtcConfig rtc_config = 5; // allow to setup a custom RtcConfiguration
optional uint32 join_retries = 6;
repeated AudioFilterModule audio_filter_handles = 7;
}

//
Expand Down
10 changes: 6 additions & 4 deletions livekit-ffi/src/livekit.proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2634,8 +2634,6 @@ pub struct RoomOptions {
pub rtc_config: ::core::option::Option<RtcConfig>,
#[prost(uint32, optional, tag="6")]
pub join_retries: ::core::option::Option<u32>,
#[prost(message, repeated, tag="7")]
pub audio_filter_handles: ::prost::alloc::vec::Vec<AudioFilterModule>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
Expand Down Expand Up @@ -3442,6 +3440,7 @@ pub struct NewAudioStreamRequest {
pub sample_rate: ::core::option::Option<u32>,
#[prost(uint32, optional, tag="4")]
pub num_channels: ::core::option::Option<u32>,
/// Unique identifier passed in LoadAudioFilterPluginRequest
#[prost(string, optional, tag="5")]
pub audio_filter_module_id: ::core::option::Option<::prost::alloc::string::String>,
#[prost(string, optional, tag="6")]
Expand Down Expand Up @@ -3779,12 +3778,15 @@ pub struct LoadAudioFilterPluginRequest {
/// Optional: paths for dependency dylibs
#[prost(string, repeated, tag="2")]
pub dependencies: ::prost::alloc::vec::Vec<::prost::alloc::string::String>,
/// Unique identifier of the plugin
#[prost(string, required, tag="3")]
pub module_id: ::prost::alloc::string::String,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct LoadAudioFilterPluginResponse {
#[prost(message, required, tag="1")]
pub handle: FfiOwnedHandle,
#[prost(string, optional, tag="1")]
pub error: ::core::option::Option<::prost::alloc::string::String>,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
#[repr(i32)]
Expand Down
14 changes: 1 addition & 13 deletions livekit-ffi/src/server/audio_plugin.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,14 @@
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};

use futures_util::Stream;
use livekit::{
webrtc::{audio_stream::native::NativeAudioStream, prelude::AudioFrame},
AudioFilterAudioStream, AudioFilterPlugin,
AudioFilterAudioStream,
};

use super::FfiHandle;
use crate::FfiHandleId;

#[derive(Clone)]
pub struct FfiAudioFilterPlugin {
pub handle_id: FfiHandleId,
pub plugin: Arc<AudioFilterPlugin>,
}

impl FfiHandle for FfiAudioFilterPlugin {}

pub trait AudioStream: Stream<Item = AudioFrame<'static>> + Send + Sync + Unpin {
fn close(&mut self);
}
Expand Down
46 changes: 24 additions & 22 deletions livekit-ffi/src/server/audio_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use std::time::Duration;
use futures_util::StreamExt;
use livekit::track::Track;
use livekit::webrtc::{audio_stream::native::NativeAudioStream, prelude::*};
use livekit::{AudioFilterAudioStream, AudioFilterStreamInfo};
use livekit::{registered_audio_filter_plugin, AudioFilterAudioStream, AudioFilterStreamInfo};
use tokio::sync::{broadcast, mpsc, oneshot};

use super::audio_plugin::AudioStreamKind;
Expand Down Expand Up @@ -66,15 +66,20 @@ impl FfiAudioStream {
));
};
let room = server.retrieve_handle::<FfiRoom>(room_handle)?.clone();
let Some(filter) = room.inner.audio_filter_handle(server, module_id) else {
let Some(filter) = registered_audio_filter_plugin(module_id) else {
return Err(FfiError::InvalidRequest(
"the audio filter wasn't associated with the room".into(),
));
};

let stream_info = AudioFilterStreamInfo {
url: room.inner.url(),
room_id: room.inner.room.maybe_sid().map(|sid| sid.to_string()).unwrap_or("".into()),
room_id: room
.inner
.room
.maybe_sid()
.map(|sid| sid.to_string())
.unwrap_or("".into()),
room_name: room.inner.room.name(),
participant_identity: room.inner.room.local_participant().identity().into(),
participant_id: room.inner.room.local_participant().name(),
Expand All @@ -99,7 +104,7 @@ impl FfiAudioStream {
NativeAudioStream::new(rtc_track, sample_rate as i32, num_channels as i32);

let stream = if let Some(audio_filter) = &audio_filter {
let Some(session) = audio_filter.plugin.clone().new_session(
let Some(session) = audio_filter.clone().new_session(
sample_rate,
new_stream.audio_filter_options.unwrap_or("".into()),
stream_info,
Expand Down Expand Up @@ -206,7 +211,7 @@ impl FfiAudioStream {
let participant_identity = ffi_participant.participant.identity();
let participant_id = ffi_participant.participant.sid();
let filter = match &request.audio_filter_module_id {
Some(module_id) => ffi_participant.room.audio_filter_handle(server, module_id),
Some(module_id) => registered_audio_filter_plugin(module_id),
None => None,
};

Expand Down Expand Up @@ -241,27 +246,24 @@ impl FfiAudioStream {
});

let mut audio_filter_session = match &filter {
Some(filter) => {
match &request.audio_filter_options {
Some(options) => {
let stream_info = AudioFilterStreamInfo {
url: url.clone(),
room_id: room_sid.clone().into(),
room_name: room_name.clone(),
participant_identity: participant_identity.clone().into(),
participant_id: participant_id.clone().into(),
track_id: track.sid().into(),
};

filter.plugin.clone().new_session(sample_rate as u32, &options, stream_info)
},
None => None,
Some(filter) => match &request.audio_filter_options {
Some(options) => {
let stream_info = AudioFilterStreamInfo {
url: url.clone(),
room_id: room_sid.clone().into(),
room_name: room_name.clone(),
participant_identity: participant_identity.clone().into(),
participant_id: participant_id.clone().into(),
track_id: track.sid().into(),
};

filter.clone().new_session(sample_rate as u32, &options, stream_info)
}
}
None => None,
},
None => None,
};


let native_stream = NativeAudioStream::new(rtc_track, sample_rate, num_channels);

let stream = if let Some(session) = audio_filter_session.take() {
Expand Down
10 changes: 4 additions & 6 deletions livekit-ffi/src/server/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ use std::{slice, sync::Arc};
use colorcvt::cvtimpl;
use livekit::{
prelude::*,
register_audio_filter_plugin,
webrtc::{native::audio_resampler, prelude::*},
AudioFilterPlugin,
};
use parking_lot::Mutex;

use super::{
audio_plugin::FfiAudioFilterPlugin,
audio_source, audio_stream, colorcvt,
participant::FfiParticipant,
resampler,
Expand Down Expand Up @@ -924,18 +924,16 @@ fn on_set_data_channel_buffered_amount_low_threshold(
}

fn on_load_audio_filter_plugin(
server: &'static FfiServer,
_server: &'static FfiServer,
request: proto::LoadAudioFilterPluginRequest,
) -> FfiResult<proto::LoadAudioFilterPluginResponse> {
let deps: Vec<_> = request.dependencies.iter().map(|d| d).collect();
let plugin = AudioFilterPlugin::new_with_dependencies(&request.plugin_path, deps)
.map_err(|e| FfiError::InvalidRequest(format!("plugin error: {}", e).into()))?;

let handle_id = server.next_id();
let ffi_plugin = FfiAudioFilterPlugin { handle_id, plugin };
server.store_handle(handle_id, ffi_plugin);
register_audio_filter_plugin(request.module_id, plugin);

Ok(proto::LoadAudioFilterPluginResponse { handle: proto::FfiOwnedHandle { id: handle_id } })
Ok(proto::LoadAudioFilterPluginResponse { error: None })
}

fn on_set_track_subscription_permissions(
Expand Down
46 changes: 12 additions & 34 deletions livekit-ffi/src/server/room.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@ use std::collections::HashMap;
use std::time::Duration;
use std::{collections::HashSet, slice, sync::Arc};

use livekit::prelude::*;
use livekit::ChatMessage;
use livekit::{prelude::*, registered_audio_filter_plugins};
use livekit_protocol as lk_proto;
use parking_lot::Mutex;
use tokio::sync::{broadcast, mpsc, oneshot, Mutex as AsyncMutex};
use tokio::task::JoinHandle;

use super::audio_plugin::FfiAudioFilterPlugin;
use super::FfiDataBuffer;
use crate::{
proto,
Expand Down Expand Up @@ -73,7 +72,7 @@ pub struct RoomInner {
// Used to forward RPC method invocation to the FfiClient and collect their results
rpc_method_invocation_waiters: Mutex<HashMap<u64, oneshot::Sender<Result<String, RpcError>>>>,

audio_filter_handles: Arc<HashMap<String, FfiHandleId>>,
// ws url associated with this room
url: String,
}

Expand Down Expand Up @@ -140,38 +139,32 @@ impl FfiRoom {
unreachable!("Connected event should always be the first event");
};

// initialize audio filter
let audio_filter_handles = server
// initialize audio filters
let result = server
.async_runtime
.spawn_blocking(move || {
let mut handles = HashMap::new();
for h in req.options.audio_filter_handles.iter() {
let filter = server
.retrieve_handle::<FfiAudioFilterPlugin>(h.handle_id)
.map_err(|e| e.to_string())?
.clone();
filter.plugin.on_load(&req.url, &req.token).map_err(|e| e.to_string())?;

handles.insert(h.module_id.clone(), h.handle_id);
for filter in registered_audio_filter_plugins().into_iter() {
filter.on_load(&req.url, &req.token).map_err(|e| e.to_string())?;
}
Ok::<HashMap<String, FfiHandleId>, String>(handles)
Ok::<(), String>(())
})
.await
.map_err(|e| e.to_string());

let audio_filter_handles = match audio_filter_handles {
match result {
Err(e) | Ok(Err(e)) => {
log::error!("error while initializing audio filter: {}", e);
let _ = server.send_event(proto::ffi_event::Message::Connect(
proto::ConnectCallback {
async_id,
message: Some(proto::connect_callback::Message::Error(e.to_string())),
message: Some(proto::connect_callback::Message::Error(
e.to_string(),
)),
..Default::default()
},
));
return;
}
Ok(Ok(handles)) => Arc::new(handles),
Ok(Ok(_)) => (),
};

let (data_tx, data_rx) = mpsc::unbounded_channel();
Expand All @@ -190,7 +183,6 @@ impl FfiRoom {
pending_unpublished_tracks: Default::default(),
track_handle_lookup: Default::default(),
rpc_method_invocation_waiters: Default::default(),
audio_filter_handles,
url: connect.url,
});

Expand Down Expand Up @@ -846,20 +838,6 @@ impl RoomInner {
proto::SetTrackSubscriptionPermissionsResponse {}
}

pub fn audio_filter_handle<S: AsRef<str>>(
&self,
server: &'static FfiServer,
module_id: S,
) -> Option<FfiAudioFilterPlugin> {
let Some(&handle_id) = self.audio_filter_handles.get(module_id.as_ref()) else {
return None;
};
let Some(plugin) = server.retrieve_handle::<FfiAudioFilterPlugin>(handle_id).ok() else {
return None;
};
Some(plugin.clone())
}

pub fn url(&self) -> String {
self.url.clone()
}
Expand Down
27 changes: 24 additions & 3 deletions livekit/src/plugin.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use std::{
collections::HashMap,
ffi::{c_char, c_void, CString},
pin::Pin,
sync::Arc,
sync::{Arc, LazyLock},
task::{Context, Poll},
time::Duration,
};

use futures_util::Stream;
use libloading::{Library, Symbol};
use libwebrtc::{audio_stream::native::NativeAudioStream, prelude::AudioFrame};
use parking_lot::RwLock;
use serde::Serialize;
use serde_json::json;

Expand All @@ -23,11 +25,30 @@ pub enum PluginError {
}

type OnLoadFn = unsafe extern "C" fn(options: *const c_char) -> i32;
type CreateFn = unsafe extern "C" fn(sampling_rate: u32, options: *const c_char, stream_info: *const c_char) -> *mut c_void;
type CreateFn = unsafe extern "C" fn(
sampling_rate: u32,
options: *const c_char,
stream_info: *const c_char,
) -> *mut c_void;
type DestroyFn = unsafe extern "C" fn(*const c_void);
type ProcessI16Fn = unsafe extern "C" fn(*const c_void, usize, *const i16, *mut i16);
type ProcessF32Fn = unsafe extern "C" fn(*const c_void, usize, *const f32, *mut f32);

static REGISTERED_PLUGINS: LazyLock<RwLock<HashMap<String, Arc<AudioFilterPlugin>>>> =
LazyLock::new(|| RwLock::new(HashMap::new()));

pub fn register_audio_filter_plugin(id: String, plugin: Arc<AudioFilterPlugin>) {
REGISTERED_PLUGINS.write().insert(id, plugin);
}

pub fn registered_audio_filter_plugin(id: &str) -> Option<Arc<AudioFilterPlugin>> {
REGISTERED_PLUGINS.read().get(id).cloned()
}

pub fn registered_audio_filter_plugins() -> Vec<Arc<AudioFilterPlugin>> {
REGISTERED_PLUGINS.read().values().map(|v| v.clone()).collect()
}

pub struct AudioFilterPlugin {
lib: Library,
dependencies: Vec<Library>,
Expand Down Expand Up @@ -143,7 +164,7 @@ impl AudioFilterPlugin {

let options = CString::new(options.as_ref()).unwrap_or(CString::new("").unwrap());

let stream_info = serde_json::to_string( &stream_info).unwrap();
let stream_info = serde_json::to_string(&stream_info).unwrap();
let stream_info = CString::new(stream_info).unwrap_or(CString::new("").unwrap());

let ptr = unsafe { create_fn(sampling_rate, options.as_ptr(), stream_info.as_ptr()) };
Expand Down

0 comments on commit 8799552

Please # to comment.