Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Support tokio and async-std runtimes #368

Merged
merged 6 commits into from
Aug 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ script:
- cargo check --no-default-features
- cargo check --no-default-features --features http
- cargo check --no-default-features --features http-tls
- cargo check --no-default-features --features ws
- cargo check --no-default-features --features ws-tls
- cargo check --no-default-features --features ws-tokio
- cargo check --no-default-features --features ws-tls-tokio
- cargo check --no-default-features --features ws-async-std
- cargo check --no-default-features --features ws-tls-async-std

after_success: |
[ $TRAVIS_BRANCH = master ] &&
Expand Down
22 changes: 12 additions & 10 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,31 @@ serde_json = "1.0.39"
tiny-keccak = { version = "2.0.1", features = ["keccak"] }
# Optional deps
## HTTP
base64 = { version = "0.12.0", optional = true }
base64 = { version = "0.12", optional = true }
hyper = { version = "0.13", optional = true, default-features = false, features = ["stream", "tcp"] }
hyper-tls = { version = "0.4", optional = true }
## WS
async-native-tls = { version = "0.3", optional = true }
async-std = { version = "1.5.0", optional = true }
async-native-tls = { version = "0.3", optional = true, default-features = false }
async-std = { version = "1.6", optional = true }
tokio = { version = "0.2", optional = true, features = ["full"] }
tokio-util = { version = "0.3", optional = true, features = ["compat"] }
soketto = { version = "0.4.1", optional = true }
## Shared (WS, HTTP)
native-tls = { version = "0.2", optional = true }
url = { version = "2.1.0", optional = true }
url = { version = "2.1", optional = true }

[dev-dependencies]
# For examples
env_logger = "0.7.0"
env_logger = "0.7"
tokio = { version = "0.2", features = ["full"] }
# WS test
async-std = { version = "1.5.0", features = ["attributes"] }

[features]
default = ["http", "ws", "http-tls", "ws-tls"]
default = ["http-tls", "ws-tls-tokio"]
http = ["hyper", "url", "base64"]
http-tls = ["hyper-tls", "native-tls", "http"]
ws = ["soketto", "async-std", "url"]
ws-tls = ["async-native-tls", "native-tls", "ws"]
ws-tokio = ["soketto", "url", "tokio", "tokio-util"]
ws-async-std = ["soketto", "url", "async-std"]
ws-tls-tokio = ["async-native-tls", "native-tls", "async-native-tls/runtime-tokio", "ws-tokio"]
ws-tls-async-std = ["async-native-tls", "native-tls", "async-native-tls/runtime-async-std", "ws-async-std"]

[workspace]
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ For more see [examples folder](./examples).
- [ ] Consider getting rid of `Unpin` requirements. (#361)
- [x] WebSockets: TLS support (#360)
- [ ] WebSockets: Reconnecting & Pings
- [ ] Consider using `tokio` instead of `async-std` for `ws.rs` transport (issue with test).
- [x] Consider using `tokio` instead of `async-std` for `ws.rs` transport (issue with test).
- [ ] Restore IPC Transport

## General
Expand Down Expand Up @@ -110,3 +110,15 @@ To complile, you need to disable the IPC feature:
```
web3 = { version = "0.11.0", default-features = false, features = ["http"] }
```

# Cargo Features

The library supports following features:
- `http` - Enables `http` transport.
- `http-tls` - Enables `http` over TLS (`https`) transport support. Implies `http`.
- `ws-tokio` - Enables `ws` tranport (`tokio` runtime).
- `ws-tls-tokio` - Enables `wss` tranport (`tokio` runtime).
- `ws-async-std` - Enables `ws` tranport (`async-std` runtime).
- `ws-tls-async-std` - Enables `wss` tranport (`async-std` runtime).

By default `http-tls` and `ws-tls-tokio` are enabled.
4 changes: 2 additions & 2 deletions src/transports/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ pub mod http;
#[cfg(feature = "http")]
pub use self::http::Http;

#[cfg(feature = "ws")]
#[cfg(any(feature = "ws-tokio", feature = "ws-async-std"))]
pub mod ws;
#[cfg(feature = "ws")]
#[cfg(any(feature = "ws-tokio", feature = "ws-async-std"))]
pub use self::ws::WebSocket;

#[cfg(feature = "url")]
Expand Down
162 changes: 133 additions & 29 deletions src/transports/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::marker::Unpin;
use std::sync::{atomic, Arc};
use std::{fmt, pin::Pin};

use self::compat::{TcpStream, TlsStream};
use crate::api::SubscriptionId;
use crate::error;
use crate::helpers;
Expand All @@ -17,9 +18,6 @@ use futures::{
};
use futures::{AsyncRead, AsyncWrite};

#[cfg(feature = "ws-tls")]
use async_native_tls::TlsStream;
use async_std::net::TcpStream;
use soketto::connection;
use soketto::handshake::{Client, ServerResponse};
use url::Url;
Expand All @@ -42,51 +40,49 @@ type Pending = oneshot::Sender<BatchResult>;
type Subscription = mpsc::UnboundedSender<rpc::Value>;

/// Stream, either plain TCP or TLS.
enum MaybeTlsStream<S> {
enum MaybeTlsStream<P, T> {
/// Unencrypted socket stream.
Plain(S),
Plain(P),
/// Encrypted socket stream.
#[cfg(feature = "ws-tls")]
Tls(TlsStream<S>),
#[allow(dead_code)]
Tls(T),
}

impl<S> AsyncRead for MaybeTlsStream<S>
impl<P, T> AsyncRead for MaybeTlsStream<P, T>
where
S: AsyncRead + AsyncWrite + Unpin,
P: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
#[cfg(feature = "ws-tls")]
MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_read(cx, buf),
}
}
}

