From d0439d1cb95674fae19a746e4ba7140e5b6bdbba Mon Sep 17 00:00:00 2001
From: Stefan <stefan@nimiq.com>
Date: Tue, 21 Jan 2025 11:05:11 +0100
Subject: [PATCH] Add support for external termination notifications for
 streams

By adding a specialized unsubsribe request handler that lets the server terminates specifc streams
---
 derive/src/lib.rs     |   8 ++-
 derive/src/service.rs |   2 +-
 server/src/lib.rs     | 147 +++++++++++++++++++++++++++++++++---------
 3 files changed, 124 insertions(+), 33 deletions(-)

diff --git a/derive/src/lib.rs b/derive/src/lib.rs
index 5dae988..2e5d647 100644
--- a/derive/src/lib.rs
+++ b/derive/src/lib.rs
@@ -149,10 +149,12 @@ impl<'a> RpcMethod<'a> {
                             request,
                             move |params: #args_struct_ident| async move {
                                 let stream = self.#method_ident(#(#method_args),*).await?;
+                                let notifier = ::std::sync::Arc::new(::nimiq_jsonrpc_server::Notify::new());
+                                let listener = notifier.clone();
 
-                                let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned());
+                                let subscription = ::nimiq_jsonrpc_server::connect_stream(stream, tx, stream_id, #method_name.to_owned(), listener);
 
-                                Ok::<_, ::nimiq_jsonrpc_core::RpcError>(subscription)
+                                Ok::<_, ::nimiq_jsonrpc_core::RpcError>((subscription, Some(notifier)))
                             }
                         ).await
                     }
