diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 5dae988..2e5d647 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -149,10 +149,12 @@ impl<'a> RpcMethod<'a> { request, move |params: #args_struct_ident| async move { let stream = self.#method_ident(#(#method_args),*).await?; + let notifier = ::std::sync::Arc::new(::nimiq_jsonrpc_server::Notify::new()); + let listener = notifier.clone(); - let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned()); + let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned(), listener); - Ok::<_, ::nimiq_jsonrpc_core::RpcError>(subscription) + Ok::<_, ::nimiq_jsonrpc_core::RpcError>((subscription, Some(notifier))) } ).await } @@ -171,7 +173,7 @@ impl<'a> RpcMethod<'a> { return ::nimiq_jsonrpc_server::dispatch_method_with_args( request, move |params: #args_struct_ident| async move { - Ok::<_, ::nimiq_jsonrpc_core::RpcError>(self.#method_ident(#(#method_args),*).await?) + Ok::<(_, Option<::std::sync::Arc<::nimiq_jsonrpc_server::Notify>>), ::nimiq_jsonrpc_core::RpcError>((self.#method_ident(#(#method_args),*).await?, None)) } ).await } diff --git a/derive/src/service.rs b/derive/src/service.rs index 1815c29..3dfb8e1 100644 --- a/derive/src/service.rs +++ b/derive/src/service.rs @@ -92,7 +92,7 @@ fn impl_service(im: &mut ItemImpl, args: &ServiceMeta) -> TokenStream { request: ::nimiq_jsonrpc_core::Request, tx: Option<&::tokio::sync::mpsc::Sender<::nimiq_jsonrpc_server::Message>>, stream_id: u64, - ) -> Option<::nimiq_jsonrpc_core::Response> { + ) -> Option<::nimiq_jsonrpc_server::ResponseAndSubScriptionNotifier> { match request.method.as_str() { #(#match_arms)* _ => ::nimiq_jsonrpc_server::method_not_found(request), diff --git a/server/src/lib.rs b/server/src/lib.rs index dd8bad5..60e1ffe 100644 --- a/server/src/lib.rs +++ b/server/src/lib.rs @@ -5,7 +5,7 @@ #![warn(rustdoc::missing_doc_code_examples)] use std::{ - collections::HashSet, + collections::{HashMap, HashSet}, error, fmt::{self, Debug}, future::Future, @@ -51,8 +51,12 @@ use nimiq_jsonrpc_core::{ }; pub use axum::extract::ws::Message; +pub use tokio::sync::Notify; use tower_http::cors::{Any, CorsLayer}; +/// Type defining a response and a possible notify handle used to terminate a subscription stream +pub type ResponseAndSubScriptionNotifier = (Response, Option>); + /// A server error. #[derive(Debug, Error)] pub enum Error { @@ -245,6 +249,7 @@ struct Inner { config: Config, dispatcher: RwLock, next_id: AtomicU64, + subscription_notifiers: RwLock>>, } /// A JSON-RPC server. @@ -266,6 +271,7 @@ impl Server { config, dispatcher: RwLock::new(dispatcher), next_id: AtomicU64::new(1), + subscription_notifiers: RwLock::new(HashMap::new()), }), } } @@ -450,7 +456,7 @@ impl Server { match request { SingleOrBatch::Single(request) => Self::handle_single_request(inner, request, tx) .await - .map(SingleOrBatch::Single), + .map(|(response, _)| SingleOrBatch::Single(response)), SingleOrBatch::Batch(requests) => { let futures = requests @@ -459,7 +465,7 @@ impl Server { .collect::>(); let responses = futures - .filter_map(|response_opt| async { response_opt }) + .filter_map(|response_opt| async { response_opt.map(|(response, _)| response) }) .collect::>() .await; @@ -469,15 +475,15 @@ impl Server { } /// Handles a single JSON RPC request - /// - /// # TODO - /// - /// - Handle subscriptions async fn handle_single_request( inner: Arc>, request: Request, tx: Option<&mpsc::Sender>, - ) -> Option { + ) -> Option { + if request.method == "unsubscribe" { + return Self::handle_unsubscribe_stream(request, inner).await; + } + let mut dispatcher = inner.dispatcher.write().await; // This ID is only used for streams let id = inner.next_id.fetch_add(1, Ordering::SeqCst); @@ -488,8 +494,65 @@ impl Server { log::debug!("response: {:#?}", response); + if let Some((_, Some(ref handler))) = response { + inner + .subscription_notifiers + .write() + .await + .insert(SubscriptionId::Number(id), handler.clone()); + } + response } + + async fn handle_unsubscribe_stream( + request: Request, + inner: Arc>, + ) -> Option { + let params = if let Some(params) = request.params { + params + } else { + return error_response(request.id, || { + RpcError::invalid_request(Some( + "Missing request parameter containing a list of subscription ids".to_owned(), + )) + }); + }; + + let subscription_ids = + if let Ok(ids) = serde_json::from_value::>(params) { + ids + } else { + return error_response(request.id, || { + RpcError::invalid_params(Some( + "A list of subscription ids is not provided".to_owned(), + )) + }); + }; + + if subscription_ids.is_empty() { + return error_response(request.id, || { + RpcError::invalid_params(Some("Empty list of subscription ids provided".to_owned())) + }); + } + + let mut terminated_streams = vec![]; + let mut subscription_notifiers = inner.subscription_notifiers.write().await; + for id in subscription_ids.iter() { + if let Some(notifier) = subscription_notifiers.remove(id) { + notifier.notify_one(); + terminated_streams.push(id); + } + } + + Some(( + Response::new_success( + serde_json::to_value(request.id.unwrap_or_default()).unwrap(), + serde_json::to_value(terminated_streams).unwrap(), + ), + None, + )) + } } /// A method dispatcher. These take a request and handle the method execution. Can be generated from an `impl` block @@ -502,7 +565,7 @@ pub trait Dispatcher: Send + Sync + 'static { request: Request, tx: Option<&mpsc::Sender>, id: u64, - ) -> Option; + ) -> Option; /// Returns whether a method should be dispatched with this dispatcher. /// @@ -542,7 +605,7 @@ impl Dispatcher for ModularDispatcher { request: Request, tx: Option<&mpsc::Sender>, id: u64, - ) -> Option { + ) -> Option { for dispatcher in &mut self.dispatchers { let m = dispatcher.match_method(&request.method); log::debug!("Matching '{}' against dispatcher -> {}", request.method, m); @@ -611,7 +674,7 @@ where request: Request, tx: Option<&mpsc::Sender>, id: u64, - ) -> Option { + ) -> Option { if self.is_allowed(&request.method) { log::debug!("Dispatching method: {}", request.method); self.inner.dispatch(request, tx, id).await @@ -649,13 +712,16 @@ where /// - Currently this always expects an object with named parameters. Do we want to accept a list too? /// - Merge with it's other variant, as a function call without arguments is just one with `()` as request parameter. /// -pub async fn dispatch_method_with_args(request: Request, f: F) -> Option +pub async fn dispatch_method_with_args( + request: Request, + f: F, +) -> Option where P: for<'de> Deserialize<'de> + Send, R: Serialize, RpcError: From, F: FnOnce(P) -> Fut + Send, - Fut: Future> + Send, + Fut: Future>), E>> + Send, { let params = match request.params { Some(params) => params, @@ -683,12 +749,15 @@ where /// /// This is a helper function used by implementations of `Dispatcher`. /// -pub async fn dispatch_method_without_args(request: Request, f: F) -> Option +pub async fn dispatch_method_without_args( + request: Request, + f: F, +) -> Option where R: Serialize, RpcError: From, F: FnOnce() -> Fut + Send, - Fut: Future> + Send, + Fut: Future>), E>> + Send, { let result = f().await; @@ -707,17 +776,20 @@ where } /// Constructs a [`Response`] if necessary (i.e., if the request ID was set). -fn response(id_opt: Option, result: Result) -> Option +fn response( + id_opt: Option, + result: Result<(R, Option>), E>, +) -> Option where R: Serialize, RpcError: From, { let response = match (id_opt, result) { - (Some(id), Ok(retval)) => { - let retval = serde_json::to_value(retval).expect("Failed to serialize return value"); - Some(Response::new_success(id, retval)) + (Some(id), Ok((value, subscription))) => { + let retval = serde_json::to_value(value).expect("Failed to serialize return value"); + Some((Response::new_success(id, retval), subscription)) } - (Some(id), Err(e)) => Some(Response::new_error(id, RpcError::from(e))), + (Some(id), Err(e)) => Some((Response::new_error(id, RpcError::from(e)), None)), (None, _) => None, }; @@ -733,14 +805,14 @@ where /// - `id_opt`: The ID field from the request. /// - `e`: A function that returns the error. This is only called, if we actually can respond with an error. /// -pub fn error_response(id_opt: Option, e: E) -> Option +pub fn error_response(id_opt: Option, e: E) -> Option where E: FnOnce() -> RpcError, { if let Some(id) = id_opt { let e = e(); log::error!("Error response: {:?}", e); - Some(Response::new_error(id, e)) + Some((Response::new_error(id, e), None)) } else { None } @@ -748,7 +820,7 @@ where /// Returns an error response for a method that was not found. This returns `None`, if the request doesn't expect a /// response. -pub fn method_not_found(request: Request) -> Option { +pub fn method_not_found(request: Request) -> Option { let ::nimiq_jsonrpc_core::Request { id, method, .. } = request; error_response(id, || { @@ -798,6 +870,7 @@ pub fn connect_stream( tx: &mpsc::Sender, stream_id: u64, method: String, + notify_handler: Arc, ) -> SubscriptionId where T: Serialize + Debug + Send + Sync, @@ -811,14 +884,30 @@ where tokio::spawn(async move { pin_mut!(stream); - while let Some(item) = stream.next().await { - if let Err(e) = forward_notification(item, &mut tx, &id, &method).await { - // Break the loop when the channel is closed - if let Error::Mpsc(_) = e { + let notify_future = notify_handler.notified(); + pin_mut!(notify_future); + + loop { + tokio::select! { + item = stream.next() => { + match item { + Some(notification) => { + if let Err(error) = forward_notification(notification, &mut tx, &id, &method).await { + // Break the loop when the channel is closed + if let Error::Mpsc(_) = error { + break; + } + + log::error!("{}", error); + } + }, + None => break, + } + } + _ = &mut notify_future => { + // Break the loop when an unsubscribe notification is received break; } - - log::error!("{}", e); } } });