impl<S> AsyncWrite for MaybeTlsStream<S>
impl<P, T> AsyncWrite for MaybeTlsStream<P, T>
where
S: AsyncRead + AsyncWrite + Unpin,
P: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
#[cfg(feature = "ws-tls")]
MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
#[cfg(feature = "ws-tls")]
MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_flush(cx),
}
}

fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), std::io::Error>> {
match self.get_mut() {
MaybeTlsStream::Plain(ref mut s) => Pin::new(s).poll_close(cx),
#[cfg(feature = "ws-tls")]
MaybeTlsStream::Tls(ref mut s) => Pin::new(s).poll_close(cx),
}
}
Expand All @@ -95,8 +91,8 @@ where
struct WsServerTask {
pending: BTreeMap<RequestId, Pending>,
subscriptions: BTreeMap<SubscriptionId, Subscription>,
sender: connection::Sender<MaybeTlsStream<TcpStream>>,
receiver: connection::Receiver<MaybeTlsStream<TcpStream>>,
sender: connection::Sender<MaybeTlsStream<TcpStream, TlsStream>>,
receiver: connection::Receiver<MaybeTlsStream<TcpStream, TlsStream>>,
}

impl WsServerTask {
Expand All @@ -115,17 +111,17 @@ impl WsServerTask {
let port = url.port().unwrap_or(if scheme == "ws" { 80 } else { 443 });
let addrs = format!("{}:{}", host, port);

let stream = TcpStream::connect(addrs).await?;

let stream = compat::raw_tcp_stream(addrs).await?;
let socket = if scheme == "wss" {
#[cfg(feature = "ws-tls")]
#[cfg(any(feature = "ws-tls-tokio", feature = "ws-tls-async-std"))]
{
let stream = async_native_tls::connect(host, stream).await?;
MaybeTlsStream::Tls(stream)
MaybeTlsStream::Tls(compat::compat(stream))
}
#[cfg(not(feature = "ws-tls"))]
panic!("The library was compiled without TLS support. Enable ws-tls feature.");
#[cfg(not(any(feature = "ws-tls-tokio", feature = "ws-tls-async-std")))]
panic!("The library was compiled without TLS support. Enable ws-tls-tokio or ws-tls-async-std feature.");
} else {
let stream = compat::compat(stream);
MaybeTlsStream::Plain(stream)
};

Expand Down Expand Up @@ -311,6 +307,9 @@ impl WebSocket {
// TODO [ToDr] Not unbounded?
let (sink, stream) = mpsc::unbounded();
// Spawn background task for the transport.
#[cfg(feature = "ws-tokio")]
tokio::spawn(task.into_task(stream));
#[cfg(feature = "ws-async-std")]
async_std::task::spawn(task.into_task(stream));

Ok(Self { id, requests: sink })
Expand Down Expand Up @@ -434,21 +433,126 @@ impl DuplexTransport for WebSocket {
}
}

/// Compatibility layer between async-std and tokio
#[cfg(feature = "ws-async-std")]
#[doc(hidden)]
pub mod compat {
pub use async_std::net::TcpListener;
pub use async_std::net::TcpStream;
/// TLS stream type for async-std runtime.
#[cfg(feature = "ws-tls-async-std")]
pub type TlsStream = async_native_tls::TlsStream<TcpStream>;
/// Dummy TLS stream type.
#[cfg(not(feature = "ws-tls-async-std"))]
pub type TlsStream = TcpStream;

/// Create new TcpStream object.
pub async fn raw_tcp_stream(addrs: String) -> std::io::Result<TcpStream> {
TcpStream::connect(addrs).await
}

/// Wrap given argument into compatibility layer.
#[inline(always)]
pub fn compat<T>(t: T) -> T {
t
}
}

/// Compatibility layer between async-std and tokio
#[cfg(feature = "ws-tokio")]
pub mod compat {
/// async-std compatible TcpStream.
pub type TcpStream = Compat<tokio::net::TcpStream>;
/// async-std compatible TcpListener.
pub type TcpListener = tokio::net::TcpListener;
/// TLS stream type for tokio runtime.
#[cfg(feature = "ws-tls-tokio")]
pub type TlsStream = Compat<async_native_tls::TlsStream<tokio::net::TcpStream>>;
/// Dummy TLS stream type.
#[cfg(not(feature = "ws-tls-tokio"))]
pub type TlsStream = TcpStream;

use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

/// Create new TcpStream object.
pub async fn raw_tcp_stream(addrs: String) -> io::Result<tokio::net::TcpStream> {
Ok(tokio::net::TcpStream::connect(addrs).await?)
}

/// Wrap given argument into compatibility layer.
pub fn compat<T>(t: T) -> Compat<T> {
Compat(t)
}

/// Compatibility layer.
pub struct Compat<T>(T);
impl<T: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for Compat<T> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
}
}

impl<T: tokio::io::AsyncWrite + Unpin> futures::AsyncWrite for Compat<T> {
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
tokio::io::AsyncWrite::poll_write(Pin::new(&mut self.0), cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::io::AsyncWrite::poll_flush(Pin::new(&mut self.0), cx)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
tokio::io::AsyncWrite::poll_shutdown(Pin::new(&mut self.0), cx)
}
}

impl<T: tokio::io::AsyncRead + Unpin> futures::AsyncRead for Compat<T> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
tokio::io::AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
}
}

