Skip to content

Commit c4bb4db

Browse files
committed
fix(http1): only send 100 Continue if request body is polled
Before, if a client request included an `Expect: 100-continue` header, the `100 Continue` response was sent immediately. However, this is problematic if the service is going to reply with some 4xx status code and reject the body. This change delays the automatic sending of the `100 Continue` status until the service has call `poll_data` on the request body once.
1 parent a354580 commit c4bb4db

File tree

7 files changed

+332
-39
lines changed

7 files changed

+332
-39
lines changed

src/body/body.rs

+151-22
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use futures_util::TryStreamExt;
1111
use http::HeaderMap;
1212
use http_body::{Body as HttpBody, SizeHint};
1313

14-
use crate::common::{task, Future, Never, Pin, Poll};
14+
use crate::common::{task, watch, Future, Never, Pin, Poll};
1515
use crate::proto::DecodedLength;
1616
use crate::upgrade::OnUpgrade;
1717

@@ -33,7 +33,7 @@ enum Kind {
3333
Once(Option<Bytes>),
3434
Chan {
3535
content_length: DecodedLength,
36-
abort_rx: oneshot::Receiver<()>,
36+
want_tx: watch::Sender,
3737
rx: mpsc::Receiver<Result<Bytes, crate::Error>>,
3838
},
3939
H2 {
@@ -79,12 +79,14 @@ enum DelayEof {
7979
/// Useful when wanting to stream chunks from another thread. See
8080
/// [`Body::channel`](Body::channel) for more.
8181
#[must_use = "Sender does nothing unless sent on"]
82-
#[derive(Debug)]
8382
pub struct Sender {
84-
abort_tx: oneshot::Sender<()>,
83+
want_rx: watch::Receiver,
8584
tx: BodySender,
8685
}
8786

87+
const WANT_PENDING: usize = 1;
88+
const WANT_READY: usize = 2;
89+
8890
impl Body {
8991
/// Create an empty `Body` stream.
9092
///
@@ -106,17 +108,22 @@ impl Body {
106108
/// Useful when wanting to stream chunks from another thread.
107109
#[inline]
108110
pub fn channel() -> (Sender, Body) {
109-
Self::new_channel(DecodedLength::CHUNKED)
111+
Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false)
110112
}
111113

112-
pub(crate) fn new_channel(content_length: DecodedLength) -> (Sender, Body) {
114+
pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, Body) {
113115
let (tx, rx) = mpsc::channel(0);
114-
let (abort_tx, abort_rx) = oneshot::channel();
115116

116-
let tx = Sender { abort_tx, tx };
117+
// If wanter is true, `Sender::poll_ready()` won't becoming ready
118+
// until the `Body` has been polled for data once.
119+
let want = if wanter { WANT_PENDING } else { WANT_READY };
120+
121+
let (want_tx, want_rx) = watch::channel(want);
122+
123+
let tx = Sender { want_rx, tx };
117124
let rx = Body::new(Kind::Chan {
118125
content_length,
119-
abort_rx,
126+
want_tx,
120127
rx,
121128
});
122129

@@ -236,11 +243,9 @@ impl Body {
236243
Kind::Chan {
237244
content_length: ref mut len,
238245
ref mut rx,
239-
ref mut abort_rx,
246+
ref mut want_tx,
240247
} => {
241-
if let Poll::Ready(Ok(())) = Pin::new(abort_rx).poll(cx) {
242-
return Poll::Ready(Some(Err(crate::Error::new_body_write_aborted())));
243-
}
248+
want_tx.send(WANT_READY);
244249

245250
match ready!(Pin::new(rx).poll_next(cx)?) {
246251
Some(chunk) => {
@@ -460,19 +465,29 @@ impl From<Cow<'static, str>> for Body {
460465
impl Sender {
461466
/// Check to see if this `Sender` can send more data.
462467
pub fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
463-
match self.abort_tx.poll_canceled(cx) {
464-
Poll::Ready(()) => return Poll::Ready(Err(crate::Error::new_closed())),
465-
Poll::Pending => (), // fallthrough
466-
}
467-
468+
// Check if the receiver end has tried polling for the body yet
469+
ready!(self.poll_want(cx)?);
468470
self.tx
469471
.poll_ready(cx)
470472
.map_err(|_| crate::Error::new_closed())
471473
}
472474

475+
fn poll_want(&mut self, cx: &mut task::Context<'_>) -> Poll<crate::Result<()>> {
476+
match self.want_rx.load(cx) {
477+
WANT_READY => Poll::Ready(Ok(())),
478+
WANT_PENDING => Poll::Pending,
479+
watch::CLOSED => Poll::Ready(Err(crate::Error::new_closed())),
480+
unexpected => unreachable!("want_rx value: {}", unexpected),
481+
}
482+
}
483+
484+
async fn ready(&mut self) -> crate::Result<()> {
485+
futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await
486+
}
487+
473488
/// Send data on this channel when it is ready.
474489
pub async fn send_data(&mut self, chunk: Bytes) -> crate::Result<()> {
475-
futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await?;
490+
self.ready().await?;
476491
self.tx
477492
.try_send(Ok(chunk))
478493
.map_err(|_| crate::Error::new_closed())
@@ -498,20 +513,41 @@ impl Sender {
498513

499514
/// Aborts the body in an abnormal fashion.
500515
pub fn abort(self) {
501-
// TODO(sean): this can just be `self.tx.clone().try_send()`
502-
let _ = self.abort_tx.send(());
516+
let _ = self
517+
.tx
518+
// clone so the send works even if buffer is full
519+
.clone()
520+
.try_send(Err(crate::Error::new_body_write_aborted()));
503521
}
504522

505523
pub(crate) fn send_error(&mut self, err: crate::Error) {
506524
let _ = self.tx.try_send(Err(err));
507525
}
508526
}
509527

528+
impl fmt::Debug for Sender {
529+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
530+
#[derive(Debug)]
531+
struct Open;
532+
#[derive(Debug)]
533+
struct Closed;
534+
535+
let mut builder = f.debug_tuple("Sender");
536+
match self.want_rx.peek() {
537+
watch::CLOSED => builder.field(&Closed),
538+
_ => builder.field(&Open),
539+
};
540+
541+
builder.finish()
542+
}
543+
}
544+
510545
#[cfg(test)]
511546
mod tests {
512547
use std::mem;
548+
use std::task::Poll;
513549

514-
use super::{Body, Sender};
550+
use super::{Body, DecodedLength, HttpBody, Sender};
515551

516552
#[test]
517553
fn test_size_of() {
@@ -541,4 +577,97 @@ mod tests {
541577
"Option<Sender>"
542578
);
543579
}
580+
581+
#[tokio::test]
582+
async fn channel_abort() {
583+
let (tx, mut rx) = Body::channel();
584+
585+
tx.abort();
586+
587+
let err = rx.data().await.unwrap().unwrap_err();
588+
assert!(err.is_body_write_aborted(), "{:?}", err);
589+
}
590+
591+
#[tokio::test]
592+
async fn channel_abort_when_buffer_is_full() {
593+
let (mut tx, mut rx) = Body::channel();
594+
595+
tx.try_send_data("chunk 1".into()).expect("send 1");
596+
// buffer is full, but can still send abort
597+
tx.abort();
598+
599+
let chunk1 = rx.data().await.expect("item 1").expect("chunk 1");
600+
assert_eq!(chunk1, "chunk 1");
601+
602+
let err = rx.data().await.unwrap().unwrap_err();
603+
assert!(err.is_body_write_aborted(), "{:?}", err);
604+
}
605+
606+
#[test]
607+
fn channel_buffers_one() {
608+
let (mut tx, _rx) = Body::channel();
609+
610+
tx.try_send_data("chunk 1".into()).expect("send 1");
611+
612+
// buffer is now full
613+
let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2");
614+
assert_eq!(chunk2, "chunk 2");
615+
}
616+
617+
#[tokio::test]
618+
async fn channel_empty() {
619+
let (_, mut rx) = Body::channel();
620+
621+
assert!(rx.data().await.is_none());
622+
}
623+
624+
#[test]
625+
fn channel_ready() {
626+
let (mut tx, _rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ false);
627+
628+
let mut tx_ready = tokio_test::task::spawn(tx.ready());
629+
630+
assert!(tx_ready.poll().is_ready(), "tx is ready immediately");
631+
}
632+
633+
#[test]
634+
fn channel_wanter() {
635+
let (mut tx, mut rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true);
636+
637+
let mut tx_ready = tokio_test::task::spawn(tx.ready());
638+
let mut rx_data = tokio_test::task::spawn(rx.data());
639+
640+
assert!(
641+
tx_ready.poll().is_pending(),
642+
"tx isn't ready before rx has been polled"
643+
);
644+
645+
assert!(rx_data.poll().is_pending(), "poll rx.data");
646+
assert!(tx_ready.is_woken(), "rx poll wakes tx");
647+
648+
assert!(
649+
tx_ready.poll().is_ready(),
650+
"tx is ready after rx has been polled"
651+
);
652+
}
653+
654+
#[test]
655+
fn channel_notices_closure() {
656+
let (mut tx, rx) = Body::new_channel(DecodedLength::CHUNKED, /*wanter = */ true);
657+
658+
let mut tx_ready = tokio_test::task::spawn(tx.ready());
659+
660+
assert!(
661+
tx_ready.poll().is_pending(),
662+
"tx isn't ready before rx has been polled"
663+
);
664+
665+
drop(rx);
666+
assert!(tx_ready.is_woken(), "dropping rx wakes tx");
667+
668+
match tx_ready.poll() {
669+
Poll::Ready(Err(ref e)) if e.is_closed() => (),
670+
unexpected => panic!("tx poll ready unexpected: {:?}", unexpected),
671+
}
672+
}
544673
}

src/common/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ pub(crate) mod io;
1414
mod lazy;
1515
mod never;
1616
pub(crate) mod task;
17+
pub(crate) mod watch;
1718

1819
pub use self::exec::Executor;
1920
pub(crate) use self::exec::{BoxSendFuture, Exec};

src/common/watch.rs

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//! An SPSC broadcast channel.
2+
//!
3+
//! - The value can only be a `usize`.
4+
//! - The consumer is only notified if the value is different.
5+
//! - The value `0` is reserved for closed.
6+
7+
use futures_util::task::AtomicWaker;
8+
use std::sync::{
9+
atomic::{AtomicUsize, Ordering},
10+
Arc,
11+
};
12+
use std::task;
13+
14+
type Value = usize;
15+
16+
pub(crate) const CLOSED: usize = 0;
17+
18+
pub(crate) fn channel(initial: Value) -> (Sender, Receiver) {
19+
debug_assert!(
20+
initial != CLOSED,
21+
"watch::channel initial state of 0 is reserved"
22+
);
23+
24+
let shared = Arc::new(Shared {
25+
value: AtomicUsize::new(initial),
26+
waker: AtomicWaker::new(),
27+
});
28+
29+
(
30+
Sender {
31+
shared: shared.clone(),
32+
},
33+
Receiver { shared },
34+
)
35+
}
36+
37+
pub(crate) struct Sender {
38+
shared: Arc<Shared>,
39+
}
40+
41+
pub(crate) struct Receiver {
42+
shared: Arc<Shared>,
43+
}
44+
45+
struct Shared {
46+
value: AtomicUsize,
47+
waker: AtomicWaker,
48+
}
49+
50+
impl Sender {
51+
pub(crate) fn send(&mut self, value: Value) {
52+
if self.shared.value.swap(value, Ordering::SeqCst) != value {
53+
self.shared.waker.wake();
54+
}
55+
}
56+
}
57+
58+
impl Drop for Sender {
59+
fn drop(&mut self) {
60+
self.send(CLOSED);
61+
}
62+
}
63+
64+
impl Receiver {
65+
pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value {
66+
self.shared.waker.register(cx.waker());
67+
self.shared.value.load(Ordering::SeqCst)
68+
}
69+
70+
pub(crate) fn peek(&self) -> Value {
71+
self.shared.value.load(Ordering::Relaxed)
72+
}
73+
}

0 commit comments

Comments
 (0)