Skip to content

Commit b1765dd

Browse files
committed
fix(client): drop in-use connections when they finish if Client is dropped
1 parent af8d11b commit b1765dd

File tree

3 files changed

+126
-47
lines changed

3 files changed

+126
-47
lines changed

src/client/pool.rs

+14-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::collections::{HashMap, VecDeque};
33
use std::fmt;
44
use std::io;
55
use std::ops::{Deref, DerefMut, BitAndAssign};
6-
use std::rc::Rc;
6+
use std::rc::{Rc, Weak};
77
use std::time::{Duration, Instant};
88

99
use futures::{Future, Async, Poll};
@@ -103,7 +103,7 @@ impl<T: Clone> Pool<T> {
103103
status: Rc::new(Cell::new(TimedKA::Busy)),
104104
},
105105
key: key,
106-
pool: self.clone(),
106+
pool: Rc::downgrade(&self.inner),
107107
}
108108
}
109109

@@ -118,7 +118,7 @@ impl<T: Clone> Pool<T> {
118118
Pooled {
119119
entry: entry,
120120
key: key,
121-
pool: self.clone(),
121+
pool: Rc::downgrade(&self.inner),
122122
}
123123
}
124124

@@ -161,7 +161,7 @@ impl<T> Clone for Pool<T> {
161161
pub struct Pooled<T> {
162162
entry: Entry<T>,
163163
key: Rc<String>,
164-
pool: Pool<T>,
164+
pool: Weak<RefCell<PoolInner<T>>>,
165165
}
166166

167167
impl<T> Deref for Pooled<T> {
@@ -194,8 +194,16 @@ impl<T: Clone> KeepAlive for Pooled<T> {
194194
return;
195195
}
196196
self.entry.is_reused = true;
197-
if self.pool.is_enabled() {
198-
self.pool.put(self.key.clone(), self.entry.clone());
197+
if let Some(inner) = self.pool.upgrade() {
198+
let mut pool = Pool {
199+
inner: inner,
200+
};
201+
if pool.is_enabled() {
202+
pool.put(self.key.clone(), self.entry.clone());
203+
}
204+
} else {
205+
trace!("pool dropped, dropping pooled ({:?})", self.key);
206+
self.entry.status.set(TimedKA::Disabled);
199207
}
200208
}
201209

src/proto/conn.rs

+19-6
Original file line numberDiff line numberDiff line change
@@ -235,20 +235,29 @@ where I: AsyncRead + AsyncWrite,
235235
//
236236
// When writing finishes, we need to wake the task up in case there
237237
// is more reading that can be done, to start a new message.
238+
239+
240+
238241
let wants_read = match self.state.reading {
239242
Reading::Body(..) |
240243
Reading::KeepAlive => return,
241244
Reading::Init => true,
242245
Reading::Closed => false,
243246
};
244247

245-
match self.state.writing {
248+
let wants_write = match self.state.writing {
246249
Writing::Continue(..) |
247250
Writing::Body(..) |
248251
Writing::Ending(..) => return,
249-
Writing::Init |
250-
Writing::KeepAlive |
251-
Writing::Closed => (),
252+
Writing::Init => true,
253+
Writing::KeepAlive => false,
254+
Writing::Closed => false,
255+
};
256+
257+
// if the client is at Reading::Init and Writing::Init,
258+
// it's not actually looking for a read, but a write.
259+
if wants_write && !T::should_read_first() {
260+
return;
252261
}
253262

254263
if !self.io.is_read_blocked() {
@@ -704,9 +713,13 @@ impl<B, K: KeepAlive> State<B, K> {
704713

705714
fn idle(&mut self) {
706715
self.method = None;
707-
self.reading = Reading::Init;
708-
self.writing = Writing::Init;
709716
self.keep_alive.idle();
717+
if self.is_idle() {
718+
self.reading = Reading::Init;
719+
self.writing = Writing::Init;
720+
} else {
721+
self.close();
722+
}
710723
}
711724

712725
fn is_idle(&self) -> bool {

tests/client.rs

+93-35
Original file line numberDiff line numberDiff line change
@@ -543,57 +543,115 @@ fn client_pooled_socket_disconnected() {
543543
}
544544
*/
545545

546-
#[test]
547-
fn drop_body_before_eof_closes_connection() {
548-
// https://github.com/hyperium/hyper/issues/1353
546+
mod dispatch_impl {
547+
use super::*;
549548
use std::io::{self, Read, Write};
550549
use std::sync::Arc;
551550
use std::sync::atomic::{AtomicUsize, Ordering};
551+
use std::thread;
552552
use std::time::Duration;
553+
554+
use futures::{self, Future};
555+
use futures::sync::oneshot;
553556
use tokio_core::reactor::{Timeout};
554557
use tokio_core::net::TcpStream;
555558
use tokio_io::{AsyncRead, AsyncWrite};
559+
556560
use hyper::client::HttpConnector;
557561
use hyper::server::Service;
558-
use hyper::Uri;
562+
use hyper::{Client, Uri};
563+
use hyper;
559564

560-
let _ = pretty_env_logger::init();
561565

562-
let server = TcpListener::bind("127.0.0.1:0").unwrap();
563-
let addr = server.local_addr().unwrap();
564-
let mut core = Core::new().unwrap();
565-
let handle = core.handle();
566-
let closes = Arc::new(AtomicUsize::new(0));
567-
let client = Client::configure()
568-
.connector(DebugConnector(HttpConnector::new(1, &core.handle()), closes.clone()))
569-
.no_proto()
570-
.build(&handle);
571566

572-
let (tx1, rx1) = oneshot::channel();
567+
#[test]
568+
fn drop_body_before_eof_closes_connection() {
569+
// https://github.com/hyperium/hyper/issues/1353
570+
let _ = pretty_env_logger::init();
573571

574-
thread::spawn(move || {
575-
let mut sock = server.accept().unwrap().0;
576-
sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
577-
sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
578-
let mut buf = [0; 4096];
579-
sock.read(&mut buf).expect("read 1");
580-
let body = vec![b'x'; 1024 * 128];
581-
write!(sock, "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", body.len()).expect("write head");
582-
let _ = sock.write_all(&body);
583-
let _ = tx1.send(());
584-
});
572+
let server = TcpListener::bind("127.0.0.1:0").unwrap();
573+
let addr = server.local_addr().unwrap();
574+
let mut core = Core::new().unwrap();
575+
let handle = core.handle();
576+
let closes = Arc::new(AtomicUsize::new(0));
577+
let client = Client::configure()
578+
.connector(DebugConnector(HttpConnector::new(1, &core.handle()), closes.clone()))
579+
.no_proto()
580+
.build(&handle);
581+
582+
let (tx1, rx1) = oneshot::channel();
583+
584+
thread::spawn(move || {
585+
let mut sock = server.accept().unwrap().0;
586+
sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
587+
sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
588+
let mut buf = [0; 4096];
589+
sock.read(&mut buf).expect("read 1");
590+
let body = vec![b'x'; 1024 * 128];
591+
write!(sock, "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", body.len()).expect("write head");
592+
let _ = sock.write_all(&body);
593+
let _ = tx1.send(());
594+
});
585595

586-
let uri = format!("http://{}/a", addr).parse().unwrap();
596+
let uri = format!("http://{}/a", addr).parse().unwrap();
587597

588-
let res = client.get(uri).and_then(move |res| {
589-
assert_eq!(res.status(), hyper::StatusCode::Ok);
590-
Timeout::new(Duration::from_secs(1), &handle).unwrap()
591-
.from_err()
592-
});
593-
let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked")));
594-
core.run(res.join(rx).map(|r| r.0)).unwrap();
598+
let res = client.get(uri).and_then(move |res| {
599+
assert_eq!(res.status(), hyper::StatusCode::Ok);
600+
Timeout::new(Duration::from_secs(1), &handle).unwrap()
601+
.from_err()
602+
});
603+
let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked")));
604+
core.run(res.join(rx).map(|r| r.0)).unwrap();
605+
606+
assert_eq!(closes.load(Ordering::Relaxed), 1);
607+
}
608+
609+
#[test]
610+
fn drop_client_closes_connection() {
611+
// https://github.com/hyperium/hyper/issues/1353
612+
let _ = pretty_env_logger::init();
595613

596-
assert_eq!(closes.load(Ordering::Relaxed), 1);
614+
let server = TcpListener::bind("127.0.0.1:0").unwrap();
615+
let addr = server.local_addr().unwrap();
616+
let mut core = Core::new().unwrap();
617+
let handle = core.handle();
618+
let closes = Arc::new(AtomicUsize::new(0));
619+
620+
let (tx1, rx1) = oneshot::channel();
621+
622+
thread::spawn(move || {
623+
let mut sock = server.accept().unwrap().0;
624+
sock.set_read_timeout(Some(Duration::from_secs(5))).unwrap();
625+
sock.set_write_timeout(Some(Duration::from_secs(5))).unwrap();
626+
let mut buf = [0; 4096];
627+
sock.read(&mut buf).expect("read 1");
628+
let body =[b'x'; 64];
629+
write!(sock, "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n", body.len()).expect("write head");
630+
let _ = sock.write_all(&body);
631+
let _ = tx1.send(());
632+
});
633+
634+
let uri = format!("http://{}/a", addr).parse().unwrap();
635+
636+
let res = {
637+
let client = Client::configure()
638+
.connector(DebugConnector(HttpConnector::new(1, &handle), closes.clone()))
639+
.no_proto()
640+
.build(&handle);
641+
client.get(uri).and_then(move |res| {
642+
assert_eq!(res.status(), hyper::StatusCode::Ok);
643+
res.body().concat2()
644+
}).and_then(|_| {
645+
Timeout::new(Duration::from_secs(1), &handle).unwrap()
646+
.from_err()
647+
})
648+
};
649+
// client is dropped
650+
let rx = rx1.map_err(|_| hyper::Error::Io(io::Error::new(io::ErrorKind::Other, "thread panicked")));
651+
core.run(res.join(rx).map(|r| r.0)).unwrap();
652+
653+
assert_eq!(closes.load(Ordering::Relaxed), 1);
654+
}
597655

598656

599657

0 commit comments

Comments
 (0)