impl<T: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for Compat<T> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
tokio::io::AsyncRead::poll_read(Pin::new(&mut self.0), cx, buf)
}
}
}

#[cfg(test)]
mod tests {
use super::WebSocket;
use super::*;
use crate::{rpc, Transport};
use async_std::net::TcpListener;
use futures::io::{BufReader, BufWriter};
use futures::StreamExt;
use soketto::handshake;

#[async_std::test]
#[test]
fn bounds_matching() {
fn async_rw<T: AsyncRead + AsyncWrite>() {}

async_rw::<TcpStream>();
async_rw::<MaybeTlsStream<TcpStream, TlsStream>>();
}

#[tokio::test]
async fn should_send_a_request() {
let _ = env_logger::try_init();
// given
let addr = "127.0.0.1:3000";
async_std::task::spawn(server(addr));
let listener = futures::executor::block_on(compat::TcpListener::bind(addr)).expect("Failed to bind");
println!("Starting the server.");
tokio::spawn(server(listener, addr));

let endpoint = "ws://127.0.0.1:3000";
let ws = WebSocket::new(endpoint).await.unwrap();
Expand All @@ -460,11 +564,11 @@ mod tests {
assert_eq!(res.await, Ok(rpc::Value::String("x".into())));
}

async fn server(addr: &str) {
let listener = futures::executor::block_on(TcpListener::bind(addr)).expect("Failed to bind");
async fn server(mut listener: compat::TcpListener, addr: &str) {
let mut incoming = listener.incoming();
println!("Listening on: {}", addr);
while let Some(Ok(socket)) = incoming.next().await {
let socket = compat::compat(socket);
let mut server = handshake::Server::new(BufReader::new(BufWriter::new(socket)));
let key = {
let req = server.receive_request().await.unwrap();
Expand Down