Skip to content

Commit 3601bb7

Browse files
authored
Add async LineWriter (#2477)
1 parent ee23679 commit 3601bb7

File tree

4 files changed

+295
-1
lines changed

4 files changed

+295
-1
lines changed

futures-util/src/io/buf_writer.rs

+64-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use pin_project_lite::pin_project;
66
use std::fmt;
77
use std::io::{self, Write};
88
use std::pin::Pin;
9+
use std::ptr;
910

1011
pin_project! {
1112
/// Wraps a writer and buffers its output.
@@ -49,7 +50,7 @@ impl<W: AsyncWrite> BufWriter<W> {
4950
Self { inner, buf: Vec::with_capacity(cap), written: 0 }
5051
}
5152

52-
fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
53+
pub(super) fn flush_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
5354
let mut this = self.project();
5455

5556
let len = this.buf.len();
@@ -83,6 +84,68 @@ impl<W: AsyncWrite> BufWriter<W> {
8384
pub fn buffer(&self) -> &[u8] {
8485
&self.buf
8586
}
87+
88+
/// Capacity of `buf`. how many chars can be held in buffer
89+
pub(super) fn capacity(&self) -> usize {
90+
self.buf.capacity()
91+
}
92+
93+
/// Remaining number of bytes to reach `buf` 's capacity
94+
#[inline]
95+
pub(super) fn spare_capacity(&self) -> usize {
96+
self.buf.capacity() - self.buf.len()
97+
}
98+
99+
/// Write a byte slice directly into buffer
100+
///
101+
/// Will truncate the number of bytes written to `spare_capacity()` so you want to
102+
/// calculate the size of your slice to avoid losing bytes
103+
///
104+
/// Based on `std::io::BufWriter`
105+
pub(super) fn write_to_buf(self: Pin<&mut Self>, buf: &[u8]) -> usize {
106+
let available = self.spare_capacity();
107+
let amt_to_buffer = available.min(buf.len());
108+
109+
// SAFETY: `amt_to_buffer` is <= buffer's spare capacity by construction.
110+
unsafe {
111+
self.write_to_buffer_unchecked(&buf[..amt_to_buffer]);
112+
}
113+
114+
amt_to_buffer
115+
}
116+
117+
/// Write byte slice directly into `self.buf`
118+
///
119+
/// Based on `std::io::BufWriter`
120+
#[inline]
121+
unsafe fn write_to_buffer_unchecked(self: Pin<&mut Self>, buf: &[u8]) {
122+
debug_assert!(buf.len() <= self.spare_capacity());
123+
let this = self.project();
124+
let old_len = this.buf.len();
125+
let buf_len = buf.len();
126+
let src = buf.as_ptr();
127+
let dst = this.buf.as_mut_ptr().add(old_len);
128+
ptr::copy_nonoverlapping(src, dst, buf_len);
129+
this.buf.set_len(old_len + buf_len);
130+
}
131+
132+
/// Write directly using `inner`, bypassing buffering
133+
pub(super) fn inner_poll_write(
134+
self: Pin<&mut Self>,
135+
cx: &mut Context<'_>,
136+
buf: &[u8],
137+
) -> Poll<io::Result<usize>> {
138+
self.project().inner.poll_write(cx, buf)
139+
}
140+
141+
/// Write directly using `inner`, bypassing buffering
142+
pub(super) fn inner_poll_write_vectored(
143+
self: Pin<&mut Self>,
144+
cx: &mut Context<'_>,
145+
bufs: &[IoSlice<'_>],
146+
) -> Poll<io::Result<usize>> {
147+
self.project().inner.poll_write_vectored(cx, bufs)
148+
}
86149
}
87150

88151
impl<W: AsyncWrite> AsyncWrite for BufWriter<W> {

futures-util/src/io/line_writer.rs

+155
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
use super::buf_writer::BufWriter;
2+
use futures_core::ready;
3+
use futures_core::task::{Context, Poll};
4+
use futures_io::AsyncWrite;
5+
use futures_io::IoSlice;
6+
use pin_project_lite::pin_project;
7+
use std::io;
8+
use std::pin::Pin;
9+
10+
pin_project! {
11+
/// Wrap a writer, like [`BufWriter`] does, but prioritizes buffering lines
12+
///
13+
/// This was written based on `std::io::LineWriter` which goes into further details
14+
/// explaining the code.
15+
///
16+
/// Buffering is actually done using `BufWriter`. This class will leverage `BufWriter`
17+
/// to write on-each-line.
18+
#[derive(Debug)]
19+
pub struct LineWriter<W: AsyncWrite> {
20+
#[pin]
21+
buf_writer: BufWriter<W>,
22+
}
23+
}
24+
25+
impl<W: AsyncWrite> LineWriter<W> {
26+
/// Create a new `LineWriter` with default buffer capacity. The default is currently 1KB
27+
/// which was taken from `std::io::LineWriter`
28+
pub fn new(inner: W) -> LineWriter<W> {
29+
LineWriter::with_capacity(1024, inner)
30+
}
31+
32+
/// Creates a new `LineWriter` with the specified buffer capacity.
33+
pub fn with_capacity(capacity: usize, inner: W) -> LineWriter<W> {
34+
LineWriter { buf_writer: BufWriter::with_capacity(capacity, inner) }
35+
}
36+
37+
/// Flush `buf_writer` if last char is "new line"
38+
fn flush_if_completed_line(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
39+
let this = self.project();
40+
match this.buf_writer.buffer().last().copied() {
41+
Some(b'\n') => this.buf_writer.flush_buf(cx),
42+
_ => Poll::Ready(Ok(())),
43+
}
44+
}
45+
46+
/// Returns a reference to `buf_writer`'s internally buffered data.
47+
pub fn buffer(&self) -> &[u8] {
48+
self.buf_writer.buffer()
49+
}
50+
51+
/// Acquires a reference to the underlying sink or stream that this combinator is
52+
/// pulling from.
53+
pub fn get_ref(&self) -> &W {
54+
self.buf_writer.get_ref()
55+
}
56+
}
57+
58+
impl<W: AsyncWrite> AsyncWrite for LineWriter<W> {
59+
fn poll_write(
60+
mut self: Pin<&mut Self>,
61+
cx: &mut Context<'_>,
62+
buf: &[u8],
63+
) -> Poll<io::Result<usize>> {
64+
let mut this = self.as_mut().project();
65+
let newline_index = match memchr::memrchr(b'\n', buf) {
66+
None => {
67+
ready!(self.as_mut().flush_if_completed_line(cx)?);
68+
return self.project().buf_writer.poll_write(cx, buf);
69+
}
70+
Some(newline_index) => newline_index + 1,
71+
};
72+
73+
ready!(this.buf_writer.as_mut().poll_flush(cx)?);
74+
75+
let lines = &buf[..newline_index];
76+
77+
let flushed = { ready!(this.buf_writer.as_mut().inner_poll_write(cx, lines))? };
78+
79+
if flushed == 0 {
80+
return Poll::Ready(Ok(0));
81+
}
82+
83+
let tail = if flushed >= newline_index {
84+
&buf[flushed..]
85+
} else if newline_index - flushed <= this.buf_writer.capacity() {
86+
&buf[flushed..newline_index]
87+
} else {
88+
let scan_area = &buf[flushed..];
89+
let scan_area = &scan_area[..this.buf_writer.capacity()];
90+
match memchr::memrchr(b'\n', scan_area) {
91+
Some(newline_index) => &scan_area[..newline_index + 1],
92+
None => scan_area,
93+
}
94+
};
95+
96+
let buffered = this.buf_writer.as_mut().write_to_buf(tail);
97+
Poll::Ready(Ok(flushed + buffered))
98+
}
99+
100+
fn poll_write_vectored(
101+
mut self: Pin<&mut Self>,
102+
cx: &mut Context<'_>,
103+
bufs: &[IoSlice<'_>],
104+
) -> Poll<io::Result<usize>> {
105+
let mut this = self.as_mut().project();
106+
// `is_write_vectored()` is handled in original code, but not in this crate
107+
// see https://github.com/rust-lang/rust/issues/70436
108+
109+
let last_newline_buf_idx = bufs
110+
.iter()
111+
.enumerate()
112+
.rev()
113+
.find_map(|(i, buf)| memchr::memchr(b'\n', buf).map(|_| i));
114+
let last_newline_buf_idx = match last_newline_buf_idx {
115+
None => {
116+
ready!(self.as_mut().flush_if_completed_line(cx)?);
117+
return self.project().buf_writer.poll_write_vectored(cx, bufs);
118+
}
119+
Some(i) => i,
120+
};
121+
122+
ready!(this.buf_writer.as_mut().poll_flush(cx)?);
123+
124+
let (lines, tail) = bufs.split_at(last_newline_buf_idx + 1);
125+
126+
let flushed = { ready!(this.buf_writer.as_mut().inner_poll_write_vectored(cx, lines))? };
127+
if flushed == 0 {
128+
return Poll::Ready(Ok(0));
129+
}
130+
131+
let lines_len = lines.iter().map(|buf| buf.len()).sum();
132+
if flushed < lines_len {
133+
return Poll::Ready(Ok(flushed));
134+
}
135+
136+
let buffered: usize = tail
137+
.iter()
138+
.filter(|buf| !buf.is_empty())
139+
.map(|buf| this.buf_writer.as_mut().write_to_buf(buf))
140+
.take_while(|&n| n > 0)
141+
.sum();
142+
143+
Poll::Ready(Ok(flushed + buffered))
144+
}
145+
146+
/// Forward to `buf_writer` 's `BufWriter::poll_flush()`
147+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
148+
self.as_mut().project().buf_writer.poll_flush(cx)
149+
}
150+
151+
/// Forward to `buf_writer` 's `BufWriter::poll_close()`
152+
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
153+
self.as_mut().project().buf_writer.poll_close(cx)
154+
}
155+
}

futures-util/src/io/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ pub use self::buf_reader::{BufReader, SeeKRelative};
6161
mod buf_writer;
6262
pub use self::buf_writer::BufWriter;
6363

64+
mod line_writer;
65+
pub use self::line_writer::LineWriter;
66+
6467
mod chain;
6568
pub use self::chain::Chain;
6669

futures/tests/io_line_writer.rs

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
use futures::executor::block_on;
2+
use futures::io::{AsyncWriteExt, LineWriter};
3+
use std::io;
4+
5+
#[test]
6+
fn line_writer() {
7+
let mut writer = LineWriter::new(Vec::new());
8+
9+
block_on(writer.write(&[0])).unwrap();
10+
assert_eq!(*writer.get_ref(), []);
11+
12+
block_on(writer.write(&[1])).unwrap();
13+
assert_eq!(*writer.get_ref(), []);
14+
15+
block_on(writer.flush()).unwrap();
16+
assert_eq!(*writer.get_ref(), [0, 1]);
17+
18+
block_on(writer.write(&[0, b'\n', 1, b'\n', 2])).unwrap();
19+
assert_eq!(*writer.get_ref(), [0, 1, 0, b'\n', 1, b'\n']);
20+
21+
block_on(writer.flush()).unwrap();
22+
assert_eq!(*writer.get_ref(), [0, 1, 0, b'\n', 1, b'\n', 2]);
23+
24+
block_on(writer.write(&[3, b'\n'])).unwrap();
25+
assert_eq!(*writer.get_ref(), [0, 1, 0, b'\n', 1, b'\n', 2, 3, b'\n']);
26+
}
27+
28+
#[test]
29+
fn line_vectored() {
30+
let mut line_writer = LineWriter::new(Vec::new());
31+
assert_eq!(
32+
block_on(line_writer.write_vectored(&[
33+
io::IoSlice::new(&[]),
34+
io::IoSlice::new(b"\n"),
35+
io::IoSlice::new(&[]),
36+
io::IoSlice::new(b"a"),
37+
]))
38+
.unwrap(),
39+
2
40+
);
41+
assert_eq!(line_writer.get_ref(), b"\n");
42+
43+
assert_eq!(
44+
block_on(line_writer.write_vectored(&[
45+
io::IoSlice::new(&[]),
46+
io::IoSlice::new(b"b"),
47+
io::IoSlice::new(&[]),
48+
io::IoSlice::new(b"a"),
49+
io::IoSlice::new(&[]),
50+
io::IoSlice::new(b"c"),
51+
]))
52+
.unwrap(),
53+
3
54+
);
55+
assert_eq!(line_writer.get_ref(), b"\n");
56+
block_on(line_writer.flush()).unwrap();
57+
assert_eq!(line_writer.get_ref(), b"\nabac");
58+
assert_eq!(block_on(line_writer.write_vectored(&[])).unwrap(), 0);
59+
60+
assert_eq!(
61+
block_on(line_writer.write_vectored(&[
62+
io::IoSlice::new(&[]),
63+
io::IoSlice::new(&[]),
64+
io::IoSlice::new(&[]),
65+
io::IoSlice::new(&[]),
66+
]))
67+
.unwrap(),
68+
0
69+
);
70+
71+
assert_eq!(block_on(line_writer.write_vectored(&[io::IoSlice::new(b"a\nb")])).unwrap(), 3);
72+
assert_eq!(line_writer.get_ref(), b"\nabaca\nb");
73+
}

0 commit comments

Comments
 (0)