diff --git a/benches/deku.rs b/benches/deku.rs index b047d7b0..b8b6fc7f 100644 --- a/benches/deku.rs +++ b/benches/deku.rs @@ -108,5 +108,40 @@ fn criterion_benchmark(c: &mut Criterion) { }); } -criterion_group!(benches, criterion_benchmark); +pub fn read_all_vs_count(c: &mut Criterion) { + #[derive(DekuRead, DekuWrite)] + pub struct AllWrapper { + #[deku(read_all)] + pub data: Vec, + } + + #[derive(DekuRead, DekuWrite)] + #[deku(ctx = "len: usize")] + pub struct CountWrapper { + #[deku(count = "len")] + pub data: Vec, + } + + c.bench_function("read_all_bytes", |b| { + b.iter(|| AllWrapper::from_bytes(black_box((&[1; 1500], 0)))) + }); + + c.bench_function("read_all", |b| { + b.iter(|| { + let mut cursor = [1u8; 1500].as_ref(); + let mut reader = Reader::new(&mut cursor); + AllWrapper::from_reader_with_ctx(black_box(&mut reader), ()) + }) + }); + + c.bench_function("count", |b| { + b.iter(|| { + let mut cursor = [1u8; 1500].as_ref(); + let mut reader = Reader::new(&mut cursor); + CountWrapper::from_reader_with_ctx(black_box(&mut reader), 1500) + }) + }); +} + +criterion_group!(benches, criterion_benchmark, read_all_vs_count); criterion_main!(benches); diff --git a/src/impls/primitive.rs b/src/impls/primitive.rs index 257f02e1..1fa05021 100644 --- a/src/impls/primitive.rs +++ b/src/impls/primitive.rs @@ -179,20 +179,9 @@ macro_rules! ImplDekuReadBytes { input: &BitSlice, (endian, size): (Endian, ByteSize), ) -> Result<(usize, Self), DekuError> { - let bit_size: usize = size.0 * 8; + let bit_size = BitSize(size.0 * 8); - let input_is_le = endian.is_le(); - - let bit_slice = &input[..bit_size]; - - let bytes = bit_slice.domain().region().unwrap().1; - let value = if input_is_le { - <$typ>::from_le_bytes(bytes.try_into()?) - } else { - <$typ>::from_be_bytes(bytes.try_into()?) - }; - - Ok((bit_size, value)) + <$typ>::read(input, (endian, bit_size)) } } diff --git a/src/reader.rs b/src/reader.rs index 19283e8a..50382c1d 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -22,11 +22,16 @@ pub enum ReaderRet { /// Max bits requested from [`Reader::read_bits`] during one call pub const MAX_BITS_AMT: usize = 128; +enum Leftover { + Byte(u8), + Bits(BitVec), +} + /// Reader to use with `from_reader_with_ctx` pub struct Reader<'a, R: Read> { inner: &'a mut R, /// bits stored from previous reads that didn't read to the end of a byte size - leftover: BitVec, + leftover: Option, /// Amount of bits read during the use of [read_bits](Reader::read_bits) and [read_bytes](Reader::read_bytes). pub bits_read: usize, } @@ -37,7 +42,7 @@ impl<'a, R: Read> Reader<'a, R> { pub fn new(inner: &'a mut R) -> Self { Self { inner, - leftover: BitVec::new(), // with_capacity 8? + leftover: None, bits_read: 0, } } @@ -81,7 +86,15 @@ impl<'a, R: Read> Reader<'a, R> { /// ``` #[inline] pub fn rest(&mut self) -> Vec { - self.leftover.iter().by_vals().collect() + match &self.leftover { + Some(Leftover::Bits(bits)) => bits.iter().by_vals().collect(), + Some(Leftover::Byte(byte)) => { + let bytes: &[u8] = &[*byte]; + let bits: BitVec = BitVec::try_from_slice(bytes).unwrap(); + bits.iter().by_vals().collect() + } + None => alloc::vec![], + } } /// Return true if we are at the end of a reader and there are no cached bits in the reader. @@ -90,7 +103,7 @@ impl<'a, R: Read> Reader<'a, R> { /// The byte that was read will be internally buffered #[inline] pub fn end(&mut self) -> bool { - if !self.leftover.is_empty() { + if self.leftover.is_some() { #[cfg(feature = "logging")] log::trace!("not end"); false @@ -104,10 +117,11 @@ impl<'a, R: Read> Reader<'a, R> { } } - // logic is best if we just turn this into bits right now - self.leftover = BitVec::try_from_slice(&buf).unwrap(); #[cfg(feature = "logging")] - log::trace!("not end"); + log::trace!("not end: read {:02x?}", &buf); + + self.bits_read += 8; + self.leftover = Some(Leftover::Byte(buf[0])); false } } @@ -145,19 +159,44 @@ impl<'a, R: Read> Reader<'a, R> { } let mut ret = BitVec::new(); - match amt.cmp(&self.leftover.len()) { + // if Leftover::Bytes exists, convert into Bits + if let Some(Leftover::Byte(byte)) = self.leftover { + let bytes: &[u8] = &[byte]; + let bits: BitVec = BitVec::try_from_slice(bytes).unwrap(); + self.leftover = Some(Leftover::Bits(bits)); + } + + let previous_len = match &self.leftover { + Some(Leftover::Bits(bits)) => bits.len(), + None => 0, + Some(Leftover::Byte(_)) => unreachable!(), + }; + + match amt.cmp(&previous_len) { // exact match, just use leftover Ordering::Equal => { - core::mem::swap(&mut ret, &mut self.leftover); - self.leftover.clear(); + if let Some(Leftover::Bits(bits)) = &mut self.leftover { + core::mem::swap(&mut ret, bits); + self.leftover = None; + } else { + unreachable!(); + } } // previous read was not enough to satisfy the amt requirement, return all previously Ordering::Greater => { // read bits - ret.extend_from_bitslice(&self.leftover); + match self.leftover { + Some(Leftover::Bits(ref bits)) => { + ret.extend_from_bitslice(bits); + } + Some(Leftover::Byte(_)) => { + unreachable!("what"); + } + None => {} + } // calculate the amount of bytes we need to read to read enough bits - let bits_left = amt - self.leftover.len(); + let bits_left = amt - ret.len(); let mut bytes_len = bits_left / 8; if (bits_left % 8) != 0 { bytes_len += 1; @@ -179,7 +218,7 @@ impl<'a, R: Read> Reader<'a, R> { // create bitslice and remove unused bits let rest = BitSlice::try_from_slice(read_buf).unwrap(); let (rest, not_needed) = rest.split_at(bits_left); - core::mem::swap(&mut not_needed.to_bitvec(), &mut self.leftover); + self.leftover = Some(Leftover::Bits(not_needed.to_bitvec())); // create return ret.extend_from_bitslice(rest); @@ -187,9 +226,14 @@ impl<'a, R: Read> Reader<'a, R> { // The entire bits we need to return have been already read previously from bytes but // not all were read, return required leftover bits Ordering::Less => { - let used = self.leftover.split_off(amt); - ret.extend_from_bitslice(&self.leftover); - self.leftover = used; + // read bits + if let Some(Leftover::Bits(bits)) = &mut self.leftover { + let used = bits.split_off(amt); + ret.extend_from_bitslice(bits); + self.leftover = Some(Leftover::Bits(used)); + } else { + unreachable!(); + } } } @@ -210,10 +254,8 @@ impl<'a, R: Read> Reader<'a, R> { pub fn read_bytes(&mut self, amt: usize, buf: &mut [u8]) -> Result { #[cfg(feature = "logging")] log::trace!("read_bytes: requesting {amt} bytes"); - if self.leftover.is_empty() { - if buf.len() < amt { - return Err(DekuError::Incomplete(NeedSize::new(amt * 8))); - } + + if self.leftover.is_none() { if let Err(e) = self.inner.read_exact(&mut buf[..amt]) { if e.kind() == ErrorKind::UnexpectedEof { return Err(DekuError::Incomplete(NeedSize::new(amt * 8))); @@ -226,12 +268,57 @@ impl<'a, R: Read> Reader<'a, R> { #[cfg(feature = "logging")] log::trace!("read_bytes: returning {:02x?}", &buf[..amt]); - Ok(ReaderRet::Bytes) - } else { - Ok(ReaderRet::Bits(self.read_bits(amt * 8)?)) + return Ok(ReaderRet::Bytes); + } + + // Trying to keep this not in the hot path + self.read_bytes_other(amt, buf) + } + + #[inline(never)] + fn read_bytes_other(&mut self, amt: usize, buf: &mut [u8]) -> Result { + match self.leftover { + Some(Leftover::Byte(byte)) => self.read_bytes_leftover(buf, byte, amt), + Some(Leftover::Bits(_)) => Ok(ReaderRet::Bits(self.read_bits(amt * 8)?)), + _ => unreachable!(), } } + #[inline(never)] + fn read_bytes_leftover( + &mut self, + buf: &mut [u8], + byte: u8, + amt: usize, + ) -> Result { + buf[0] = byte; + + #[cfg(feature = "logging")] + log::trace!("read_bytes_leftover: using previous read {:02x?}", &buf[0]); + + self.leftover = None; + let remaining = amt - 1; + let buf_len = buf.len(); + if buf_len < remaining { + return Err(DekuError::Incomplete(NeedSize::new(remaining * 8))); + } + if let Err(e) = self + .inner + .read_exact(&mut buf[amt - remaining..][..remaining]) + { + if e.kind() == ErrorKind::UnexpectedEof { + return Err(DekuError::Incomplete(NeedSize::new(remaining * 8))); + } + return Err(DekuError::Io(e.kind())); + } + self.bits_read += remaining * 8; + + #[cfg(feature = "logging")] + log::trace!("read_bytes_leftover: returning {:02x?}", &buf); + + Ok(ReaderRet::Bytes) + } + /// Attempt to read bytes from `Reader`. This will return `ReaderRet::Bytes` with a valid /// `buf` of bytes if we have no "leftover" bytes and thus are byte aligned. If we are not byte /// aligned, this will call `read_bits` and return `ReaderRet::Bits(_)` of size `N` * 8. @@ -244,8 +331,9 @@ impl<'a, R: Read> Reader<'a, R> { buf: &mut [u8; N], ) -> Result { #[cfg(feature = "logging")] - log::trace!("read_bytes: requesting {N} bytes"); - if self.leftover.is_empty() { + log::trace!("read_bytes_const: requesting {N} bytes"); + + if self.leftover.is_none() { if let Err(e) = self.inner.read_exact(buf) { if e.kind() == ErrorKind::UnexpectedEof { return Err(DekuError::Incomplete(NeedSize::new(N * 8))); @@ -256,13 +344,63 @@ impl<'a, R: Read> Reader<'a, R> { self.bits_read += N * 8; #[cfg(feature = "logging")] - log::trace!("read_bytes: returning {:02x?}", &buf); + log::trace!("read_bytes_const: returning {:02x?}", &buf); - Ok(ReaderRet::Bytes) - } else { - Ok(ReaderRet::Bits(self.read_bits(N * 8)?)) + return Ok(ReaderRet::Bytes); + } + + // Trying to keep this not in the hot path + self.read_bytes_const_other::(buf) + } + + #[inline(never)] + fn read_bytes_const_other( + &mut self, + buf: &mut [u8; N], + ) -> Result { + match self.leftover { + Some(Leftover::Byte(byte)) => self.read_bytes_const_leftover(buf, byte), + Some(Leftover::Bits(_)) => Ok(ReaderRet::Bits(self.read_bits(N * 8)?)), + _ => unreachable!(), } } + + #[inline(never)] + fn read_bytes_const_leftover( + &mut self, + buf: &mut [u8; N], + byte: u8, + ) -> Result { + buf[0] = byte; + + #[cfg(feature = "logging")] + log::trace!( + "read_bytes_const_leftover: using previous read {:02x?}", + &buf[0] + ); + + self.leftover = None; + let remaining = N - 1; + let buf_len = buf.len(); + if buf_len < remaining { + return Err(DekuError::Incomplete(NeedSize::new(remaining * 8))); + } + if let Err(e) = self + .inner + .read_exact(&mut buf[N - remaining..][..remaining]) + { + if e.kind() == ErrorKind::UnexpectedEof { + return Err(DekuError::Incomplete(NeedSize::new(remaining * 8))); + } + return Err(DekuError::Io(e.kind())); + } + self.bits_read += remaining * 8; + + #[cfg(feature = "logging")] + log::trace!("read_bytes_const_leftover: returning {:02x?}", &buf); + + Ok(ReaderRet::Bytes) + } } #[cfg(test)]