diff --git a/crates/wasi-http/src/body.rs b/crates/wasi-http/src/body.rs index 95bf4fa3a702..11a9b3ce78ca 100644 --- a/crates/wasi-http/src/body.rs +++ b/crates/wasi-http/src/body.rs @@ -1,186 +1,215 @@ use crate::{bindings::http::types, types::FieldMap}; -use anyhow::anyhow; +use anyhow::{anyhow, Result}; use bytes::Bytes; +use http_body::{Body, Frame}; use http_body_util::combinators::BoxBody; +use http_body_util::BodyExt; use std::future::Future; +use std::mem; +use std::task::{Context, Poll}; use std::{pin::Pin, sync::Arc, time::Duration}; use tokio::sync::{mpsc, oneshot}; use wasmtime_wasi::preview2::{ - self, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, StreamError, Subscribe, + self, poll_noop, AbortOnDropJoinHandle, HostInputStream, HostOutputStream, StreamError, + Subscribe, }; pub type HyperIncomingBody = BoxBody; -/// Holds onto the things needed to construct a [`HostIncomingBody`] until we are ready to build -/// one. The HostIncomingBody spawns a task that starts consuming the incoming body, and we don't -/// want to do that unless the user asks to consume the body. -pub struct HostIncomingBodyBuilder { - pub body: HyperIncomingBody, - pub between_bytes_timeout: Duration, +/// Small wrapper around `BoxBody` which adds a timeout to every frame. +struct BodyWithTimeout { + /// Underlying stream that frames are coming from. + inner: HyperIncomingBody, + /// Currently active timeout that's reset between frames. + timeout: Pin>, + /// Whether or not `timeout` needs to be reset on the next call to + /// `poll_frame`. + reset_sleep: bool, + /// Maximal duration between when a frame is first requested and when it's + /// allowed to arrive. + between_bytes_timeout: Duration, } -impl HostIncomingBodyBuilder { - /// Consume the state held in the [`HostIncomingBodyBuilder`] to spawn a task that will drive the - /// streaming body to completion. Data segments will be communicated out over the - /// [`HostIncomingBodyStream`], and a [`HostFutureTrailers`] gives a way to block on/retrieve - /// the trailers. - pub fn build(mut self) -> HostIncomingBody { - let (body_writer, body_receiver) = mpsc::channel(1); - let (trailer_writer, trailers) = oneshot::channel(); - - let worker = preview2::spawn(async move { - loop { - let frame = match tokio::time::timeout( - self.between_bytes_timeout, - http_body_util::BodyExt::frame(&mut self.body), - ) - .await - { - Ok(None) => break, - - Ok(Some(Ok(frame))) => frame, - - Ok(Some(Err(e))) => { - match body_writer.send(Err(e)).await { - Ok(_) => {} - // If the body read end has dropped, then we report this error with the - // trailers. unwrap and rewrap Err because the Ok side of these two Results - // are different. - Err(e) => { - let _ = trailer_writer.send(Err(e.0.unwrap_err())); - } - } - break; - } +impl BodyWithTimeout { + fn new(inner: HyperIncomingBody, between_bytes_timeout: Duration) -> BodyWithTimeout { + BodyWithTimeout { + inner, + between_bytes_timeout, + reset_sleep: true, + timeout: Box::pin(preview2::with_ambient_tokio_runtime(|| { + tokio::time::sleep(Duration::new(0, 0)) + })), + } + } +} - Err(_) => { - match body_writer - .send(Err(types::Error::TimeoutError( - "data frame timed out".to_string(), - ) - .into())) - .await - { - Ok(_) => {} - Err(e) => { - let _ = trailer_writer.send(Err(e.0.unwrap_err())); - } - } - break; - } - }; +impl Body for BodyWithTimeout { + type Data = Bytes; + type Error = types::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, types::Error>>> { + let me = Pin::into_inner(self); + + // If the timeout timer needs to be reset, do that now relative to the + // current instant. Otherwise test the timeout timer and see if it's + // fired yet and if so we've timed out and return an error. + if me.reset_sleep { + me.timeout + .as_mut() + .reset(tokio::time::Instant::now() + me.between_bytes_timeout); + me.reset_sleep = false; + } - if frame.is_trailers() { - // We know we're not going to write any more data frames at this point, so we - // explicitly drop the body_writer so that anything waiting on the read end returns - // immediately. - drop(body_writer); + // Register interest in this context on the sleep timer, and if the + // sleep elapsed that means that we've timed out. + if let Poll::Ready(()) = me.timeout.as_mut().poll(cx) { + return Poll::Ready(Some(Err(types::Error::TimeoutError( + "frame timed out".to_string(), + )))); + } - let trailers = frame.into_trailers().unwrap(); + // Without timeout business now handled check for the frame. If a frame + // arrives then the sleep timer will be reset on the next frame. + let result = Pin::new(&mut me.inner).poll_frame(cx); + me.reset_sleep = result.is_ready(); + match result { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(Some(Ok(frame))) => Poll::Ready(Some(Ok(frame))), + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(match e.downcast() { + Ok(e) => e, + Err(e) => types::Error::ProtocolError(format!("{e:?}")), + }))), + } + } +} - // TODO: this will fail in two cases: - // 1. we've already used the channel once, which should be imposible, - // 2. the read end is closed. - // I'm not sure how to differentiate between these two cases, or really - // if we need to do anything to handle either. - let _ = trailer_writer.send(Ok(trailers)); +pub struct HostIncomingBody { + body: IncomingBodyState, + /// An optional worker task to keep alive while this body is being read. + /// This ensures that if the parent of this body is dropped before the body + /// then the backing data behind this worker is kept alive. + worker: Option>>>, +} - break; - } +enum IncomingBodyState { + /// The body is stored here meaning that within `HostIncomingBody` the + /// `take_stream` method can be called for example. + Start(BodyWithTimeout), - assert!(frame.is_data(), "frame wasn't data"); + /// The body is within a `HostIncomingBodyStream` meaning that it's not + /// currently owned here. The body will be sent back over this channel when + /// it's done, however. + InBodyStream(oneshot::Receiver), +} - let data = frame.into_data().unwrap(); +/// Message sent when a `HostIncomingBodyStream` is done to the +/// `HostFutureTrailers` state. +enum StreamEnd { + /// The body wasn't completely read and was dropped early. May still have + /// trailers, but requires reading more frames. + Remaining(BodyWithTimeout), - // If the receiver no longer exists, thats ok - in that case we want to keep the - // loop running to relieve backpressure, so we get to the trailers. - let _ = body_writer.send(Ok(data)).await; - } - }); + /// Body was completely read and trailers were read. Here are the trailers. + /// Note that `None` means that the body finished without trailers. + Trailers(Option), +} +impl HostIncomingBody { + pub fn new(body: HyperIncomingBody, between_bytes_timeout: Duration) -> HostIncomingBody { + let body = BodyWithTimeout::new(body, between_bytes_timeout); HostIncomingBody { - worker, - stream: Some(HostIncomingBodyStream::new(body_receiver)), - trailers, + body: IncomingBodyState::Start(body), + worker: None, } } -} -pub struct HostIncomingBody { - pub worker: AbortOnDropJoinHandle<()>, - pub stream: Option, - pub trailers: oneshot::Receiver>, -} + pub fn retain_worker(&mut self, worker: &Arc>>) { + assert!(self.worker.is_none()); + self.worker = Some(worker.clone()); + } -impl HostIncomingBody { - pub fn into_future_trailers(self) -> HostFutureTrailers { - HostFutureTrailers { - _worker: self.worker, - state: HostFutureTrailersState::Waiting(self.trailers), + pub fn take_stream(&mut self) -> Option { + match &mut self.body { + IncomingBodyState::Start(_) => {} + IncomingBodyState::InBodyStream(_) => return None, } + let (tx, rx) = oneshot::channel(); + let body = match mem::replace(&mut self.body, IncomingBodyState::InBodyStream(rx)) { + IncomingBodyState::Start(b) => b, + IncomingBodyState::InBodyStream(_) => unreachable!(), + }; + Some(HostIncomingBodyStream { + state: IncomingBodyStreamState::Open { body, tx }, + buffer: Bytes::new(), + error: None, + }) + } + + pub fn into_future_trailers(self) -> HostFutureTrailers { + HostFutureTrailers::Waiting(self) } } pub struct HostIncomingBodyStream { - pub open: bool, - pub receiver: mpsc::Receiver>, - pub buffer: Bytes, - pub error: Option, + state: IncomingBodyStreamState, + buffer: Bytes, + error: Option, } -impl HostIncomingBodyStream { - fn new(receiver: mpsc::Receiver>) -> Self { - Self { - open: true, - receiver, - buffer: Bytes::new(), - error: None, - } - } +enum IncomingBodyStreamState { + /// The body is currently open for reading and present here. + /// + /// When trailers are read, or when this is dropped, the body is sent along + /// `tx`. + /// + /// This state is transitioned to `Closed` when an error happens, EOF + /// happens, or when trailers are read. + Open { + body: BodyWithTimeout, + tx: oneshot::Sender, + }, + + /// This body is closed and no longer available for reading, no more data + /// will come. + Closed, } #[async_trait::async_trait] impl HostInputStream for HostIncomingBodyStream { fn read(&mut self, size: usize) -> Result { - use mpsc::error::TryRecvError; - - if !self.buffer.is_empty() { - let len = size.min(self.buffer.len()); - let chunk = self.buffer.split_to(len); - return Ok(chunk); - } - - if let Some(e) = self.error.take() { - return Err(StreamError::LastOperationFailed(e)); - } - - if !self.open { - return Err(StreamError::Closed); - } - - match self.receiver.try_recv() { - Ok(Ok(mut bytes)) => { - let len = bytes.len().min(size); - let chunk = bytes.split_to(len); - if !bytes.is_empty() { - self.buffer = bytes; - } - + loop { + // Handle buffered data/errors if any + if !self.buffer.is_empty() { + let len = size.min(self.buffer.len()); + let chunk = self.buffer.split_to(len); return Ok(chunk); } - Ok(Err(e)) => { - self.open = false; + if let Some(e) = self.error.take() { return Err(StreamError::LastOperationFailed(e)); } - Err(TryRecvError::Empty) => { - return Ok(Bytes::new()); - } + // Extract the body that we're reading from. If present perform a + // non-blocking poll to see if a frame is already here. If it is + // then turn the loop again to operate on the results. If it's not + // here then return an empty buffer as no data is available at this + // time. + let body = match &mut self.state { + IncomingBodyStreamState::Open { body, .. } => body, + IncomingBodyStreamState::Closed => return Err(StreamError::Closed), + }; - Err(TryRecvError::Disconnected) => { - self.open = false; - return Err(StreamError::Closed); + let future = body.frame(); + futures::pin_mut!(future); + match poll_noop(future) { + Some(result) => { + self.record_frame(result); + } + None => return Ok(Bytes::new()), } } } @@ -189,50 +218,156 @@ impl HostInputStream for HostIncomingBodyStream { #[async_trait::async_trait] impl Subscribe for HostIncomingBodyStream { async fn ready(&mut self) { - if !self.buffer.is_empty() { + if !self.buffer.is_empty() || self.error.is_some() { return; } - if !self.open { - return; + if let IncomingBodyStreamState::Open { body, .. } = &mut self.state { + let frame = body.frame().await; + self.record_frame(frame); } + } +} - match self.receiver.recv().await { - Some(Ok(bytes)) => self.buffer = bytes, +impl HostIncomingBodyStream { + fn record_frame(&mut self, frame: Option, types::Error>>) { + match frame { + Some(Ok(frame)) => match frame.into_data() { + // A data frame was received, so queue up the buffered data for + // the next `read` call. + Ok(bytes) => { + assert!(self.buffer.is_empty()); + self.buffer = bytes; + } + // Trailers were received meaning that this was the final frame. + // Throw away the body and send the trailers along the + // `tx` channel to make them available. + Err(trailers) => { + let trailers = trailers.into_trailers().unwrap(); + let tx = match mem::replace(&mut self.state, IncomingBodyStreamState::Closed) { + IncomingBodyStreamState::Open { body: _, tx } => tx, + IncomingBodyStreamState::Closed => unreachable!(), + }; + + // NB: ignore send failures here because if this fails then + // no one was interested in the trailers. + let _ = tx.send(StreamEnd::Trailers(Some(trailers))); + } + }, + + // An error was received meaning that the stream is now done. + // Destroy the body to terminate the stream while enqueueing the + // error to get returned from the next call to `read`. Some(Err(e)) => { - self.error = Some(e); - self.open = false; + self.error = Some(e.into()); + self.state = IncomingBodyStreamState::Closed; } - None => self.open = false, + // No more frames are going to be received again, so drop the `body` + // and the `tx` channel we'd send the body back onto because it's + // not needed as frames are done. + None => { + self.state = IncomingBodyStreamState::Closed; + } } } } -pub struct HostFutureTrailers { - _worker: AbortOnDropJoinHandle<()>, - pub state: HostFutureTrailersState, +impl Drop for HostIncomingBodyStream { + fn drop(&mut self) { + // When a body stream is dropped, for whatever reason, attempt to send + // the body back to the `tx` which will provide the trailers if desired. + // This isn't necessary if the state is already closed. Additionally, + // like `record_frame` above, `send` errors are ignored as they indicate + // that the body/trailers aren't actually needed. + let prev = mem::replace(&mut self.state, IncomingBodyStreamState::Closed); + if let IncomingBodyStreamState::Open { body, tx } = prev { + let _ = tx.send(StreamEnd::Remaining(body)); + } + } } -pub enum HostFutureTrailersState { - Waiting(oneshot::Receiver>), +pub enum HostFutureTrailers { + /// Trailers aren't here yet. + /// + /// This state represents two similar states: + /// + /// * The body is here and ready for reading and we're waiting to read + /// trailers. This can happen for example when the actual body wasn't read + /// or if the body was only partially read. + /// + /// * The body is being read by something else and we're waiting for that to + /// send us the trailers (or the body itself). This state will get entered + /// when the body stream is dropped for example. If the body stream reads + /// the trailers itself it will also send a message over here with the + /// trailers. + Waiting(HostIncomingBody), + + /// Trailers are ready and here they are. + /// + /// Note that `Ok(None)` means that there were no trailers for this request + /// while `Ok(Some(_))` means that trailers were found in the request. Done(Result), } #[async_trait::async_trait] impl Subscribe for HostFutureTrailers { async fn ready(&mut self) { - if let HostFutureTrailersState::Waiting(rx) = &mut self.state { - let result = match rx.await { - Ok(Ok(headers)) => Ok(FieldMap::from(headers)), - Ok(Err(e)) => Err(types::Error::ProtocolError(format!("hyper error: {e:?}"))), - Err(_) => Err(types::Error::ProtocolError( - "stream hung up before trailers were received".to_string(), - )), - }; - self.state = HostFutureTrailersState::Done(result); + let body = match self { + HostFutureTrailers::Waiting(body) => body, + HostFutureTrailers::Done(_) => return, + }; + + // If the body is itself being read by a body stream then we need to + // wait for that to be done. + if let IncomingBodyState::InBodyStream(rx) = &mut body.body { + match rx.await { + // Trailers were read for us and here they are, so store the + // result. + Ok(StreamEnd::Trailers(Some(t))) => *self = Self::Done(Ok(t)), + Ok(StreamEnd::Trailers(None)) => *self = Self::Done(Ok(Default::default())), + + // The body wasn't fully read and was dropped before trailers + // were reached. It's up to us now to complete the body. + Ok(StreamEnd::Remaining(b)) => body.body = IncomingBodyState::Start(b), + + // Technically this shouldn't be possible as the sender + // shouldn't get destroyed without receiving a message. Handle + // this just in case though. + Err(_) => { + debug_assert!(false, "should be unreachable"); + *self = HostFutureTrailers::Done(Err(types::Error::ProtocolError( + "stream hung up before trailers were received".to_string(), + ))); + } + } } + + // Here it should be guaranteed that `InBodyStream` is now gone, so if + // we have the body ourselves then read frames until trailers are found. + let body = match self { + HostFutureTrailers::Waiting(body) => body, + HostFutureTrailers::Done(_) => return, + }; + let hyper_body = match &mut body.body { + IncomingBodyState::Start(body) => body, + IncomingBodyState::InBodyStream(_) => unreachable!(), + }; + let result = loop { + match hyper_body.frame().await { + None => break Ok(Default::default()), + Some(Err(e)) => break Err(e), + Some(Ok(frame)) => { + // If this frame is a data frame ignore it as we're only + // interested in trailers. + if let Ok(headers) = frame.into_trailers() { + break Ok(headers); + } + } + } + }; + *self = HostFutureTrailers::Done(result); } } @@ -251,9 +386,6 @@ pub struct HostOutgoingBody { impl HostOutgoingBody { pub fn new() -> (Self, HyperOutgoingBody) { - use http_body_util::BodyExt; - use hyper::body::{Body, Frame}; - use std::task::{Context, Poll}; use tokio::sync::oneshot::error::RecvError; struct BodyImpl { body_receiver: mpsc::Receiver, diff --git a/crates/wasi-http/src/types.rs b/crates/wasi-http/src/types.rs index 5176fa9add04..a032b0020078 100644 --- a/crates/wasi-http/src/types.rs +++ b/crates/wasi-http/src/types.rs @@ -3,11 +3,12 @@ use crate::{ bindings::http::types::{self, Method, Scheme}, - body::{HostIncomingBodyBuilder, HyperIncomingBody, HyperOutgoingBody}, + body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody}, }; use anyhow::Context; use http_body_util::BodyExt; use std::any::Any; +use std::sync::Arc; use std::time::Duration; use tokio::net::TcpStream; use tokio::time::timeout; @@ -35,11 +36,11 @@ pub trait WasiHttpView: Send { req: hyper::Request, ) -> wasmtime::Result> { let (parts, body) = req.into_parts(); - let body = HostIncomingBodyBuilder { + let body = HostIncomingBody::new( body, // TODO: this needs to be plumbed through - between_bytes_timeout: std::time::Duration::from_millis(600 * 1000), - }; + std::time::Duration::from_millis(600 * 1000), + ); Ok(self.table().push_resource(HostIncomingRequest { parts, body: Some(body), @@ -159,7 +160,7 @@ pub fn default_send_request( Ok(IncomingResponseInternal { resp, - worker, + worker: Arc::new(worker), between_bytes_timeout, }) }); @@ -199,7 +200,7 @@ fn invalid_url(e: std::io::Error) -> anyhow::Error { pub struct HostIncomingRequest { pub parts: http::request::Parts, - pub body: Option, + pub body: Option, } pub struct HostResponseOutparam { @@ -219,8 +220,8 @@ pub struct HostOutgoingRequest { pub struct HostIncomingResponse { pub status: u16, pub headers: FieldMap, - pub body: Option, - pub worker: AbortOnDropJoinHandle>, + pub body: Option, + pub worker: Arc>>, } pub struct HostOutgoingResponse { @@ -271,7 +272,7 @@ pub enum HostFields { pub struct IncomingResponseInternal { pub resp: hyper::Response, - pub worker: AbortOnDropJoinHandle>, + pub worker: Arc>>, pub between_bytes_timeout: std::time::Duration, } diff --git a/crates/wasi-http/src/types_impl.rs b/crates/wasi-http/src/types_impl.rs index 0cbde9f5cd85..490fb58e6880 100644 --- a/crates/wasi-http/src/types_impl.rs +++ b/crates/wasi-http/src/types_impl.rs @@ -1,9 +1,9 @@ use crate::bindings::http::types::{Error, Headers, Method, Scheme, StatusCode, Trailers}; -use crate::body::{FinishMessage, HostFutureTrailers, HostFutureTrailersState}; +use crate::body::{FinishMessage, HostFutureTrailers}; use crate::types::{HostIncomingRequest, HostOutgoingResponse}; use crate::WasiHttpView; use crate::{ - body::{HostIncomingBody, HostIncomingBodyBuilder, HostOutgoingBody}, + body::{HostIncomingBody, HostOutgoingBody}, types::{ FieldMap, HostFields, HostFutureIncomingResponse, HostIncomingResponse, HostOutgoingRequest, HostResponseOutparam, @@ -234,8 +234,8 @@ impl crate::bindings::http::types::HostIncomingRequest for T { ) -> wasmtime::Result, ()>> { let req = self.table().get_resource_mut(&id)?; match req.body.take() { - Some(builder) => { - let id = self.table().push_resource(builder.build())?; + Some(body) => { + let id = self.table().push_resource(body)?; Ok(Ok(id)) } @@ -375,8 +375,8 @@ impl crate::bindings::http::types::HostIncomingResponse for T { .context("[incoming_response_consume] getting response")?; match r.body.take() { - Some(builder) => { - let id = self.table().push_resource(builder.build())?; + Some(body) => { + let id = self.table().push_resource(body)?; Ok(Ok(id)) } @@ -406,16 +406,16 @@ impl crate::bindings::http::types::HostFutureTrailers for T { id: Resource, ) -> wasmtime::Result, Error>>> { let trailers = self.table().get_resource_mut(&id)?; - match &trailers.state { - HostFutureTrailersState::Waiting(_) => return Ok(None), - HostFutureTrailersState::Done(Err(e)) => return Ok(Some(Err(e.clone()))), - HostFutureTrailersState::Done(Ok(_)) => {} + match trailers { + HostFutureTrailers::Waiting(_) => return Ok(None), + HostFutureTrailers::Done(Err(e)) => return Ok(Some(Err(e.clone()))), + HostFutureTrailers::Done(Ok(_)) => {} } fn get_fields(elem: &mut dyn Any) -> &mut FieldMap { let trailers = elem.downcast_mut::().unwrap(); - match &mut trailers.state { - HostFutureTrailersState::Done(Ok(e)) => e, + match trailers { + HostFutureTrailers::Done(Ok(e)) => e, _ => unreachable!(), } } @@ -439,7 +439,7 @@ impl crate::bindings::http::types::HostIncomingBody for T { ) -> wasmtime::Result, ()>> { let body = self.table().get_resource_mut(&id)?; - if let Some(stream) = body.stream.take() { + if let Some(stream) = body.take_stream() { let stream = InputStream::Host(Box::new(stream)); let stream = self.table().push_child_resource(stream, &id)?; return Ok(Ok(stream)); @@ -539,9 +539,10 @@ impl crate::bindings::http::types::HostFutureIncomingResponse f let resp = self.table().push_resource(HostIncomingResponse { status: parts.status.as_u16(), headers: FieldMap::from(parts.headers), - body: Some(HostIncomingBodyBuilder { - body, - between_bytes_timeout: resp.between_bytes_timeout, + body: Some({ + let mut body = HostIncomingBody::new(body, resp.between_bytes_timeout); + body.retain_worker(&resp.worker); + body }), worker: resp.worker, })?; diff --git a/crates/wasi/src/preview2/mod.rs b/crates/wasi/src/preview2/mod.rs index 1dacb8a91e56..9ebb87d68070 100644 --- a/crates/wasi/src/preview2/mod.rs +++ b/crates/wasi/src/preview2/mod.rs @@ -250,7 +250,13 @@ pub fn in_tokio(f: F) -> F::Output { } } -fn with_ambient_tokio_runtime(f: impl FnOnce() -> R) -> R { +/// Executes the closure `f` with an "ambient Tokio runtime" which basically +/// means that if code in `f` tries to get a runtime `Handle` it'll succeed. +/// +/// If a `Handle` is already available, e.g. in async contexts, then `f` is run +/// immediately. Otherwise for synchronous contexts this crate's fallback +/// runtime is configured and then `f` is executed. +pub fn with_ambient_tokio_runtime(f: impl FnOnce() -> R) -> R { match tokio::runtime::Handle::try_current() { Ok(_) => f(), Err(_) => { @@ -260,7 +266,14 @@ fn with_ambient_tokio_runtime(f: impl FnOnce() -> R) -> R { } } -fn poll_noop(future: Pin<&mut F>) -> Option +/// Attempts to get the result of a `future`. +/// +/// This function does not block and will poll the provided future once. If the +/// result is here then `Some` is returned, otherwise `None` is returned. +/// +/// Note that by polling `future` this means that `future` must be re-polled +/// later if it's to wake up a task. +pub fn poll_noop(future: Pin<&mut F>) -> Option where F: Future, {