diff --git a/benches/deku.rs b/benches/deku.rs index b047d7b0..b8b6fc7f 100644 --- a/benches/deku.rs +++ b/benches/deku.rs @@ -108,5 +108,40 @@ fn criterion_benchmark(c: &mut Criterion) { }); } -criterion_group!(benches, criterion_benchmark); +pub fn read_all_vs_count(c: &mut Criterion) { + #[derive(DekuRead, DekuWrite)] + pub struct AllWrapper { + #[deku(read_all)] + pub data: Vec, + } + + #[derive(DekuRead, DekuWrite)] + #[deku(ctx = "len: usize")] + pub struct CountWrapper { + #[deku(count = "len")] + pub data: Vec, + } + + c.bench_function("read_all_bytes", |b| { + b.iter(|| AllWrapper::from_bytes(black_box((&[1; 1500], 0)))) + }); + + c.bench_function("read_all", |b| { + b.iter(|| { + let mut cursor = [1u8; 1500].as_ref(); + let mut reader = Reader::new(&mut cursor); + AllWrapper::from_reader_with_ctx(black_box(&mut reader), ()) + }) + }); + + c.bench_function("count", |b| { + b.iter(|| { + let mut cursor = [1u8; 1500].as_ref(); + let mut reader = Reader::new(&mut cursor); + CountWrapper::from_reader_with_ctx(black_box(&mut reader), 1500) + }) + }); +} + +criterion_group!(benches, criterion_benchmark, read_all_vs_count); criterion_main!(benches); diff --git a/src/reader.rs b/src/reader.rs index 19283e8a..8fc7e9a6 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -22,11 +22,16 @@ pub enum ReaderRet { /// Max bits requested from [`Reader::read_bits`] during one call pub const MAX_BITS_AMT: usize = 128; +enum Leftover { + Bits(BitVec), + Bytes([u8; 1]), +} + /// Reader to use with `from_reader_with_ctx` pub struct Reader<'a, R: Read> { inner: &'a mut R, /// bits stored from previous reads that didn't read to the end of a byte size - leftover: BitVec, + leftover: Option, /// Amount of bits read during the use of [read_bits](Reader::read_bits) and [read_bytes](Reader::read_bytes). pub bits_read: usize, } @@ -37,7 +42,7 @@ impl<'a, R: Read> Reader<'a, R> { pub fn new(inner: &'a mut R) -> Self { Self { inner, - leftover: BitVec::new(), // with_capacity 8? + leftover: None, bits_read: 0, } } @@ -81,7 +86,14 @@ impl<'a, R: Read> Reader<'a, R> { /// ``` #[inline] pub fn rest(&mut self) -> Vec { - self.leftover.iter().by_vals().collect() + match &self.leftover { + Some(Leftover::Bits(bits)) => bits.iter().by_vals().collect(), + Some(Leftover::Bytes(bytes)) => { + let bits: BitVec = BitVec::try_from_slice(bytes).unwrap(); + bits.iter().by_vals().collect() + } + None => alloc::vec![], + } } /// Return true if we are at the end of a reader and there are no cached bits in the reader. @@ -90,7 +102,7 @@ impl<'a, R: Read> Reader<'a, R> { /// The byte that was read will be internally buffered #[inline] pub fn end(&mut self) -> bool { - if !self.leftover.is_empty() { + if self.leftover.is_some() { #[cfg(feature = "logging")] log::trace!("not end"); false @@ -104,10 +116,10 @@ impl<'a, R: Read> Reader<'a, R> { } } - // logic is best if we just turn this into bits right now - self.leftover = BitVec::try_from_slice(&buf).unwrap(); #[cfg(feature = "logging")] - log::trace!("not end"); + log::trace!("not end: read {:02x?}", &buf); + + self.leftover = Some(Leftover::Bytes(buf)); false } } @@ -145,19 +157,43 @@ impl<'a, R: Read> Reader<'a, R> { } let mut ret = BitVec::new(); - match amt.cmp(&self.leftover.len()) { + // if Leftover::Bytes exists, convert into Bits + if let Some(Leftover::Bytes(bytes)) = self.leftover { + let bits: BitVec = BitVec::try_from_slice(&bytes).unwrap(); + self.leftover = Some(Leftover::Bits(bits)); + } + + let previous_len = match &self.leftover { + Some(Leftover::Bits(bits)) => bits.len(), + None => 0, + Some(Leftover::Bytes(_)) => unreachable!(), + }; + + match amt.cmp(&previous_len) { // exact match, just use leftover Ordering::Equal => { - core::mem::swap(&mut ret, &mut self.leftover); - self.leftover.clear(); + if let Some(Leftover::Bits(bits)) = &mut self.leftover { + core::mem::swap(&mut ret, bits); + self.leftover = None; + } else { + unreachable!(); + } } // previous read was not enough to satisfy the amt requirement, return all previously Ordering::Greater => { // read bits - ret.extend_from_bitslice(&self.leftover); + match self.leftover { + Some(Leftover::Bits(ref bits)) => { + ret.extend_from_bitslice(&bits); + } + Some(Leftover::Bytes(_)) => { + unreachable!("what"); + } + None => {} + } // calculate the amount of bytes we need to read to read enough bits - let bits_left = amt - self.leftover.len(); + let bits_left = amt - ret.len(); let mut bytes_len = bits_left / 8; if (bits_left % 8) != 0 { bytes_len += 1; @@ -179,7 +215,7 @@ impl<'a, R: Read> Reader<'a, R> { // create bitslice and remove unused bits let rest = BitSlice::try_from_slice(read_buf).unwrap(); let (rest, not_needed) = rest.split_at(bits_left); - core::mem::swap(&mut not_needed.to_bitvec(), &mut self.leftover); + self.leftover = Some(Leftover::Bits(not_needed.to_bitvec())); // create return ret.extend_from_bitslice(rest); @@ -187,9 +223,14 @@ impl<'a, R: Read> Reader<'a, R> { // The entire bits we need to return have been already read previously from bytes but // not all were read, return required leftover bits Ordering::Less => { - let used = self.leftover.split_off(amt); - ret.extend_from_bitslice(&self.leftover); - self.leftover = used; + // read bits + if let Some(Leftover::Bits(bits)) = &mut self.leftover { + let used = bits.split_off(amt); + ret.extend_from_bitslice(&bits); + self.leftover = Some(Leftover::Bits(used)); + } else { + unreachable!(); + } } } @@ -210,26 +251,43 @@ impl<'a, R: Read> Reader<'a, R> { pub fn read_bytes(&mut self, amt: usize, buf: &mut [u8]) -> Result { #[cfg(feature = "logging")] log::trace!("read_bytes: requesting {amt} bytes"); - if self.leftover.is_empty() { - if buf.len() < amt { - return Err(DekuError::Incomplete(NeedSize::new(amt * 8))); - } - if let Err(e) = self.inner.read_exact(&mut buf[..amt]) { - if e.kind() == ErrorKind::UnexpectedEof { - return Err(DekuError::Incomplete(NeedSize::new(amt * 8))); - } - return Err(DekuError::Io(e.kind())); - } - self.bits_read += amt * 8; + // previous read of bits? goto read_bits + if let Some(Leftover::Bits(_)) = &self.leftover { + return Ok(ReaderRet::Bits(self.read_bits(amt * 8)?)); + } + // previous read of bytes? use that one byte + let remaining = if let Some(Leftover::Bytes(bytes)) = self.leftover { + buf[0] = bytes[0]; #[cfg(feature = "logging")] - log::trace!("read_bytes: returning {:02x?}", &buf[..amt]); - - Ok(ReaderRet::Bytes) + log::trace!("read_bytes: using previous read {:02x?}", &buf[0]); + self.leftover = None; + amt - 1 } else { - Ok(ReaderRet::Bits(self.read_bits(amt * 8)?)) + amt + }; + + let buf_len = buf.len(); + if buf_len < remaining { + return Err(DekuError::Incomplete(NeedSize::new(remaining * 8))); } + if let Err(e) = self + .inner + .read_exact(&mut buf[amt - remaining..][..remaining]) + { + if e.kind() == ErrorKind::UnexpectedEof { + return Err(DekuError::Incomplete(NeedSize::new(remaining * 8))); + } + return Err(DekuError::Io(e.kind())); + } + + self.bits_read += remaining * 8; + + #[cfg(feature = "logging")] + log::trace!("read_bytes: returning {:02x?}", &buf[..amt]); + + Ok(ReaderRet::Bytes) } /// Attempt to read bytes from `Reader`. This will return `ReaderRet::Bytes` with a valid @@ -244,24 +302,44 @@ impl<'a, R: Read> Reader<'a, R> { buf: &mut [u8; N], ) -> Result { #[cfg(feature = "logging")] - log::trace!("read_bytes: requesting {N} bytes"); - if self.leftover.is_empty() { - if let Err(e) = self.inner.read_exact(buf) { - if e.kind() == ErrorKind::UnexpectedEof { - return Err(DekuError::Incomplete(NeedSize::new(N * 8))); - } - return Err(DekuError::Io(e.kind())); - } + log::trace!("read_bytes_const: requesting {N} bytes"); - self.bits_read += N * 8; + // previous read of bits? goto read_bits + if let Some(Leftover::Bits(_)) = &self.leftover { + return Ok(ReaderRet::Bits(self.read_bits(N * 8)?)); + } + // previous read of bytes? use that one byte + let remaining = if let Some(Leftover::Bytes(bytes)) = self.leftover { + buf[0] = bytes[0]; #[cfg(feature = "logging")] - log::trace!("read_bytes: returning {:02x?}", &buf); - - Ok(ReaderRet::Bytes) + log::trace!("read_bytes_const: using previous read {:02x?}", &buf[0]); + self.leftover = None; + N - 1 } else { - Ok(ReaderRet::Bits(self.read_bits(N * 8)?)) + N + }; + + let buf_len = buf.len(); + if buf_len < remaining { + return Err(DekuError::Incomplete(NeedSize::new(remaining * 8))); + } + if let Err(e) = self + .inner + .read_exact(&mut buf[N - remaining..][..remaining]) + { + if e.kind() == ErrorKind::UnexpectedEof { + return Err(DekuError::Incomplete(NeedSize::new(remaining * 8))); + } + return Err(DekuError::Io(e.kind())); } + + self.bits_read += remaining * 8; + + #[cfg(feature = "logging")] + log::trace!("read_bytes_const: returning {:02x?}", &buf[..N]); + + Ok(ReaderRet::Bytes) } }