diff --git a/src/client.rs b/src/client.rs index 3369c1e..48255db 100644 --- a/src/client.rs +++ b/src/client.rs @@ -12,8 +12,8 @@ use bevy::prelude::*; use bytes::Bytes; use futures::sink::SinkExt; use futures_util::StreamExt; -use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint}; -use quinn_proto::ConnectionStats; +use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint, SendStream}; +use quinn_proto::{ConnectionStats, VarInt}; use serde::Deserialize; use tokio::{ runtime::{self}, @@ -105,8 +105,9 @@ type InternalConnectionRef = QuinnConnection; /// Current state of a client connection #[derive(Debug)] enum ConnectionState { - Disconnected, + Connecting, Connected(InternalConnectionRef), + Disconnected, } #[derive(Debug)] @@ -171,9 +172,12 @@ impl Connection { } pub fn send_message(&self, message: T) -> Result<(), QuinnetError> { - match bincode::serialize(&message) { - Ok(payload) => self.send_payload(payload), - Err(_) => Err(QuinnetError::Serialization), + match &self.state { + ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed), + _ => match bincode::serialize(&message) { + Ok(payload) => self.send_payload(payload), + Err(_) => Err(QuinnetError::Serialization), + }, } } @@ -186,11 +190,14 @@ impl Connection { } pub fn send_payload>(&self, payload: T) -> Result<(), QuinnetError> { - match self.sender.try_send(payload.into()) { - Ok(_) => Ok(()), - Err(err) => match err { - TrySendError::Full(_) => Err(QuinnetError::FullQueue), - TrySendError::Closed(_) => Err(QuinnetError::ChannelClosed), + match &self.state { + ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed), + _ => match self.sender.try_send(payload.into()) { + Ok(_) => Ok(()), + Err(err) => match err { + TrySendError::Full(_) => Err(QuinnetError::FullQueue), + TrySendError::Closed(_) => Err(QuinnetError::ChannelClosed), + }, }, } } @@ -204,11 +211,14 @@ impl Connection { } pub fn receive_payload(&mut self) -> Result, QuinnetError> { - match self.receiver.try_recv() { - Ok(msg_payload) => Ok(Some(msg_payload)), - Err(err) => match err { - TryRecvError::Empty => Ok(None), - TryRecvError::Disconnected => Err(QuinnetError::ChannelClosed), + match &self.state { + ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed), + _ => match self.receiver.try_recv() { + Ok(msg_payload) => Ok(Some(msg_payload)), + Err(err) => match err { + TryRecvError::Empty => Ok(None), + TryRecvError::Disconnected => Err(QuinnetError::ChannelClosed), + }, }, } } @@ -226,27 +236,30 @@ impl Connection { /// Disconnect from the server on this connection. This does not send any message to the server, and simply closes all the connection's tasks locally. fn disconnect(&mut self) -> Result<(), QuinnetError> { - if self.is_connected() { - if let Err(_) = self.close_sender.send(()) { - return Err(QuinnetError::ChannelClosed); + match &self.state { + ConnectionState::Disconnected => Ok(()), + _ => { + self.state = ConnectionState::Disconnected; + match self.close_sender.send(()) { + Ok(_) => Ok(()), + Err(_) => Err(QuinnetError::ChannelClosed), + } } } - self.state = ConnectionState::Disconnected; - Ok(()) } pub fn is_connected(&self) -> bool { match self.state { - ConnectionState::Disconnected => false, ConnectionState::Connected(_) => true, + _ => false, } } /// Returns statistics about the current connection if connected. pub fn stats(&self) -> Option { match &self.state { - ConnectionState::Disconnected => None, ConnectionState::Connected(connection) => Some(connection.stats()), + _ => None, } } } @@ -331,7 +344,7 @@ impl Client { ) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE); let connection = Connection { - state: ConnectionState::Disconnected, + state: ConnectionState::Connecting, sender: to_server_sender, receiver: from_server_receiver, close_sender: close_sender.clone(), @@ -474,57 +487,100 @@ async fn connection_task(mut spawn_config: ConnectionSpawnConfig) { .expect("Failed to open send stream"); let mut frame_send = FramedWrite::new(send, LengthDelimitedCodec::new()); - let close_sender_clone = spawn_config.close_sender.clone(); - let _network_sends = tokio::spawn(async move { - tokio::select! { - _ = spawn_config.close_receiver.recv() => { - trace!("Unidirectional send Stream forced to disconnected") - } - _ = async { - while let Some(msg_bytes) = spawn_config.to_server_receiver.recv().await { - if let Err(err) = frame_send.send(msg_bytes).await { - error!("Error while sending, {}", err); // TODO Clean: error handling - error!("Client seems disconnected, closing resources"); - if let Err(_) = close_sender_clone.send(()) { - error!("Failed to close all client streams & resources") - } - spawn_config.to_sync_client.send( - InternalAsyncMessage::LostConnection) - .await - .expect("Failed to signal connection lost to sync client"); + let _close_waiter = { + let conn = connection.clone(); + let to_sync_client = spawn_config.to_sync_client.clone(); + let close_sender = spawn_config.close_sender.clone(); + tokio::spawn(async move { + let conn_err = conn.closed().await; + info!("Disconnected: {}", conn_err); + close_sender.send(()).ok(); + to_sync_client + .send(InternalAsyncMessage::LostConnection) + .await + .expect("Failed to signal connection lost to sync client"); + }) + }; + + let _network_sends = { + let close_sender_clone = spawn_config.close_sender.clone(); + let conn = connection.clone(); + tokio::spawn(async move { + tokio::select! { + _ = spawn_config.close_receiver.recv() => { + trace!("Unidirectional send Stream forced to disconnected") + } + _ = async { + while let Some(msg_bytes) = spawn_config.to_server_receiver.recv().await { + send_msg(&close_sender_clone, &spawn_config.to_sync_client, &mut frame_send, msg_bytes).await; } + } => { + trace!("Unidirectional send Stream ended") } - } => { - trace!("Unidirectional send Stream ended") } - } - }); - - let mut uni_receivers: JoinSet<()> = JoinSet::new(); - let mut close_receiver = spawn_config.close_sender.subscribe(); - let _network_reads = tokio::spawn(async move { - tokio::select! { - _ = close_receiver.recv() => { - trace!("New Stream listener forced to disconnected") + while let Ok(msg_bytes) = spawn_config.to_server_receiver.try_recv() { + if let Err(err) = frame_send.send(msg_bytes).await { + error!("Error while sending, {}", err); + } } - _ = async { - while let Ok(recv)= connection.accept_uni().await { - let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); - let from_server_sender = spawn_config.from_server_sender.clone(); - - uni_receivers.spawn(async move { - while let Some(Ok(msg_bytes)) = frame_recv.next().await { - from_server_sender.send(msg_bytes.into()).await.unwrap(); // TODO Clean: error handling - } - }); + if let Err(err) = frame_send.flush().await { + error!("Error while flushing stream: {}", err); + } + if let Err(err) = frame_send.into_inner().finish().await { + error!("Failed to shutdown stream gracefully: {}", err); + } + conn.close(VarInt::from_u32(0), "closed".as_bytes()); + }) + }; + + let _network_reads = { + let mut uni_receivers: JoinSet<()> = JoinSet::new(); + let mut close_receiver = spawn_config.close_sender.subscribe(); + tokio::spawn(async move { + tokio::select! { + _ = close_receiver.recv() => { + trace!("New Stream listener forced to disconnected") + } + _ = async { + while let Ok(recv)= connection.accept_uni().await { + let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new()); + let from_server_sender = spawn_config.from_server_sender.clone(); + + uni_receivers.spawn(async move { + while let Some(Ok(msg_bytes)) = frame_recv.next().await { + from_server_sender.send(msg_bytes.into()).await.unwrap(); // TODO Clean: error handling + } + }); + } + } => { + trace!("New Stream listener ended ") } - } => { - trace!("New Stream listener ended ") } - } - uni_receivers.shutdown().await; - trace!("All unidirectional stream receivers cleaned"); - }); + uni_receivers.shutdown().await; + trace!("All unidirectional stream receivers cleaned"); + }) + }; + } + } +} + +async fn send_msg( + close_sender: &tokio::sync::broadcast::Sender<()>, + to_sync_client: &mpsc::Sender, + frame_send: &mut FramedWrite, + msg_bytes: Bytes, +) { + if let Err(err) = frame_send.send(msg_bytes).await { + error!("Error while sending, {}", err); + error!("Client seems disconnected, closing resources"); + // Emit LostConnection to properly update the connection about its state. + // Raise LostConnection event before emitting a close signal because we have no guarantee to continue this async execution after the close signal has been processed. + to_sync_client + .send(InternalAsyncMessage::LostConnection) + .await + .expect("Failed to signal connection lost to sync client"); + if let Err(_) = close_sender.send(()) { + error!("Failed to close all client streams & resources") } } } @@ -545,10 +601,13 @@ fn update_sync_client( connection.state = ConnectionState::Connected(internal_connection); connection_events.send(ConnectionEvent { id: *connection_id }); } - InternalAsyncMessage::LostConnection => { - connection.state = ConnectionState::Disconnected; - connection_lost_events.send(ConnectionLostEvent { id: *connection_id }); - } + InternalAsyncMessage::LostConnection => match connection.state { + ConnectionState::Disconnected => (), + _ => { + connection.state = ConnectionState::Disconnected; + connection_lost_events.send(ConnectionLostEvent { id: *connection_id }); + } + }, InternalAsyncMessage::CertificateInteractionRequest { status, info, diff --git a/src/server.rs b/src/server.rs index bde4419..f89ad58 100644 --- a/src/server.rs +++ b/src/server.rs @@ -4,7 +4,8 @@ use bevy::prelude::*; use bytes::Bytes; use futures::sink::SinkExt; use futures_util::StreamExt; -use quinn::{Endpoint as QuinnEndpoint, ServerConfig}; +use quinn::{Endpoint as QuinnEndpoint, SendStream, ServerConfig}; +use quinn_proto::VarInt; use serde::Deserialize; use tokio::{ runtime, @@ -451,20 +452,22 @@ async fn handle_client_connection( // Create an ordered reliable send channel for this client let (to_client_sender, to_client_receiver) = mpsc::channel::(DEFAULT_MESSAGE_QUEUE_SIZE); - let to_sync_server_clone = to_sync_server.clone(); - let close_sender_clone = client_close_sender.clone(); - let connection_clone = connection.clone(); - tokio::spawn(async move { - client_sender_task( - client_id, - connection_clone, - to_client_receiver, - client_close_receiver, - close_sender_clone, - to_sync_server_clone, - ) - .await - }); + let _client_sender = { + let to_sync_server_clone = to_sync_server.clone(); + let close_sender_clone = client_close_sender.clone(); + let connection_clone = connection.clone(); + tokio::spawn(async move { + client_sender_task( + client_id, + connection_clone, + to_client_receiver, + client_close_receiver, + close_sender_clone, + to_sync_server_clone, + ) + .await + }) + }; // Signal the sync server of this new connection to_sync_server @@ -483,6 +486,21 @@ async fn handle_client_connection( } } + let _client_close_wait = { + let conn = connection.clone(); + let close_sender = client_close_sender.clone(); + let to_sync_server = to_sync_server.clone(); + tokio::spawn(async move { + let conn_err = conn.closed().await; + info!("Client {} disconnected: {}", client_id, conn_err); + close_sender.send(()).ok(); + to_sync_server + .send(InternalAsyncMessage::ClientLostConnection(client_id)) + .await + .expect("Failed to signal connection lost to sync server"); + }); + }; + // Spawn a task to listen for streams opened by this client let _client_receiver = tokio::spawn(async move { client_receiver_task( @@ -519,22 +537,61 @@ async fn client_sender_task( } _ = async { while let Some(msg_bytes) = to_client_receiver.recv().await { - // TODO Perf: Batch frames for a send_all - // TODO Clean: Error handling - if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await { - error!("Error while sending to client {}: {}", client_id, err); - error!("Client {} seems disconnected, closing resources", client_id); - if let Err(_) = close_sender.send(()) { - error!("Failed to close all client streams & resources for client {}", client_id) - } - to_sync_server.send( - InternalAsyncMessage::ClientLostConnection(client_id)) - .await - .expect("Failed to signal connection lost to sync server"); - }; + send_msg( + client_id, + &close_sender, + &to_sync_server, + &mut framed_send_stream, + msg_bytes, + ) + .await } } => {} } + while let Ok(msg_bytes) = to_client_receiver.try_recv() { + if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await { + error!("Error while sending to client {}: {}", client_id, err); + }; + } + if let Err(err) = framed_send_stream.flush().await { + error!( + "Error while flushing stream to client {}: {}", + client_id, err + ); + } + if let Err(err) = framed_send_stream.into_inner().finish().await { + error!( + "Failed to shutdown stream gracefully for client {}: {}", + client_id, err + ); + } + connection.close(VarInt::from_u32(0), "closed".as_bytes()); +} + +async fn send_msg( + client_id: ClientId, + close_sender: &tokio::sync::broadcast::Sender<()>, + to_sync_server: &mpsc::Sender, + framed_send_stream: &mut FramedWrite, + msg_bytes: Bytes, +) { + // TODO Perf: Batch frames for a send_all + if let Err(err) = framed_send_stream.send(msg_bytes.clone()).await { + error!("Error while sending to client {}: {}", client_id, err); + error!("Client {} seems disconnected, closing resources", client_id); + // Emit ClientLostConnection to properly update the server about this client state. + // Raise ClientLostConnection event before emitting a close signal because we have no guarantee to continue this async execution after the close signal has been processed. + to_sync_server + .send(InternalAsyncMessage::ClientLostConnection(client_id)) + .await + .expect("Failed to signal connection lost to sync server"); + if let Err(_) = close_sender.send(()) { + error!( + "Failed to close all client streams & resources for client {}", + client_id + ) + } + }; } async fn client_receiver_task( @@ -605,8 +662,12 @@ fn update_sync_server( connection_events.send(ConnectionEvent { id: id }); } InternalAsyncMessage::ClientLostConnection(client_id) => { - endpoint.clients.remove(&client_id); - connection_lost_events.send(ConnectionLostEvent { id: client_id }); + match endpoint.clients.remove(&client_id) { + Some(_) => { + connection_lost_events.send(ConnectionLostEvent { id: client_id }) + } + None => (), + } } } } diff --git a/src/shared.rs b/src/shared.rs index 247841f..608bb0c 100644 --- a/src/shared.rs +++ b/src/shared.rs @@ -25,6 +25,8 @@ pub enum QuinnetError { UnknownClient(ClientId), #[error("Connection with id `{0}` is unknown")] UnknownConnection(ConnectionId), + #[error("Connection is closed")] + ConnectionClosed, #[error("Endpoint is already closed")] EndpointAlreadyClosed, #[error("Failed serialization")]