@@ -11,7 +11,7 @@ use futures_util::TryStreamExt;
11
11
use http:: HeaderMap ;
12
12
use http_body:: { Body as HttpBody , SizeHint } ;
13
13
14
- use crate :: common:: { task, Future , Never , Pin , Poll } ;
14
+ use crate :: common:: { task, watch , Future , Never , Pin , Poll } ;
15
15
use crate :: proto:: DecodedLength ;
16
16
use crate :: upgrade:: OnUpgrade ;
17
17
@@ -33,7 +33,7 @@ enum Kind {
33
33
Once ( Option < Bytes > ) ,
34
34
Chan {
35
35
content_length : DecodedLength ,
36
- abort_rx : oneshot :: Receiver < ( ) > ,
36
+ want_tx : watch :: Sender ,
37
37
rx : mpsc:: Receiver < Result < Bytes , crate :: Error > > ,
38
38
} ,
39
39
H2 {
@@ -79,12 +79,14 @@ enum DelayEof {
79
79
/// Useful when wanting to stream chunks from another thread. See
80
80
/// [`Body::channel`](Body::channel) for more.
81
81
#[ must_use = "Sender does nothing unless sent on" ]
82
- #[ derive( Debug ) ]
83
82
pub struct Sender {
84
- abort_tx : oneshot :: Sender < ( ) > ,
83
+ want_rx : watch :: Receiver ,
85
84
tx : BodySender ,
86
85
}
87
86
87
+ const WANT_PENDING : usize = 1 ;
88
+ const WANT_READY : usize = 2 ;
89
+
88
90
impl Body {
89
91
/// Create an empty `Body` stream.
90
92
///
@@ -106,17 +108,22 @@ impl Body {
106
108
/// Useful when wanting to stream chunks from another thread.
107
109
#[ inline]
108
110
pub fn channel ( ) -> ( Sender , Body ) {
109
- Self :: new_channel ( DecodedLength :: CHUNKED )
111
+ Self :: new_channel ( DecodedLength :: CHUNKED , /*wanter =*/ false )
110
112
}
111
113
112
- pub ( crate ) fn new_channel ( content_length : DecodedLength ) -> ( Sender , Body ) {
114
+ pub ( crate ) fn new_channel ( content_length : DecodedLength , wanter : bool ) -> ( Sender , Body ) {
113
115
let ( tx, rx) = mpsc:: channel ( 0 ) ;
114
- let ( abort_tx, abort_rx) = oneshot:: channel ( ) ;
115
116
116
- let tx = Sender { abort_tx, tx } ;
117
+ // If wanter is true, `Sender::poll_ready()` won't becoming ready
118
+ // until the `Body` has been polled for data once.
119
+ let want = if wanter { WANT_PENDING } else { WANT_READY } ;
120
+
121
+ let ( want_tx, want_rx) = watch:: channel ( want) ;
122
+
123
+ let tx = Sender { want_rx, tx } ;
117
124
let rx = Body :: new ( Kind :: Chan {
118
125
content_length,
119
- abort_rx ,
126
+ want_tx ,
120
127
rx,
121
128
} ) ;
122
129
@@ -236,11 +243,9 @@ impl Body {
236
243
Kind :: Chan {
237
244
content_length : ref mut len,
238
245
ref mut rx,
239
- ref mut abort_rx ,
246
+ ref mut want_tx ,
240
247
} => {
241
- if let Poll :: Ready ( Ok ( ( ) ) ) = Pin :: new ( abort_rx) . poll ( cx) {
242
- return Poll :: Ready ( Some ( Err ( crate :: Error :: new_body_write_aborted ( ) ) ) ) ;
243
- }
248
+ want_tx. send ( WANT_READY ) ;
244
249
245
250
match ready ! ( Pin :: new( rx) . poll_next( cx) ?) {
246
251
Some ( chunk) => {
@@ -460,19 +465,29 @@ impl From<Cow<'static, str>> for Body {
460
465
impl Sender {
461
466
/// Check to see if this `Sender` can send more data.
462
467
pub fn poll_ready ( & mut self , cx : & mut task:: Context < ' _ > ) -> Poll < crate :: Result < ( ) > > {
463
- match self . abort_tx . poll_canceled ( cx) {
464
- Poll :: Ready ( ( ) ) => return Poll :: Ready ( Err ( crate :: Error :: new_closed ( ) ) ) ,
465
- Poll :: Pending => ( ) , // fallthrough
466
- }
467
-
468
+ // Check if the receiver end has tried polling for the body yet
469
+ ready ! ( self . poll_want( cx) ?) ;
468
470
self . tx
469
471
. poll_ready ( cx)
470
472
. map_err ( |_| crate :: Error :: new_closed ( ) )
471
473
}
472
474
475
+ fn poll_want ( & mut self , cx : & mut task:: Context < ' _ > ) -> Poll < crate :: Result < ( ) > > {
476
+ match self . want_rx . load ( cx) {
477
+ WANT_READY => Poll :: Ready ( Ok ( ( ) ) ) ,
478
+ WANT_PENDING => Poll :: Pending ,
479
+ watch:: CLOSED => Poll :: Ready ( Err ( crate :: Error :: new_closed ( ) ) ) ,
480
+ unexpected => unreachable ! ( "want_rx value: {}" , unexpected) ,
481
+ }
482
+ }
483
+
484
+ async fn ready ( & mut self ) -> crate :: Result < ( ) > {
485
+ futures_util:: future:: poll_fn ( |cx| self . poll_ready ( cx) ) . await
486
+ }
487
+
473
488
/// Send data on this channel when it is ready.
474
489
pub async fn send_data ( & mut self , chunk : Bytes ) -> crate :: Result < ( ) > {
475
- futures_util :: future :: poll_fn ( |cx| self . poll_ready ( cx ) ) . await ?;
490
+ self . ready ( ) . await ?;
476
491
self . tx
477
492
. try_send ( Ok ( chunk) )
478
493
. map_err ( |_| crate :: Error :: new_closed ( ) )
@@ -498,20 +513,41 @@ impl Sender {
498
513
499
514
/// Aborts the body in an abnormal fashion.
500
515
pub fn abort ( self ) {
501
- // TODO(sean): this can just be `self.tx.clone().try_send()`
502
- let _ = self . abort_tx . send ( ( ) ) ;
516
+ let _ = self
517
+ . tx
518
+ // clone so the send works even if buffer is full
519
+ . clone ( )
520
+ . try_send ( Err ( crate :: Error :: new_body_write_aborted ( ) ) ) ;
503
521
}
504
522
505
523
pub ( crate ) fn send_error ( & mut self , err : crate :: Error ) {
506
524
let _ = self . tx . try_send ( Err ( err) ) ;
507
525
}
508
526
}
509
527
528
+ impl fmt:: Debug for Sender {
529
+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
530
+ #[ derive( Debug ) ]
531
+ struct Open ;
532
+ #[ derive( Debug ) ]
533
+ struct Closed ;
534
+
535
+ let mut builder = f. debug_tuple ( "Sender" ) ;
536
+ match self . want_rx . peek ( ) {
537
+ watch:: CLOSED => builder. field ( & Closed ) ,
538
+ _ => builder. field ( & Open ) ,
539
+ } ;
540
+
541
+ builder. finish ( )
542
+ }
543
+ }
544
+
510
545
#[ cfg( test) ]
511
546
mod tests {
512
547
use std:: mem;
548
+ use std:: task:: Poll ;
513
549
514
- use super :: { Body , Sender } ;
550
+ use super :: { Body , DecodedLength , HttpBody , Sender } ;
515
551
516
552
#[ test]
517
553
fn test_size_of ( ) {
@@ -541,4 +577,97 @@ mod tests {
541
577
"Option<Sender>"
542
578
) ;
543
579
}
580
+
581
+ #[ tokio:: test]
582
+ async fn channel_abort ( ) {
583
+ let ( tx, mut rx) = Body :: channel ( ) ;
584
+
585
+ tx. abort ( ) ;
586
+
587
+ let err = rx. data ( ) . await . unwrap ( ) . unwrap_err ( ) ;
588
+ assert ! ( err. is_body_write_aborted( ) , "{:?}" , err) ;
589
+ }
590
+
591
+ #[ tokio:: test]
592
+ async fn channel_abort_when_buffer_is_full ( ) {
593
+ let ( mut tx, mut rx) = Body :: channel ( ) ;
594
+
595
+ tx. try_send_data ( "chunk 1" . into ( ) ) . expect ( "send 1" ) ;
596
+ // buffer is full, but can still send abort
597
+ tx. abort ( ) ;
598
+
599
+ let chunk1 = rx. data ( ) . await . expect ( "item 1" ) . expect ( "chunk 1" ) ;
600
+ assert_eq ! ( chunk1, "chunk 1" ) ;
601
+
602
+ let err = rx. data ( ) . await . unwrap ( ) . unwrap_err ( ) ;
603
+ assert ! ( err. is_body_write_aborted( ) , "{:?}" , err) ;
604
+ }
605
+
606
+ #[ test]
607
+ fn channel_buffers_one ( ) {
608
+ let ( mut tx, _rx) = Body :: channel ( ) ;
609
+
610
+ tx. try_send_data ( "chunk 1" . into ( ) ) . expect ( "send 1" ) ;
611
+
612
+ // buffer is now full
613
+ let chunk2 = tx. try_send_data ( "chunk 2" . into ( ) ) . expect_err ( "send 2" ) ;
614
+ assert_eq ! ( chunk2, "chunk 2" ) ;
615
+ }
616
+
617
+ #[ tokio:: test]
618
+ async fn channel_empty ( ) {
619
+ let ( _, mut rx) = Body :: channel ( ) ;
620
+
621
+ assert ! ( rx. data( ) . await . is_none( ) ) ;
622
+ }
623
+
624
+ #[ test]
625
+ fn channel_ready ( ) {
626
+ let ( mut tx, _rx) = Body :: new_channel ( DecodedLength :: CHUNKED , /*wanter = */ false ) ;
627
+
628
+ let mut tx_ready = tokio_test:: task:: spawn ( tx. ready ( ) ) ;
629
+
630
+ assert ! ( tx_ready. poll( ) . is_ready( ) , "tx is ready immediately" ) ;
631
+ }
632
+
633
+ #[ test]
634
+ fn channel_wanter ( ) {
635
+ let ( mut tx, mut rx) = Body :: new_channel ( DecodedLength :: CHUNKED , /*wanter = */ true ) ;
636
+
637
+ let mut tx_ready = tokio_test:: task:: spawn ( tx. ready ( ) ) ;
638
+ let mut rx_data = tokio_test:: task:: spawn ( rx. data ( ) ) ;
639
+
640
+ assert ! (
641
+ tx_ready. poll( ) . is_pending( ) ,
642
+ "tx isn't ready before rx has been polled"
643
+ ) ;
644
+
645
+ assert ! ( rx_data. poll( ) . is_pending( ) , "poll rx.data" ) ;
646
+ assert ! ( tx_ready. is_woken( ) , "rx poll wakes tx" ) ;
647
+
648
+ assert ! (
649
+ tx_ready. poll( ) . is_ready( ) ,
650
+ "tx is ready after rx has been polled"
651
+ ) ;
652
+ }
653
+
654
+ #[ test]
655
+ fn channel_notices_closure ( ) {
656
+ let ( mut tx, rx) = Body :: new_channel ( DecodedLength :: CHUNKED , /*wanter = */ true ) ;
657
+
658
+ let mut tx_ready = tokio_test:: task:: spawn ( tx. ready ( ) ) ;
659
+
660
+ assert ! (
661
+ tx_ready. poll( ) . is_pending( ) ,
662
+ "tx isn't ready before rx has been polled"
663
+ ) ;
664
+
665
+ drop ( rx) ;
666
+ assert ! ( tx_ready. is_woken( ) , "dropping rx wakes tx" ) ;
667
+
668
+ match tx_ready. poll ( ) {
669
+ Poll :: Ready ( Err ( ref e) ) if e. is_closed ( ) => ( ) ,
670
+ unexpected => panic ! ( "tx poll ready unexpected: {:?}" , unexpected) ,
671
+ }
672
+ }
544
673
}
0 commit comments