Skip to content

Commit

Permalink
Merge branch 'main' into chxiao/native_tls_builder
Browse files Browse the repository at this point in the history
  • Loading branch information
Devdutt Shenoi authored May 30, 2024
2 parents ccbfd8d + e63bcab commit 26d7821
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 112 deletions.
9 changes: 8 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions rumqttc/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* `ConnectionAborted` variant on `StateError` type to denote abrupt end to a connection
* `set_session_expiry_interval` and `session_expiry_interval` methods on `MqttOptions`.
* `Auth` packet as per MQTT5 standards
* Allow configuring the `nodelay` property of underlying TCP client with the `tcp_nodelay` field in `NetworkOptions`

### Changed

* rename `N` as `AsyncReadWrite` to describe usage.
* use `Framed` to encode/decode MQTT packets.
* use `Login` to store credentials
* Made `DisconnectProperties` struct public.
* Replace `Vec<Option<u16>>` with `FixedBitSet` for managing packet ids of released QoS 2 publishes and incoming QoS 2 publishes in `MqttState`.
* Accept `native_tls::TlsConnectorBuilder` as input of `Transport::tls_with_config`.

### Deprecated
Expand All @@ -32,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Validate filters while creating subscription requests.
* Make v4::Connect::write return correct value
* Ordering of `State.events` related to `QoS > 0` publishes
* Filter PUBACK in pending save requests to fix unexpected PUBACK sent to reconnected broker.
* Resume session only if broker sends `CONNACK` with `session_present == 1`.

### Security
Expand Down
1 change: 1 addition & 0 deletions rumqttc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ url = { version = "2", default-features = false, optional = true }
# proxy
async-http-proxy = { version = "1.2.5", features = ["runtime-tokio", "basic-auth"], optional = true }
tokio-stream = "0.1.15"
fixedbitset = "0.5.7"

[dev-dependencies]
bincode = "1.3.3"
Expand Down
25 changes: 16 additions & 9 deletions rumqttc/src/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,15 @@ impl EventLoop {
self.pending.extend(self.state.clean());

// drain requests from channel which weren't yet received
let requests_in_channel = self.requests_rx.drain();
let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect();

requests_in_channel.retain(|request| {
match request {
Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack
_ => true,
}
});

self.pending.extend(requests_in_channel);
}

Expand Down Expand Up @@ -323,6 +331,8 @@ pub(crate) async fn socket_connect(
SocketAddr::V6(_) => TcpSocket::new_v6()?,
};

socket.set_nodelay(network_options.tcp_nodelay)?;

