Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Gracefully disconnect connections and trigger events #6

Merged
merged 4 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 132 additions & 73 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use bevy::prelude::*;
use bytes::Bytes;
use futures::sink::SinkExt;
use futures_util::StreamExt;
use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint};
use quinn_proto::ConnectionStats;
use quinn::{ClientConfig, Connection as QuinnConnection, Endpoint, SendStream};
use quinn_proto::{ConnectionStats, VarInt};
use serde::Deserialize;
use tokio::{
runtime::{self},
Expand Down Expand Up @@ -105,8 +105,9 @@ type InternalConnectionRef = QuinnConnection;
/// Current state of a client connection
#[derive(Debug)]
enum ConnectionState {
Disconnected,
Connecting,
Connected(InternalConnectionRef),
Disconnected,
}

#[derive(Debug)]
Expand Down Expand Up @@ -171,9 +172,12 @@ impl Connection {
}

pub fn send_message<T: serde::Serialize>(&self, message: T) -> Result<(), QuinnetError> {
match bincode::serialize(&message) {
Ok(payload) => self.send_payload(payload),
Err(_) => Err(QuinnetError::Serialization),
match &self.state {
ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed),
_ => match bincode::serialize(&message) {
Ok(payload) => self.send_payload(payload),
Err(_) => Err(QuinnetError::Serialization),
},
}
}

Expand All @@ -186,11 +190,14 @@ impl Connection {
}

pub fn send_payload<T: Into<Bytes>>(&self, payload: T) -> Result<(), QuinnetError> {
match self.sender.try_send(payload.into()) {
Ok(_) => Ok(()),
Err(err) => match err {
TrySendError::Full(_) => Err(QuinnetError::FullQueue),
TrySendError::Closed(_) => Err(QuinnetError::ChannelClosed),
match &self.state {
ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed),
_ => match self.sender.try_send(payload.into()) {
Ok(_) => Ok(()),
Err(err) => match err {
TrySendError::Full(_) => Err(QuinnetError::FullQueue),
TrySendError::Closed(_) => Err(QuinnetError::ChannelClosed),
},
},
}
}
Expand All @@ -204,11 +211,14 @@ impl Connection {
}

pub fn receive_payload(&mut self) -> Result<Option<Bytes>, QuinnetError> {
match self.receiver.try_recv() {
Ok(msg_payload) => Ok(Some(msg_payload)),
Err(err) => match err {
TryRecvError::Empty => Ok(None),
TryRecvError::Disconnected => Err(QuinnetError::ChannelClosed),
match &self.state {
ConnectionState::Disconnected => Err(QuinnetError::ConnectionClosed),
_ => match self.receiver.try_recv() {
Ok(msg_payload) => Ok(Some(msg_payload)),
Err(err) => match err {
TryRecvError::Empty => Ok(None),
TryRecvError::Disconnected => Err(QuinnetError::ChannelClosed),
},
},
}
}
Expand All @@ -226,27 +236,30 @@ impl Connection {

/// Disconnect from the server on this connection. This does not send any message to the server, and simply closes all the connection's tasks locally.
fn disconnect(&mut self) -> Result<(), QuinnetError> {
if self.is_connected() {
if let Err(_) = self.close_sender.send(()) {
return Err(QuinnetError::ChannelClosed);
match &self.state {
ConnectionState::Disconnected => Ok(()),
_ => {
self.state = ConnectionState::Disconnected;
match self.close_sender.send(()) {
Ok(_) => Ok(()),
Err(_) => Err(QuinnetError::ChannelClosed),
}
}
}
self.state = ConnectionState::Disconnected;
Ok(())
}

pub fn is_connected(&self) -> bool {
match self.state {
ConnectionState::Disconnected => false,
ConnectionState::Connected(_) => true,
_ => false,
}
}

/// Returns statistics about the current connection if connected.
pub fn stats(&self) -> Option<ConnectionStats> {
match &self.state {
ConnectionState::Disconnected => None,
ConnectionState::Connected(connection) => Some(connection.stats()),
_ => None,
}
}
}
Expand Down Expand Up @@ -331,7 +344,7 @@ impl Client {
) = broadcast::channel(DEFAULT_KILL_MESSAGE_QUEUE_SIZE);

let connection = Connection {
state: ConnectionState::Disconnected,
state: ConnectionState::Connecting,
sender: to_server_sender,
receiver: from_server_receiver,
close_sender: close_sender.clone(),
Expand Down Expand Up @@ -474,57 +487,100 @@ async fn connection_task(mut spawn_config: ConnectionSpawnConfig) {
.expect("Failed to open send stream");
let mut frame_send = FramedWrite::new(send, LengthDelimitedCodec::new());

let close_sender_clone = spawn_config.close_sender.clone();
let _network_sends = tokio::spawn(async move {
tokio::select! {
_ = spawn_config.close_receiver.recv() => {
trace!("Unidirectional send Stream forced to disconnected")
}
_ = async {
while let Some(msg_bytes) = spawn_config.to_server_receiver.recv().await {
if let Err(err) = frame_send.send(msg_bytes).await {
error!("Error while sending, {}", err); // TODO Clean: error handling
error!("Client seems disconnected, closing resources");
if let Err(_) = close_sender_clone.send(()) {
error!("Failed to close all client streams & resources")
}
spawn_config.to_sync_client.send(
InternalAsyncMessage::LostConnection)
.await
.expect("Failed to signal connection lost to sync client");
let _close_waiter = {
let conn = connection.clone();
let to_sync_client = spawn_config.to_sync_client.clone();
let close_sender = spawn_config.close_sender.clone();
tokio::spawn(async move {
let conn_err = conn.closed().await;
info!("Disconnected: {}", conn_err);
close_sender.send(()).ok();
to_sync_client
.send(InternalAsyncMessage::LostConnection)
.await
.expect("Failed to signal connection lost to sync client");
})
};

let _network_sends = {
let close_sender_clone = spawn_config.close_sender.clone();
let conn = connection.clone();
tokio::spawn(async move {
tokio::select! {
_ = spawn_config.close_receiver.recv() => {
trace!("Unidirectional send Stream forced to disconnected")
}
_ = async {
while let Some(msg_bytes) = spawn_config.to_server_receiver.recv().await {
send_msg(&close_sender_clone, &spawn_config.to_sync_client, &mut frame_send, msg_bytes).await;
}
} => {
trace!("Unidirectional send Stream ended")
}
} => {
trace!("Unidirectional send Stream ended")
}
}
});

let mut uni_receivers: JoinSet<()> = JoinSet::new();
let mut close_receiver = spawn_config.close_sender.subscribe();
let _network_reads = tokio::spawn(async move {
tokio::select! {
_ = close_receiver.recv() => {
trace!("New Stream listener forced to disconnected")
while let Ok(msg_bytes) = spawn_config.to_server_receiver.try_recv() {
if let Err(err) = frame_send.send(msg_bytes).await {
error!("Error while sending, {}", err);
}
}
_ = async {
while let Ok(recv)= connection.accept_uni().await {
let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
let from_server_sender = spawn_config.from_server_sender.clone();

uni_receivers.spawn(async move {
while let Some(Ok(msg_bytes)) = frame_recv.next().await {
from_server_sender.send(msg_bytes.into()).await.unwrap(); // TODO Clean: error handling
}
});
if let Err(err) = frame_send.flush().await {
error!("Error while flushing stream: {}", err);
}
if let Err(err) = frame_send.into_inner().finish().await {
error!("Failed to shutdown stream gracefully: {}", err);
}
conn.close(VarInt::from_u32(0), "closed".as_bytes());
})
};

let _network_reads = {
let mut uni_receivers: JoinSet<()> = JoinSet::new();
let mut close_receiver = spawn_config.close_sender.subscribe();
tokio::spawn(async move {
tokio::select! {
_ = close_receiver.recv() => {
trace!("New Stream listener forced to disconnected")
}
_ = async {
while let Ok(recv)= connection.accept_uni().await {
let mut frame_recv = FramedRead::new(recv, LengthDelimitedCodec::new());
let from_server_sender = spawn_config.from_server_sender.clone();

uni_receivers.spawn(async move {
while let Some(Ok(msg_bytes)) = frame_recv.next().await {
from_server_sender.send(msg_bytes.into()).await.unwrap(); // TODO Clean: error handling
}
});
}
} => {
trace!("New Stream listener ended ")
}
} => {
trace!("New Stream listener ended ")
}
}
uni_receivers.shutdown().await;
trace!("All unidirectional stream receivers cleaned");
});
uni_receivers.shutdown().await;
trace!("All unidirectional stream receivers cleaned");
})
};
}
}
}

async fn send_msg(
close_sender: &tokio::sync::broadcast::Sender<()>,
to_sync_client: &mpsc::Sender<InternalAsyncMessage>,
frame_send: &mut FramedWrite<SendStream, LengthDelimitedCodec>,
msg_bytes: Bytes,
) {
if let Err(err) = frame_send.send(msg_bytes).await {
error!("Error while sending, {}", err);
error!("Client seems disconnected, closing resources");
// Emit LostConnection to properly update the connection about its state.
// Raise LostConnection event before emitting a close signal because we have no guarantee to continue this async execution after the close signal has been processed.
to_sync_client
.send(InternalAsyncMessage::LostConnection)
.await
.expect("Failed to signal connection lost to sync client");
if let Err(_) = close_sender.send(()) {
error!("Failed to close all client streams & resources")
}
}
}
Expand All @@ -545,10 +601,13 @@ fn update_sync_client(
connection.state = ConnectionState::Connected(internal_connection);
connection_events.send(ConnectionEvent { id: *connection_id });
}
InternalAsyncMessage::LostConnection => {
connection.state = ConnectionState::Disconnected;
connection_lost_events.send(ConnectionLostEvent { id: *connection_id });
}
InternalAsyncMessage::LostConnection => match connection.state {
ConnectionState::Disconnected => (),
_ => {
connection.state = ConnectionState::Disconnected;
connection_lost_events.send(ConnectionLostEvent { id: *connection_id });
}
},
InternalAsyncMessage::CertificateInteractionRequest {
status,
info,
Expand Down
Loading