@@ -171,7 +173,7 @@ impl<'a> RpcMethod<'a> {
                     return ::nimiq_jsonrpc_server::dispatch_method_with_args(
                         request,
                         move |params: #args_struct_ident| async move {
-                            Ok::<_, ::nimiq_jsonrpc_core::RpcError>(self.#method_ident(#(#method_args),*).await?)
+                            Ok::<(_, Option<::std::sync::Arc<::nimiq_jsonrpc_server::Notify>>), ::nimiq_jsonrpc_core::RpcError>((self.#method_ident(#(#method_args),*).await?, None))
                         }
                     ).await
                 }
diff --git a/derive/src/service.rs b/derive/src/service.rs
index 1815c29..3dfb8e1 100644
--- a/derive/src/service.rs
+++ b/derive/src/service.rs
@@ -92,7 +92,7 @@ fn impl_service(im: &mut ItemImpl, args: &ServiceMeta) -> TokenStream {
                 request: ::nimiq_jsonrpc_core::Request,
                 tx: Option<&::tokio::sync::mpsc::Sender<::nimiq_jsonrpc_server::Message>>,
                 stream_id: u64,
-            ) -> Option<::nimiq_jsonrpc_core::Response> {
+            ) -> Option<::nimiq_jsonrpc_server::ResponseAndSubScriptionNotifier> {
                 match request.method.as_str() {
                     #(#match_arms)*
                     _ => ::nimiq_jsonrpc_server::method_not_found(request),
diff --git a/server/src/lib.rs b/server/src/lib.rs
index dd8bad5..60e1ffe 100644
--- a/server/src/lib.rs
+++ b/server/src/lib.rs
@@ -5,7 +5,7 @@
 #![warn(rustdoc::missing_doc_code_examples)]
 
 use std::{
-    collections::HashSet,
+    collections::{HashMap, HashSet},
     error,
     fmt::{self, Debug},
     future::Future,
@@ -51,8 +51,12 @@ use nimiq_jsonrpc_core::{
 };
 
 pub use axum::extract::ws::Message;
+pub use tokio::sync::Notify;
 use tower_http::cors::{Any, CorsLayer};
 
+/// Type defining a response and a possible notify handle used to terminate a subscription stream
+pub type ResponseAndSubScriptionNotifier = (Response, Option<Arc<Notify>>);
+
 /// A server error.
 #[derive(Debug, Error)]
 pub enum Error {
@@ -245,6 +249,7 @@ struct Inner<D: Dispatcher> {
     config: Config,
     dispatcher: RwLock<D>,
     next_id: AtomicU64,
+    subscription_notifiers: RwLock<HashMap<SubscriptionId, Arc<Notify>>>,
 }
 
 /// A JSON-RPC server.
@@ -266,6 +271,7 @@ impl<D: Dispatcher> Server<D> {
                 config,
                 dispatcher: RwLock::new(dispatcher),
                 next_id: AtomicU64::new(1),
+                subscription_notifiers: RwLock::new(HashMap::new()),
             }),
         }
     }
@@ -450,7 +456,7 @@ impl<D: Dispatcher> Server<D> {
         match request {
             SingleOrBatch::Single(request) => Self::handle_single_request(inner, request, tx)
                 .await
-                .map(SingleOrBatch::Single),
+                .map(|(response, _)| SingleOrBatch::Single(response)),
 
             SingleOrBatch::Batch(requests) => {
                 let futures = requests
@@ -459,7 +465,7 @@ impl<D: Dispatcher> Server<D> {
                     .collect::<FuturesUnordered<_>>();
 
                 let responses = futures
-                    .filter_map(|response_opt| async { response_opt })
+                    .filter_map(|response_opt| async { response_opt.map(|(response, _)| response) })
                     .collect::<Vec<Response>>()
                     .await;
 
@@ -469,15 +475,15 @@ impl<D: Dispatcher> Server<D> {
     }
 
     /// Handles a single JSON RPC request
-    ///
-    /// # TODO
-    ///
-    /// - Handle subscriptions
     async fn handle_single_request(
         inner: Arc<Inner<D>>,
         request: Request,
         tx: Option<&mpsc::Sender<Message>>,
-    ) -> Option<Response> {
+    ) -> Option<ResponseAndSubScriptionNotifier> {
+        if request.method == "unsubscribe" {
+            return Self::handle_unsubscribe_stream(request, inner).await;
+        }
+
         let mut dispatcher = inner.dispatcher.write().await;
         // This ID is only used for streams
         let id = inner.next_id.fetch_add(1, Ordering::SeqCst);
@@ -488,8 +494,65 @@ impl<D: Dispatcher> Server<D> {
 
         log::debug!("response: {:#?}", response);
 
+        if let Some((_, Some(ref handler))) = response {
+            inner
+                .subscription_notifiers
+                .write()
+                .await
+                .insert(SubscriptionId::Number(id), handler.clone());
+        }
+
         response
     }
+
+    async fn handle_unsubscribe_stream(
+        request: Request,
+        inner: Arc<Inner<D>>,
+    ) -> Option<ResponseAndSubScriptionNotifier> {
+        let params = if let Some(params) = request.params {
+            params
+        } else {
+            return error_response(request.id, || {
+                RpcError::invalid_request(Some(
+                    "Missing request parameter containing a list of subscription ids".to_owned(),
+                ))
+            });
+        };
+
+        let subscription_ids =
+            if let Ok(ids) = serde_json::from_value::<Vec<SubscriptionId>>(params) {
+                ids
+            } else {
+                return error_response(request.id, || {
+                    RpcError::invalid_params(Some(
+                        "A list of subscription ids is not provided".to_owned(),
+                    ))
+                });
+            };
+
+        if subscription_ids.is_empty() {
+            return error_response(request.id, || {
+                RpcError::invalid_params(Some("Empty list of subscription ids provided".to_owned()))
+            });
+        }
+
+        let mut terminated_streams = vec![];
+        let mut subscription_notifiers = inner.subscription_notifiers.write().await;
+        for id in subscription_ids.iter() {
+            if let Some(notifier) = subscription_notifiers.remove(id) {
+                notifier.notify_one();
+                terminated_streams.push(id);
+            }
+        }
+
+        Some((
+            Response::new_success(
+                serde_json::to_value(request.id.unwrap_or_default()).unwrap(),
+                serde_json::to_value(terminated_streams).unwrap(),
+            ),
+            None,
+        ))
+    }
 }
 
 /// A method dispatcher. These take a request and handle the method execution. Can be generated from an `impl` block
@@ -502,7 +565,7 @@ pub trait Dispatcher: Send + Sync + 'static {
         request: Request,
         tx: Option<&mpsc::Sender<Message>>,
         id: u64,
-    ) -> Option<Response>;
+    ) -> Option<ResponseAndSubScriptionNotifier>;
 
     /// Returns whether a method should be dispatched with this dispatcher.
     ///
@@ -542,7 +605,7 @@ impl Dispatcher for ModularDispatcher {
         request: Request,
         tx: Option<&mpsc::Sender<Message>>,
         id: u64,
-    ) -> Option<Response> {
+    ) -> Option<ResponseAndSubScriptionNotifier> {
         for dispatcher in &mut self.dispatchers {
             let m = dispatcher.match_method(&request.method);
             log::debug!("Matching '{}' against dispatcher -> {}", request.method, m);
@@ -611,7 +674,7 @@ where
         request: Request,
         tx: Option<&mpsc::Sender<Message>>,
         id: u64,
-    ) -> Option<Response> {
+    ) -> Option<ResponseAndSubScriptionNotifier> {
         if self.is_allowed(&request.method) {
             log::debug!("Dispatching method: {}", request.method);
             self.inner.dispatch(request, tx, id).await
@@ -649,13 +712,16 @@ where
 ///  - Currently this always expects an object with named parameters. Do we want to accept a list too?
 ///  - Merge with it's other variant, as a function call without arguments is just one with `()` as request parameter.
 ///
-pub async fn dispatch_method_with_args<P, R, E, F, Fut>(request: Request, f: F) -> Option<Response>
+pub async fn dispatch_method_with_args<P, R, E, F, Fut>(
+    request: Request,
+    f: F,
+) -> Option<ResponseAndSubScriptionNotifier>
 where
     P: for<'de> Deserialize<'de> + Send,
     R: Serialize,
     RpcError: From<E>,
     F: FnOnce(P) -> Fut + Send,
-    Fut: Future<Output = Result<R, E>> + Send,
+    Fut: Future<Output = Result<(R, Option<Arc<Notify>>), E>> + Send,
 {
     let params = match request.params {
         Some(params) => params,
@@ -683,12 +749,15 @@ where
 ///
 /// This is a helper function used by implementations of `Dispatcher`.
 ///
-pub async fn dispatch_method_without_args<R, E, F, Fut>(request: Request, f: F) -> Option<Response>
+pub async fn dispatch_method_without_args<R, E, F, Fut>(
+    request: Request,
+    f: F,
+) -> Option<ResponseAndSubScriptionNotifier>
 where
     R: Serialize,
     RpcError: From<E>,
     F: FnOnce() -> Fut + Send,
-    Fut: Future<Output = Result<R, E>> + Send,
+    Fut: Future<Output = Result<(R, Option<Arc<Notify>>), E>> + Send,
 {
     let result = f().await;
 
@@ -707,17 +776,20 @@ where
 }
 
 /// Constructs a [`Response`] if necessary (i.e., if the request ID was set).
-fn response<R, E>(id_opt: Option<Value>, result: Result<R, E>) -> Option<Response>
+fn response<R, E>(
+    id_opt: Option<Value>,
+    result: Result<(R, Option<Arc<Notify>>), E>,
+) -> Option<ResponseAndSubScriptionNotifier>
 where
     R: Serialize,
     RpcError: From<E>,
 {
     let response = match (id_opt, result) {
-        (Some(id), Ok(retval)) => {
-            let retval = serde_json::to_value(retval).expect("Failed to serialize return value");
-            Some(Response::new_success(id, retval))
+        (Some(id), Ok((value, subscription))) => {
+            let retval = serde_json::to_value(value).expect("Failed to serialize return value");
+            Some((Response::new_success(id, retval), subscription))
         }
-        (Some(id), Err(e)) => Some(Response::new_error(id, RpcError::from(e))),
+        (Some(id), Err(e)) => Some((Response::new_error(id, RpcError::from(e)), None)),
         (None, _) => None,
     };
 
@@ -733,14 +805,14 @@ where
 ///  - `id_opt`: The ID field from the request.
 ///  - `e`: A function that returns the error. This is only called, if we actually can respond with an error.
 ///
-pub fn error_response<E>(id_opt: Option<Value>, e: E) -> Option<Response>
+pub fn error_response<E>(id_opt: Option<Value>, e: E) -> Option<ResponseAndSubScriptionNotifier>
 where
     E: FnOnce() -> RpcError,
 {
     if let Some(id) = id_opt {
         let e = e();
         log::error!("Error response: {:?}", e);
-        Some(Response::new_error(id, e))
+        Some((Response::new_error(id, e), None))
     } else {
         None
     }
@@ -748,7 +820,7 @@ where
 
 /// Returns an error response for a method that was not found. This returns `None`, if the request doesn't expect a
 /// response.
-pub fn method_not_found(request: Request) -> Option<Response> {
+pub fn method_not_found(request: Request) -> Option<ResponseAndSubScriptionNotifier> {
     let ::nimiq_jsonrpc_core::Request { id, method, .. } = request;
 
     error_response(id, || {
@@ -798,6 +870,7 @@ pub fn connect_stream<T, S>(
     tx: &mpsc::Sender<Message>,
     stream_id: u64,
     method: String,
+    notify_handler: Arc<Notify>,
 ) -> SubscriptionId
 where
     T: Serialize + Debug + Send + Sync,
@@ -811,14 +884,30 @@ where
         tokio::spawn(async move {
             pin_mut!(stream);
 
-            while let Some(item) = stream.next().await {
-                if let Err(e) = forward_notification(item, &mut tx, &id, &method).await {
-                    // Break the loop when the channel is closed
-                    if let Error::Mpsc(_) = e {
+            let notify_future = notify_handler.notified();
+            pin_mut!(notify_future);
+
+            loop {
+                tokio::select! {
+                    item = stream.next() => {
+                        match item {
+                            Some(notification) => {
+                                if let Err(error) = forward_notification(notification, &mut tx, &id, &method).await {
+                                    // Break the loop when the channel is closed
+                                    if let Error::Mpsc(_) = error {
+                                        break;
+                                    }
+
+                                    log::error!("{}", error);
+                                }
+                            },
+                            None => break,
+                        }
+                    }
+                    _ = &mut notify_future => {
+                        // Break the loop when an unsubscribe notification is received
                         break;
                     }
-
-                    log::error!("{}", e);
                 }
             }
         });