diff --git a/.gitmodules b/.gitmodules index 80d7961bfecb..2f6c628a2f24 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,10 +1,10 @@ -[submodule "spec_testsuite"] +[submodule "tests/spec_testsuite"] path = tests/spec_testsuite url = https://github.com/WebAssembly/testsuite -[submodule "crates/c-api/examples/wasm-c-api"] +[submodule "crates/c-api/wasm-c-api"] path = crates/c-api/wasm-c-api url = https://github.com/WebAssembly/wasm-c-api -[submodule "WASI"] +[submodule "crates/wasi-common/WASI"] path = crates/wasi-common/WASI url = https://github.com/WebAssembly/WASI [submodule "crates/wasi-nn/spec"] @@ -13,10 +13,7 @@ [submodule "tests/wasi_testsuite/wasi-threads"] path = tests/wasi_testsuite/wasi-threads url = https://github.com/WebAssembly/wasi-threads -[submodule "crates/wasi-http/wasi-http"] - path = crates/wasi-http/wasi-http - url = https://github.com/WebAssembly/wasi-http -[submodule "tests/wasi_testsuite/wasi"] +[submodule "tests/wasi_testsuite/wasi-common"] path = tests/wasi_testsuite/wasi-common url = https://github.com/WebAssembly/wasi-testsuite.git branch = prod/testsuite-base diff --git a/Cargo.lock b/Cargo.lock index 1ffa394338aa..d4554190edc1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3769,9 +3769,11 @@ dependencies = [ "once_cell", "rustix 0.38.8", "system-interface", + "test-log", "thiserror", "tokio", "tracing", + "tracing-subscriber", "wasi-cap-std-sync", "wasi-common", "wasi-tokio", diff --git a/Cargo.toml b/Cargo.toml index a12c80ddc183..56411e399a17 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -258,6 +258,8 @@ futures = { version = "0.3.27", default-features = false } indexmap = "2.0.0" pretty_env_logger = "0.5.0" syn = "2.0.25" +test-log = { version = "0.2", default-features = false, features = ["trace"] } +tracing-subscriber = { version = "0.3.1", default-features = false, features = ['fmt', 'env-filter'] } [features] default = [ diff --git a/RELEASES.md b/RELEASES.md index 8cc2cdb13d0a..3627b0bc979e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -87,6 +87,11 @@ Unreleased. be turned off. [#6547](https://github.com/bytecodealliance/wasmtime/pull/6547) +* WASI Preview 2 output-stream has been redesigned with changes to + backpressure and flushing. The `HostOutputStream` trait has changed + substantially. + [#6877](https://github.com/bytecodealliance/wasmtime/pull/6877) + ### Removed * Wasmtime's experimental implementation of wasi-crypto has been removed. More diff --git a/ci/build-test-matrix.js b/ci/build-test-matrix.js index ce14be4f5aad..35c2c00efd57 100644 --- a/ci/build-test-matrix.js +++ b/ci/build-test-matrix.js @@ -95,7 +95,13 @@ const array = [ "qemu_target": "riscv64-linux-user", "name": "Test Linux riscv64", "filter": "linux-riscv64", - "isa": "riscv64" + "isa": "riscv64", + // There appears to be a miscompile in Rust 1.72 for riscv64 where + // wasmtime-wasi tests are segfaulting in CI with the stack pointing in + // Tokio. Updating rustc seems to do the trick, so without doing a full + // rigorous investigation this uses beta for now but Rust 1.73 should be + // good to go for this. + "rust": "beta-2023-09-10", } ]; diff --git a/crates/test-programs/Cargo.toml b/crates/test-programs/Cargo.toml index f64082e77370..4808b8e462e7 100644 --- a/crates/test-programs/Cargo.toml +++ b/crates/test-programs/Cargo.toml @@ -26,12 +26,9 @@ tracing = { workspace = true } [dev-dependencies] anyhow = { workspace = true } tempfile = { workspace = true } -test-log = { version = "0.2", default-features = false, features = ["trace"] } +test-log = { workspace = true } tracing = { workspace = true } -tracing-subscriber = { version = "0.3.1", default-features = false, features = [ - 'fmt', - 'env-filter', -] } +tracing-subscriber = { workspace = true } lazy_static = "1" wasmtime = { workspace = true, features = ['cranelift', 'component-model'] } diff --git a/crates/test-programs/reactor-tests/src/lib.rs b/crates/test-programs/reactor-tests/src/lib.rs index 532dda77f4e1..c0f9f2dcc1b1 100644 --- a/crates/test-programs/reactor-tests/src/lib.rs +++ b/crates/test-programs/reactor-tests/src/lib.rs @@ -3,9 +3,21 @@ wit_bindgen::generate!("test-reactor" in "../../wasi/wit"); export_test_reactor!(T); struct T; +use wasi::io::streams; +use wasi::poll::poll; static mut STATE: Vec = Vec::new(); +struct DropPollable { + pollable: poll::Pollable, +} + +impl Drop for DropPollable { + fn drop(&mut self) { + poll::drop_pollable(self.pollable); + } +} + impl TestReactor for T { fn add_strings(ss: Vec) -> u32 { for s in ss { @@ -24,10 +36,38 @@ impl TestReactor for T { } fn write_strings_to(o: OutputStream) -> Result<(), ()> { + let sub = DropPollable { + pollable: streams::subscribe_to_output_stream(o), + }; unsafe { for s in STATE.iter() { - wasi::io::streams::write(o, s.as_bytes()).map_err(|_| ())?; + let mut out = s.as_bytes(); + while !out.is_empty() { + poll::poll_oneoff(&[sub.pollable]); + let n = match streams::check_write(o) { + Ok(n) => n, + Err(_) => return Err(()), + }; + + let len = (n as usize).min(out.len()); + match streams::write(o, &out[..len]) { + Ok(_) => out = &out[len..], + Err(_) => return Err(()), + } + } } + + match streams::flush(o) { + Ok(_) => {} + Err(_) => return Err(()), + } + + poll::poll_oneoff(&[sub.pollable]); + match streams::check_write(o) { + Ok(_) => {} + Err(_) => return Err(()), + } + Ok(()) } } diff --git a/crates/test-programs/tests/reactor.rs b/crates/test-programs/tests/reactor.rs index 521c67af8c9e..555c756d086d 100644 --- a/crates/test-programs/tests/reactor.rs +++ b/crates/test-programs/tests/reactor.rs @@ -103,7 +103,7 @@ async fn reactor_tests() -> Result<()> { // `host` and `wasi-common` crate. // Note, this works because of the add_to_linker invocations using the // `host` crate for `streams`, not because of `with` in the bindgen macro. - let writepipe = preview2::pipe::MemoryOutputPipe::new(); + let writepipe = preview2::pipe::MemoryOutputPipe::new(4096); let table_ix = preview2::TableStreamExt::push_output_stream( store.data_mut().table_mut(), Box::new(writepipe.clone()), diff --git a/crates/test-programs/tests/wasi-http-components-sync.rs b/crates/test-programs/tests/wasi-http-components-sync.rs index cd3705d685d1..bb4975cfe96e 100644 --- a/crates/test-programs/tests/wasi-http-components-sync.rs +++ b/crates/test-programs/tests/wasi-http-components-sync.rs @@ -69,8 +69,8 @@ fn instantiate_component( } fn run(name: &str) -> anyhow::Result<()> { - let stdout = MemoryOutputPipe::new(); - let stderr = MemoryOutputPipe::new(); + let stdout = MemoryOutputPipe::new(4096); + let stderr = MemoryOutputPipe::new(4096); let r = { let mut table = Table::new(); let component = get_component(name); diff --git a/crates/test-programs/tests/wasi-http-components.rs b/crates/test-programs/tests/wasi-http-components.rs index 811b23151b24..86c45350779b 100644 --- a/crates/test-programs/tests/wasi-http-components.rs +++ b/crates/test-programs/tests/wasi-http-components.rs @@ -70,8 +70,8 @@ async fn instantiate_component( } async fn run(name: &str) -> anyhow::Result<()> { - let stdout = MemoryOutputPipe::new(); - let stderr = MemoryOutputPipe::new(); + let stdout = MemoryOutputPipe::new(4096); + let stderr = MemoryOutputPipe::new(4096); let r = { let mut table = Table::new(); let component = get_component(name); diff --git a/crates/test-programs/tests/wasi-http-modules.rs b/crates/test-programs/tests/wasi-http-modules.rs index 38dd25ffba54..8bc0a677c347 100644 --- a/crates/test-programs/tests/wasi-http-modules.rs +++ b/crates/test-programs/tests/wasi-http-modules.rs @@ -74,8 +74,8 @@ async fn instantiate_module(module: Module, ctx: Ctx) -> Result<(Store, Fun } async fn run(name: &str) -> anyhow::Result<()> { - let stdout = MemoryOutputPipe::new(); - let stderr = MemoryOutputPipe::new(); + let stdout = MemoryOutputPipe::new(4096); + let stderr = MemoryOutputPipe::new(4096); let r = { let mut table = Table::new(); let module = get_module(name); diff --git a/crates/test-programs/tests/wasi-preview1-host-in-preview2.rs b/crates/test-programs/tests/wasi-preview1-host-in-preview2.rs index 05dbd32717d9..0daa796179d5 100644 --- a/crates/test-programs/tests/wasi-preview1-host-in-preview2.rs +++ b/crates/test-programs/tests/wasi-preview1-host-in-preview2.rs @@ -30,8 +30,8 @@ pub fn prepare_workspace(exe_name: &str) -> Result { async fn run(name: &str, inherit_stdio: bool) -> Result<()> { let workspace = prepare_workspace(name)?; - let stdout = MemoryOutputPipe::new(); - let stderr = MemoryOutputPipe::new(); + let stdout = MemoryOutputPipe::new(4096); + let stderr = MemoryOutputPipe::new(4096); let r = { let mut linker = Linker::new(&ENGINE); add_to_linker_async(&mut linker)?; diff --git a/crates/test-programs/tests/wasi-preview2-components-sync.rs b/crates/test-programs/tests/wasi-preview2-components-sync.rs index 08b60217be62..10c0bd8df0b3 100644 --- a/crates/test-programs/tests/wasi-preview2-components-sync.rs +++ b/crates/test-programs/tests/wasi-preview2-components-sync.rs @@ -30,8 +30,8 @@ pub fn prepare_workspace(exe_name: &str) -> Result { fn run(name: &str, inherit_stdio: bool) -> Result<()> { let workspace = prepare_workspace(name)?; - let stdout = MemoryOutputPipe::new(); - let stderr = MemoryOutputPipe::new(); + let stdout = MemoryOutputPipe::new(4096); + let stderr = MemoryOutputPipe::new(4096); let r = { let mut linker = Linker::new(&ENGINE); add_to_linker(&mut linker)?; diff --git a/crates/test-programs/tests/wasi-preview2-components.rs b/crates/test-programs/tests/wasi-preview2-components.rs index 77a3e71ff6d8..b7bbdfbc164f 100644 --- a/crates/test-programs/tests/wasi-preview2-components.rs +++ b/crates/test-programs/tests/wasi-preview2-components.rs @@ -30,8 +30,8 @@ pub fn prepare_workspace(exe_name: &str) -> Result { async fn run(name: &str, inherit_stdio: bool) -> Result<()> { let workspace = prepare_workspace(name)?; - let stdout = MemoryOutputPipe::new(); - let stderr = MemoryOutputPipe::new(); + let stdout = MemoryOutputPipe::new(4096); + let stderr = MemoryOutputPipe::new(4096); let r = { let mut linker = Linker::new(&ENGINE); add_to_linker(&mut linker)?; diff --git a/crates/test-programs/wasi-http-tests/src/lib.rs b/crates/test-programs/wasi-http-tests/src/lib.rs index 6de349b42dc1..808f30830d21 100644 --- a/crates/test-programs/wasi-http-tests/src/lib.rs +++ b/crates/test-programs/wasi-http-tests/src/lib.rs @@ -7,7 +7,7 @@ pub mod bindings { }); } -use anyhow::{anyhow, Context, Result}; +use anyhow::{anyhow, Result}; use std::fmt; use std::sync::OnceLock; @@ -42,6 +42,16 @@ impl Response { } } +struct DropPollable { + pollable: poll::Pollable, +} + +impl Drop for DropPollable { + fn drop(&mut self) { + poll::drop_pollable(self.pollable); + } +} + pub async fn request( method: http_types::Method, scheme: http_types::Scheme, @@ -72,27 +82,39 @@ pub async fn request( let request_body = http_types::outgoing_request_write(request) .map_err(|_| anyhow!("outgoing request write failed"))?; - if let Some(body) = body { - let output_stream_pollable = streams::subscribe_to_output_stream(request_body); - let len = body.len(); - if len == 0 { - let (_written, _status) = streams::write(request_body, &[]) - .map_err(|_| anyhow!("request_body stream write failed")) - .context("writing empty request body")?; - } else { - let mut body_cursor = 0; - while body_cursor < body.len() { - let (written, _status) = streams::write(request_body, &body[body_cursor..]) - .map_err(|_| anyhow!("request_body stream write failed")) - .context("writing request body")?; - body_cursor += written as usize; + if let Some(mut buf) = body { + let sub = DropPollable { + pollable: streams::subscribe_to_output_stream(request_body), + }; + while !buf.is_empty() { + poll::poll_oneoff(&[sub.pollable]); + + let permit = match streams::check_write(request_body) { + Ok(n) => usize::try_from(n)?, + Err(_) => anyhow::bail!("output stream error"), + }; + + let len = buf.len().min(permit); + let (chunk, rest) = buf.split_at(len); + buf = rest; + + match streams::write(request_body, chunk) { + Err(_) => anyhow::bail!("output stream error"), + _ => {} } } - // TODO: enable when working as expected - // let _ = poll::poll_oneoff(&[output_stream_pollable]); + match streams::flush(request_body) { + Err(_) => anyhow::bail!("output stream error"), + _ => {} + } + + poll::poll_oneoff(&[sub.pollable]); - poll::drop_pollable(output_stream_pollable); + match streams::check_write(request_body) { + Ok(_) => {} + Err(_) => anyhow::bail!("output stream error"), + }; } let future_response = outgoing_handler::handle(request, None); diff --git a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs index c0eb0afdc473..2388cb074afe 100644 --- a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs +++ b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v4.rs @@ -1,24 +1,10 @@ //! A simple TCP testcase, using IPv4. -use wasi::io::streams; -use wasi::poll::poll; use wasi::sockets::network::{IpAddressFamily, IpSocketAddress, Ipv4SocketAddress}; -use wasi::sockets::{instance_network, network, tcp, tcp_create_socket}; +use wasi::sockets::{instance_network, tcp, tcp_create_socket}; use wasi_sockets_tests::*; -fn wait(sub: poll::Pollable) { - loop { - let wait = poll::poll_oneoff(&[sub]); - if wait[0] { - break; - } - } -} - fn main() { - let first_message = b"Hello, world!"; - let second_message = b"Greetings, planet!"; - let net = instance_network::instance_network(); let sock = tcp_create_socket::create_tcp_socket(IpAddressFamily::Ipv4).unwrap(); @@ -31,82 +17,11 @@ fn main() { let sub = tcp::subscribe(sock); tcp::start_bind(sock, net, addr).unwrap(); - wait(sub); - tcp::finish_bind(sock).unwrap(); - - tcp::start_listen(sock).unwrap(); - wait(sub); - tcp::finish_listen(sock).unwrap(); - - let addr = tcp::local_address(sock).unwrap(); - - let client = tcp_create_socket::create_tcp_socket(IpAddressFamily::Ipv4).unwrap(); - let client_sub = tcp::subscribe(client); - - tcp::start_connect(client, net, addr).unwrap(); - wait(client_sub); - let (client_input, client_output) = tcp::finish_connect(client).unwrap(); - - let (n, status) = streams::write(client_output, &[]).unwrap(); - assert_eq!(n, 0); - assert_eq!(status, streams::StreamStatus::Open); - - let (n, status) = streams::write(client_output, first_message).unwrap(); - assert_eq!(n, first_message.len() as u64); // Not guaranteed to work but should work in practice. - assert_eq!(status, streams::StreamStatus::Open); - - streams::drop_input_stream(client_input); - streams::drop_output_stream(client_output); - poll::drop_pollable(client_sub); - tcp::drop_tcp_socket(client); wait(sub); - let (accepted, input, output) = tcp::accept(sock).unwrap(); + wasi::poll::poll::drop_pollable(sub); - let (empty_data, status) = streams::read(input, 0).unwrap(); - assert!(empty_data.is_empty()); - assert_eq!(status, streams::StreamStatus::Open); - - let (data, status) = streams::blocking_read(input, first_message.len() as u64).unwrap(); - assert_eq!(status, streams::StreamStatus::Open); - - tcp::drop_tcp_socket(accepted); - streams::drop_input_stream(input); - streams::drop_output_stream(output); - - // Check that we sent and recieved our message! - assert_eq!(data, first_message); // Not guaranteed to work but should work in practice. - - // Another client - let client = tcp_create_socket::create_tcp_socket(IpAddressFamily::Ipv4).unwrap(); - let client_sub = tcp::subscribe(client); - - tcp::start_connect(client, net, addr).unwrap(); - wait(client_sub); - let (client_input, client_output) = tcp::finish_connect(client).unwrap(); - - let (n, status) = streams::write(client_output, second_message).unwrap(); - assert_eq!(n, second_message.len() as u64); // Not guaranteed to work but should work in practice. - assert_eq!(status, streams::StreamStatus::Open); - - streams::drop_input_stream(client_input); - streams::drop_output_stream(client_output); - poll::drop_pollable(client_sub); - tcp::drop_tcp_socket(client); - - wait(sub); - let (accepted, input, output) = tcp::accept(sock).unwrap(); - let (data, status) = streams::blocking_read(input, second_message.len() as u64).unwrap(); - assert_eq!(status, streams::StreamStatus::Open); - - streams::drop_input_stream(input); - streams::drop_output_stream(output); - tcp::drop_tcp_socket(accepted); - - // Check that we sent and recieved our message! - assert_eq!(data, second_message); // Not guaranteed to work but should work in practice. + tcp::finish_bind(sock).unwrap(); - poll::drop_pollable(sub); - tcp::drop_tcp_socket(sock); - network::drop_network(net); + example_body(net, sock, IpAddressFamily::Ipv4) } diff --git a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs index 36e465f0fac7..b5ff8358cc09 100644 --- a/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs +++ b/crates/test-programs/wasi-sockets-tests/src/bin/tcp_v6.rs @@ -1,24 +1,10 @@ //! Like v4.rs, but with IPv6. -use wasi::io::streams; -use wasi::poll::poll; use wasi::sockets::network::{IpAddressFamily, IpSocketAddress, Ipv6SocketAddress}; -use wasi::sockets::{instance_network, network, tcp, tcp_create_socket}; +use wasi::sockets::{instance_network, tcp, tcp_create_socket}; use wasi_sockets_tests::*; -fn wait(sub: poll::Pollable) { - loop { - let wait = poll::poll_oneoff(&[sub]); - if wait[0] { - break; - } - } -} - fn main() { - let first_message = b"Hello, world!"; - let second_message = b"Greetings, planet!"; - let net = instance_network::instance_network(); let sock = tcp_create_socket::create_tcp_socket(IpAddressFamily::Ipv6).unwrap(); @@ -33,82 +19,11 @@ fn main() { let sub = tcp::subscribe(sock); tcp::start_bind(sock, net, addr).unwrap(); - wait(sub); - tcp::finish_bind(sock).unwrap(); - - tcp::start_listen(sock).unwrap(); - wait(sub); - tcp::finish_listen(sock).unwrap(); - - let addr = tcp::local_address(sock).unwrap(); - - let client = tcp_create_socket::create_tcp_socket(IpAddressFamily::Ipv6).unwrap(); - let client_sub = tcp::subscribe(client); - - tcp::start_connect(client, net, addr).unwrap(); - wait(client_sub); - let (client_input, client_output) = tcp::finish_connect(client).unwrap(); - - let (n, status) = streams::write(client_output, &[]).unwrap(); - assert_eq!(n, 0); - assert_eq!(status, streams::StreamStatus::Open); - - let (n, status) = streams::write(client_output, first_message).unwrap(); - assert_eq!(n, first_message.len() as u64); // Not guaranteed to work but should work in practice. - assert_eq!(status, streams::StreamStatus::Open); - - streams::drop_input_stream(client_input); - streams::drop_output_stream(client_output); - poll::drop_pollable(client_sub); - tcp::drop_tcp_socket(client); wait(sub); - let (accepted, input, output) = tcp::accept(sock).unwrap(); + wasi::poll::poll::drop_pollable(sub); - let (empty_data, status) = streams::read(input, 0).unwrap(); - assert!(empty_data.is_empty()); - assert_eq!(status, streams::StreamStatus::Open); - - let (data, status) = streams::blocking_read(input, first_message.len() as u64).unwrap(); - assert_eq!(status, streams::StreamStatus::Open); - - tcp::drop_tcp_socket(accepted); - streams::drop_input_stream(input); - streams::drop_output_stream(output); - - // Check that we sent and recieved our message! - assert_eq!(data, first_message); // Not guaranteed to work but should work in practice. - - // Another client - let client = tcp_create_socket::create_tcp_socket(IpAddressFamily::Ipv6).unwrap(); - let client_sub = tcp::subscribe(client); - - tcp::start_connect(client, net, addr).unwrap(); - wait(client_sub); - let (client_input, client_output) = tcp::finish_connect(client).unwrap(); - - let (n, status) = streams::write(client_output, second_message).unwrap(); - assert_eq!(n, second_message.len() as u64); // Not guaranteed to work but should work in practice. - assert_eq!(status, streams::StreamStatus::Open); - - streams::drop_input_stream(client_input); - streams::drop_output_stream(client_output); - poll::drop_pollable(client_sub); - tcp::drop_tcp_socket(client); - - wait(sub); - let (accepted, input, output) = tcp::accept(sock).unwrap(); - let (data, status) = streams::blocking_read(input, second_message.len() as u64).unwrap(); - assert_eq!(status, streams::StreamStatus::Open); - - streams::drop_input_stream(input); - streams::drop_output_stream(output); - tcp::drop_tcp_socket(accepted); - - // Check that we sent and recieved our message! - assert_eq!(data, second_message); // Not guaranteed to work but should work in practice. + tcp::finish_bind(sock).unwrap(); - poll::drop_pollable(sub); - tcp::drop_tcp_socket(sock); - network::drop_network(net); + example_body(net, sock, IpAddressFamily::Ipv6) } diff --git a/crates/test-programs/wasi-sockets-tests/src/lib.rs b/crates/test-programs/wasi-sockets-tests/src/lib.rs index cf3ecf02f82c..fb06654c313a 100644 --- a/crates/test-programs/wasi-sockets-tests/src/lib.rs +++ b/crates/test-programs/wasi-sockets-tests/src/lib.rs @@ -1 +1,143 @@ wit_bindgen::generate!("test-command-with-sockets" in "../../wasi/wit"); + +use wasi::io::streams; +use wasi::poll::poll; +use wasi::sockets::{network, tcp, tcp_create_socket}; + +pub fn wait(sub: poll::Pollable) { + loop { + let wait = poll::poll_oneoff(&[sub]); + if wait[0] { + break; + } + } +} + +pub struct DropPollable { + pub pollable: poll::Pollable, +} + +impl Drop for DropPollable { + fn drop(&mut self) { + poll::drop_pollable(self.pollable); + } +} + +pub fn write(output: streams::OutputStream, mut bytes: &[u8]) -> (usize, streams::StreamStatus) { + let total = bytes.len(); + let mut written = 0; + + let s = DropPollable { + pollable: streams::subscribe_to_output_stream(output), + }; + + while !bytes.is_empty() { + poll::poll_oneoff(&[s.pollable]); + + let permit = match streams::check_write(output) { + Ok(n) => n, + Err(_) => return (written, streams::StreamStatus::Ended), + }; + + let len = bytes.len().min(permit as usize); + let (chunk, rest) = bytes.split_at(len); + + match streams::write(output, chunk) { + Ok(()) => {} + Err(_) => return (written, streams::StreamStatus::Ended), + } + + match streams::blocking_flush(output) { + Ok(()) => {} + Err(_) => return (written, streams::StreamStatus::Ended), + } + + bytes = rest; + written += len; + } + + (total, streams::StreamStatus::Open) +} + +pub fn example_body(net: tcp::Network, sock: tcp::TcpSocket, family: network::IpAddressFamily) { + let first_message = b"Hello, world!"; + let second_message = b"Greetings, planet!"; + + let sub = tcp::subscribe(sock); + + tcp::start_listen(sock).unwrap(); + wait(sub); + tcp::finish_listen(sock).unwrap(); + + let addr = tcp::local_address(sock).unwrap(); + + let client = tcp_create_socket::create_tcp_socket(family).unwrap(); + let client_sub = tcp::subscribe(client); + + tcp::start_connect(client, net, addr).unwrap(); + wait(client_sub); + let (client_input, client_output) = tcp::finish_connect(client).unwrap(); + + let (n, status) = write(client_output, &[]); + assert_eq!(n, 0); + assert_eq!(status, streams::StreamStatus::Open); + + let (n, status) = write(client_output, first_message); + assert_eq!(n, first_message.len()); + assert_eq!(status, streams::StreamStatus::Open); + + streams::drop_input_stream(client_input); + streams::drop_output_stream(client_output); + poll::drop_pollable(client_sub); + tcp::drop_tcp_socket(client); + + wait(sub); + let (accepted, input, output) = tcp::accept(sock).unwrap(); + + let (empty_data, status) = streams::read(input, 0).unwrap(); + assert!(empty_data.is_empty()); + assert_eq!(status, streams::StreamStatus::Open); + + let (data, status) = streams::blocking_read(input, first_message.len() as u64).unwrap(); + assert_eq!(status, streams::StreamStatus::Open); + + streams::drop_input_stream(input); + streams::drop_output_stream(output); + tcp::drop_tcp_socket(accepted); + + // Check that we sent and recieved our message! + assert_eq!(data, first_message); // Not guaranteed to work but should work in practice. + + // Another client + let client = tcp_create_socket::create_tcp_socket(family).unwrap(); + let client_sub = tcp::subscribe(client); + + tcp::start_connect(client, net, addr).unwrap(); + wait(client_sub); + let (client_input, client_output) = tcp::finish_connect(client).unwrap(); + + let (n, status) = write(client_output, second_message); + assert_eq!(n, second_message.len()); + assert_eq!(status, streams::StreamStatus::Open); + + streams::drop_input_stream(client_input); + streams::drop_output_stream(client_output); + poll::drop_pollable(client_sub); + tcp::drop_tcp_socket(client); + + wait(sub); + let (accepted, input, output) = tcp::accept(sock).unwrap(); + let (data, status) = streams::blocking_read(input, second_message.len() as u64).unwrap(); + assert_eq!(status, streams::StreamStatus::Open); + + streams::drop_input_stream(input); + streams::drop_output_stream(output); + tcp::drop_tcp_socket(accepted); + + // Check that we sent and recieved our message! + assert_eq!(data, second_message); // Not guaranteed to work but should work in practice. + + poll::drop_pollable(sub); + tcp::drop_tcp_socket(sock); + network::drop_network(net); +} diff --git a/crates/wasi-http/src/component_impl.rs b/crates/wasi-http/src/component_impl.rs index 3ffc0e0f2036..98f6c57579b7 100644 --- a/crates/wasi-http/src/component_impl.rs +++ b/crates/wasi-http/src/component_impl.rs @@ -609,14 +609,90 @@ pub fn add_component_to_linker( }) }, )?; + linker.func_wrap2_async( + "wasi:io/streams", + "check-write", + move |mut caller: Caller<'_, T>, stream: u32, ptr: u32| { + Box::new(async move { + let memory: Memory = memory_get(&mut caller)?; + let ctx = get_cx(caller.data_mut()); + + let result = match io::streams::Host::check_write(ctx, stream).await { + // 0 == outer result tag (success) + // 1 == result value (u64 lower 32 bits) + // 2 == result value (u64 upper 32 bits) + Ok(len) => [0, len as u32, (len >> 32) as u32], + + // 0 == outer result tag (failure) + // 1 == result value (u64 lower 32 bits) + // 2 == result value (unused) + Err(_) => todo!("how do we extract runtime error cases?"), + }; + + let raw = u32_array_to_u8(&result); + memory.write(caller.as_context_mut(), ptr as _, &raw)?; + + Ok(()) + }) + }, + )?; + linker.func_wrap2_async( + "wasi:io/streams", + "flush", + move |mut caller: Caller<'_, T>, stream: u32, ptr: u32| { + Box::new(async move { + let ctx = get_cx(caller.data_mut()); + + let result: [u32; 2] = match io::streams::Host::flush(ctx, stream).await { + // 0 == outer result tag + // 1 == unused + Ok(_) => [0, 0], + + // 0 == outer result tag + // 1 == inner result tag + Err(_) => todo!("how do we extract runtime error cases?"), + }; + + let raw = u32_array_to_u8(&result); + let memory: Memory = memory_get(&mut caller)?; + memory.write(caller.as_context_mut(), ptr as _, &raw)?; + + Ok(()) + }) + }, + )?; + linker.func_wrap2_async( + "wasi:io/streams", + "blocking-flush", + move |mut caller: Caller<'_, T>, stream: u32, ptr: u32| { + Box::new(async move { + let ctx = get_cx(caller.data_mut()); + + let result: [u32; 2] = match io::streams::Host::blocking_flush(ctx, stream).await { + // 0 == outer result tag + // 1 == unused + Ok(_) => [0, 0], + + // 0 == outer result tag + // 1 == inner result tag + Err(_) => todo!("how do we extract runtime error cases?"), + }; + + let raw = u32_array_to_u8(&result); + let memory: Memory = memory_get(&mut caller)?; + memory.write(caller.as_context_mut(), ptr as _, &raw)?; + + Ok(()) + }) + }, + )?; linker.func_wrap4_async( "wasi:io/streams", "write", move |mut caller: Caller<'_, T>, stream: u32, body_ptr: u32, body_len: u32, ptr: u32| { Box::new(async move { - let memory = memory_get(&mut caller)?; - let body = - string_from_memory(&memory, caller.as_context_mut(), body_ptr, body_len)?; + let memory: Memory = memory_get(&mut caller)?; + let body = slice_from_memory(&memory, caller.as_context_mut(), body_ptr, body_len)?; let ctx = get_cx(caller.data_mut()); tracing::trace!( @@ -629,19 +705,11 @@ pub fn add_component_to_linker( "[module='wasi:io/streams' function='write'] return result={:?}", result ); - let (len, status) = result?.map_err(|_| anyhow!("write failed"))?; - - let written: u32 = len.try_into()?; - let done: u32 = match status { - io::streams::StreamStatus::Open => 0, - io::streams::StreamStatus::Ended => 1, - }; + result?; // First == is_err // Second == {ok: is_err = false, tag: is_err = true} - // Third == amount of bytes written - // Fifth == enum status - let result: [u32; 5] = [0, 0, written, 0, done]; + let result: [u32; 2] = [0, 0]; let raw = u32_array_to_u8(&result); memory.write(caller.as_context_mut(), ptr as _, &raw)?; @@ -652,37 +720,28 @@ pub fn add_component_to_linker( )?; linker.func_wrap4_async( "wasi:io/streams", - "blocking-write", + "blocking-write-and-flush", move |mut caller: Caller<'_, T>, stream: u32, body_ptr: u32, body_len: u32, ptr: u32| { Box::new(async move { - let memory = memory_get(&mut caller)?; - let body = - string_from_memory(&memory, caller.as_context_mut(), body_ptr, body_len)?; + let memory: Memory = memory_get(&mut caller)?; + let body = slice_from_memory(&memory, caller.as_context_mut(), body_ptr, body_len)?; let ctx = get_cx(caller.data_mut()); tracing::trace!( - "[module='wasi:io/streams' function='blocking-write'] call stream={:?} body={:?}", + "[module='wasi:io/streams' function='blocking-write-and-flush'] call stream={:?} body={:?}", stream, body ); - let result = io::streams::Host::blocking_write(ctx, stream, body.into()).await; + let result = io::streams::Host::blocking_write_and_flush(ctx, stream, body.into()).await; tracing::trace!( - "[module='wasi:io/streams' function='blocking-write'] return result={:?}", + "[module='wasi:io/streams' function='blocking-write-and-flush'] return result={:?}", result ); - let (len, status) = result?.map_err(|_| anyhow!("write failed"))?; - - let written: u32 = len.try_into()?; - let done: u32 = match status { - io::streams::StreamStatus::Open => 0, - io::streams::StreamStatus::Ended => 1, - }; + result?; // First == is_err // Second == {ok: is_err = false, tag: is_err = true} - // Third == amount of bytes written - // Fifth == enum status - let result: [u32; 5] = [0, 0, written, 0, done]; + let result: [u32; 2] = [0, 0]; let raw = u32_array_to_u8(&result); memory.write(caller.as_context_mut(), ptr as _, &raw)?; @@ -1437,19 +1496,11 @@ pub mod sync { "[module='wasi:io/streams' function='write'] return result={:?}", result ); - let (len, status) = result?.map_err(|_| anyhow!("write failed"))?; - - let written: u32 = len.try_into()?; - let done: u32 = match status { - io::streams::StreamStatus::Open => 0, - io::streams::StreamStatus::Ended => 1, - }; + result?; // First == is_err // Second == {ok: is_err = false, tag: is_err = true} - // Third == amount of bytes written - // Fifth == enum status - let result: [u32; 5] = [0, 0, written, 0, done]; + let result: [u32; 2] = [0, 0]; let raw = u32_array_to_u8(&result); memory.write(caller.as_context_mut(), ptr as _, &raw)?; @@ -1459,7 +1510,7 @@ pub mod sync { )?; linker.func_wrap( "wasi:io/streams", - "blocking-write", + "blocking-write-and-flush", move |mut caller: Caller<'_, T>, stream: u32, body_ptr: u32, @@ -1472,28 +1523,20 @@ pub mod sync { let ctx = get_cx(caller.data_mut()); tracing::trace!( - "[module='wasi:io/streams' function='blocking-write'] call stream={:?} body={:?}", + "[module='wasi:io/streams' function='blocking-write-and-flush'] call stream={:?} body={:?}", stream, body ); - let result = io::streams::Host::blocking_write(ctx, stream, body.into()); + let result = io::streams::Host::blocking_write_and_flush(ctx, stream, body.into()); tracing::trace!( - "[module='wasi:io/streams' function='blocking-write'] return result={:?}", + "[module='wasi:io/streams' function='blocking-write-and-flush'] return result={:?}", result ); - let (len, status) = result?.map_err(|_| anyhow!("write failed"))?; - - let written: u32 = len.try_into()?; - let done: u32 = match status { - io::streams::StreamStatus::Open => 0, - io::streams::StreamStatus::Ended => 1, - }; + result?; // First == is_err // Second == {ok: is_err = false, tag: is_err = true} - // Third == amount of bytes written - // Fifth == enum status - let result: [u32; 5] = [0, 0, written, 0, done]; + let result: [u32; 2] = [0, 0]; let raw = u32_array_to_u8(&result); memory.write(caller.as_context_mut(), ptr as _, &raw)?; @@ -1501,6 +1544,31 @@ pub mod sync { Ok(()) }, )?; + linker.func_wrap( + "wasi:io/streams", + "check-write", + move |mut caller: Caller<'_, T>, stream: u32, ptr: u32| { + let memory = memory_get(&mut caller)?; + let ctx = get_cx(caller.data_mut()); + + let result = match io::streams::Host::check_write(ctx, stream) { + // 0 == outer result tag (success) + // 1 == result value (u64 lower 32 bits) + // 2 == result value (u64 upper 32 bits) + Ok(len) => [0, len as u32, (len >> 32) as u32], + + // 0 == outer result tag (failure) + // 1 == result value (u64 lower 32 bits) + // 2 == result value (unused) + Err(_) => todo!("how do we extract runtime error cases?"), + }; + + let raw = u32_array_to_u8(&result); + memory.write(caller.as_context_mut(), ptr as _, &raw)?; + + Ok(()) + }, + )?; linker.func_wrap( "wasi:http/types", "drop-fields", diff --git a/crates/wasi-http/src/http_impl.rs b/crates/wasi-http/src/http_impl.rs index cd0cdfdf4f4d..e181ef143e0d 100644 --- a/crates/wasi-http/src/http_impl.rs +++ b/crates/wasi-http/src/http_impl.rs @@ -113,8 +113,7 @@ impl WasiHttpViewExt for T { let request = self .table() .get_request(request_id) - .context("[handle_async] getting request")? - .clone(); + .context("[handle_async] getting request")?; tracing::debug!("http request retrieved from table"); let method = match request.method() { @@ -324,6 +323,7 @@ impl WasiHttpViewExt for T { let (stream_id, stream) = self .table_mut() .push_stream(Bytes::from(buf), response_id) + .await .context("[handle_async] pushing stream")?; let response = self .table_mut() diff --git a/crates/wasi-http/src/types.rs b/crates/wasi-http/src/types.rs index e611bc4ac858..ff4c56f79fda 100644 --- a/crates/wasi-http/src/types.rs +++ b/crates/wasi-http/src/types.rs @@ -287,6 +287,7 @@ impl Stream { } } +#[async_trait::async_trait] pub trait TableHttpExt { fn push_request(&mut self, request: Box) -> Result; fn get_request(&self, id: u32) -> Result<&(dyn HttpRequest), TableError>; @@ -308,12 +309,17 @@ pub trait TableHttpExt { fn get_fields_mut(&mut self, id: u32) -> Result<&mut Box, TableError>; fn delete_fields(&mut self, id: u32) -> Result<(), TableError>; - fn push_stream(&mut self, content: Bytes, parent: u32) -> Result<(u32, Stream), TableError>; + async fn push_stream( + &mut self, + content: Bytes, + parent: u32, + ) -> Result<(u32, Stream), TableError>; fn get_stream(&self, id: u32) -> Result<&Stream, TableError>; fn get_stream_mut(&mut self, id: u32) -> Result<&mut Box, TableError>; fn delete_stream(&mut self, id: u32) -> Result<(), TableError>; } +#[async_trait::async_trait] impl TableHttpExt for Table { fn push_request(&mut self, request: Box) -> Result { self.push(Box::new(request)) @@ -367,20 +373,30 @@ impl TableHttpExt for Table { self.delete::>(id).map(|_old| ()) } - fn push_stream(&mut self, content: Bytes, parent: u32) -> Result<(u32, Stream), TableError> { + async fn push_stream( + &mut self, + mut content: Bytes, + parent: u32, + ) -> Result<(u32, Stream), TableError> { tracing::debug!("preparing http body stream"); let (a, b) = tokio::io::duplex(MAX_BUF_SIZE); let (_, write_stream) = tokio::io::split(a); let (read_stream, _) = tokio::io::split(b); let input_stream = AsyncReadStream::new(read_stream); - let mut output_stream = AsyncWriteStream::new(write_stream); + // TODO: more informed budget here + let mut output_stream = AsyncWriteStream::new(4096, write_stream); - let mut cursor = 0; - while cursor < content.len() { - let (written, _) = output_stream - .write(content.slice(cursor..content.len())) + while !content.is_empty() { + let permit = output_stream + .write_ready() + .await + .map_err(|_| TableError::NotPresent)?; + + let chunk = content.split_to(permit as usize); + + output_stream + .write(chunk) .map_err(|_| TableError::NotPresent)?; - cursor += written; } let input_stream = Box::new(input_stream); diff --git a/crates/wasi-http/src/types_impl.rs b/crates/wasi-http/src/types_impl.rs index 8cf3a7e3d6cb..c044cd1bdde9 100644 --- a/crates/wasi-http/src/types_impl.rs +++ b/crates/wasi-http/src/types_impl.rs @@ -213,10 +213,13 @@ impl crate::bindings::http::types::Host for T .table() .get_request(request) .context("[outgoing_request_write] getting request")?; - let stream_id = req.body().unwrap_or_else(|| { + let stream_id = if let Some(stream_id) = req.body() { + stream_id + } else { let (new, stream) = self .table_mut() .push_stream(Bytes::new(), request) + .await .expect("[outgoing_request_write] valid output stream"); self.http_ctx_mut().streams.insert(new, stream); let req = self @@ -225,7 +228,7 @@ impl crate::bindings::http::types::Host for T .expect("[outgoing_request_write] request to be found"); req.set_body(new); new - }); + }; let stream = self .table() .get_stream(stream_id) diff --git a/crates/wasi-http/wit/deps/io/streams.wit b/crates/wasi-http/wit/deps/io/streams.wit index 98df181c1ea4..e2631f66a569 100644 --- a/crates/wasi-http/wit/deps/io/streams.wit +++ b/crates/wasi-http/wit/deps/io/streams.wit @@ -134,58 +134,115 @@ interface streams { /// This [represents a resource](https://github.com/WebAssembly/WASI/blob/main/docs/WitInWasi.md#Resources). type output-stream = u32 - /// Perform a non-blocking write of bytes to a stream. + /// An error for output-stream operations. /// - /// This function returns a `u64` and a `stream-status`. The `u64` indicates - /// the number of bytes from `buf` that were written, which may be less than - /// the length of `buf`. The `stream-status` indicates if further writes to - /// the stream are expected to be read. + /// Contrary to input-streams, a closed output-stream is reported using + /// an error. + enum write-error { + /// The last operation (a write or flush) failed before completion. + last-operation-failed, + /// The stream is closed: no more input will be accepted by the + /// stream. A closed output-stream will return this error on all + /// future operations. + closed + } + /// Check readiness for writing. This function never blocks. + /// + /// Returns the number of bytes permitted for the next call to `write`, + /// or an error. Calling `write` with more bytes than this function has + /// permitted will trap. + /// + /// When this function returns 0 bytes, the `subscribe-to-output-stream` + /// pollable will become ready when this function will report at least + /// 1 byte, or an error. + check-write: func( + this: output-stream + ) -> result + + /// Perform a write. This function never blocks. /// - /// When the returned `stream-status` is `open`, the `u64` return value may - /// be less than the length of `buf`. This indicates that no more bytes may - /// be written to the stream promptly. In that case the - /// `subscribe-to-output-stream` pollable will indicate when additional bytes - /// may be promptly written. + /// Precondition: check-write gave permit of Ok(n) and contents has a + /// length of less than or equal to n. Otherwise, this function will trap. /// - /// Writing an empty list must return a non-error result with `0` for the - /// `u64` return value, and the current `stream-status`. + /// returns Err(closed) without writing if the stream has closed since + /// the last call to check-write provided a permit. write: func( this: output-stream, - /// Data to write - buf: list - ) -> result> + contents: list + ) -> result<_, write-error> - /// Blocking write of bytes to a stream. + /// Perform a write of up to 4096 bytes, and then flush the stream. Block + /// until all of these operations are complete, or an error occurs. /// - /// This is similar to `write`, except that it blocks until at least one - /// byte can be written. - blocking-write: func( - this: output-stream, - /// Data to write - buf: list - ) -> result> + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe-to-output-stream`, `write`, and `flush`, and is implemented + /// with the following pseudo-code: + /// + /// ```text + /// let pollable = subscribe-to-output-stream(this); + /// while !contents.is_empty() { + /// // Wait for the stream to become writable + /// poll-oneoff(pollable); + /// let Ok(n) = check-write(this); // eliding error handling + /// let len = min(n, contents.len()); + /// let (chunk, rest) = contents.split_at(len); + /// write(this, chunk); // eliding error handling + /// contents = rest; + /// } + /// flush(this); + /// // Wait for completion of `flush` + /// poll-oneoff(pollable); + /// // Check for any errors that arose during `flush` + /// let _ = check-write(this); // eliding error handling + /// ``` + blocking-write-and-flush: func( + this: output-stream, + contents: list + ) -> result<_, write-error> - /// Write multiple zero-bytes to a stream. + /// Request to flush buffered output. This function never blocks. /// - /// This function returns a `u64` indicating the number of zero-bytes - /// that were written; it may be less than `len`. Equivelant to a call to - /// `write` with a list of zeroes of the given length. - write-zeroes: func( + /// This tells the output-stream that the caller intends any buffered + /// output to be flushed. the output which is expected to be flushed + /// is all that has been passed to `write` prior to this call. + /// + /// Upon calling this function, the `output-stream` will not accept any + /// writes (`check-write` will return `ok(0)`) until the flush has + /// completed. The `subscribe-to-output-stream` pollable will become ready + /// when the flush has completed and the stream can accept more writes. + flush: func( this: output-stream, - /// The number of zero-bytes to write - len: u64 - ) -> result> + ) -> result<_, write-error> + + /// Request to flush buffered output, and block until flush completes + /// and stream is ready for writing again. + blocking-flush: func( + this: output-stream, + ) -> result<_, write-error> + + /// Create a `pollable` which will resolve once the output-stream + /// is ready for more writing, or an error has occured. When this + /// pollable is ready, `check-write` will return `ok(n)` with n>0, or an + /// error. + /// + /// If the stream is closed, this pollable is always ready immediately. + /// + /// The created `pollable` is a child resource of the `output-stream`. + /// Implementations may trap if the `output-stream` is dropped before + /// all derived `pollable`s created with this function are dropped. + subscribe-to-output-stream: func(this: output-stream) -> pollable - /// Write multiple zero bytes to a stream, with blocking. + /// Write zeroes to a stream. /// - /// This is similar to `write-zeroes`, except that it blocks until at least - /// one byte can be written. Equivelant to a call to `blocking-write` with - /// a list of zeroes of the given length. - blocking-write-zeroes: func( + /// this should be used precisely like `write` with the exact same + /// preconditions (must use check-write first), but instead of + /// passing a list of bytes, you simply pass the number of zero-bytes + /// that should be written. + write-zeroes: func( this: output-stream, - /// The number of zero bytes to write + /// The number of zero-bytes to write len: u64 - ) -> result> + ) -> result<_, write-error> /// Read from one stream and write to another. /// @@ -232,16 +289,6 @@ interface streams { src: input-stream ) -> result> - /// Create a `pollable` which will resolve once either the specified stream - /// is ready to accept bytes or the `stream-state` has become closed. - /// - /// Once the stream-state is closed, this pollable is always ready - /// immediately. - /// - /// The created `pollable` is a child resource of the `output-stream`. - /// Implementations may trap if the `output-stream` is dropped before - /// all derived `pollable`s created with this function are dropped. - subscribe-to-output-stream: func(this: output-stream) -> pollable /// Dispose of the specified `output-stream`, after which it may no longer /// be used. diff --git a/crates/wasi-http/wit/deps/sockets/tcp.wit b/crates/wasi-http/wit/deps/sockets/tcp.wit index 4edb1db7f0b1..3922769b308e 100644 --- a/crates/wasi-http/wit/deps/sockets/tcp.wit +++ b/crates/wasi-http/wit/deps/sockets/tcp.wit @@ -81,6 +81,9 @@ interface tcp { /// - /// - start-connect: func(this: tcp-socket, network: network, remote-address: ip-socket-address) -> result<_, error-code> + /// Note: the returned `input-stream` and `output-stream` are child + /// resources of the `tcp-socket`. Implementations may trap if the + /// `tcp-socket` is dropped before both of these streams are dropped. finish-connect: func(this: tcp-socket) -> result, error-code> /// Start listening for new connections. @@ -116,6 +119,10 @@ interface tcp { /// /// On success, this function returns the newly accepted client socket along with /// a pair of streams that can be used to read & write to the connection. + /// + /// Note: the returned `input-stream` and `output-stream` are child + /// resources of the returned `tcp-socket`. Implementations may trap if the + /// `tcp-socket` is dropped before its child streams are dropped. /// /// # Typical errors /// - `not-listening`: Socket is not in the Listener state. (EINVAL) @@ -223,6 +230,10 @@ interface tcp { /// Create a `pollable` which will resolve once the socket is ready for I/O. /// + /// The created `pollable` is a child resource of the `tcp-socket`. + /// Implementations may trap if the `tcp-socket` is dropped before all + /// derived `pollable`s created with this function are dropped. + /// /// Note: this function is here for WASI Preview2 only. /// It's planned to be removed when `future` is natively supported in Preview3. subscribe: func(this: tcp-socket) -> pollable diff --git a/crates/wasi-preview1-component-adapter/src/descriptors.rs b/crates/wasi-preview1-component-adapter/src/descriptors.rs index bcc1ba7cf71b..6045358bde36 100644 --- a/crates/wasi-preview1-component-adapter/src/descriptors.rs +++ b/crates/wasi-preview1-component-adapter/src/descriptors.rs @@ -5,7 +5,9 @@ use crate::bindings::wasi::cli::{ use crate::bindings::wasi::filesystem::types as filesystem; use crate::bindings::wasi::io::streams::{self, InputStream, OutputStream}; use crate::bindings::wasi::sockets::tcp; -use crate::{set_stderr_stream, BumpArena, File, ImportAlloc, TrappingUnwrap, WasmStr}; +use crate::{ + set_stderr_stream, BlockingMode, BumpArena, File, ImportAlloc, TrappingUnwrap, WasmStr, +}; use core::cell::{Cell, UnsafeCell}; use core::mem::MaybeUninit; use wasi::{Errno, Fd}; @@ -232,7 +234,7 @@ impl Descriptors { descriptor_type: filesystem::get_type(preopen.descriptor).trapping_unwrap(), position: Cell::new(0), append: false, - blocking: false, + blocking_mode: BlockingMode::Blocking, }), })) .trapping_unwrap(); diff --git a/crates/wasi-preview1-component-adapter/src/lib.rs b/crates/wasi-preview1-component-adapter/src/lib.rs index 595ffe615657..09ae02facf6a 100644 --- a/crates/wasi-preview1-component-adapter/src/lib.rs +++ b/crates/wasi-preview1-component-adapter/src/lib.rs @@ -546,7 +546,7 @@ pub unsafe extern "C" fn fd_fdstat_get(fd: Fd, stat: *mut Fdstat) -> Errno { if file.append { fs_flags |= FDFLAGS_APPEND; } - if !file.blocking { + if matches!(file.blocking_mode, BlockingMode::NonBlocking) { fs_flags |= FDFLAGS_NONBLOCK; } let fs_rights_inheriting = fs_rights_base; @@ -611,7 +611,11 @@ pub unsafe extern "C" fn fd_fdstat_set_flags(fd: Fd, flags: Fdflags) -> Errno { _ => Err(wasi::ERRNO_BADF)?, }; file.append = flags & FDFLAGS_APPEND == FDFLAGS_APPEND; - file.blocking = !(flags & FDFLAGS_NONBLOCK == FDFLAGS_NONBLOCK); + file.blocking_mode = if flags & FDFLAGS_NONBLOCK == FDFLAGS_NONBLOCK { + BlockingMode::NonBlocking + } else { + BlockingMode::Blocking + }; Ok(()) }) } @@ -875,23 +879,17 @@ pub unsafe extern "C" fn fd_read( State::with(|state| { match state.descriptors().get(fd)? { Descriptor::Streams(streams) => { - let blocking = if let StreamType::File(file) = &streams.type_ { - file.blocking + let blocking_mode = if let StreamType::File(file) = &streams.type_ { + file.blocking_mode } else { - false + BlockingMode::Blocking }; let read_len = u64::try_from(len).trapping_unwrap(); let wasi_stream = streams.get_read_stream()?; let (data, stream_stat) = state .import_alloc - .with_buffer(ptr, len, || { - if blocking { - streams::blocking_read(wasi_stream, read_len) - } else { - streams::read(wasi_stream, read_len) - } - }) + .with_buffer(ptr, len, || blocking_mode.read(wasi_stream, read_len)) .map_err(|_| ERRNO_IO)?; assert_eq!(data.as_ptr(), ptr); @@ -1268,18 +1266,13 @@ pub unsafe extern "C" fn fd_write( Descriptor::Streams(streams) => { let wasi_stream = streams.get_write_stream()?; - let (bytes, _stream_stat) = if let StreamType::File(file) = &streams.type_ { - if file.blocking { - streams::blocking_write(wasi_stream, bytes) - } else { - streams::write(wasi_stream, bytes) - } + let nbytes = if let StreamType::File(file) = &streams.type_ { + file.blocking_mode.write(wasi_stream, bytes)? } else { // Use blocking writes on non-file streams (stdout, stderr, as sockets // aren't currently used). - streams::blocking_write(wasi_stream, bytes) - } - .map_err(|_| ERRNO_IO)?; + BlockingMode::Blocking.write(wasi_stream, bytes)? + }; // If this is a file, keep the current-position pointer up to date. if let StreamType::File(file) = &streams.type_ { @@ -1287,12 +1280,11 @@ pub unsafe extern "C" fn fd_write( // we should set the position to the new end of the file, but // we don't have an API to do that atomically. if !file.append { - file.position - .set(file.position.get() + filesystem::Filesize::from(bytes)); + file.position.set(file.position.get() + nbytes as u64); } } - *nwritten = bytes as usize; + *nwritten = nbytes; Ok(()) } Descriptor::Closed(_) => Err(ERRNO_BADF), @@ -1454,7 +1446,11 @@ pub unsafe extern "C" fn path_open( descriptor_type, position: Cell::new(0), append, - blocking: (fdflags & wasi::FDFLAGS_NONBLOCK) == 0, + blocking_mode: if fdflags & wasi::FDFLAGS_NONBLOCK == 0 { + BlockingMode::Blocking + } else { + BlockingMode::NonBlocking + }, }), }); @@ -2127,6 +2123,72 @@ impl From for wasi::Filetype { } } +#[derive(Clone, Copy)] +pub enum BlockingMode { + NonBlocking, + Blocking, +} + +impl BlockingMode { + // note: these methods must take self, not &self, to avoid rustc creating a constant + // out of a BlockingMode literal that it places in .romem, creating a data section and + // breaking our fragile linking scheme + fn read( + self, + input_stream: streams::InputStream, + read_len: u64, + ) -> Result<(Vec, streams::StreamStatus), ()> { + match self { + BlockingMode::NonBlocking => streams::read(input_stream, read_len), + BlockingMode::Blocking => streams::blocking_read(input_stream, read_len), + } + } + fn write(self, output_stream: streams::OutputStream, mut bytes: &[u8]) -> Result { + match self { + BlockingMode::Blocking => { + let total = bytes.len(); + while !bytes.is_empty() { + let len = bytes.len().min(4096); + let (chunk, rest) = bytes.split_at(len); + bytes = rest; + match streams::blocking_write_and_flush(output_stream, chunk) { + Ok(()) => {} + Err(_) => return Err(ERRNO_IO), + } + } + Ok(total) + } + + BlockingMode::NonBlocking => { + let permit = match streams::check_write(output_stream) { + Ok(n) => n, + Err(streams::WriteError::Closed) => 0, + Err(streams::WriteError::LastOperationFailed) => return Err(ERRNO_IO), + }; + + let len = bytes.len().min(permit as usize); + if len == 0 { + return Ok(0); + } + + match streams::write(output_stream, &bytes[..len]) { + Ok(_) => {} + Err(streams::WriteError::Closed) => return Ok(0), + Err(streams::WriteError::LastOperationFailed) => return Err(ERRNO_IO), + } + + match streams::blocking_flush(output_stream) { + Ok(_) => {} + Err(streams::WriteError::Closed) => return Ok(0), + Err(streams::WriteError::LastOperationFailed) => return Err(ERRNO_IO), + } + + Ok(len) + } + } + } +} + #[repr(C)] pub struct File { /// The handle to the preview2 descriptor that this file is referencing. @@ -2142,9 +2204,9 @@ pub struct File { append: bool, /// In blocking mode, read and write calls dispatch to blocking_read and - /// blocking_write on the underlying streams. When false, read and write - /// dispatch to stream's plain read and write. - blocking: bool, + /// blocking_check_write on the underlying streams. When false, read and write + /// dispatch to stream's plain read and check_write. + blocking_mode: BlockingMode, } impl File { diff --git a/crates/wasi/Cargo.toml b/crates/wasi/Cargo.toml index 6817470e9eda..5bc3b63cf8b0 100644 --- a/crates/wasi/Cargo.toml +++ b/crates/wasi/Cargo.toml @@ -41,6 +41,8 @@ futures = { workspace = true, optional = true } [dev-dependencies] tokio = { workspace = true, features = ["time", "sync", "io-std", "io-util", "rt", "rt-multi-thread", "net", "macros"] } +test-log = { workspace = true } +tracing-subscriber = { workspace = true } [target.'cfg(unix)'.dependencies] rustix = { workspace = true, features = ["event", "fs", "net"], optional = true } diff --git a/crates/wasi/src/preview2/filesystem.rs b/crates/wasi/src/preview2/filesystem.rs index 53d8362a3178..f99b6ab2d9af 100644 --- a/crates/wasi/src/preview2/filesystem.rs +++ b/crates/wasi/src/preview2/filesystem.rs @@ -1,5 +1,10 @@ -use crate::preview2::{StreamRuntimeError, StreamState, Table, TableError}; +use crate::preview2::{ + AbortOnDropJoinHandle, HostOutputStream, OutputStreamError, StreamRuntimeError, StreamState, + Table, TableError, +}; +use anyhow::anyhow; use bytes::{Bytes, BytesMut}; +use futures::future::{maybe_done, MaybeDone}; use std::sync::Arc; bitflags::bitflags! { @@ -159,15 +164,7 @@ fn read_result(r: Result) -> Result<(usize, StreamState), Ok(0) => Ok((0, StreamState::Closed)), Ok(n) => Ok((n, StreamState::Open)), Err(e) if e.kind() == std::io::ErrorKind::Interrupted => Ok((0, StreamState::Open)), - Err(e) => Err(StreamRuntimeError::from(anyhow::anyhow!(e)).into()), - } -} - -fn write_result(r: Result) -> Result<(usize, StreamState), anyhow::Error> { - match r { - Ok(0) => Ok((0, StreamState::Closed)), - Ok(n) => Ok((n, StreamState::Open)), - Err(e) => Err(StreamRuntimeError::from(anyhow::anyhow!(e)).into()), + Err(e) => Err(StreamRuntimeError::from(anyhow!(e)).into()), } } @@ -180,35 +177,101 @@ pub(crate) enum FileOutputMode { pub(crate) struct FileOutputStream { file: Arc, mode: FileOutputMode, + // Allows join future to be awaited in a cancellable manner. Gone variant indicates + // no task is currently outstanding. + task: MaybeDone>>, + closed: bool, } impl FileOutputStream { pub fn write_at(file: Arc, position: u64) -> Self { Self { file, mode: FileOutputMode::Position(position), + task: MaybeDone::Gone, + closed: false, } } pub fn append(file: Arc) -> Self { Self { file, mode: FileOutputMode::Append, + task: MaybeDone::Gone, + closed: false, } } - /// Write bytes. On success, returns the number of bytes written. - pub async fn write(&mut self, buf: Bytes) -> anyhow::Result<(usize, StreamState)> { +} + +// FIXME: configurable? determine from how much space left in file? +const FILE_WRITE_CAPACITY: usize = 1024 * 1024; + +#[async_trait::async_trait] +impl HostOutputStream for FileOutputStream { + fn write(&mut self, buf: Bytes) -> Result<(), OutputStreamError> { use system_interface::fs::FileIoExt; + + if self.closed { + return Err(OutputStreamError::Closed); + } + if !matches!(self.task, MaybeDone::Gone) { + // a write is pending - this call was not permitted + return Err(OutputStreamError::Trap(anyhow!( + "write not permitted: FileOutputStream write pending" + ))); + } let f = Arc::clone(&self.file); let m = self.mode; - let r = tokio::task::spawn_blocking(move || match m { - FileOutputMode::Position(p) => f.write_at(buf.as_ref(), p), - FileOutputMode::Append => f.append(buf.as_ref()), - }) - .await - .unwrap(); - let (n, state) = write_result(r)?; - if let FileOutputMode::Position(ref mut position) = self.mode { - *position += n as u64; + self.task = maybe_done(AbortOnDropJoinHandle::from(tokio::task::spawn_blocking( + move || match m { + FileOutputMode::Position(mut p) => { + let mut buf = buf; + while !buf.is_empty() { + let nwritten = f.write_at(buf.as_ref(), p)?; + // afterwards buf contains [nwritten, len): + let _ = buf.split_to(nwritten); + p += nwritten as u64; + } + Ok(()) + } + FileOutputMode::Append => { + let mut buf = buf; + while !buf.is_empty() { + let nwritten = f.append(buf.as_ref())?; + let _ = buf.split_to(nwritten); + } + Ok(()) + } + }, + ))); + Ok(()) + } + fn flush(&mut self) -> Result<(), OutputStreamError> { + if self.closed { + return Err(OutputStreamError::Closed); + } + // Only userland buffering of file writes is in the blocking task. + Ok(()) + } + async fn write_ready(&mut self) -> Result { + if self.closed { + return Err(OutputStreamError::Closed); + } + // If there is no outstanding task, accept more input: + if matches!(self.task, MaybeDone::Gone) { + return Ok(FILE_WRITE_CAPACITY); + } + // Wait for outstanding task: + std::pin::Pin::new(&mut self.task).await; + + // Mark task as finished, and handle output: + match std::pin::Pin::new(&mut self.task) + .take_output() + .expect("just awaited for MaybeDone completion") + { + Ok(()) => Ok(FILE_WRITE_CAPACITY), + Err(e) => { + self.closed = true; + Err(OutputStreamError::LastOperationFailed(e.into())) + } } - Ok((n, state)) } } diff --git a/crates/wasi/src/preview2/host/filesystem.rs b/crates/wasi/src/preview2/host/filesystem.rs index 189f5ca34aea..d286c7bf5e31 100644 --- a/crates/wasi/src/preview2/host/filesystem.rs +++ b/crates/wasi/src/preview2/host/filesystem.rs @@ -777,10 +777,7 @@ impl types::Host for T { fd: types::Descriptor, offset: types::Filesize, ) -> Result { - use crate::preview2::{ - filesystem::FileOutputStream, - stream::{InternalOutputStream, InternalTableStreamExt}, - }; + use crate::preview2::{filesystem::FileOutputStream, TableStreamExt}; // Trap if fd lookup fails: let f = self.table().get_file(fd)?; @@ -796,9 +793,7 @@ impl types::Host for T { let writer = FileOutputStream::write_at(clone, offset); // Insert the stream view into the table. Trap if the table is full. - let index = self - .table_mut() - .push_internal_output_stream(InternalOutputStream::File(writer))?; + let index = self.table_mut().push_output_stream(Box::new(writer))?; Ok(index) } @@ -807,10 +802,7 @@ impl types::Host for T { &mut self, fd: types::Descriptor, ) -> Result { - use crate::preview2::{ - filesystem::FileOutputStream, - stream::{InternalOutputStream, InternalTableStreamExt}, - }; + use crate::preview2::{filesystem::FileOutputStream, TableStreamExt}; // Trap if fd lookup fails: let f = self.table().get_file(fd)?; @@ -825,9 +817,7 @@ impl types::Host for T { let appender = FileOutputStream::append(clone); // Insert the stream view into the table. Trap if the table is full. - let index = self - .table_mut() - .push_internal_output_stream(InternalOutputStream::File(appender))?; + let index = self.table_mut().push_output_stream(Box::new(appender))?; Ok(index) } diff --git a/crates/wasi/src/preview2/host/io.rs b/crates/wasi/src/preview2/host/io.rs index 7dc52a027e94..ce431af4f542 100644 --- a/crates/wasi/src/preview2/host/io.rs +++ b/crates/wasi/src/preview2/host/io.rs @@ -1,13 +1,13 @@ use crate::preview2::{ bindings::io::streams::{self, InputStream, OutputStream}, bindings::poll::poll::Pollable, - filesystem::{FileInputStream, FileOutputStream}, + filesystem::FileInputStream, poll::PollableFuture, stream::{ - HostInputStream, HostOutputStream, InternalInputStream, InternalOutputStream, - InternalTableStreamExt, StreamRuntimeError, StreamState, + HostInputStream, HostOutputStream, InternalInputStream, InternalTableStreamExt, + OutputStreamError, StreamRuntimeError, StreamState, TableStreamExt, }, - HostPollable, TablePollableExt, WasiView, + HostPollable, TableError, TablePollableExt, WasiView, }; use std::any::Any; @@ -20,7 +20,23 @@ impl From for streams::StreamStatus { } } -const ZEROS: &[u8] = &[0; 4 * 1024 * 1024]; +impl From for streams::Error { + fn from(e: TableError) -> streams::Error { + streams::Error::trap(e.into()) + } +} +impl From for streams::Error { + fn from(e: OutputStreamError) -> streams::Error { + match e { + OutputStreamError::Closed => streams::WriteError::Closed.into(), + OutputStreamError::LastOperationFailed(e) => { + tracing::debug!("streams::WriteError::LastOperationFailed: {e:?}"); + streams::WriteError::LastOperationFailed.into() + } + OutputStreamError::Trap(e) => streams::Error::trap(e), + } + } +} #[async_trait::async_trait] impl streams::Host for T { @@ -30,7 +46,7 @@ impl streams::Host for T { } async fn drop_output_stream(&mut self, stream: OutputStream) -> anyhow::Result<()> { - self.table_mut().delete_internal_output_stream(stream)?; + self.table_mut().delete_output_stream(stream)?; Ok(()) } @@ -112,71 +128,6 @@ impl streams::Host for T { } } - async fn write( - &mut self, - stream: OutputStream, - bytes: Vec, - ) -> anyhow::Result> { - match self.table_mut().get_internal_output_stream_mut(stream)? { - InternalOutputStream::Host(s) => { - let (bytes_written, status) = - match HostOutputStream::write(s.as_mut(), bytes.into()) { - Ok(a) => a, - Err(e) => { - if let Some(e) = e.downcast_ref::() { - tracing::debug!("stream runtime error: {e:?}"); - return Ok(Err(())); - } else { - return Err(e); - } - } - }; - Ok(Ok((u64::try_from(bytes_written).unwrap(), status.into()))) - } - InternalOutputStream::File(s) => { - let (nwritten, state) = FileOutputStream::write(s, bytes.into()).await?; - Ok(Ok((nwritten as u64, state.into()))) - } - } - } - - async fn blocking_write( - &mut self, - stream: OutputStream, - bytes: Vec, - ) -> anyhow::Result> { - match self.table_mut().get_internal_output_stream_mut(stream)? { - InternalOutputStream::Host(s) => { - let mut bytes = bytes::Bytes::from(bytes); - let mut nwritten: usize = 0; - loop { - s.ready().await?; - let (written, state) = match HostOutputStream::write(s.as_mut(), bytes.clone()) - { - Ok(a) => a, - Err(e) => { - if let Some(e) = e.downcast_ref::() { - tracing::debug!("stream runtime error: {e:?}"); - return Ok(Err(())); - } else { - return Err(e); - } - } - }; - let _ = bytes.split_to(written); - nwritten += written; - if bytes.is_empty() || state == StreamState::Closed { - return Ok(Ok((nwritten as u64, state.into()))); - } - } - } - InternalOutputStream::File(s) => { - let (written, state) = FileOutputStream::write(s, bytes.into()).await?; - Ok(Ok((written as u64, state.into()))) - } - } - } - async fn skip( &mut self, stream: InputStream, @@ -256,85 +207,129 @@ impl streams::Host for T { } } - async fn write_zeroes( - &mut self, - stream: OutputStream, - len: u64, - ) -> anyhow::Result> { - let s = self.table_mut().get_internal_output_stream_mut(stream)?; - let mut bytes = bytes::Bytes::from_static(ZEROS); - bytes.truncate((len as usize).min(bytes.len())); - let (written, state) = match s { - InternalOutputStream::Host(s) => match HostOutputStream::write(s.as_mut(), bytes) { - Ok(a) => a, - Err(e) => { - if let Some(e) = e.downcast_ref::() { - tracing::debug!("stream runtime error: {e:?}"); - return Ok(Err(())); - } else { - return Err(e); + async fn subscribe_to_input_stream(&mut self, stream: InputStream) -> anyhow::Result { + // Ensure that table element is an input-stream: + let pollable = match self.table_mut().get_internal_input_stream_mut(stream)? { + InternalInputStream::Host(_) => { + fn input_stream_ready<'a>(stream: &'a mut dyn Any) -> PollableFuture<'a> { + let stream = stream + .downcast_mut::() + .expect("downcast to InternalInputStream failed"); + match *stream { + InternalInputStream::Host(ref mut hs) => hs.ready(), + _ => unreachable!(), } } - }, - InternalOutputStream::File(s) => match FileOutputStream::write(s, bytes).await { - Ok(a) => a, - Err(e) => { - if let Some(e) = e.downcast_ref::() { - tracing::debug!("stream runtime error: {e:?}"); - return Ok(Err(())); - } else { - return Err(e); - } + + HostPollable::TableEntry { + index: stream, + make_future: input_stream_ready, } - }, + } + // Files are always "ready" immediately (because we have no way to actually wait on + // readiness in epoll) + InternalInputStream::File(_) => { + HostPollable::Closure(Box::new(|| Box::pin(futures::future::ready(Ok(()))))) + } }; - Ok(Ok((written as u64, state.into()))) + Ok(self.table_mut().push_host_pollable(pollable)?) } - async fn blocking_write_zeroes( + /* -------------------------------------------------------------- + * + * OutputStream methods + * + * -------------------------------------------------------------- */ + + async fn check_write(&mut self, stream: OutputStream) -> Result { + let s = self.table_mut().get_output_stream_mut(stream)?; + match futures::future::poll_immediate(s.write_ready()).await { + Some(Ok(permit)) => Ok(permit as u64), + Some(Err(e)) => Err(e.into()), + None => Ok(0), + } + } + async fn write(&mut self, stream: OutputStream, bytes: Vec) -> Result<(), streams::Error> { + let s = self.table_mut().get_output_stream_mut(stream)?; + HostOutputStream::write(s, bytes.into())?; + Ok(()) + } + + async fn subscribe_to_output_stream( &mut self, stream: OutputStream, - len: u64, - ) -> anyhow::Result> { - let mut remaining = len as usize; - let s = self.table_mut().get_internal_output_stream_mut(stream)?; - loop { - if let InternalOutputStream::Host(s) = s { - HostOutputStream::ready(s.as_mut()).await?; - } - let mut bytes = bytes::Bytes::from_static(ZEROS); - bytes.truncate(remaining.min(bytes.len())); - let (written, state) = match s { - InternalOutputStream::Host(s) => match HostOutputStream::write(s.as_mut(), bytes) { - Ok(a) => a, - Err(e) => { - if let Some(e) = e.downcast_ref::() { - tracing::debug!("stream runtime error: {e:?}"); - return Ok(Err(())); - } else { - return Err(e); - } - } - }, - InternalOutputStream::File(s) => match FileOutputStream::write(s, bytes).await { - Ok(a) => a, - Err(e) => { - if let Some(e) = e.downcast_ref::() { - tracing::debug!("stream runtime error: {e:?}"); - return Ok(Err(())); - } else { - return Err(e); - } - } - }, - }; - remaining -= written; - if remaining == 0 || state == StreamState::Closed { - return Ok(Ok((len - remaining as u64, state.into()))); - } + ) -> anyhow::Result { + // Ensure that table element is an output-stream: + let _ = self.table_mut().get_output_stream_mut(stream)?; + + fn output_stream_ready<'a>(stream: &'a mut dyn Any) -> PollableFuture<'a> { + let stream = stream + .downcast_mut::>() + .expect("downcast to HostOutputStream failed"); + Box::pin(async move { + let _ = stream.write_ready().await?; + Ok(()) + }) } + + Ok(self + .table_mut() + .push_host_pollable(HostPollable::TableEntry { + index: stream, + make_future: output_stream_ready, + })?) } + async fn blocking_write_and_flush( + &mut self, + stream: OutputStream, + bytes: Vec, + ) -> Result<(), streams::Error> { + let s = self.table_mut().get_output_stream_mut(stream)?; + + if bytes.len() > 4096 { + return Err(streams::Error::trap(anyhow::anyhow!( + "Buffer too large for blocking-write-and-flush (expected at most 4096)" + ))); + } + + let mut bytes = bytes::Bytes::from(bytes); + while !bytes.is_empty() { + let permit = s.write_ready().await?; + let len = bytes.len().min(permit); + let chunk = bytes.split_to(len); + HostOutputStream::write(s, chunk)?; + } + + HostOutputStream::flush(s)?; + let _ = s.write_ready().await?; + + Ok(()) + } + + async fn write_zeroes(&mut self, stream: OutputStream, len: u64) -> Result<(), streams::Error> { + let s = self.table_mut().get_output_stream_mut(stream)?; + HostOutputStream::write_zeroes(s, len as usize)?; + Ok(()) + } + + async fn flush(&mut self, stream: OutputStream) -> Result<(), streams::Error> { + let s = self.table_mut().get_output_stream_mut(stream)?; + HostOutputStream::flush(s)?; + Ok(()) + } + async fn blocking_flush(&mut self, stream: OutputStream) -> Result<(), streams::Error> { + let s = self.table_mut().get_output_stream_mut(stream)?; + HostOutputStream::flush(s)?; + let _ = s.write_ready().await?; + Ok(()) + } + + /* -------------------------------------------------------------- + * + * Aspirational methods + * + * -------------------------------------------------------------- */ async fn splice( &mut self, _src: InputStream, @@ -403,69 +398,11 @@ impl streams::Host for T { todo!("stream forward is not implemented") } - - async fn subscribe_to_input_stream(&mut self, stream: InputStream) -> anyhow::Result { - // Ensure that table element is an input-stream: - let pollable = match self.table_mut().get_internal_input_stream_mut(stream)? { - InternalInputStream::Host(_) => { - fn input_stream_ready<'a>(stream: &'a mut dyn Any) -> PollableFuture<'a> { - let stream = stream - .downcast_mut::() - .expect("downcast to InternalInputStream failed"); - match *stream { - InternalInputStream::Host(ref mut hs) => hs.ready(), - _ => unreachable!(), - } - } - - HostPollable::TableEntry { - index: stream, - make_future: input_stream_ready, - } - } - // Files are always "ready" immediately (because we have no way to actually wait on - // readiness in epoll) - InternalInputStream::File(_) => { - HostPollable::Closure(Box::new(|| Box::pin(futures::future::ready(Ok(()))))) - } - }; - Ok(self.table_mut().push_host_pollable(pollable)?) - } - - async fn subscribe_to_output_stream( - &mut self, - stream: OutputStream, - ) -> anyhow::Result { - // Ensure that table element is an output-stream: - let pollable = match self.table_mut().get_internal_output_stream_mut(stream)? { - InternalOutputStream::Host(_) => { - fn output_stream_ready<'a>(stream: &'a mut dyn Any) -> PollableFuture<'a> { - let stream = stream - .downcast_mut::() - .expect("downcast to HostOutputStream failed"); - match *stream { - InternalOutputStream::Host(ref mut hs) => hs.ready(), - _ => unreachable!(), - } - } - - HostPollable::TableEntry { - index: stream, - make_future: output_stream_ready, - } - } - InternalOutputStream::File(_) => { - HostPollable::Closure(Box::new(|| Box::pin(futures::future::ready(Ok(()))))) - } - }; - - Ok(self.table_mut().push_host_pollable(pollable)?) - } } pub mod sync { use crate::preview2::{ - bindings::io::streams::{Host as AsyncHost, StreamStatus as AsyncStreamStatus}, + bindings::io::streams::{self as async_streams, Host as AsyncHost}, bindings::sync_io::io::streams::{self, InputStream, OutputStream}, bindings::sync_io::poll::poll::Pollable, in_tokio, WasiView, @@ -473,15 +410,34 @@ pub mod sync { // same boilerplate everywhere, converting between two identical types with different // definition sites. one day wasmtime-wit-bindgen will make all this unnecessary - fn xform(r: Result<(A, AsyncStreamStatus), ()>) -> Result<(A, streams::StreamStatus), ()> { + fn xform( + r: Result<(A, async_streams::StreamStatus), ()>, + ) -> Result<(A, streams::StreamStatus), ()> { r.map(|(a, b)| (a, b.into())) } - impl From for streams::StreamStatus { - fn from(other: AsyncStreamStatus) -> Self { + impl From for streams::StreamStatus { + fn from(other: async_streams::StreamStatus) -> Self { match other { - AsyncStreamStatus::Open => Self::Open, - AsyncStreamStatus::Ended => Self::Ended, + async_streams::StreamStatus::Open => Self::Open, + async_streams::StreamStatus::Ended => Self::Ended, + } + } + } + + impl From for streams::WriteError { + fn from(other: async_streams::WriteError) -> Self { + match other { + async_streams::WriteError::LastOperationFailed => Self::LastOperationFailed, + async_streams::WriteError::Closed => Self::Closed, + } + } + } + impl From for streams::Error { + fn from(other: async_streams::Error) -> Self { + match other.downcast() { + Ok(write_error) => streams::Error::from(streams::WriteError::from(write_error)), + Err(e) => streams::Error::trap(e), } } } @@ -511,20 +467,41 @@ pub mod sync { in_tokio(async { AsyncHost::blocking_read(self, stream, len).await }).map(xform) } - fn write( + fn check_write(&mut self, stream: OutputStream) -> Result { + Ok(in_tokio(async { + AsyncHost::check_write(self, stream).await + })?) + } + fn write(&mut self, stream: OutputStream, bytes: Vec) -> Result<(), streams::Error> { + Ok(in_tokio(async { + AsyncHost::write(self, stream, bytes).await + })?) + } + fn blocking_write_and_flush( &mut self, stream: OutputStream, bytes: Vec, - ) -> anyhow::Result> { - in_tokio(async { AsyncHost::write(self, stream, bytes).await }).map(xform) + ) -> Result<(), streams::Error> { + Ok(in_tokio(async { + AsyncHost::blocking_write_and_flush(self, stream, bytes).await + })?) + } + fn subscribe_to_output_stream(&mut self, stream: OutputStream) -> anyhow::Result { + in_tokio(async { AsyncHost::subscribe_to_output_stream(self, stream).await }) + } + fn write_zeroes(&mut self, stream: OutputStream, len: u64) -> Result<(), streams::Error> { + Ok(in_tokio(async { + AsyncHost::write_zeroes(self, stream, len).await + })?) } - fn blocking_write( - &mut self, - stream: OutputStream, - bytes: Vec, - ) -> anyhow::Result> { - in_tokio(async { AsyncHost::blocking_write(self, stream, bytes).await }).map(xform) + fn flush(&mut self, stream: OutputStream) -> Result<(), streams::Error> { + Ok(in_tokio(async { AsyncHost::flush(self, stream).await })?) + } + fn blocking_flush(&mut self, stream: OutputStream) -> Result<(), streams::Error> { + Ok(in_tokio(async { + AsyncHost::blocking_flush(self, stream).await + })?) } fn skip( @@ -543,22 +520,6 @@ pub mod sync { in_tokio(async { AsyncHost::blocking_skip(self, stream, len).await }).map(xform) } - fn write_zeroes( - &mut self, - stream: OutputStream, - len: u64, - ) -> anyhow::Result> { - in_tokio(async { AsyncHost::write_zeroes(self, stream, len).await }).map(xform) - } - - fn blocking_write_zeroes( - &mut self, - stream: OutputStream, - len: u64, - ) -> anyhow::Result> { - in_tokio(async { AsyncHost::blocking_write_zeroes(self, stream, len).await }).map(xform) - } - fn splice( &mut self, src: InputStream, @@ -588,9 +549,5 @@ pub mod sync { fn subscribe_to_input_stream(&mut self, stream: InputStream) -> anyhow::Result { in_tokio(async { AsyncHost::subscribe_to_input_stream(self, stream).await }) } - - fn subscribe_to_output_stream(&mut self, stream: OutputStream) -> anyhow::Result { - in_tokio(async { AsyncHost::subscribe_to_output_stream(self, stream).await }) - } } } diff --git a/crates/wasi/src/preview2/host/tcp.rs b/crates/wasi/src/preview2/host/tcp.rs index acdca7ce018c..89f9bc240207 100644 --- a/crates/wasi/src/preview2/host/tcp.rs +++ b/crates/wasi/src/preview2/host/tcp.rs @@ -140,14 +140,9 @@ impl tcp::Host for T { }; socket.tcp_state = HostTcpState::Connected; - - let input_clone = socket.clone_inner(); - let output_clone = socket.clone_inner(); - - let input_stream = self.table_mut().push_input_stream(Box::new(input_clone))?; - let output_stream = self - .table_mut() - .push_output_stream(Box::new(output_clone))?; + let (input, output) = socket.as_split(); + let input_stream = self.table_mut().push_input_stream_child(input, this)?; + let output_stream = self.table_mut().push_output_stream_child(output, this)?; Ok((input_stream, output_stream)) } @@ -207,16 +202,20 @@ impl tcp::Host for T { .as_socketlike_view::() .accept_with(Blocking::No) })?; - let tcp_socket = HostTcpSocket::from_tcp_stream(connection)?; + let mut tcp_socket = HostTcpSocket::from_tcp_stream(connection)?; + + // Mark the socket as connected so that we can exit early from methods like `start-bind`. + tcp_socket.tcp_state = HostTcpState::Connected; - let input_clone = tcp_socket.clone_inner(); - let output_clone = tcp_socket.clone_inner(); + let (input, output) = tcp_socket.as_split(); let tcp_socket = self.table_mut().push_tcp_socket(tcp_socket)?; - let input_stream = self.table_mut().push_input_stream(Box::new(input_clone))?; + let input_stream = self + .table_mut() + .push_input_stream_child(input, tcp_socket)?; let output_stream = self .table_mut() - .push_output_stream(Box::new(output_clone))?; + .push_output_stream_child(output, tcp_socket)?; Ok((tcp_socket, input_stream, output_stream)) } @@ -431,7 +430,6 @@ impl tcp::Host for T { let join = Box::pin(async move { socket .inner - .tcp_socket .ready(Interest::READABLE | Interest::WRITABLE) .await .unwrap(); @@ -488,10 +486,7 @@ impl tcp::Host for T { | HostTcpState::ConnectReady => {} HostTcpState::Listening | HostTcpState::Connecting | HostTcpState::Connected => { - match rustix::net::shutdown( - &dropped.inner.tcp_socket, - rustix::net::Shutdown::ReadWrite, - ) { + match rustix::net::shutdown(&dropped.inner, rustix::net::Shutdown::ReadWrite) { Ok(()) | Err(Errno::NOTCONN) => {} Err(err) => Err(err).unwrap(), } diff --git a/crates/wasi/src/preview2/mod.rs b/crates/wasi/src/preview2/mod.rs index 96a738e1782c..5810a05672c5 100644 --- a/crates/wasi/src/preview2/mod.rs +++ b/crates/wasi/src/preview2/mod.rs @@ -40,7 +40,8 @@ pub use self::poll::{ClosureFuture, HostPollable, MakeFuture, PollableFuture, Ta pub use self::random::{thread_rng, Deterministic}; pub use self::stdio::{stderr, stdin, stdout, IsATTY, Stderr, Stdin, Stdout}; pub use self::stream::{ - HostInputStream, HostOutputStream, StreamRuntimeError, StreamState, TableStreamExt, + HostInputStream, HostOutputStream, OutputStreamError, StreamRuntimeError, StreamState, + TableStreamExt, }; pub use self::table::{OccupiedEntry, Table, TableError}; pub use cap_fs_ext::SystemTimeSpec; @@ -58,6 +59,7 @@ pub mod bindings { ", tracing: true, trappable_error_type: { + "wasi:io/streams"::"write-error": Error, "wasi:filesystem/types"::"error-code": Error, }, with: { @@ -92,6 +94,7 @@ pub mod bindings { tracing: true, async: true, trappable_error_type: { + "wasi:io/streams"::"write-error": Error, "wasi:filesystem/types"::"error-code": Error, }, with: { @@ -125,6 +128,7 @@ pub mod bindings { ", tracing: true, trappable_error_type: { + "wasi:io/streams"::"write-error": Error, "wasi:filesystem/types"::"error-code": Error, "wasi:sockets/network"::"error-code": Error, }, @@ -153,18 +157,56 @@ pub(crate) static RUNTIME: once_cell::sync::Lazy = .unwrap() }); -pub(crate) fn spawn(f: F) -> tokio::task::JoinHandle +pub(crate) struct AbortOnDropJoinHandle(tokio::task::JoinHandle); +impl Drop for AbortOnDropJoinHandle { + fn drop(&mut self) { + self.0.abort() + } +} +impl std::ops::Deref for AbortOnDropJoinHandle { + type Target = tokio::task::JoinHandle; + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl std::ops::DerefMut for AbortOnDropJoinHandle { + fn deref_mut(&mut self) -> &mut tokio::task::JoinHandle { + &mut self.0 + } +} +impl From> for AbortOnDropJoinHandle { + fn from(jh: tokio::task::JoinHandle) -> Self { + AbortOnDropJoinHandle(jh) + } +} +impl std::future::Future for AbortOnDropJoinHandle { + type Output = T; + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + use std::pin::Pin; + use std::task::Poll; + match Pin::new(&mut self.as_mut().0).poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(r) => Poll::Ready(r.expect("child task panicked")), + } + } +} + +pub(crate) fn spawn(f: F) -> AbortOnDropJoinHandle where F: std::future::Future + Send + 'static, G: Send + 'static, { - match tokio::runtime::Handle::try_current() { + let j = match tokio::runtime::Handle::try_current() { Ok(_) => tokio::task::spawn(f), Err(_) => { let _enter = RUNTIME.enter(); tokio::task::spawn(f) } - } + }; + AbortOnDropJoinHandle(j) } pub fn in_tokio(f: F) -> F::Output { diff --git a/crates/wasi/src/preview2/pipe.rs b/crates/wasi/src/preview2/pipe.rs index f8d64546c635..d4878b3b4df7 100644 --- a/crates/wasi/src/preview2/pipe.rs +++ b/crates/wasi/src/preview2/pipe.rs @@ -7,9 +7,11 @@ //! Some convenience constructors are included for common backing types like `Vec` and `String`, //! but the virtual pipes can be instantiated with any `Read` or `Write` type. //! -use crate::preview2::{HostInputStream, HostOutputStream, StreamState}; -use anyhow::Error; +use crate::preview2::{HostInputStream, HostOutputStream, OutputStreamError, StreamState}; +use anyhow::{anyhow, Error}; use bytes::Bytes; +use std::sync::{Arc, Mutex}; +use tokio::sync::mpsc; #[derive(Debug)] pub struct MemoryInputPipe { @@ -53,12 +55,14 @@ impl HostInputStream for MemoryInputPipe { #[derive(Debug, Clone)] pub struct MemoryOutputPipe { + capacity: usize, buffer: std::sync::Arc>, } impl MemoryOutputPipe { - pub fn new() -> Self { + pub fn new(capacity: usize) -> Self { MemoryOutputPipe { + capacity, buffer: std::sync::Arc::new(std::sync::Mutex::new(bytes::BytesMut::new())), } } @@ -74,26 +78,40 @@ impl MemoryOutputPipe { #[async_trait::async_trait] impl HostOutputStream for MemoryOutputPipe { - fn write(&mut self, bytes: Bytes) -> Result<(usize, StreamState), anyhow::Error> { + fn write(&mut self, bytes: Bytes) -> Result<(), OutputStreamError> { let mut buf = self.buffer.lock().unwrap(); + if bytes.len() > self.capacity - buf.len() { + return Err(OutputStreamError::Trap(anyhow!( + "write beyond capacity of MemoryOutputPipe" + ))); + } buf.extend_from_slice(bytes.as_ref()); - Ok((bytes.len(), StreamState::Open)) + // Always ready for writing + Ok(()) } - - async fn ready(&mut self) -> Result<(), Error> { - // This stream is always ready for writing. + fn flush(&mut self) -> Result<(), OutputStreamError> { + // This stream is always flushed Ok(()) } + async fn write_ready(&mut self) -> Result { + let consumed = self.buffer.lock().unwrap().len(); + if consumed < self.capacity { + Ok(self.capacity - consumed) + } else { + // Since the buffer is full, no more bytes will ever be written + Err(OutputStreamError::Closed) + } + } } -/// TODO +/// FIXME: this needs docs pub fn pipe(size: usize) -> (AsyncReadStream, AsyncWriteStream) { let (a, b) = tokio::io::duplex(size); let (_read_half, write_half) = tokio::io::split(a); let (read_half, _write_half) = tokio::io::split(b); ( AsyncReadStream::new(read_half), - AsyncWriteStream::new(write_half), + AsyncWriteStream::new(size, write_half), ) } @@ -101,15 +119,16 @@ pub fn pipe(size: usize) -> (AsyncReadStream, AsyncWriteStream) { pub struct AsyncReadStream { state: StreamState, buffer: Option>, - receiver: tokio::sync::mpsc::Receiver>, - pub(crate) join_handle: tokio::task::JoinHandle<()>, + receiver: mpsc::Receiver>, + #[allow(unused)] // just used to implement unix stdin + pub(crate) join_handle: crate::preview2::AbortOnDropJoinHandle<()>, } impl AsyncReadStream { /// Create a [`AsyncReadStream`]. In order to use the [`HostInputStream`] impl /// provided by this struct, the argument must impl [`tokio::io::AsyncRead`]. pub fn new(mut reader: T) -> Self { - let (sender, receiver) = tokio::sync::mpsc::channel(1); + let (sender, receiver) = mpsc::channel(1); let join_handle = crate::preview2::spawn(async move { loop { use tokio::io::AsyncReadExt; @@ -136,16 +155,10 @@ impl AsyncReadStream { } } -impl Drop for AsyncReadStream { - fn drop(&mut self) { - self.join_handle.abort() - } -} - #[async_trait::async_trait] impl HostInputStream for AsyncReadStream { fn read(&mut self, size: usize) -> Result<(Bytes, StreamState), Error> { - use tokio::sync::mpsc::error::TryRecvError; + use mpsc::error::TryRecvError; match self.buffer.take() { Some(Ok(mut bytes)) => { @@ -181,7 +194,7 @@ impl HostInputStream for AsyncReadStream { } Ok(Err(e)) => Err(e.into()), Err(TryRecvError::Empty) => Ok((Bytes::new(), self.state)), - Err(TryRecvError::Disconnected) => Err(anyhow::anyhow!( + Err(TryRecvError::Disconnected) => Err(anyhow!( "AsyncReadStream sender died - should be impossible" )), } @@ -200,7 +213,7 @@ impl HostInputStream for AsyncReadStream { } Some(Err(e)) => self.buffer = Some(Err(e)), None => { - return Err(anyhow::anyhow!( + return Err(anyhow!( "no more sender for an open AsyncReadStream - should be impossible" )) } @@ -210,163 +223,195 @@ impl HostInputStream for AsyncReadStream { } #[derive(Debug)] -enum WriteState { - Ready, - Pending, - Err(std::io::Error), +struct WorkerState { + alive: bool, + items: std::collections::VecDeque, + write_budget: usize, + flush_pending: bool, + error: Option, } -/// Provides a [`HostOutputStream`] impl from a [`tokio::io::AsyncWrite`] impl -pub struct AsyncWriteStream { - state: Option, - sender: tokio::sync::mpsc::Sender, - result_receiver: tokio::sync::mpsc::Receiver>, - join_handle: tokio::task::JoinHandle<()>, +impl WorkerState { + fn check_error(&mut self) -> Result<(), OutputStreamError> { + if let Some(e) = self.error.take() { + return Err(OutputStreamError::LastOperationFailed(e)); + } + if !self.alive { + return Err(OutputStreamError::Closed); + } + Ok(()) + } } -impl AsyncWriteStream { - /// Create a [`AsyncWriteStream`]. In order to use the [`HostOutputStream`] impl - /// provided by this struct, the argument must impl [`tokio::io::AsyncWrite`]. - pub fn new(mut writer: T) -> Self { - let (sender, mut receiver) = tokio::sync::mpsc::channel::(1); - let (result_sender, result_receiver) = tokio::sync::mpsc::channel(1); +struct Worker { + state: Mutex, + new_work: tokio::sync::Notify, + write_ready_changed: tokio::sync::Notify, +} - let join_handle = crate::preview2::spawn(async move { - 'outer: loop { - use tokio::io::AsyncWriteExt; - match receiver.recv().await { - Some(mut bytes) => { - while !bytes.is_empty() { - match writer.write_buf(&mut bytes).await { - Ok(0) => { - let _ = result_sender.send(Ok(StreamState::Closed)).await; - break 'outer; - } - Ok(_) => { - if bytes.is_empty() { - match result_sender.send(Ok(StreamState::Open)).await { - Ok(_) => break, - Err(_) => break 'outer, - } - } - continue; - } - Err(e) => { - let _ = result_sender.send(Err(e)).await; - break 'outer; - } - } - } - } +enum Job { + Flush, + Write(Bytes), +} - // The other side of the channel hung up, the task can exit now - None => break 'outer, - } - } - }); +enum WriteStatus<'a> { + Done(Result), + Pending(tokio::sync::futures::Notified<'a>), +} - AsyncWriteStream { - state: Some(WriteState::Ready), - sender, - result_receiver, - join_handle, +impl Worker { + fn new(write_budget: usize) -> Self { + Self { + state: Mutex::new(WorkerState { + alive: true, + items: std::collections::VecDeque::new(), + write_budget, + flush_pending: false, + error: None, + }), + new_work: tokio::sync::Notify::new(), + write_ready_changed: tokio::sync::Notify::new(), } } + fn check_write(&self) -> WriteStatus<'_> { + let mut state = self.state(); + if let Err(e) = state.check_error() { + return WriteStatus::Done(Err(e)); + } - fn send(&mut self, bytes: Bytes) -> anyhow::Result<(usize, StreamState)> { - use tokio::sync::mpsc::error::TrySendError; - - debug_assert!(matches!(self.state, Some(WriteState::Ready))); + if state.flush_pending || state.write_budget == 0 { + return WriteStatus::Pending(self.write_ready_changed.notified()); + } - let len = bytes.len(); - match self.sender.try_send(bytes) { - Ok(_) => { - self.state = Some(WriteState::Pending); - Ok((len, StreamState::Open)) - } - Err(TrySendError::Full(_)) => { - unreachable!("task shouldnt be full when writestate is ready") + WriteStatus::Done(Ok(state.write_budget)) + } + fn state(&self) -> std::sync::MutexGuard { + self.state.lock().unwrap() + } + fn pop(&self) -> Option { + let mut state = self.state(); + if state.items.is_empty() { + if state.flush_pending { + return Some(Job::Flush); } - Err(TrySendError::Closed(_)) => unreachable!("task shouldn't die while not closed"), + } else if let Some(bytes) = state.items.pop_front() { + return Some(Job::Write(bytes)); } - } -} -impl Drop for AsyncWriteStream { - fn drop(&mut self) { - self.join_handle.abort() + None } -} - -#[async_trait::async_trait] -impl HostOutputStream for AsyncWriteStream { - fn write(&mut self, bytes: Bytes) -> Result<(usize, StreamState), anyhow::Error> { - use tokio::sync::mpsc::error::TryRecvError; - - match self.state { - Some(WriteState::Ready) => self.send(bytes), - Some(WriteState::Pending) => match self.result_receiver.try_recv() { - Ok(Ok(StreamState::Open)) => { - self.state = Some(WriteState::Ready); - self.send(bytes) - } - - Ok(Ok(StreamState::Closed)) => { - self.state = None; - Ok((0, StreamState::Closed)) - } + fn report_error(&self, e: std::io::Error) { + { + let mut state = self.state(); + state.alive = false; + state.error = Some(e.into()); + state.flush_pending = false; + } + self.write_ready_changed.notify_waiters(); + } + async fn work(&self, mut writer: T) { + use tokio::io::AsyncWriteExt; + loop { + let notified = self.new_work.notified(); + while let Some(job) = self.pop() { + match job { + Job::Flush => { + if let Err(e) = writer.flush().await { + self.report_error(e); + return; + } - Ok(Err(e)) => { - self.state = None; - Err(e.into()) - } + tracing::debug!("worker marking flush complete"); + self.state().flush_pending = false; + } - Err(TryRecvError::Empty) => { - self.state = Some(WriteState::Pending); - Ok((0, StreamState::Open)) + Job::Write(mut bytes) => { + tracing::debug!("worker writing: {bytes:?}"); + let len = bytes.len(); + match writer.write_all_buf(&mut bytes).await { + Err(e) => { + self.report_error(e); + return; + } + Ok(_) => { + self.state().write_budget += len; + } + } + } } - Err(TryRecvError::Disconnected) => { - unreachable!("task shouldn't die while pending") - } - }, - Some(WriteState::Err(_)) => { - // Move the error payload out of self.state, because errors are not Copy, - // and set self.state to None, because the stream is now closed. - if let Some(WriteState::Err(e)) = self.state.take() { - Err(e.into()) - } else { - unreachable!("self.state shown to be Some(Err(e)) in match clause") - } + self.write_ready_changed.notify_waiters(); } - None => Ok((0, StreamState::Closed)), + notified.await; } } +} - async fn ready(&mut self) -> Result<(), Error> { - match &self.state { - Some(WriteState::Pending) => match self.result_receiver.recv().await { - Some(Ok(StreamState::Open)) => { - self.state = Some(WriteState::Ready); - } +/// Provides a [`HostOutputStream`] impl from a [`tokio::io::AsyncWrite`] impl +pub struct AsyncWriteStream { + worker: Arc, + _join_handle: crate::preview2::AbortOnDropJoinHandle<()>, +} - Some(Ok(StreamState::Closed)) => { - self.state = None; - } +impl AsyncWriteStream { + /// Create a [`AsyncWriteStream`]. In order to use the [`HostOutputStream`] impl + /// provided by this struct, the argument must impl [`tokio::io::AsyncWrite`]. + pub fn new( + write_budget: usize, + writer: T, + ) -> Self { + let worker = Arc::new(Worker::new(write_budget)); - Some(Err(e)) => { - self.state = Some(WriteState::Err(e)); - } + let w = Arc::clone(&worker); + let join_handle = crate::preview2::spawn(async move { w.work(writer).await }); - None => unreachable!("task shouldn't die while pending"), - }, + AsyncWriteStream { + worker, + _join_handle: join_handle, + } + } +} - Some(WriteState::Ready | WriteState::Err(_)) | None => {} +#[async_trait::async_trait] +impl HostOutputStream for AsyncWriteStream { + fn write(&mut self, bytes: Bytes) -> Result<(), OutputStreamError> { + let mut state = self.worker.state(); + state.check_error()?; + if state.flush_pending { + return Err(OutputStreamError::Trap(anyhow!( + "write not permitted while flush pending" + ))); + } + match state.write_budget.checked_sub(bytes.len()) { + Some(remaining_budget) => { + state.write_budget = remaining_budget; + state.items.push_back(bytes); + } + None => return Err(OutputStreamError::Trap(anyhow!("write exceeded budget"))), } + drop(state); + self.worker.new_work.notify_waiters(); + Ok(()) + } + fn flush(&mut self) -> Result<(), OutputStreamError> { + let mut state = self.worker.state(); + state.check_error()?; + + state.flush_pending = true; + self.worker.new_work.notify_waiters(); Ok(()) } + + async fn write_ready(&mut self) -> Result { + loop { + match self.worker.check_write() { + WriteStatus::Done(r) => return r, + WriteStatus::Pending(notifier) => notifier.await, + } + } + } } /// An output stream that consumes all input written to it, and is always ready. @@ -374,13 +419,18 @@ pub struct SinkOutputStream; #[async_trait::async_trait] impl HostOutputStream for SinkOutputStream { - fn write(&mut self, buf: Bytes) -> Result<(usize, StreamState), Error> { - Ok((buf.len(), StreamState::Open)) + fn write(&mut self, _buf: Bytes) -> Result<(), OutputStreamError> { + Ok(()) } - - async fn ready(&mut self) -> Result<(), Error> { + fn flush(&mut self) -> Result<(), OutputStreamError> { + // This stream is always flushed Ok(()) } + + async fn write_ready(&mut self) -> Result { + // This stream is always ready for writing. + Ok(usize::MAX) + } } /// A stream that is ready immediately, but will always report that it's closed. @@ -402,12 +452,15 @@ pub struct ClosedOutputStream; #[async_trait::async_trait] impl HostOutputStream for ClosedOutputStream { - fn write(&mut self, _: Bytes) -> Result<(usize, StreamState), Error> { - Ok((0, StreamState::Closed)) + fn write(&mut self, _: Bytes) -> Result<(), OutputStreamError> { + Err(OutputStreamError::Closed) + } + fn flush(&mut self) -> Result<(), OutputStreamError> { + Err(OutputStreamError::Closed) } - async fn ready(&mut self) -> Result<(), Error> { - Ok(()) + async fn write_ready(&mut self) -> Result { + Err(OutputStreamError::Closed) } } @@ -416,8 +469,37 @@ mod test { use super::*; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; - // 10ms was enough for every CI platform except linux riscv64: - const REASONABLE_DURATION: std::time::Duration = std::time::Duration::from_millis(100); + // This is a gross way to handle CI running under qemu for non-x86 architectures. + #[cfg(not(target_arch = "x86_64"))] + const TEST_ITERATIONS: usize = 10; + + // This is a gross way to handle CI running under qemu for non-x86 architectures. + #[cfg(not(target_arch = "x86_64"))] + const REASONABLE_DURATION: std::time::Duration = std::time::Duration::from_millis(200); + + #[cfg(target_arch = "x86_64")] + const TEST_ITERATIONS: usize = 100; + + #[cfg(target_arch = "x86_64")] + const REASONABLE_DURATION: std::time::Duration = std::time::Duration::from_millis(10); + + async fn resolves_immediately(fut: F) -> O + where + F: futures::Future, + { + tokio::time::timeout(REASONABLE_DURATION, fut) + .await + .expect("operation timed out") + } + + // TODO: is there a way to get tokio to warp through timeouts when it knows nothing is + // happening? + async fn never_resolves(fut: F) { + tokio::time::timeout(REASONABLE_DURATION, fut) + .await + .err() + .expect("operation should time out"); + } pub fn simplex(size: usize) -> (impl AsyncRead, impl AsyncWrite) { let (a, b) = tokio::io::duplex(size); @@ -426,7 +508,7 @@ mod test { (read_half, write_half) } - #[tokio::test(flavor = "multi_thread")] + #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn empty_read_stream() { let mut reader = AsyncReadStream::new(tokio::io::empty()); let (bs, state) = reader.read(10).unwrap(); @@ -440,9 +522,8 @@ mod test { // The reader task hasn't run yet. Call `ready` to await and fill the buffer. StreamState::Open => { - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) + resolves_immediately(reader.ready()) .await - .expect("the reader should be ready instantly") .expect("ready is ok"); let (bs, state) = reader.read(0).unwrap(); assert!(bs.is_empty()); @@ -451,7 +532,7 @@ mod test { } } - #[tokio::test(flavor = "multi_thread")] + #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn infinite_read_stream() { let mut reader = AsyncReadStream::new(tokio::io::repeat(0)); @@ -459,9 +540,8 @@ mod test { assert_eq!(state, StreamState::Open); if bs.is_empty() { // Reader task hasn't run yet. Call `ready` to await and fill the buffer. - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) + resolves_immediately(reader.ready()) .await - .expect("the reader should be ready instantly") .expect("ready is ok"); // Now a read should succeed let (bs, state) = reader.read(10).unwrap(); @@ -488,7 +568,7 @@ mod test { r } - #[tokio::test(flavor = "multi_thread")] + #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn finite_read_stream() { let mut reader = AsyncReadStream::new(finite_async_reader(&[1; 123]).await); @@ -496,9 +576,8 @@ mod test { assert_eq!(state, StreamState::Open); if bs.is_empty() { // Reader task hasn't run yet. Call `ready` to await and fill the buffer. - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) + resolves_immediately(reader.ready()) .await - .expect("the reader should be ready instantly") .expect("ready is ok"); // Now a read should succeed let (bs, state) = reader.read(123).unwrap(); @@ -516,9 +595,8 @@ mod test { StreamState::Closed => {} // Correct! StreamState::Open => { // Need to await to give this side time to catch up - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) + resolves_immediately(reader.ready()) .await - .expect("the reader should be ready instantly") .expect("ready is ok"); // Now a read should show closed let (bs, state) = reader.read(0).unwrap(); @@ -528,7 +606,7 @@ mod test { } } - #[tokio::test(flavor = "multi_thread")] + #[test_log::test(tokio::test(flavor = "multi_thread"))] // Test that you can write items into the stream, and they get read out in the order they were // written, with the proper indications of readiness for reading: async fn multiple_chunks_read_stream() { @@ -541,9 +619,8 @@ mod test { assert_eq!(state, StreamState::Open); if bs.is_empty() { // Reader task hasn't run yet. Call `ready` to await and fill the buffer. - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) + resolves_immediately(reader.ready()) .await - .expect("the reader should be ready instantly") .expect("ready is ok"); // Now a read should succeed let (bs, state) = reader.read(1).unwrap(); @@ -559,10 +636,7 @@ mod test { assert_eq!(state, StreamState::Open); // We can wait on readiness and it will time out: - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) - .await - .err() - .expect("the reader should time out"); + never_resolves(reader.ready()).await; // Still open and empty: let (bs, state) = reader.read(1).unwrap(); @@ -574,9 +648,8 @@ mod test { // Wait readiness (yes we could possibly win the race and read it out faster, leaving that // out of the test for simplicity) - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) + resolves_immediately(reader.ready()) .await - .expect("the reader should be ready instantly") .expect("the ready is ok"); // read the something else back out: @@ -590,10 +663,7 @@ mod test { assert_eq!(state, StreamState::Open); // We can wait on readiness and it will time out: - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) - .await - .err() - .expect("the reader should time out"); + never_resolves(reader.ready()).await; // nothing else in there: let (bs, state) = reader.read(1).unwrap(); @@ -605,9 +675,8 @@ mod test { // Wait readiness (yes we could possibly win the race and read it out faster, leaving that // out of the test for simplicity) - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) + resolves_immediately(reader.ready()) .await - .expect("the reader should be ready instantly") .expect("the ready is ok"); // empty and now closed: @@ -616,7 +685,7 @@ mod test { assert_eq!(state, StreamState::Closed); } - #[tokio::test(flavor = "multi_thread")] + #[test_log::test(tokio::test(flavor = "multi_thread"))] // At the moment we are restricting AsyncReadStream from buffering more than 4k. This isn't a // suitable design for all applications, and we will probably make a knob or change the // behavior at some point, but this test shows the behavior as it is implemented: @@ -630,9 +699,8 @@ mod test { w }); - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) + resolves_immediately(reader.ready()) .await - .expect("the reader should be ready instantly") .expect("ready is ok"); // Now we expect the reader task has sent 4k from the stream to the reader. @@ -642,9 +710,8 @@ mod test { assert_eq!(state, StreamState::Open); // Allow the crank to turn more: - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) + resolves_immediately(reader.ready()) .await - .expect("the reader should be ready instantly") .expect("ready is ok"); // Again we expect the reader task has sent 4k from the stream to the reader. @@ -654,16 +721,14 @@ mod test { assert_eq!(state, StreamState::Open); // The writer task is now finished - join with it: - let w = tokio::time::timeout(REASONABLE_DURATION, writer_task) - .await - .expect("the join should be ready instantly"); + let w = resolves_immediately(writer_task).await; + // And close the pipe: drop(w); // Allow the crank to turn more: - tokio::time::timeout(REASONABLE_DURATION, reader.ready()) + resolves_immediately(reader.ready()) .await - .expect("the reader should be ready instantly") .expect("ready is ok"); // Now we expect the reader to be empty, and the stream closed: @@ -672,89 +737,153 @@ mod test { assert_eq!(state, StreamState::Closed); } - #[tokio::test(flavor = "multi_thread")] + #[test_log::test(test_log::test(tokio::test(flavor = "multi_thread")))] async fn sink_write_stream() { - let mut writer = AsyncWriteStream::new(tokio::io::sink()); + let mut writer = AsyncWriteStream::new(2048, tokio::io::sink()); let chunk = Bytes::from_static(&[0; 1024]); + let readiness = resolves_immediately(writer.write_ready()) + .await + .expect("write_ready does not trap"); + assert_eq!(readiness, 2048); // I can write whatever: - let (len, state) = writer.write(chunk.clone()).unwrap(); - assert_eq!(len, chunk.len()); - assert_eq!(state, StreamState::Open); - - // It is possible for subsequent writes to be refused, but it is nondeterminstic because - // the worker task consuming them is in another thread: - let (len, state) = writer.write(chunk.clone()).unwrap(); - assert_eq!(state, StreamState::Open); - if !(len == 0 || len == chunk.len()) { - unreachable!() - } + writer.write(chunk.clone()).expect("write does not error"); - tokio::time::timeout(REASONABLE_DURATION, writer.ready()) + // This may consume 1k of the buffer: + let readiness = resolves_immediately(writer.write_ready()) .await - .expect("the writer should be ready instantly") - .expect("ready is ok"); + .expect("write_ready does not trap"); + assert!( + readiness == 1024 || readiness == 2048, + "readiness should be 1024 or 2048, got {readiness}" + ); - // Now additional writes will work: - let (len, state) = writer.write(chunk.clone()).unwrap(); - assert_eq!(len, chunk.len()); - assert_eq!(state, StreamState::Open); + if readiness == 1024 { + writer.write(chunk.clone()).expect("write does not error"); + + let readiness = resolves_immediately(writer.write_ready()) + .await + .expect("write_ready does not trap"); + assert!( + readiness == 1024 || readiness == 2048, + "readiness should be 1024 or 2048, got {readiness}" + ); + } } - #[tokio::test(flavor = "multi_thread")] + #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn closed_write_stream() { - let (reader, writer) = simplex(1024); - drop(reader); - let mut writer = AsyncWriteStream::new(writer); + // Run many times because the test is nondeterministic: + for n in 0..TEST_ITERATIONS { + closed_write_stream_(n).await + } + } + #[tracing::instrument] + async fn closed_write_stream_(n: usize) { + let (reader, writer) = simplex(1); + let mut writer = AsyncWriteStream::new(1024, writer); - // Without checking write readiness, perform a nonblocking write: this should succeed - // because we will buffer up the write. - let chunk = Bytes::from_static(&[0; 1]); - let (len, state) = writer.write(chunk.clone()).unwrap(); + // Drop the reader to allow the worker to transition to the closed state eventually. + drop(reader); - assert_eq!(len, chunk.len()); - assert_eq!(state, StreamState::Open); + // First the api is going to report the last operation failed, then subsequently + // it will be reported as closed. We set this flag once we see LastOperationFailed. + let mut should_be_closed = false; - // Check write readiness: - tokio::time::timeout(REASONABLE_DURATION, writer.ready()) - .await - .expect("the writer should be ready instantly") - .expect("ready is ok"); + // Write some data to the stream to ensure we have data that cannot be flushed. + let chunk = Bytes::from_static(&[0; 1]); + writer + .write(chunk.clone()) + .expect("first write should succeed"); + + // The rest of this test should be valid whether or not we check write readiness: + let mut write_ready_res = None; + if n % 2 == 0 { + let r = resolves_immediately(writer.write_ready()).await; + // Check write readiness: + match r { + // worker hasn't processed write yet: + Ok(1023) => {} + // worker reports failure: + Err(OutputStreamError::LastOperationFailed(_)) => { + tracing::debug!("discovered stream failure in first write_ready"); + should_be_closed = true; + } + r => panic!("unexpected write_ready: {r:?}"), + } + write_ready_res = Some(r); + } // When we drop the simplex reader, that causes the simplex writer to return BrokenPipe on // its write. Now that the buffering crank has turned, our next write will give BrokenPipe. - let err = writer.write(chunk.clone()).err().unwrap(); - assert_eq!( - err.downcast_ref::().unwrap().kind(), - std::io::ErrorKind::BrokenPipe - ); + let flush_res = writer.flush(); + match flush_res { + // worker reports failure: + Err(OutputStreamError::LastOperationFailed(_)) => { + tracing::debug!("discovered stream failure trying to flush"); + assert!(!should_be_closed); + should_be_closed = true; + } + // Already reported failure, now closed + Err(OutputStreamError::Closed) => { + assert!( + should_be_closed, + "expected a LastOperationFailed before we see Closed. {write_ready_res:?}" + ); + } + // Also possible the worker hasnt processed write yet: + Ok(()) => {} + Err(e) => panic!("unexpected flush error: {e:?} {write_ready_res:?}"), + } - // Now that we got the error out of the writer, it should be closed - subsequent writes - // will not work - let (len, state) = writer.write(chunk.clone()).unwrap(); - assert_eq!(len, 0); - assert_eq!(state, StreamState::Closed); + // Waiting for the flush to complete should always indicate that the channel has been + // closed. + match resolves_immediately(writer.write_ready()).await { + // worker reports failure: + Err(OutputStreamError::LastOperationFailed(_)) => { + tracing::debug!("discovered stream failure trying to flush"); + assert!(!should_be_closed); + } + // Already reported failure, now closed + Err(OutputStreamError::Closed) => { + assert!(should_be_closed); + } + r => { + panic!("stream should be reported closed by the end of write_ready after flush, got {r:?}. {write_ready_res:?} {flush_res:?}") + } + } } - #[tokio::test(flavor = "multi_thread")] + #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn multiple_chunks_write_stream() { + // Run many times because the test is nondeterministic: + for n in 0..TEST_ITERATIONS { + multiple_chunks_write_stream_aux(n).await + } + } + #[tracing::instrument] + async fn multiple_chunks_write_stream_aux(_: usize) { use std::ops::Deref; let (mut reader, writer) = simplex(1024); - let mut writer = AsyncWriteStream::new(writer); + let mut writer = AsyncWriteStream::new(1024, writer); // Write a chunk: let chunk = Bytes::from_static(&[123; 1]); - let (len, state) = writer.write(chunk.clone()).unwrap(); - assert_eq!(len, chunk.len()); - assert_eq!(state, StreamState::Open); + let permit = resolves_immediately(writer.write_ready()) + .await + .expect("write should be ready"); + assert_eq!(permit, 1024); + + writer.write(chunk.clone()).expect("write does not trap"); - // After the write, still ready for more writing: - tokio::time::timeout(REASONABLE_DURATION, writer.ready()) + // At this point the message will either be waiting for the worker to process the write, or + // it will be buffered in the simplex channel. + let permit = resolves_immediately(writer.write_ready()) .await - .expect("the writer should be ready instantly") - .expect("ready is ok"); + .expect("write should be ready"); + assert!(matches!(permit, 1023 | 1024)); let mut read_buf = vec![0; chunk.len()]; let read_len = reader.read_exact(&mut read_buf).await.unwrap(); @@ -763,84 +892,176 @@ mod test { // Write a second, different chunk: let chunk2 = Bytes::from_static(&[45; 1]); - let (len, state) = writer.write(chunk2.clone()).unwrap(); - assert_eq!(len, chunk2.len()); - assert_eq!(state, StreamState::Open); - // After the write, still ready for more writing: - tokio::time::timeout(REASONABLE_DURATION, writer.ready()) + // We're only guaranteed to see a consistent write budget if we flush. + writer.flush().expect("channel is still alive"); + + let permit = resolves_immediately(writer.write_ready()) .await - .expect("the writer should be ready instantly") - .expect("ready is ok"); + .expect("write should be ready"); + assert_eq!(permit, 1024); + + writer.write(chunk2.clone()).expect("write does not trap"); + + // At this point the message will either be waiting for the worker to process the write, or + // it will be buffered in the simplex channel. + let permit = resolves_immediately(writer.write_ready()) + .await + .expect("write should be ready"); + assert!(matches!(permit, 1023 | 1024)); let mut read2_buf = vec![0; chunk2.len()]; let read2_len = reader.read_exact(&mut read2_buf).await.unwrap(); assert_eq!(read2_len, chunk2.len()); assert_eq!(read2_buf.as_slice(), chunk2.deref()); + + // We're only guaranteed to see a consistent write budget if we flush. + writer.flush().expect("channel is still alive"); + + let permit = resolves_immediately(writer.write_ready()) + .await + .expect("write should be ready"); + assert_eq!(permit, 1024); } - #[tokio::test(flavor = "multi_thread")] + #[test_log::test(tokio::test(flavor = "multi_thread"))] async fn backpressure_write_stream() { - // Stream can buffer up to 1k, plus one write chunk, before not + // Run many times because the test is nondeterministic: + for n in 0..TEST_ITERATIONS { + backpressure_write_stream_aux(n).await + } + } + #[tracing::instrument] + async fn backpressure_write_stream_aux(_: usize) { + use futures::future::poll_immediate; + + // The channel can buffer up to 1k, plus another 1k in the stream, before not // accepting more input: let (mut reader, writer) = simplex(1024); - let mut writer = AsyncWriteStream::new(writer); + let mut writer = AsyncWriteStream::new(1024, writer); - // Write enough to fill the simplex buffer: let chunk = Bytes::from_static(&[0; 1024]); - let (len, state) = writer.write(chunk.clone()).unwrap(); - - assert_eq!(len, chunk.len()); - assert_eq!(state, StreamState::Open); - // turn the crank and it should be ready for writing again: - tokio::time::timeout(REASONABLE_DURATION, writer.ready()) + let permit = resolves_immediately(writer.write_ready()) .await - .expect("the writer should be ready instantly") - .expect("ready is ok"); + .expect("write should be ready"); + assert_eq!(permit, 1024); - // Now fill the buffer between here and the writer task: - let (len, state) = writer.write(chunk.clone()).unwrap(); - assert_eq!(len, chunk.len()); - assert_eq!(state, StreamState::Open); + writer.write(chunk.clone()).expect("write succeeds"); - // Try shoving even more down there, and it shouldnt accept more input: - let (len, state) = writer.write(chunk.clone()).unwrap(); - assert_eq!(len, 0); - assert_eq!(state, StreamState::Open); + // We might still be waiting for the worker to process the message, or the worker may have + // processed it and released all the budget back to us. + let permit = poll_immediate(writer.write_ready()).await; + assert!(matches!(permit, None | Some(Ok(1024)))); - // turn the crank and it should Not become ready for writing until we read something out. - tokio::time::timeout(REASONABLE_DURATION, writer.ready()) + // Given a little time, the worker will process the message and release all the budget + // back. + let permit = resolves_immediately(writer.write_ready()) .await + .expect("write should be ready"); + assert_eq!(permit, 1024); + + // Now fill the buffer between here and the writer task. This should always indicate + // back-pressure because now both buffers (simplex and worker) are full. + writer.write(chunk.clone()).expect("write does not trap"); + + // Try shoving even more down there, and it shouldnt accept more input: + writer + .write(chunk.clone()) .err() - .expect("the writer should be not become ready"); + .expect("unpermitted write does trap"); - // Still not ready from the .write interface either: - let (len, state) = writer.write(chunk.clone()).unwrap(); - assert_eq!(len, 0); - assert_eq!(state, StreamState::Open); + // No amount of waiting will resolve the situation, as nothing is emptying the simplex + // buffer. + never_resolves(writer.write_ready()).await; - // There is 2k in the buffer. I should be able to read all of it out: + // There is 2k buffered between the simplex and worker buffers. I should be able to read + // all of it out: let mut buf = [0; 2048]; reader.read_exact(&mut buf).await.unwrap(); // and no more: - tokio::time::timeout(REASONABLE_DURATION, reader.read(&mut buf)) + never_resolves(reader.read(&mut buf)).await; + + // Now the backpressure should be cleared, and an additional write should be accepted. + let permit = resolves_immediately(writer.write_ready()) + .await + .expect("ready is ok"); + assert_eq!(permit, 1024); + + // and the write succeeds: + writer.write(chunk.clone()).expect("write does not trap"); + } + + #[test_log::test(tokio::test(flavor = "multi_thread"))] + async fn backpressure_write_stream_with_flush() { + for n in 0..TEST_ITERATIONS { + backpressure_write_stream_with_flush_aux(n).await; + } + } + + async fn backpressure_write_stream_with_flush_aux(_: usize) { + // The channel can buffer up to 1k, plus another 1k in the stream, before not + // accepting more input: + let (mut reader, writer) = simplex(1024); + let mut writer = AsyncWriteStream::new(1024, writer); + + let chunk = Bytes::from_static(&[0; 1024]); + + let permit = resolves_immediately(writer.write_ready()) .await + .expect("write should be ready"); + assert_eq!(permit, 1024); + + writer.write(chunk.clone()).expect("write succeeds"); + + writer.flush().expect("flush succeeds"); + + // Waiting for write_ready to resolve after a flush should always show that we have the + // full budget available, as the message will have flushed to the simplex channel. + let permit = resolves_immediately(writer.write_ready()) + .await + .expect("write_ready succeeds"); + assert_eq!(permit, 1024); + + // Write enough to fill the simplex buffer: + writer.write(chunk.clone()).expect("write does not trap"); + + // Writes should be refused until this flush succeeds. + writer.flush().expect("flush succeeds"); + + // Try shoving even more down there, and it shouldnt accept more input: + writer + .write(chunk.clone()) .err() - .expect("nothing more buffered in the system"); + .expect("unpermitted write does trap"); - // Now the backpressure should be cleared, and an additional write should be accepted. + // No amount of waiting will resolve the situation, as nothing is emptying the simplex + // buffer. + never_resolves(writer.write_ready()).await; + + // There is 2k buffered between the simplex and worker buffers. I should be able to read + // all of it out: + let mut buf = [0; 2048]; + reader.read_exact(&mut buf).await.unwrap(); + + // and no more: + never_resolves(reader.read(&mut buf)).await; - // immediately ready for writing: - tokio::time::timeout(REASONABLE_DURATION, writer.ready()) + // Now the backpressure should be cleared, and an additional write should be accepted. + let permit = resolves_immediately(writer.write_ready()) .await - .expect("the writer should be ready instantly") .expect("ready is ok"); + assert_eq!(permit, 1024); // and the write succeeds: - let (len, state) = writer.write(chunk.clone()).unwrap(); - assert_eq!(len, chunk.len()); - assert_eq!(state, StreamState::Open); + writer.write(chunk.clone()).expect("write does not trap"); + + writer.flush().expect("flush succeeds"); + + let permit = resolves_immediately(writer.write_ready()) + .await + .expect("ready is ok"); + assert_eq!(permit, 1024); } } diff --git a/crates/wasi/src/preview2/preview1.rs b/crates/wasi/src/preview2/preview1.rs index 534c3950d5a8..b45955063b50 100644 --- a/crates/wasi/src/preview2/preview1.rs +++ b/crates/wasi/src/preview2/preview1.rs @@ -5,6 +5,7 @@ use crate::preview2::bindings::cli::{ use crate::preview2::bindings::clocks::{monotonic_clock, wall_clock}; use crate::preview2::bindings::filesystem::{preopens, types as filesystem}; use crate::preview2::bindings::io::streams; +use crate::preview2::bindings::poll; use crate::preview2::filesystem::TableFsExt; use crate::preview2::host::filesystem::TableReaddirExt; use crate::preview2::{bindings, IsATTY, TableError, WasiView}; @@ -18,7 +19,7 @@ use std::slice; use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use wiggle::tracing::instrument; -use wiggle::{GuestError, GuestPtr, GuestSliceMut, GuestStrCow, GuestType}; +use wiggle::{GuestError, GuestPtr, GuestSlice, GuestSliceMut, GuestStrCow, GuestType}; #[derive(Clone, Debug)] struct File { @@ -31,10 +32,95 @@ struct File { /// In append mode, all writes append to the file. append: bool, - /// In blocking mode, read and write calls dispatch to blocking_read and - /// blocking_write on the underlying streams. When false, read and write - /// dispatch to stream's plain read and write. - blocking: bool, + /// When blocking, read and write calls dispatch to blocking_read and + /// blocking_check_write on the underlying streams. When false, read and write + /// dispatch to stream's plain read and check_write. + blocking_mode: BlockingMode, +} + +#[derive(Clone, Copy, Debug)] +enum BlockingMode { + Blocking, + NonBlocking, +} +impl BlockingMode { + fn from_fdflags(flags: &types::Fdflags) -> Self { + if flags.contains(types::Fdflags::NONBLOCK) { + BlockingMode::NonBlocking + } else { + BlockingMode::Blocking + } + } + async fn read( + &self, + host: &mut impl streams::Host, + input_stream: streams::InputStream, + max_size: usize, + ) -> Result<(Vec, streams::StreamStatus), types::Error> { + let max_size = max_size.try_into().unwrap_or(u64::MAX); + match self { + BlockingMode::Blocking => { + stream_res(streams::Host::blocking_read(host, input_stream, max_size).await) + } + BlockingMode::NonBlocking => { + stream_res(streams::Host::read(host, input_stream, max_size).await) + } + } + } + async fn write( + &self, + host: &mut (impl streams::Host + poll::poll::Host), + output_stream: streams::OutputStream, + mut bytes: &[u8], + ) -> Result { + use streams::Host as Streams; + + match self { + BlockingMode::Blocking => { + let total = bytes.len(); + while !bytes.is_empty() { + // NOTE: blocking_write_and_flush takes at most one 4k buffer. + let len = bytes.len().min(4096); + let (chunk, rest) = bytes.split_at(len); + bytes = rest; + + Streams::blocking_write_and_flush(host, output_stream, Vec::from(chunk)).await? + } + + Ok(total) + } + BlockingMode::NonBlocking => { + let n = match Streams::check_write(host, output_stream).await { + Ok(n) => n, + Err(e) if matches!(e.downcast_ref(), Some(streams::WriteError::Closed)) => 0, + Err(e) => Err(e)?, + }; + + let len = bytes.len().min(n as usize); + if len == 0 { + return Ok(0); + } + + match Streams::write(host, output_stream, bytes[..len].to_vec()).await { + Ok(()) => {} + Err(e) if matches!(e.downcast_ref(), Some(streams::WriteError::Closed)) => { + return Ok(0) + } + Err(e) => Err(e)?, + } + + match Streams::blocking_flush(host, output_stream).await { + Ok(()) => {} + Err(e) if matches!(e.downcast_ref(), Some(streams::WriteError::Closed)) => { + return Ok(0) + } + Err(e) => Err(e)?, + }; + + Ok(len) + } + } + } } #[derive(Clone, Debug)] @@ -452,6 +538,18 @@ impl wiggle::GuestErrorType for types::Errno { } } +impl From for types::Error { + fn from(err: streams::Error) -> Self { + match err.downcast() { + Ok(streams::WriteError::Closed | streams::WriteError::LastOperationFailed) => { + types::Errno::Io.into() + } + + Err(t) => types::Error::trap(t), + } + } +} + fn stream_res(r: anyhow::Result>) -> Result { match r { Ok(Ok(a)) => Ok(a), @@ -711,13 +809,15 @@ fn read_string<'a>(ptr: impl Borrow>) -> Result { } // Find first non-empty buffer. -fn first_non_empty_ciovec(ciovs: &types::CiovecArray<'_>) -> Result>> { +fn first_non_empty_ciovec<'a, 'b>( + ciovs: &'a types::CiovecArray<'b>, +) -> Result>> { for iov in ciovs.iter() { let iov = iov?.read()?; if iov.buf_len == 0 { continue; } - return Ok(Some(iov.buf.as_array(iov.buf_len).to_vec()?)); + return Ok(iov.buf.as_array(iov.buf_len).as_slice()?); } Ok(None) } @@ -1013,14 +1113,11 @@ impl< } Descriptor::File(File { fd, - blocking, + blocking_mode, append, .. - }) => (*fd, *blocking, *append), + }) => (*fd, *blocking_mode, *append), }; - - // TODO: use `try_join!` to poll both futures async, unfortunately that is not currently - // possible, because `bindgen` generates methods with `&mut self` receivers. let flags = self.get_flags(fd).await.map_err(|e| { e.try_into() .context("failed to call `get-flags`") @@ -1056,7 +1153,7 @@ impl< if append { fs_flags |= types::Fdflags::APPEND; } - if !blocking { + if matches!(blocking, BlockingMode::NonBlocking) { fs_flags |= types::Fdflags::NONBLOCK; } Ok(types::Fdstat { @@ -1077,7 +1174,9 @@ impl< ) -> Result<(), types::Error> { let mut st = self.transact()?; let File { - append, blocking, .. + append, + blocking_mode, + .. } = st.get_file_mut(fd)?; // Only support changing the NONBLOCK or APPEND flags. @@ -1088,7 +1187,7 @@ impl< return Err(types::Errno::Inval.into()); } *append = flags.contains(types::Fdflags::APPEND); - *blocking = !flags.contains(types::Fdflags::NONBLOCK); + *blocking_mode = BlockingMode::from_fdflags(&flags); Ok(()) } @@ -1214,7 +1313,7 @@ impl< let (mut buf, read, state) = match desc { Descriptor::File(File { fd, - blocking, + blocking_mode, position, .. }) if self.table().is_file(fd) => { @@ -1228,13 +1327,7 @@ impl< .context("failed to call `read-via-stream`") .unwrap_or_else(types::Error::trap) })?; - let max = buf.len().try_into().unwrap_or(u64::MAX); - let (read, state) = if blocking { - stream_res(streams::Host::blocking_read(self, stream, max).await)? - } else { - stream_res(streams::Host::read(self, stream, max).await)? - }; - + let (read, state) = blocking_mode.read(self, stream, buf.len()).await?; let n = read.len().try_into()?; let pos = pos.checked_add(n).ok_or(types::Errno::Overflow)?; position.store(pos, Ordering::Relaxed); @@ -1280,7 +1373,9 @@ impl< ) -> Result { let desc = self.transact()?.get_descriptor(fd)?.clone(); let (mut buf, read, state) = match desc { - Descriptor::File(File { fd, blocking, .. }) if self.table().is_file(fd) => { + Descriptor::File(File { + fd, blocking_mode, .. + }) if self.table().is_file(fd) => { let Some(buf) = first_non_empty_iovec(iovs)? else { return Ok(0); }; @@ -1290,13 +1385,7 @@ impl< .context("failed to call `read-via-stream`") .unwrap_or_else(types::Error::trap) })?; - let max = buf.len().try_into().unwrap_or(u64::MAX); - let (read, state) = if blocking { - stream_res(streams::Host::blocking_read(self, stream, max).await)? - } else { - stream_res(streams::Host::read(self, stream, max).await)? - }; - + let (read, state) = blocking_mode.read(self, stream, buf.len()).await?; (buf, read, state) } Descriptor::Stdin { .. } => { @@ -1326,10 +1415,10 @@ impl< ciovs: &types::CiovecArray<'a>, ) -> Result { let desc = self.transact()?.get_descriptor(fd)?.clone(); - let n = match desc { + match desc { Descriptor::File(File { fd, - blocking, + blocking_mode, append, position, }) if self.table().is_file(fd) => { @@ -1352,29 +1441,24 @@ impl< })?; (stream, position) }; - let (n, _stat) = if blocking { - stream_res(streams::Host::blocking_write(self, stream, buf).await)? - } else { - stream_res(streams::Host::write(self, stream, buf).await)? - }; + let n = blocking_mode.write(self, stream, &buf).await?; if !append { - let pos = pos.checked_add(n).ok_or(types::Errno::Overflow)?; + let pos = pos.checked_add(n as u64).ok_or(types::Errno::Overflow)?; position.store(pos, Ordering::Relaxed); } - n + Ok(n.try_into()?) } Descriptor::Stdout { output_stream, .. } | Descriptor::Stderr { output_stream, .. } => { let Some(buf) = first_non_empty_ciovec(ciovs)? else { return Ok(0); }; - let (n, _stat) = - stream_res(streams::Host::blocking_write(self, output_stream, buf).await)?; - n + Ok(BlockingMode::Blocking + .write(self, output_stream, &buf) + .await? + .try_into()?) } - _ => return Err(types::Errno::Badf.into()), - }; - let n = n.try_into()?; - Ok(n) + _ => Err(types::Errno::Badf.into()), + } } /// Write to a file descriptor, without using and updating the file descriptor's offset. @@ -1387,8 +1471,10 @@ impl< offset: types::Filesize, ) -> Result { let desc = self.transact()?.get_descriptor(fd)?.clone(); - let (n, _stat) = match desc { - Descriptor::File(File { fd, blocking, .. }) if self.table().is_file(fd) => { + let n = match desc { + Descriptor::File(File { + fd, blocking_mode, .. + }) if self.table().is_file(fd) => { let Some(buf) = first_non_empty_ciovec(ciovs)? else { return Ok(0); }; @@ -1397,11 +1483,7 @@ impl< .context("failed to call `write-via-stream`") .unwrap_or_else(types::Error::trap) })?; - if blocking { - stream_res(streams::Host::blocking_write(self, stream, buf).await)? - } else { - stream_res(streams::Host::write(self, stream, buf).await)? - } + blocking_mode.write(self, stream, &buf).await? } Descriptor::Stdout { .. } | Descriptor::Stderr { .. } => { // NOTE: legacy implementation returns SPIPE here @@ -1409,8 +1491,7 @@ impl< } _ => return Err(types::Errno::Badf.into()), }; - let n = n.try_into()?; - Ok(n) + Ok(n.try_into()?) } /// Return a description of the given preopened file descriptor. @@ -1806,7 +1887,7 @@ impl< fd, position: Default::default(), append: fdflags.contains(types::Fdflags::APPEND), - blocking: !fdflags.contains(types::Fdflags::NONBLOCK), + blocking_mode: BlockingMode::from_fdflags(&fdflags), })?; Ok(fd.into()) } @@ -1939,7 +2020,7 @@ impl< #[instrument(skip(self))] fn sched_yield(&mut self) -> Result<(), types::Error> { - // TODO: This is not yet covered in Preview2. + // No such thing in preview 2. Intentionally left empty. Ok(()) } diff --git a/crates/wasi/src/preview2/stdio.rs b/crates/wasi/src/preview2/stdio.rs index 947a4885f558..ca8453a3b999 100644 --- a/crates/wasi/src/preview2/stdio.rs +++ b/crates/wasi/src/preview2/stdio.rs @@ -4,8 +4,7 @@ use crate::preview2::bindings::cli::{ }; use crate::preview2::bindings::io::streams; use crate::preview2::pipe::AsyncWriteStream; -use crate::preview2::{HostOutputStream, StreamState, WasiView}; -use anyhow::Error; +use crate::preview2::{HostOutputStream, OutputStreamError, WasiView}; use bytes::Bytes; use is_terminal::IsTerminal; @@ -19,10 +18,19 @@ mod worker_thread_stdin; #[cfg(windows)] pub use self::worker_thread_stdin::{stdin, Stdin}; +// blocking-write-and-flush must accept 4k. It doesn't seem likely that we need to +// buffer more than that to implement a wrapper on the host process's stdio. If users +// really need more, they can write their own implementation using AsyncWriteStream +// and tokio's stdout/err. +const STDIO_BUFFER_SIZE: usize = 4096; + pub struct Stdout(AsyncWriteStream); pub fn stdout() -> Stdout { - Stdout(AsyncWriteStream::new(tokio::io::stdout())) + Stdout(AsyncWriteStream::new( + STDIO_BUFFER_SIZE, + tokio::io::stdout(), + )) } impl IsTerminal for Stdout { fn is_terminal(&self) -> bool { @@ -31,18 +39,24 @@ impl IsTerminal for Stdout { } #[async_trait::async_trait] impl HostOutputStream for Stdout { - fn write(&mut self, bytes: Bytes) -> Result<(usize, StreamState), Error> { + fn write(&mut self, bytes: Bytes) -> Result<(), OutputStreamError> { self.0.write(bytes) } - async fn ready(&mut self) -> Result<(), Error> { - self.0.ready().await + fn flush(&mut self) -> Result<(), OutputStreamError> { + self.0.flush() + } + async fn write_ready(&mut self) -> Result { + self.0.write_ready().await } } pub struct Stderr(AsyncWriteStream); pub fn stderr() -> Stderr { - Stderr(AsyncWriteStream::new(tokio::io::stderr())) + Stderr(AsyncWriteStream::new( + STDIO_BUFFER_SIZE, + tokio::io::stderr(), + )) } impl IsTerminal for Stderr { fn is_terminal(&self) -> bool { @@ -51,11 +65,14 @@ impl IsTerminal for Stderr { } #[async_trait::async_trait] impl HostOutputStream for Stderr { - fn write(&mut self, bytes: Bytes) -> Result<(usize, StreamState), Error> { + fn write(&mut self, bytes: Bytes) -> Result<(), OutputStreamError> { self.0.write(bytes) } - async fn ready(&mut self) -> Result<(), Error> { - self.0.ready().await + fn flush(&mut self) -> Result<(), OutputStreamError> { + self.0.flush() + } + async fn write_ready(&mut self) -> Result { + self.0.write_ready().await } } @@ -321,12 +338,16 @@ mod test { ) } + // This test doesn't work under qemu because of the use of fork in the test helper. #[test] + #[cfg_attr(not(target = "x86_64"), ignore)] fn test_async_fd_stdin() { test_stdin_by_forking(super::stdin); } + // This test doesn't work under qemu because of the use of fork in the test helper. #[test] + #[cfg_attr(not(target = "x86_64"), ignore)] fn test_worker_thread_stdin() { test_stdin_by_forking(super::worker_thread_stdin::stdin); } diff --git a/crates/wasi/src/preview2/stream.rs b/crates/wasi/src/preview2/stream.rs index 827c8cbe6070..2c8a46b7f99d 100644 --- a/crates/wasi/src/preview2/stream.rs +++ b/crates/wasi/src/preview2/stream.rs @@ -1,4 +1,4 @@ -use crate::preview2::filesystem::{FileInputStream, FileOutputStream}; +use crate::preview2::filesystem::FileInputStream; use crate::preview2::{Table, TableError}; use anyhow::Error; use bytes::Bytes; @@ -77,56 +77,93 @@ pub trait HostInputStream: Send + Sync { async fn ready(&mut self) -> Result<(), Error>; } +#[derive(Debug)] +pub enum OutputStreamError { + Closed, + LastOperationFailed(anyhow::Error), + Trap(anyhow::Error), +} +impl std::fmt::Display for OutputStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OutputStreamError::Closed => write!(f, "closed"), + OutputStreamError::LastOperationFailed(e) => write!(f, "last operation failed: {e}"), + OutputStreamError::Trap(e) => write!(f, "trap: {e}"), + } + } +} +impl std::error::Error for OutputStreamError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + OutputStreamError::Closed => None, + OutputStreamError::LastOperationFailed(e) | OutputStreamError::Trap(e) => e.source(), + } + } +} + /// Host trait for implementing the `wasi:io/streams.output-stream` resource: /// A bytestream which can be written to. #[async_trait::async_trait] pub trait HostOutputStream: Send + Sync { - /// Write bytes. On success, returns the number of bytes written. - /// Important: this write must be non-blocking! - /// Returning an Err which downcasts to a [`StreamRuntimeError`] will be - /// reported to Wasm as the empty error result. Otherwise, errors will trap. - fn write(&mut self, bytes: Bytes) -> Result<(usize, StreamState), Error>; + /// Write bytes after obtaining a permit to write those bytes + /// Prior to calling [`write`](Self::write) + /// the caller must call [`write_ready`](Self::write_ready), + /// which resolves to a non-zero permit + /// + /// This method must never block. + /// [`write_ready`](Self::write_ready) permit indicates the maximum amount of bytes that are + /// permitted to be written in a single [`write`](Self::write) following the + /// [`write_ready`](Self::write_ready) resolution + /// + /// # Errors + /// + /// Returns an [OutputStreamError] if: + /// - stream is closed + /// - prior operation ([`write`](Self::write) or [`flush`](Self::flush)) failed + /// - caller performed an illegal operation (e.g. wrote more bytes than were permitted) + fn write(&mut self, bytes: Bytes) -> Result<(), OutputStreamError>; - /// Transfer bytes directly from an input stream to an output stream. - /// Important: this splice must be non-blocking! - /// Returning an Err which downcasts to a [`StreamRuntimeError`] will be - /// reported to Wasm as the empty error result. Otherwise, errors will trap. - fn splice( - &mut self, - src: &mut dyn HostInputStream, - nelem: usize, - ) -> Result<(usize, StreamState), Error> { - let mut nspliced = 0; - let mut state = StreamState::Open; - - // TODO: handle the case where `bs.len()` is less than `nelem` - let (bs, read_state) = src.read(nelem)?; - // TODO: handle the case where write returns less than `bs.len()` - let (nwritten, _write_state) = self.write(bs)?; - nspliced += nwritten; - if read_state.is_closed() { - state = read_state; - } + /// Trigger a flush of any bytes buffered in this stream implementation. + /// + /// This method may be called at any time and must never block. + /// + /// After this method is called, [`write_ready`](Self::write_ready) must pend until flush is + /// complete. + /// When [`write_ready`](Self::write_ready) becomes ready after a flush, that guarantees that + /// all prior writes have been flushed from the implementation successfully, or that any error + /// associated with those writes is reported in the return value of [`flush`](Self::flush) or + /// [`write_ready`](Self::write_ready) + /// + /// # Errors + /// + /// Returns an [OutputStreamError] if: + /// - stream is closed + /// - prior operation ([`write`](Self::write) or [`flush`](Self::flush)) failed + /// - caller performed an illegal operation (e.g. wrote more bytes than were permitted) + fn flush(&mut self) -> Result<(), OutputStreamError>; - Ok((nspliced, state)) - } + /// Returns a future, which: + /// - when pending, indicates 0 bytes are permitted for writing + /// - when ready, returns a non-zero number of bytes permitted to write + /// + /// # Errors + /// + /// Returns an [OutputStreamError] if: + /// - stream is closed + /// - prior operation ([`write`](Self::write) or [`flush`](Self::flush)) failed + async fn write_ready(&mut self) -> Result; - /// Repeatedly write a byte to a stream. Important: this write must be - /// non-blocking! + /// Repeatedly write a byte to a stream. + /// Important: this write must be non-blocking! /// Returning an Err which downcasts to a [`StreamRuntimeError`] will be /// reported to Wasm as the empty error result. Otherwise, errors will trap. - fn write_zeroes(&mut self, nelem: usize) -> Result<(usize, StreamState), Error> { + fn write_zeroes(&mut self, nelem: usize) -> Result<(), OutputStreamError> { // TODO: We could optimize this to not allocate one big zeroed buffer, and instead write // repeatedly from a 'static buffer of zeros. let bs = Bytes::from_iter(core::iter::repeat(0 as u8).take(nelem)); - let r = self.write(bs)?; - Ok(r) + self.write(bs)?; + Ok(()) } - - /// Check for write readiness: this method blocks until the stream is - /// ready for writing. - /// Returning an error will trap execution. - async fn ready(&mut self) -> Result<(), Error>; } pub(crate) enum InternalInputStream { @@ -134,34 +171,21 @@ pub(crate) enum InternalInputStream { File(FileInputStream), } -pub(crate) enum InternalOutputStream { - Host(Box), - File(FileOutputStream), -} - pub(crate) trait InternalTableStreamExt { fn push_internal_input_stream( &mut self, istream: InternalInputStream, ) -> Result; + fn push_internal_input_stream_child( + &mut self, + istream: InternalInputStream, + parent: u32, + ) -> Result; fn get_internal_input_stream_mut( &mut self, fd: u32, ) -> Result<&mut InternalInputStream, TableError>; fn delete_internal_input_stream(&mut self, fd: u32) -> Result; - - fn push_internal_output_stream( - &mut self, - ostream: InternalOutputStream, - ) -> Result; - fn get_internal_output_stream_mut( - &mut self, - fd: u32, - ) -> Result<&mut InternalOutputStream, TableError>; - fn delete_internal_output_stream( - &mut self, - fd: u32, - ) -> Result; } impl InternalTableStreamExt for Table { fn push_internal_input_stream( @@ -170,32 +194,20 @@ impl InternalTableStreamExt for Table { ) -> Result { self.push(Box::new(istream)) } - fn get_internal_input_stream_mut( - &mut self, - fd: u32, - ) -> Result<&mut InternalInputStream, TableError> { - self.get_mut(fd) - } - fn delete_internal_input_stream(&mut self, fd: u32) -> Result { - self.delete(fd) - } - - fn push_internal_output_stream( + fn push_internal_input_stream_child( &mut self, - ostream: InternalOutputStream, + istream: InternalInputStream, + parent: u32, ) -> Result { - self.push(Box::new(ostream)) + self.push_child(Box::new(istream), parent) } - fn get_internal_output_stream_mut( + fn get_internal_input_stream_mut( &mut self, fd: u32, - ) -> Result<&mut InternalOutputStream, TableError> { + ) -> Result<&mut InternalInputStream, TableError> { self.get_mut(fd) } - fn delete_internal_output_stream( - &mut self, - fd: u32, - ) -> Result { + fn delete_internal_input_stream(&mut self, fd: u32) -> Result { self.delete(fd) } } @@ -204,6 +216,13 @@ impl InternalTableStreamExt for Table { pub trait TableStreamExt { /// Push a [`HostInputStream`] into a [`Table`], returning the table index. fn push_input_stream(&mut self, istream: Box) -> Result; + /// Same as [`push_input_stream`](Self::push_output_stream) except assigns a parent resource to + /// the input-stream created. + fn push_input_stream_child( + &mut self, + istream: Box, + parent: u32, + ) -> Result; /// Get a mutable reference to a [`HostInputStream`] in a [`Table`]. fn get_input_stream_mut(&mut self, fd: u32) -> Result<&mut dyn HostInputStream, TableError>; /// Remove [`HostInputStream`] from table: @@ -212,6 +231,13 @@ pub trait TableStreamExt { /// Push a [`HostOutputStream`] into a [`Table`], returning the table index. fn push_output_stream(&mut self, ostream: Box) -> Result; + /// Same as [`push_output_stream`](Self::push_output_stream) except assigns a parent resource + /// to the output-stream created. + fn push_output_stream_child( + &mut self, + ostream: Box, + parent: u32, + ) -> Result; /// Get a mutable reference to a [`HostOutputStream`] in a [`Table`]. fn get_output_stream_mut(&mut self, fd: u32) -> Result<&mut dyn HostOutputStream, TableError>; @@ -222,6 +248,13 @@ impl TableStreamExt for Table { fn push_input_stream(&mut self, istream: Box) -> Result { self.push_internal_input_stream(InternalInputStream::Host(istream)) } + fn push_input_stream_child( + &mut self, + istream: Box, + parent: u32, + ) -> Result { + self.push_internal_input_stream_child(InternalInputStream::Host(istream), parent) + } fn get_input_stream_mut(&mut self, fd: u32) -> Result<&mut dyn HostInputStream, TableError> { match self.get_internal_input_stream_mut(fd)? { InternalInputStream::Host(ref mut h) => Ok(h.as_mut()), @@ -246,26 +279,21 @@ impl TableStreamExt for Table { &mut self, ostream: Box, ) -> Result { - self.push_internal_output_stream(InternalOutputStream::Host(ostream)) + self.push(Box::new(ostream)) + } + fn push_output_stream_child( + &mut self, + ostream: Box, + parent: u32, + ) -> Result { + self.push_child(Box::new(ostream), parent) } fn get_output_stream_mut(&mut self, fd: u32) -> Result<&mut dyn HostOutputStream, TableError> { - match self.get_internal_output_stream_mut(fd)? { - InternalOutputStream::Host(ref mut h) => Ok(h.as_mut()), - _ => Err(TableError::WrongType), - } + let boxed: &mut Box = self.get_mut(fd)?; + Ok(boxed.as_mut()) } fn delete_output_stream(&mut self, fd: u32) -> Result, TableError> { - let occ = self.entry(fd)?; - match occ.get().downcast_ref::() { - Some(InternalOutputStream::Host(_)) => { - let any = occ.remove_entry()?; - match *any.downcast().expect("downcast checked above") { - InternalOutputStream::Host(h) => Ok(h), - _ => unreachable!("variant checked above"), - } - } - _ => Err(TableError::WrongType), - } + self.delete(fd) } } @@ -275,18 +303,7 @@ mod test { #[test] fn input_stream_in_table() { - struct DummyInputStream; - #[async_trait::async_trait] - impl HostInputStream for DummyInputStream { - fn read(&mut self, _size: usize) -> Result<(Bytes, StreamState), Error> { - unimplemented!(); - } - async fn ready(&mut self) -> Result<(), Error> { - unimplemented!(); - } - } - - let dummy = DummyInputStream; + let dummy = crate::preview2::pipe::ClosedInputStream; let mut table = Table::new(); // Put it into the table: let ix = table.push_input_stream(Box::new(dummy)).unwrap(); @@ -308,18 +325,7 @@ mod test { #[test] fn output_stream_in_table() { - struct DummyOutputStream; - #[async_trait::async_trait] - impl HostOutputStream for DummyOutputStream { - fn write(&mut self, _: Bytes) -> Result<(usize, StreamState), Error> { - unimplemented!(); - } - async fn ready(&mut self) -> Result<(), Error> { - unimplemented!(); - } - } - - let dummy = DummyOutputStream; + let dummy = crate::preview2::pipe::SinkOutputStream; let mut table = Table::new(); // Put it in the table: let ix = table.push_output_stream(Box::new(dummy)).unwrap(); diff --git a/crates/wasi/src/preview2/tcp.rs b/crates/wasi/src/preview2/tcp.rs index a563a6e09b3b..9721e2c3c37a 100644 --- a/crates/wasi/src/preview2/tcp.rs +++ b/crates/wasi/src/preview2/tcp.rs @@ -1,13 +1,12 @@ -use crate::preview2::{HostInputStream, HostOutputStream, StreamState, Table, TableError}; -use bytes::{Bytes, BytesMut}; +use super::{HostInputStream, HostOutputStream, OutputStreamError}; +use crate::preview2::{ + with_ambient_tokio_runtime, AbortOnDropJoinHandle, StreamState, Table, TableError, +}; use cap_net_ext::{AddressFamily, Blocking, TcpListenerExt}; -use cap_std::net::{TcpListener, TcpStream}; +use cap_std::net::TcpListener; use io_lifetimes::raw::{FromRawSocketlike, IntoRawSocketlike}; -use io_lifetimes::AsSocketlike; use std::io; use std::sync::Arc; -use system_interface::io::IoExt; -use tokio::io::Interest; /// The state of a TCP socket. /// @@ -47,116 +46,210 @@ pub(crate) enum HostTcpState { pub(crate) struct HostTcpSocket { /// The part of a `HostTcpSocket` which is reference-counted so that we /// can pass it to async tasks. - pub(crate) inner: Arc, + pub(crate) inner: Arc, /// The current state in the bind/listen/accept/connect progression. pub(crate) tcp_state: HostTcpState, } -/// The inner reference-counted state of a `HostTcpSocket`. -pub(crate) struct HostTcpSocketInner { - pub(crate) tcp_socket: tokio::net::TcpStream, +pub(crate) struct TcpReadStream { + stream: Arc, + closed: bool, } -impl HostTcpSocket { - /// Create a new socket in the given family. - pub fn new(family: AddressFamily) -> io::Result { - // Create a new host socket and set it to non-blocking, which is needed - // by our async implementation. - let tcp_socket = TcpListener::new(family, Blocking::No)?; - - let std_socket = - unsafe { std::net::TcpStream::from_raw_socketlike(tcp_socket.into_raw_socketlike()) }; +impl TcpReadStream { + fn new(stream: Arc) -> Self { + Self { + stream, + closed: false, + } + } + fn stream_state(&self) -> StreamState { + if self.closed { + StreamState::Closed + } else { + StreamState::Open + } + } +} - let tokio_tcp_socket = crate::preview2::with_ambient_tokio_runtime(|| { - tokio::net::TcpStream::try_from(std_socket).unwrap() - }); +#[async_trait::async_trait] +impl HostInputStream for TcpReadStream { + fn read(&mut self, size: usize) -> Result<(bytes::Bytes, StreamState), anyhow::Error> { + if size == 0 || self.closed { + return Ok((bytes::Bytes::new(), self.stream_state())); + } - Ok(Self { - inner: Arc::new(HostTcpSocketInner { - tcp_socket: tokio_tcp_socket, - }), - tcp_state: HostTcpState::Default, - }) - } + let mut buf = bytes::BytesMut::with_capacity(size); + let n = match self.stream.try_read_buf(&mut buf) { + // A 0-byte read indicates that the stream has closed. + Ok(0) => { + self.closed = true; + 0 + } + Ok(n) => n, - /// Create a `HostTcpSocket` from an existing socket. - /// - /// The socket must be in non-blocking mode. - pub fn from_tcp_stream(tcp_socket: cap_std::net::TcpStream) -> io::Result { - let fd = rustix::fd::OwnedFd::from(tcp_socket); - let tcp_socket = TcpListener::from(fd); + // Failing with `EWOULDBLOCK` is how we differentiate between a closed channel and no + // data to read right now. + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => 0, - let std_tcp_socket = - unsafe { std::net::TcpStream::from_raw_socketlike(tcp_socket.into_raw_socketlike()) }; - let tokio_tcp_socket = crate::preview2::with_ambient_tokio_runtime(|| { - tokio::net::TcpStream::try_from(std_tcp_socket).unwrap() - }); + Err(e) => { + tracing::debug!("unexpected error on TcpReadStream read: {e:?}"); + self.closed = true; + 0 + } + }; - Ok(Self { - inner: Arc::new(HostTcpSocketInner { - tcp_socket: tokio_tcp_socket, - }), - tcp_state: HostTcpState::Default, - }) + buf.truncate(n); + Ok((buf.freeze(), self.stream_state())) } - pub fn tcp_socket(&self) -> &tokio::net::TcpStream { - self.inner.tcp_socket() + async fn ready(&mut self) -> Result<(), anyhow::Error> { + if self.closed { + return Ok(()); + } + self.stream.readable().await?; + Ok(()) } +} - pub fn clone_inner(&self) -> Arc { - Arc::clone(&self.inner) - } +const SOCKET_READY_SIZE: usize = 1024 * 1024 * 1024; + +pub(crate) struct TcpWriteStream { + stream: Arc, + write_handle: Option>>, } -impl HostTcpSocketInner { - pub fn tcp_socket(&self) -> &tokio::net::TcpStream { - let tcp_socket = &self.tcp_socket; +impl TcpWriteStream { + pub(crate) fn new(stream: Arc) -> Self { + Self { + stream, + write_handle: None, + } + } + + /// Write `bytes` in a background task, remembering the task handle for use in a future call to + /// `write_ready` + fn background_write(&mut self, mut bytes: bytes::Bytes) { + assert!(self.write_handle.is_none()); - tcp_socket + let stream = self.stream.clone(); + self.write_handle + .replace(crate::preview2::spawn(async move { + // Note: we are not using the AsyncWrite impl here, and instead using the TcpStream + // primitive try_write, which goes directly to attempt a write with mio. This has + // two advantages: 1. this operation takes a &TcpStream instead of a &mut TcpStream + // required to AsyncWrite, and 2. it eliminates any buffering in tokio we may need + // to flush. + while !bytes.is_empty() { + stream.writable().await?; + match stream.try_write(&bytes) { + Ok(n) => { + let _ = bytes.split_to(n); + } + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => continue, + Err(e) => return Err(e.into()), + } + } + + Ok(()) + })); } } #[async_trait::async_trait] -impl HostInputStream for Arc { - fn read(&mut self, size: usize) -> anyhow::Result<(Bytes, StreamState)> { - if size == 0 { - return Ok((Bytes::new(), StreamState::Open)); +impl HostOutputStream for TcpWriteStream { + fn write(&mut self, mut bytes: bytes::Bytes) -> Result<(), OutputStreamError> { + if self.write_handle.is_some() { + return Err(OutputStreamError::Trap(anyhow::anyhow!( + "unpermitted: cannot write while background write ongoing" + ))); } - let mut buf = BytesMut::zeroed(size); - let socket = self.tcp_socket(); - let r = socket.try_io(Interest::READABLE, || { - socket.as_socketlike_view::().read(&mut buf) - }); - let (n, state) = read_result(r)?; - buf.truncate(n); - Ok((buf.freeze(), state)) + while !bytes.is_empty() { + match self.stream.try_write(&bytes) { + Ok(n) => { + let _ = bytes.split_to(n); + } + + Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => { + // As `try_write` indicated that it would have blocked, we'll perform the write + // in the background to allow us to return immediately. + self.background_write(bytes); + + return Ok(()); + } + + Err(e) => return Err(OutputStreamError::LastOperationFailed(e.into())), + } + } + + Ok(()) } - async fn ready(&mut self) -> anyhow::Result<()> { - self.tcp_socket.readable().await?; + fn flush(&mut self) -> Result<(), OutputStreamError> { + // `flush` is a no-op here, as we're not managing any internal buffer. Additionally, + // `write_ready` will join the background write task if it's active, so following `flush` + // with `write_ready` will have the desired effect. Ok(()) } -} -#[async_trait::async_trait] -impl HostOutputStream for Arc { - fn write(&mut self, buf: Bytes) -> anyhow::Result<(usize, StreamState)> { - if buf.is_empty() { - return Ok((0, StreamState::Open)); + async fn write_ready(&mut self) -> Result { + if let Some(handle) = &mut self.write_handle { + handle + .await + .map_err(|e| OutputStreamError::LastOperationFailed(e.into()))?; + + // Only clear out the write handle once the task has exited, to ensure that + // `write_ready` remains cancel-safe. + self.write_handle = None; } - let socket = self.tcp_socket(); - let r = socket.try_io(Interest::WRITABLE, || { - socket.as_socketlike_view::().write(buf.as_ref()) - }); - let (n, state) = write_result(r)?; - Ok((n, state)) + + self.stream + .writable() + .await + .map_err(|e| OutputStreamError::LastOperationFailed(e.into()))?; + + Ok(SOCKET_READY_SIZE) } +} - async fn ready(&mut self) -> anyhow::Result<()> { - self.tcp_socket.writable().await?; - Ok(()) +impl HostTcpSocket { + /// Create a new socket in the given family. + pub fn new(family: AddressFamily) -> io::Result { + // Create a new host socket and set it to non-blocking, which is needed + // by our async implementation. + let tcp_listener = TcpListener::new(family, Blocking::No)?; + Self::from_tcp_listener(tcp_listener) + } + + /// Create a `HostTcpSocket` from an existing socket. + /// + /// The socket must be in non-blocking mode. + pub fn from_tcp_stream(tcp_socket: cap_std::net::TcpStream) -> io::Result { + let tcp_listener = TcpListener::from(rustix::fd::OwnedFd::from(tcp_socket)); + Self::from_tcp_listener(tcp_listener) + } + + pub fn from_tcp_listener(tcp_listener: cap_std::net::TcpListener) -> io::Result { + let fd = tcp_listener.into_raw_socketlike(); + let std_stream = unsafe { std::net::TcpStream::from_raw_socketlike(fd) }; + let stream = with_ambient_tokio_runtime(|| tokio::net::TcpStream::try_from(std_stream))?; + + Ok(Self { + inner: Arc::new(stream), + tcp_state: HostTcpState::Default, + }) + } + + pub fn tcp_socket(&self) -> &tokio::net::TcpStream { + &self.inner + } + + /// Create the input/output stream pair for a tcp socket. + pub fn as_split(&self) -> (Box, Box) { + let input = Box::new(TcpReadStream::new(self.inner.clone())); + let output = Box::new(TcpWriteStream::new(self.inner.clone())); + (input, output) } } @@ -185,31 +278,3 @@ impl TableTcpSocketExt for Table { self.get_mut(fd) } } - -pub(crate) fn read_result(r: io::Result) -> io::Result<(usize, StreamState)> { - match r { - Ok(0) => Ok((0, StreamState::Closed)), - Ok(n) => Ok((n, StreamState::Open)), - Err(e) - if e.kind() == io::ErrorKind::Interrupted || e.kind() == io::ErrorKind::WouldBlock => - { - Ok((0, StreamState::Open)) - } - Err(e) => Err(e), - } -} - -pub(crate) fn write_result(r: io::Result) -> io::Result<(usize, StreamState)> { - match r { - // We special-case zero-write stores ourselves, so if we get a zero - // back from a `write`, it means the stream is closed on some - // platforms. - Ok(0) => Ok((0, StreamState::Closed)), - Ok(n) => Ok((n, StreamState::Open)), - #[cfg(not(windows))] - Err(e) if e.raw_os_error() == Some(rustix::io::Errno::PIPE.raw_os_error()) => { - Ok((0, StreamState::Closed)) - } - Err(e) => Err(e), - } -} diff --git a/crates/wasi/wit/deps/io/streams.wit b/crates/wasi/wit/deps/io/streams.wit index 98df181c1ea4..e2631f66a569 100644 --- a/crates/wasi/wit/deps/io/streams.wit +++ b/crates/wasi/wit/deps/io/streams.wit @@ -134,58 +134,115 @@ interface streams { /// This [represents a resource](https://github.com/WebAssembly/WASI/blob/main/docs/WitInWasi.md#Resources). type output-stream = u32 - /// Perform a non-blocking write of bytes to a stream. + /// An error for output-stream operations. /// - /// This function returns a `u64` and a `stream-status`. The `u64` indicates - /// the number of bytes from `buf` that were written, which may be less than - /// the length of `buf`. The `stream-status` indicates if further writes to - /// the stream are expected to be read. + /// Contrary to input-streams, a closed output-stream is reported using + /// an error. + enum write-error { + /// The last operation (a write or flush) failed before completion. + last-operation-failed, + /// The stream is closed: no more input will be accepted by the + /// stream. A closed output-stream will return this error on all + /// future operations. + closed + } + /// Check readiness for writing. This function never blocks. + /// + /// Returns the number of bytes permitted for the next call to `write`, + /// or an error. Calling `write` with more bytes than this function has + /// permitted will trap. + /// + /// When this function returns 0 bytes, the `subscribe-to-output-stream` + /// pollable will become ready when this function will report at least + /// 1 byte, or an error. + check-write: func( + this: output-stream + ) -> result + + /// Perform a write. This function never blocks. /// - /// When the returned `stream-status` is `open`, the `u64` return value may - /// be less than the length of `buf`. This indicates that no more bytes may - /// be written to the stream promptly. In that case the - /// `subscribe-to-output-stream` pollable will indicate when additional bytes - /// may be promptly written. + /// Precondition: check-write gave permit of Ok(n) and contents has a + /// length of less than or equal to n. Otherwise, this function will trap. /// - /// Writing an empty list must return a non-error result with `0` for the - /// `u64` return value, and the current `stream-status`. + /// returns Err(closed) without writing if the stream has closed since + /// the last call to check-write provided a permit. write: func( this: output-stream, - /// Data to write - buf: list - ) -> result> + contents: list + ) -> result<_, write-error> - /// Blocking write of bytes to a stream. + /// Perform a write of up to 4096 bytes, and then flush the stream. Block + /// until all of these operations are complete, or an error occurs. /// - /// This is similar to `write`, except that it blocks until at least one - /// byte can be written. - blocking-write: func( - this: output-stream, - /// Data to write - buf: list - ) -> result> + /// This is a convenience wrapper around the use of `check-write`, + /// `subscribe-to-output-stream`, `write`, and `flush`, and is implemented + /// with the following pseudo-code: + /// + /// ```text + /// let pollable = subscribe-to-output-stream(this); + /// while !contents.is_empty() { + /// // Wait for the stream to become writable + /// poll-oneoff(pollable); + /// let Ok(n) = check-write(this); // eliding error handling + /// let len = min(n, contents.len()); + /// let (chunk, rest) = contents.split_at(len); + /// write(this, chunk); // eliding error handling + /// contents = rest; + /// } + /// flush(this); + /// // Wait for completion of `flush` + /// poll-oneoff(pollable); + /// // Check for any errors that arose during `flush` + /// let _ = check-write(this); // eliding error handling + /// ``` + blocking-write-and-flush: func( + this: output-stream, + contents: list + ) -> result<_, write-error> - /// Write multiple zero-bytes to a stream. + /// Request to flush buffered output. This function never blocks. /// - /// This function returns a `u64` indicating the number of zero-bytes - /// that were written; it may be less than `len`. Equivelant to a call to - /// `write` with a list of zeroes of the given length. - write-zeroes: func( + /// This tells the output-stream that the caller intends any buffered + /// output to be flushed. the output which is expected to be flushed + /// is all that has been passed to `write` prior to this call. + /// + /// Upon calling this function, the `output-stream` will not accept any + /// writes (`check-write` will return `ok(0)`) until the flush has + /// completed. The `subscribe-to-output-stream` pollable will become ready + /// when the flush has completed and the stream can accept more writes. + flush: func( this: output-stream, - /// The number of zero-bytes to write - len: u64 - ) -> result> + ) -> result<_, write-error> + + /// Request to flush buffered output, and block until flush completes + /// and stream is ready for writing again. + blocking-flush: func( + this: output-stream, + ) -> result<_, write-error> + + /// Create a `pollable` which will resolve once the output-stream + /// is ready for more writing, or an error has occured. When this + /// pollable is ready, `check-write` will return `ok(n)` with n>0, or an + /// error. + /// + /// If the stream is closed, this pollable is always ready immediately. + /// + /// The created `pollable` is a child resource of the `output-stream`. + /// Implementations may trap if the `output-stream` is dropped before + /// all derived `pollable`s created with this function are dropped. + subscribe-to-output-stream: func(this: output-stream) -> pollable - /// Write multiple zero bytes to a stream, with blocking. + /// Write zeroes to a stream. /// - /// This is similar to `write-zeroes`, except that it blocks until at least - /// one byte can be written. Equivelant to a call to `blocking-write` with - /// a list of zeroes of the given length. - blocking-write-zeroes: func( + /// this should be used precisely like `write` with the exact same + /// preconditions (must use check-write first), but instead of + /// passing a list of bytes, you simply pass the number of zero-bytes + /// that should be written. + write-zeroes: func( this: output-stream, - /// The number of zero bytes to write + /// The number of zero-bytes to write len: u64 - ) -> result> + ) -> result<_, write-error> /// Read from one stream and write to another. /// @@ -232,16 +289,6 @@ interface streams { src: input-stream ) -> result> - /// Create a `pollable` which will resolve once either the specified stream - /// is ready to accept bytes or the `stream-state` has become closed. - /// - /// Once the stream-state is closed, this pollable is always ready - /// immediately. - /// - /// The created `pollable` is a child resource of the `output-stream`. - /// Implementations may trap if the `output-stream` is dropped before - /// all derived `pollable`s created with this function are dropped. - subscribe-to-output-stream: func(this: output-stream) -> pollable /// Dispose of the specified `output-stream`, after which it may no longer /// be used. diff --git a/crates/wasi/wit/deps/sockets/tcp.wit b/crates/wasi/wit/deps/sockets/tcp.wit index 4edb1db7f0b1..3922769b308e 100644 --- a/crates/wasi/wit/deps/sockets/tcp.wit +++ b/crates/wasi/wit/deps/sockets/tcp.wit @@ -81,6 +81,9 @@ interface tcp { /// - /// - start-connect: func(this: tcp-socket, network: network, remote-address: ip-socket-address) -> result<_, error-code> + /// Note: the returned `input-stream` and `output-stream` are child + /// resources of the `tcp-socket`. Implementations may trap if the + /// `tcp-socket` is dropped before both of these streams are dropped. finish-connect: func(this: tcp-socket) -> result, error-code> /// Start listening for new connections. @@ -116,6 +119,10 @@ interface tcp { /// /// On success, this function returns the newly accepted client socket along with /// a pair of streams that can be used to read & write to the connection. + /// + /// Note: the returned `input-stream` and `output-stream` are child + /// resources of the returned `tcp-socket`. Implementations may trap if the + /// `tcp-socket` is dropped before its child streams are dropped. /// /// # Typical errors /// - `not-listening`: Socket is not in the Listener state. (EINVAL) @@ -223,6 +230,10 @@ interface tcp { /// Create a `pollable` which will resolve once the socket is ready for I/O. /// + /// The created `pollable` is a child resource of the `tcp-socket`. + /// Implementations may trap if the `tcp-socket` is dropped before all + /// derived `pollable`s created with this function are dropped. + /// /// Note: this function is here for WASI Preview2 only. /// It's planned to be removed when `future` is natively supported in Preview3. subscribe: func(this: tcp-socket) -> pollable diff --git a/tests/all/cli_tests.rs b/tests/all/cli_tests.rs index be10bc1d9d9c..dea58a99551b 100644 --- a/tests/all/cli_tests.rs +++ b/tests/all/cli_tests.rs @@ -739,7 +739,7 @@ fn wasi_misaligned_pointer() -> Result<()> { } #[test] -#[ignore] // FIXME(#6811) currently is flaky and may produce no output +#[cfg_attr(not(feature = "component-model"), ignore)] fn hello_with_preview2() -> Result<()> { let wasm = build_wasm("tests/all/cli_tests/hello_wasi_snapshot1.wat")?; let stdout = run_wasmtime(&[ diff --git a/tests/all/cli_tests/wasi-http.wat b/tests/all/cli_tests/wasi-http.wat index 9f5d51979426..8893c3ca4584 100644 --- a/tests/all/cli_tests/wasi-http.wat +++ b/tests/all/cli_tests/wasi-http.wat @@ -3,8 +3,8 @@ (func $__wasi_proc_exit (param i32))) (import "wasi:io/streams" "write" (func $__wasi_io_streams_write (param i32 i32 i32 i32))) - (import "wasi:io/streams" "blocking-write" - (func $__wasi_io_streams_blocking_write (param i32 i32 i32 i32))) + (import "wasi:io/streams" "blocking-write-and-flush" + (func $__wasi_io_streams_blocking_write_and_flush (param i32 i32 i32 i32))) (import "wasi:io/streams" "subscribe-to-output-stream" (func $__wasi_io_streams_subscribe_to_output_stream (param i32) (result i32))) (import "wasi:http/types" "new-fields" @@ -70,7 +70,7 @@ ;; A helper function for printing ptr-len strings. (func $print (param $ptr i32) (param $len i32) - (call $__wasi_io_streams_blocking_write + (call $__wasi_io_streams_blocking_write_and_flush i32.const 4 ;; Value for stdout local.get $ptr local.get $len