Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Unsubscribe streams #37

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,12 @@ impl<'a> RpcMethod<'a> {
request,
move |params: #args_struct_ident| async move {
let stream = self.#method_ident(#(#method_args),*).await?;
let notifier = ::std::sync::Arc::new(::nimiq_jsonrpc_server::Notify::new());
let listener = notifier.clone();

let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned());
let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned(), listener);

Ok::<_, ::nimiq_jsonrpc_core::RpcError>(subscription)
Ok::<_, ::nimiq_jsonrpc_core::RpcError>((subscription, Some(notifier)))
}
).await
}
Expand All @@ -171,7 +173,7 @@ impl<'a> RpcMethod<'a> {
return ::nimiq_jsonrpc_server::dispatch_method_with_args(
request,
move |params: #args_struct_ident| async move {
Ok::<_, ::nimiq_jsonrpc_core::RpcError>(self.#method_ident(#(#method_args),*).await?)
Ok::<(_, Option<::std::sync::Arc<::nimiq_jsonrpc_server::Notify>>), ::nimiq_jsonrpc_core::RpcError>((self.#method_ident(#(#method_args),*).await?, None))
}
).await
}
Expand Down
2 changes: 1 addition & 1 deletion derive/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ fn impl_service(im: &mut ItemImpl, args: &ServiceMeta) -> TokenStream {
request: ::nimiq_jsonrpc_core::Request,
tx: Option<&::tokio::sync::mpsc::Sender<::nimiq_jsonrpc_server::Message>>,
stream_id: u64,
) -> Option<::nimiq_jsonrpc_core::Response> {
) -> Option<::nimiq_jsonrpc_server::ResponseAndSubScriptionNotifier> {
match request.method.as_str() {
#(#match_arms)*
_ => ::nimiq_jsonrpc_server::method_not_found(request),
Expand Down
147 changes: 118 additions & 29 deletions server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#![warn(rustdoc::missing_doc_code_examples)]

use std::{
collections::HashSet,
collections::{HashMap, HashSet},
error,
fmt::{self, Debug},
future::Future,
Expand Down Expand Up @@ -51,8 +51,12 @@ use nimiq_jsonrpc_core::{
};

pub use axum::extract::ws::Message;
pub use tokio::sync::Notify;
use tower_http::cors::{Any, CorsLayer};

/// Type defining a response and a possible notify handle used to terminate a subscription stream
pub type ResponseAndSubScriptionNotifier = (Response, Option<Arc<Notify>>);

/// A server error.
#[derive(Debug, Error)]
pub enum Error {
Expand Down Expand Up @@ -245,6 +249,7 @@ struct Inner<D: Dispatcher> {
config: Config,
dispatcher: RwLock<D>,
next_id: AtomicU64,
subscription_notifiers: RwLock<HashMap<SubscriptionId, Arc<Notify>>>,
}

/// A JSON-RPC server.
Expand All @@ -266,6 +271,7 @@ impl<D: Dispatcher> Server<D> {
config,
dispatcher: RwLock::new(dispatcher),
next_id: AtomicU64::new(1),
subscription_notifiers: RwLock::new(HashMap::new()),
}),
}
}
Expand Down Expand Up @@ -450,7 +456,7 @@ impl<D: Dispatcher> Server<D> {
match request {
SingleOrBatch::Single(request) => Self::handle_single_request(inner, request, tx)
.await
.map(SingleOrBatch::Single),
.map(|(response, _)| SingleOrBatch::Single(response)),

SingleOrBatch::Batch(requests) => {
let futures = requests
Expand All @@ -459,7 +465,7 @@ impl<D: Dispatcher> Server<D> {
.collect::<FuturesUnordered<_>>();

let responses = futures
.filter_map(|response_opt| async { response_opt })
.filter_map(|response_opt| async { response_opt.map(|(response, _)| response) })
.collect::<Vec<Response>>()
.await;

Expand All @@ -469,15 +475,15 @@ impl<D: Dispatcher> Server<D> {
}

/// Handles a single JSON RPC request
///
/// # TODO
///
/// - Handle subscriptions
async fn handle_single_request(
inner: Arc<Inner<D>>,
request: Request,
tx: Option<&mpsc::Sender<Message>>,
) -> Option<Response> {
) -> Option<ResponseAndSubScriptionNotifier> {
if request.method == "unsubscribe" {
return Self::handle_unsubscribe_stream(request, inner).await;
}

let mut dispatcher = inner.dispatcher.write().await;
// This ID is only used for streams
let id = inner.next_id.fetch_add(1, Ordering::SeqCst);
Expand All @@ -488,8 +494,65 @@ impl<D: Dispatcher> Server<D> {

log::debug!("response: {:#?}", response);

if let Some((_, Some(ref handler))) = response {
inner
.subscription_notifiers
.write()
.await
.insert(SubscriptionId::Number(id), handler.clone());
}

response
}

async fn handle_unsubscribe_stream(
request: Request,
inner: Arc<Inner<D>>,
) -> Option<ResponseAndSubScriptionNotifier> {
let params = if let Some(params) = request.params {
params
} else {
return error_response(request.id, || {
RpcError::invalid_request(Some(
"Missing request parameter containing a list of subscription ids".to_owned(),
))
});
};

let subscription_ids =
if let Ok(ids) = serde_json::from_value::<Vec<SubscriptionId>>(params) {
ids
} else {
return error_response(request.id, || {
RpcError::invalid_params(Some(
"A list of subscription ids is not provided".to_owned(),
))
});
};

if subscription_ids.is_empty() {
return error_response(request.id, || {
RpcError::invalid_params(Some("Empty list of subscription ids provided".to_owned()))
});
}

let mut terminated_streams = vec![];
let mut subscription_notifiers = inner.subscription_notifiers.write().await;
for id in subscription_ids.iter() {
if let Some(notifier) = subscription_notifiers.remove(id) {
notifier.notify_one();
terminated_streams.push(id);
}
}

Some((
Response::new_success(
serde_json::to_value(request.id.unwrap_or_default()).unwrap(),
serde_json::to_value(terminated_streams).unwrap(),
),
None,
))
}
}

/// A method dispatcher. These take a request and handle the method execution. Can be generated from an `impl` block
Expand All @@ -502,7 +565,7 @@ pub trait Dispatcher: Send + Sync + 'static {
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
) -> Option<Response>;
) -> Option<ResponseAndSubScriptionNotifier>;

/// Returns whether a method should be dispatched with this dispatcher.
///
Expand Down Expand Up @@ -542,7 +605,7 @@ impl Dispatcher for ModularDispatcher {
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
) -> Option<Response> {
) -> Option<ResponseAndSubScriptionNotifier> {
for dispatcher in &mut self.dispatchers {
let m = dispatcher.match_method(&request.method);
log::debug!("Matching '{}' against dispatcher -> {}", request.method, m);
Expand Down Expand Up @@ -611,7 +674,7 @@ where
request: Request,
tx: Option<&mpsc::Sender<Message>>,
id: u64,
) -> Option<Response> {
) -> Option<ResponseAndSubScriptionNotifier> {
if self.is_allowed(&request.method) {
log::debug!("Dispatching method: {}", request.method);
self.inner.dispatch(request, tx, id).await
Expand Down Expand Up @@ -649,13 +712,16 @@ where
/// - Currently this always expects an object with named parameters. Do we want to accept a list too?
/// - Merge with it's other variant, as a function call without arguments is just one with `()` as request parameter.
///
pub async fn dispatch_method_with_args<P, R, E, F, Fut>(request: Request, f: F) -> Option<Response>
pub async fn dispatch_method_with_args<P, R, E, F, Fut>(
request: Request,
f: F,
) -> Option<ResponseAndSubScriptionNotifier>
where
P: for<'de> Deserialize<'de> + Send,
R: Serialize,
RpcError: From<E>,
F: FnOnce(P) -> Fut + Send,
Fut: Future<Output = Result<R, E>> + Send,
Fut: Future<Output = Result<(R, Option<Arc<Notify>>), E>> + Send,
{
let params = match request.params {
Some(params) => params,
Expand Down Expand Up @@ -683,12 +749,15 @@ where
///
/// This is a helper function used by implementations of `Dispatcher`.
///
pub async fn dispatch_method_without_args<R, E, F, Fut>(request: Request, f: F) -> Option<Response>
pub async fn dispatch_method_without_args<R, E, F, Fut>(
request: Request,
f: F,
) -> Option<ResponseAndSubScriptionNotifier>
where
R: Serialize,
RpcError: From<E>,
F: FnOnce() -> Fut + Send,
Fut: Future<Output = Result<R, E>> + Send,
Fut: Future<Output = Result<(R, Option<Arc<Notify>>), E>> + Send,
{
let result = f().await;

Expand All @@ -707,17 +776,20 @@ where
}

/// Constructs a [`Response`] if necessary (i.e., if the request ID was set).
fn response<R, E>(id_opt: Option<Value>, result: Result<R, E>) -> Option<Response>
fn response<R, E>(
id_opt: Option<Value>,
result: Result<(R, Option<Arc<Notify>>), E>,
) -> Option<ResponseAndSubScriptionNotifier>
where
R: Serialize,
RpcError: From<E>,
{
let response = match (id_opt, result) {
(Some(id), Ok(retval)) => {
let retval = serde_json::to_value(retval).expect("Failed to serialize return value");
Some(Response::new_success(id, retval))
(Some(id), Ok((value, subscription))) => {
let retval = serde_json::to_value(value).expect("Failed to serialize return value");
Some((Response::new_success(id, retval), subscription))
}
(Some(id), Err(e)) => Some(Response::new_error(id, RpcError::from(e))),
(Some(id), Err(e)) => Some((Response::new_error(id, RpcError::from(e)), None)),
(None, _) => None,
};

Expand All @@ -733,22 +805,22 @@ where
/// - `id_opt`: The ID field from the request.
/// - `e`: A function that returns the error. This is only called, if we actually can respond with an error.
///
pub fn error_response<E>(id_opt: Option<Value>, e: E) -> Option<Response>
pub fn error_response<E>(id_opt: Option<Value>, e: E) -> Option<ResponseAndSubScriptionNotifier>
where
E: FnOnce() -> RpcError,
{
if let Some(id) = id_opt {
let e = e();
log::error!("Error response: {:?}", e);
Some(Response::new_error(id, e))
Some((Response::new_error(id, e), None))
} else {
None
}
}

/// Returns an error response for a method that was not found. This returns `None`, if the request doesn't expect a
/// response.
pub fn method_not_found(request: Request) -> Option<Response> {
pub fn method_not_found(request: Request) -> Option<ResponseAndSubScriptionNotifier> {
let ::nimiq_jsonrpc_core::Request { id, method, .. } = request;

error_response(id, || {
Expand Down Expand Up @@ -798,6 +870,7 @@ pub fn connect_stream<T, S>(
tx: &mpsc::Sender<Message>,
stream_id: u64,
method: String,
notify_handler: Arc<Notify>,
) -> SubscriptionId
where
T: Serialize + Debug + Send + Sync,
Expand All @@ -811,14 +884,30 @@ where
tokio::spawn(async move {
pin_mut!(stream);

while let Some(item) = stream.next().await {
if let Err(e) = forward_notification(item, &mut tx, &id, &method).await {
// Break the loop when the channel is closed
if let Error::Mpsc(_) = e {
let notify_future = notify_handler.notified();
pin_mut!(notify_future);

loop {
tokio::select! {
item = stream.next() => {
match item {
Some(notification) => {
if let Err(error) = forward_notification(notification, &mut tx, &id, &method).await {
// Break the loop when the channel is closed
if let Error::Mpsc(_) = error {
break;
}

log::error!("{}", error);
}
},
None => break,
}
}
_ = &mut notify_future => {
// Break the loop when an unsubscribe notification is received
break;
}

log::error!("{}", e);
}
}
});
Expand Down
Loading