From d630a376d5eba48cdbe3cd5e2f657b552c78b24e Mon Sep 17 00:00:00 2001 From: Kolby Moroz Liebl <31669092+KolbyML@users.noreply.github.com> Date: Mon, 2 Oct 2023 18:31:54 -0600 Subject: [PATCH] fix: handle 3rd potential id case, which prevented resets from being handled (#114) --- src/socket.rs | 65 ++++++++++++++++++++++++++++++++++----------------- 1 file changed, 44 insertions(+), 21 deletions(-) diff --git a/src/socket.rs b/src/socket.rs index 41caf95..dad22f4 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -79,19 +79,21 @@ where } }; - let init_cid = cid_from_packet(&packet, &src, false); - let acc_cid = cid_from_packet(&packet, &src, true); + let peer_init_cid = cid_from_packet(&packet, &src, IdType::SendIdPeerInitiated); + let we_init_cid = cid_from_packet(&packet, &src, IdType::SendIdWeInitiated); + let acc_cid = cid_from_packet(&packet, &src, IdType::RecvId); let mut conns = conns.write().unwrap(); let conn = conns .get(&acc_cid) - .or_else(|| conns.get(&init_cid)); + .or_else(|| conns.get(&we_init_cid)) + .or_else(|| conns.get(&peer_init_cid)); match conn { Some(conn) => { let _ = conn.send(StreamEvent::Incoming(packet)); } None => { if std::matches!(packet.packet_type(), PacketType::Syn) { - let cid = cid_from_packet(&packet, &src, true); + let cid = cid_from_packet(&packet, &src, IdType::RecvId); let mut awaiting = awaiting.write().unwrap(); // If there was an awaiting connection with the CID, then @@ -124,6 +126,9 @@ where packet = ?packet.packet_type(), seq = %packet.seq_num(), ack = %packet.ack_num(), + peer_init_cid = ?peer_init_cid, + we_init_cid = ?we_init_cid, + acc_cid = ?acc_cid, "received uTP packet for non-existing conn" ); } @@ -334,29 +339,47 @@ where } } +#[derive(Copy, Clone, Debug)] +enum IdType { + RecvId, + SendIdWeInitiated, + SendIdPeerInitiated, +} + fn cid_from_packet( packet: &Packet, src: &P, - from_initiator: bool, + id_type: IdType, ) -> ConnectionId

{ - if !from_initiator { - let (send, recv) = (packet.conn_id().wrapping_add(1), packet.conn_id()); - ConnectionId { - send, - recv, - peer: src.clone(), + match id_type { + IdType::RecvId => { + let (send, recv) = match packet.packet_type() { + PacketType::Syn => (packet.conn_id(), packet.conn_id().wrapping_add(1)), + PacketType::State | PacketType::Data | PacketType::Fin | PacketType::Reset => { + (packet.conn_id().wrapping_sub(1), packet.conn_id()) + } + }; + ConnectionId { + send, + recv, + peer: src.clone(), + } } - } else { - let (send, recv) = match packet.packet_type() { - PacketType::Syn => (packet.conn_id(), packet.conn_id().wrapping_add(1)), - PacketType::State | PacketType::Data | PacketType::Fin | PacketType::Reset => { - (packet.conn_id().wrapping_sub(1), packet.conn_id()) + IdType::SendIdWeInitiated => { + let (send, recv) = (packet.conn_id().wrapping_add(1), packet.conn_id()); + ConnectionId { + send, + recv, + peer: src.clone(), + } + } + IdType::SendIdPeerInitiated => { + let (send, recv) = (packet.conn_id(), packet.conn_id().wrapping_sub(1)); + ConnectionId { + send, + recv, + peer: src.clone(), } - }; - ConnectionId { - send, - recv, - peer: src.clone(), } } }