if let Some(send_buff_size) = network_options.tcp_send_buffer_size {
socket.set_send_buffer_size(send_buff_size).unwrap();
}
Expand Down Expand Up @@ -476,18 +486,15 @@ async fn mqtt_connect(
options: &MqttOptions,
network: &mut Network,
) -> Result<ConnAck, ConnectionError> {
let keep_alive = options.keep_alive().as_secs() as u16;
let clean_session = options.clean_session();
let last_will = options.last_will();

let mut connect = Connect::new(options.client_id());
connect.keep_alive = keep_alive;
connect.clean_session = clean_session;
connect.last_will = last_will;
connect.keep_alive = options.keep_alive().as_secs() as u16;
connect.clean_session = options.clean_session();
connect.last_will = options.last_will();
connect.login = options.credentials();

// send mqtt connect packet
network.connect(connect).await?;
network.write(Packet::Connect(connect)).await?;
network.flush().await?;

// validate connack
match network.read().await? {
Expand Down
6 changes: 0 additions & 6 deletions rumqttc/src/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,6 @@ impl Network {
.map_err(StateError::Deserialization)
}

pub async fn connect(&mut self, connect: Connect) -> Result<(), StateError> {
self.write(Packet::Connect(connect)).await?;

self.flush().await
}

pub async fn flush(&mut self) -> Result<(), crate::state::StateError> {
self.framed
.flush()
Expand Down
6 changes: 6 additions & 0 deletions rumqttc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ impl From<TlsConnectorBuilder> for TlsConfiguration {
pub struct NetworkOptions {
tcp_send_buffer_size: Option<u32>,
tcp_recv_buffer_size: Option<u32>,
tcp_nodelay: bool,
conn_timeout: u64,
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
bind_device: Option<String>,
Expand All @@ -394,12 +395,17 @@ impl NetworkOptions {
NetworkOptions {
tcp_send_buffer_size: None,
tcp_recv_buffer_size: None,
tcp_nodelay: false,
conn_timeout: 5,
#[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
bind_device: None,
}
}

pub fn set_tcp_nodelay(&mut self, nodelay: bool) {
self.tcp_nodelay = nodelay;
}

pub fn set_tcp_send_buffer_size(&mut self, size: u32) {
self.tcp_send_buffer_size = Some(size);
}
Expand Down
56 changes: 21 additions & 35 deletions rumqttc/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{Event, Incoming, Outgoing, Request};

use crate::mqttbytes::v4::*;
use crate::mqttbytes::{self, *};
use fixedbitset::FixedBitSet;
use std::collections::VecDeque;
use std::{io, time::Instant};

Expand Down Expand Up @@ -62,9 +63,9 @@ pub struct MqttState {
/// Outgoing QoS 1, 2 publishes which aren't acked yet
pub(crate) outgoing_pub: Vec<Option<Publish>>,
/// Packet ids of released QoS 2 publishes
pub(crate) outgoing_rel: Vec<Option<u16>>,
pub(crate) outgoing_rel: FixedBitSet,
/// Packet ids on incoming QoS 2 publishes
pub(crate) incoming_pub: Vec<Option<u16>>,
pub(crate) incoming_pub: FixedBitSet,
/// Last collision due to broker not acking in order
pub collision: Option<Publish>,
/// Buffered incoming packets
Expand All @@ -89,8 +90,8 @@ impl MqttState {
max_inflight,
// index 0 is wasted as 0 is not a valid packet id
outgoing_pub: vec![None; max_inflight as usize + 1],
outgoing_rel: vec![None; max_inflight as usize + 1],
incoming_pub: vec![None; u16::MAX as usize + 1],
outgoing_rel: FixedBitSet::with_capacity(max_inflight as usize + 1),
incoming_pub: FixedBitSet::with_capacity(u16::MAX as usize + 1),
collision: None,
// TODO: Optimize these sizes later
events: VecDeque::with_capacity(100),
Expand All @@ -113,17 +114,14 @@ impl MqttState {
}

// remove and collect pending releases
for rel in self.outgoing_rel.iter_mut() {
if let Some(pkid) = rel.take() {
let request = Request::PubRel(PubRel::new(pkid));
pending.push(request);
}
for pkid in self.outgoing_rel.ones() {
let request = Request::PubRel(PubRel::new(pkid as u16));
pending.push(request);
}
self.outgoing_rel.clear();

// remove packed ids of incoming qos2 publishes
for id in self.incoming_pub.iter_mut() {
id.take();
}
// remove packet ids of incoming qos2 publishes
self.incoming_pub.clear();

self.await_pingresp = false;
self.collision_ping_count = 0;
Expand Down Expand Up @@ -210,7 +208,7 @@ impl MqttState {
}
QoS::ExactlyOnce => {
let pkid = publish.pkid;
self.incoming_pub[pkid as usize] = Some(pkid);
self.incoming_pub.insert(pkid as usize);

if !self.manual_acks {
let pubrec = PubRec::new(pkid);
Expand Down Expand Up @@ -261,7 +259,7 @@ impl MqttState {
}

// NOTE: Inflight - 1 for qos2 in comp
self.outgoing_rel[pubrec.pkid as usize] = Some(pubrec.pkid);
self.outgoing_rel.insert(pubrec.pkid as usize);
let pubrel = PubRel { pkid: pubrec.pkid };
let event = Event::Outgoing(Outgoing::PubRel(pubrec.pkid));
self.events.push_back(event);
Expand All @@ -270,16 +268,12 @@ impl MqttState {
}

fn handle_incoming_pubrel(&mut self, pubrel: &PubRel) -> Result<Option<Packet>, StateError> {
let publish = self
.incoming_pub
.get_mut(pubrel.pkid as usize)
.ok_or(StateError::Unsolicited(pubrel.pkid))?;

if publish.take().is_none() {
if !self.incoming_pub.contains(pubrel.pkid as usize) {
error!("Unsolicited pubrel packet: {:?}", pubrel.pkid);
return Err(StateError::Unsolicited(pubrel.pkid));
}

self.incoming_pub.set(pubrel.pkid as usize, false);
let event = Event::Outgoing(Outgoing::PubComp(pubrel.pkid));
let pubcomp = PubComp { pkid: pubrel.pkid };
self.events.push_back(event);
Expand All @@ -288,17 +282,12 @@ impl MqttState {
}

fn handle_incoming_pubcomp(&mut self, pubcomp: &PubComp) -> Result<Option<Packet>, StateError> {
if self
.outgoing_rel
.get_mut(pubcomp.pkid as usize)
.ok_or(StateError::Unsolicited(pubcomp.pkid))?
.take()
.is_none()
{
if !self.outgoing_rel.contains(pubcomp.pkid as usize) {
error!("Unsolicited pubcomp packet: {:?}", pubcomp.pkid);
return Err(StateError::Unsolicited(pubcomp.pkid));
}

self.outgoing_rel.set(pubcomp.pkid as usize, false);
self.inflight -= 1;
let packet = self.check_collision(pubcomp.pkid).map(|publish| {
let event = Event::Outgoing(Outgoing::Publish(publish.pkid));
Expand Down Expand Up @@ -486,7 +475,7 @@ impl MqttState {
_ => pubrel,
};

self.outgoing_rel[pubrel.pkid as usize] = Some(pubrel.pkid);
self.outgoing_rel.insert(pubrel.pkid as usize);
self.inflight += 1;
Ok(pubrel)
}
Expand Down Expand Up @@ -610,10 +599,8 @@ mod test {
mqtt.handle_incoming_publish(&publish2).unwrap();
mqtt.handle_incoming_publish(&publish3).unwrap();

let pkid = mqtt.incoming_pub[3].unwrap();

// only qos2 publish should be add to queue
assert_eq!(pkid, 3);
assert!(mqtt.incoming_pub.contains(3));
}

#[test]
Expand Down Expand Up @@ -656,8 +643,7 @@ mod test {
mqtt.handle_incoming_publish(&publish2).unwrap();
mqtt.handle_incoming_publish(&publish3).unwrap();

let pkid = mqtt.incoming_pub[3].unwrap();
assert_eq!(pkid, 3);
assert!(mqtt.incoming_pub.contains(3));

assert!(mqtt.events.is_empty());
}
Expand Down Expand Up @@ -725,7 +711,7 @@ mod test {
assert_eq!(backup.unwrap().pkid, 1);

// check if the qos2 element's release pkid is 2
assert_eq!(mqtt.outgoing_rel[2].unwrap(), 2);
assert!(mqtt.outgoing_rel.contains(2));
}

#[test]
Expand Down
34 changes: 21 additions & 13 deletions rumqttc/src/v5/eventloop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,15 @@ impl EventLoop {
self.pending.extend(self.state.clean());

// drain requests from channel which weren't yet received
let requests_in_channel = self.requests_rx.drain();
let mut requests_in_channel: Vec<_> = self.requests_rx.drain().collect();

requests_in_channel.retain(|request| {
match request {
Request::PubAck(_) => false, // Wait for publish retransmission, else the broker could be confused by an unexpected ack
_ => true,
}
});

self.pending.extend(requests_in_channel);
}

Expand Down Expand Up @@ -390,20 +398,20 @@ async fn mqtt_connect(
options: &mut MqttOptions,
network: &mut Network,
) -> Result<ConnAck, ConnectionError> {
let keep_alive = options.keep_alive().as_secs() as u16;
let clean_start = options.clean_start();
let client_id = options.client_id();
let properties = options.connect_properties();

let connect = Connect {
keep_alive,
client_id,
clean_start,
properties,
};
let packet = Packet::Connect(
Connect {
client_id: options.client_id(),
keep_alive: options.keep_alive().as_secs() as u16,
clean_start: options.clean_start(),
properties: options.connect_properties(),
},
options.last_will(),
options.credentials(),
);

// send mqtt connect packet
network.connect(connect, options).await?;
network.write(packet).await?;
network.flush().await?;

// validate connack
match network.read().await? {
Expand Down
16 changes: 1 addition & 15 deletions rumqttc/src/v5/framed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ use tokio_util::codec::Framed;
use crate::framed::AsyncReadWrite;

use super::mqttbytes::v5::Packet;
use super::{mqttbytes, Codec, Connect, MqttOptions, MqttState};
use super::{Incoming, StateError};
use super::{mqttbytes, Codec, Incoming, MqttState, StateError};

/// Network transforms packets <-> frames efficiently. It takes
/// advantage of pre-allocation, buffering and vectorization when
Expand Down Expand Up @@ -86,19 +85,6 @@ impl Network {
.map_err(StateError::Deserialization)
}

pub async fn connect(
&mut self,
connect: Connect,
options: &MqttOptions,
) -> Result<(), StateError> {
let last_will = options.last_will();
let login = options.credentials();
self.write(Packet::Connect(connect, last_will, login))
.await?;

self.flush().await
}

pub async fn flush(&mut self) -> Result<(), StateError> {
self.framed
.flush()
Expand Down
Loading

0 comments on commit 26d7821

Please # to comment.