Skip to content

Commit

Permalink
Merge 833f696 into bd352e5
Browse files Browse the repository at this point in the history
  • Loading branch information
wcampbell0x2a authored May 22, 2024
2 parents bd352e5 + 833f696 commit d2397b4
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 43 deletions.
37 changes: 36 additions & 1 deletion benches/deku.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,5 +108,40 @@ fn criterion_benchmark(c: &mut Criterion) {
});
}

criterion_group!(benches, criterion_benchmark);
pub fn read_all_vs_count(c: &mut Criterion) {
#[derive(DekuRead, DekuWrite)]
pub struct AllWrapper {
#[deku(read_all)]
pub data: Vec<u8>,
}

#[derive(DekuRead, DekuWrite)]
#[deku(ctx = "len: usize")]
pub struct CountWrapper {
#[deku(count = "len")]
pub data: Vec<u8>,
}

c.bench_function("read_all_bytes", |b| {
b.iter(|| AllWrapper::from_bytes(black_box((&[1; 1500], 0))))
});

c.bench_function("read_all", |b| {
b.iter(|| {
let mut cursor = [1u8; 1500].as_ref();
let mut reader = Reader::new(&mut cursor);
AllWrapper::from_reader_with_ctx(black_box(&mut reader), ())
})
});

c.bench_function("count", |b| {
b.iter(|| {
let mut cursor = [1u8; 1500].as_ref();
let mut reader = Reader::new(&mut cursor);
CountWrapper::from_reader_with_ctx(black_box(&mut reader), 1500)
})
});
}

criterion_group!(benches, criterion_benchmark, read_all_vs_count);
criterion_main!(benches);
15 changes: 2 additions & 13 deletions src/impls/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,20 +179,9 @@ macro_rules! ImplDekuReadBytes {
input: &BitSlice<u8, Msb0>,
(endian, size): (Endian, ByteSize),
) -> Result<(usize, Self), DekuError> {
let bit_size: usize = size.0 * 8;
let bit_size = BitSize(size.0 * 8);

let input_is_le = endian.is_le();

let bit_slice = &input[..bit_size];

let bytes = bit_slice.domain().region().unwrap().1;
let value = if input_is_le {
<$typ>::from_le_bytes(bytes.try_into()?)
} else {
<$typ>::from_be_bytes(bytes.try_into()?)
};

Ok((bit_size, value))
<$typ>::read(input, (endian, bit_size))
}
}

Expand Down
196 changes: 167 additions & 29 deletions src/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@ pub enum ReaderRet {
/// Max bits requested from [`Reader::read_bits`] during one call
pub const MAX_BITS_AMT: usize = 128;

enum Leftover {
Byte(u8),
Bits(BitVec<u8, Msb0>),
}

/// Reader to use with `from_reader_with_ctx`
pub struct Reader<'a, R: Read> {
inner: &'a mut R,
/// bits stored from previous reads that didn't read to the end of a byte size
leftover: BitVec<u8, Msb0>,
leftover: Option<Leftover>,
/// Amount of bits read during the use of [read_bits](Reader::read_bits) and [read_bytes](Reader::read_bytes).
pub bits_read: usize,
}
Expand All @@ -37,7 +42,7 @@ impl<'a, R: Read> Reader<'a, R> {
pub fn new(inner: &'a mut R) -> Self {
Self {
inner,
leftover: BitVec::new(), // with_capacity 8?
leftover: None,
bits_read: 0,
}
}
Expand Down Expand Up @@ -81,7 +86,15 @@ impl<'a, R: Read> Reader<'a, R> {
/// ```
#[inline]
pub fn rest(&mut self) -> Vec<bool> {
self.leftover.iter().by_vals().collect()
match &self.leftover {
Some(Leftover::Bits(bits)) => bits.iter().by_vals().collect(),
Some(Leftover::Byte(byte)) => {
let bytes: &[u8] = &[*byte];
let bits: BitVec<u8, Msb0> = BitVec::try_from_slice(bytes).unwrap();
bits.iter().by_vals().collect()
}
None => alloc::vec![],
}
}

/// Return true if we are at the end of a reader and there are no cached bits in the reader.
Expand All @@ -90,7 +103,7 @@ impl<'a, R: Read> Reader<'a, R> {
/// The byte that was read will be internally buffered
#[inline]
pub fn end(&mut self) -> bool {
if !self.leftover.is_empty() {
if self.leftover.is_some() {
#[cfg(feature = "logging")]
log::trace!("not end");
false
Expand All @@ -104,10 +117,11 @@ impl<'a, R: Read> Reader<'a, R> {
}
}

// logic is best if we just turn this into bits right now
self.leftover = BitVec::try_from_slice(&buf).unwrap();
#[cfg(feature = "logging")]
log::trace!("not end");
log::trace!("not end: read {:02x?}", &buf);

self.bits_read += 8;
self.leftover = Some(Leftover::Byte(buf[0]));
false
}
}
Expand Down Expand Up @@ -145,19 +159,44 @@ impl<'a, R: Read> Reader<'a, R> {
}
let mut ret = BitVec::new();

match amt.cmp(&self.leftover.len()) {
// if Leftover::Bytes exists, convert into Bits
if let Some(Leftover::Byte(byte)) = self.leftover {
let bytes: &[u8] = &[byte];
let bits: BitVec<u8, Msb0> = BitVec::try_from_slice(bytes).unwrap();
self.leftover = Some(Leftover::Bits(bits));
}

let previous_len = match &self.leftover {
Some(Leftover::Bits(bits)) => bits.len(),
None => 0,
Some(Leftover::Byte(_)) => unreachable!(),
};

match amt.cmp(&previous_len) {
// exact match, just use leftover
Ordering::Equal => {
core::mem::swap(&mut ret, &mut self.leftover);
self.leftover.clear();
if let Some(Leftover::Bits(bits)) = &mut self.leftover {
core::mem::swap(&mut ret, bits);
self.leftover = None;
} else {
unreachable!();
}
}
// previous read was not enough to satisfy the amt requirement, return all previously
Ordering::Greater => {
// read bits
ret.extend_from_bitslice(&self.leftover);
match self.leftover {
Some(Leftover::Bits(ref bits)) => {
ret.extend_from_bitslice(bits);
}
Some(Leftover::Byte(_)) => {
unreachable!("what");
}
None => {}
}

// calculate the amount of bytes we need to read to read enough bits
let bits_left = amt - self.leftover.len();
let bits_left = amt - ret.len();
let mut bytes_len = bits_left / 8;
if (bits_left % 8) != 0 {
bytes_len += 1;
Expand All @@ -179,17 +218,22 @@ impl<'a, R: Read> Reader<'a, R> {
// create bitslice and remove unused bits
let rest = BitSlice::try_from_slice(read_buf).unwrap();
let (rest, not_needed) = rest.split_at(bits_left);
core::mem::swap(&mut not_needed.to_bitvec(), &mut self.leftover);
self.leftover = Some(Leftover::Bits(not_needed.to_bitvec()));

// create return
ret.extend_from_bitslice(rest);
}
// The entire bits we need to return have been already read previously from bytes but
// not all were read, return required leftover bits
Ordering::Less => {
let used = self.leftover.split_off(amt);
ret.extend_from_bitslice(&self.leftover);
self.leftover = used;
// read bits
if let Some(Leftover::Bits(bits)) = &mut self.leftover {
let used = bits.split_off(amt);
ret.extend_from_bitslice(bits);
self.leftover = Some(Leftover::Bits(used));
} else {
unreachable!();
}
}
}

Expand All @@ -210,10 +254,8 @@ impl<'a, R: Read> Reader<'a, R> {
pub fn read_bytes(&mut self, amt: usize, buf: &mut [u8]) -> Result<ReaderRet, DekuError> {
#[cfg(feature = "logging")]
log::trace!("read_bytes: requesting {amt} bytes");
if self.leftover.is_empty() {
if buf.len() < amt {
return Err(DekuError::Incomplete(NeedSize::new(amt * 8)));
}

if self.leftover.is_none() {
if let Err(e) = self.inner.read_exact(&mut buf[..amt]) {
if e.kind() == ErrorKind::UnexpectedEof {
return Err(DekuError::Incomplete(NeedSize::new(amt * 8)));
Expand All @@ -226,12 +268,57 @@ impl<'a, R: Read> Reader<'a, R> {
#[cfg(feature = "logging")]
log::trace!("read_bytes: returning {:02x?}", &buf[..amt]);

Ok(ReaderRet::Bytes)
} else {
Ok(ReaderRet::Bits(self.read_bits(amt * 8)?))
return Ok(ReaderRet::Bytes);
}

// Trying to keep this not in the hot path
self.read_bytes_other(amt, buf)
}

#[inline(never)]
fn read_bytes_other(&mut self, amt: usize, buf: &mut [u8]) -> Result<ReaderRet, DekuError> {
match self.leftover {
Some(Leftover::Byte(byte)) => self.read_bytes_leftover(buf, byte, amt),
Some(Leftover::Bits(_)) => Ok(ReaderRet::Bits(self.read_bits(amt * 8)?)),
_ => unreachable!(),
}
}

#[inline(never)]
fn read_bytes_leftover(
&mut self,
buf: &mut [u8],
byte: u8,
amt: usize,
) -> Result<ReaderRet, DekuError> {
buf[0] = byte;

#[cfg(feature = "logging")]
log::trace!("read_bytes_leftover: using previous read {:02x?}", &buf[0]);

self.leftover = None;
let remaining = amt - 1;
let buf_len = buf.len();
if buf_len < remaining {
return Err(DekuError::Incomplete(NeedSize::new(remaining * 8)));
}
if let Err(e) = self
.inner
.read_exact(&mut buf[amt - remaining..][..remaining])
{
if e.kind() == ErrorKind::UnexpectedEof {
return Err(DekuError::Incomplete(NeedSize::new(remaining * 8)));
}
return Err(DekuError::Io(e.kind()));
}
self.bits_read += remaining * 8;

#[cfg(feature = "logging")]
log::trace!("read_bytes_leftover: returning {:02x?}", &buf);

Ok(ReaderRet::Bytes)
}

/// Attempt to read bytes from `Reader`. This will return `ReaderRet::Bytes` with a valid
/// `buf` of bytes if we have no "leftover" bytes and thus are byte aligned. If we are not byte
/// aligned, this will call `read_bits` and return `ReaderRet::Bits(_)` of size `N` * 8.
Expand All @@ -244,8 +331,9 @@ impl<'a, R: Read> Reader<'a, R> {
buf: &mut [u8; N],
) -> Result<ReaderRet, DekuError> {
#[cfg(feature = "logging")]
log::trace!("read_bytes: requesting {N} bytes");
if self.leftover.is_empty() {
log::trace!("read_bytes_const: requesting {N} bytes");

if self.leftover.is_none() {
if let Err(e) = self.inner.read_exact(buf) {
if e.kind() == ErrorKind::UnexpectedEof {
return Err(DekuError::Incomplete(NeedSize::new(N * 8)));
Expand All @@ -256,13 +344,63 @@ impl<'a, R: Read> Reader<'a, R> {
self.bits_read += N * 8;

#[cfg(feature = "logging")]
log::trace!("read_bytes: returning {:02x?}", &buf);
log::trace!("read_bytes_const: returning {:02x?}", &buf);

Ok(ReaderRet::Bytes)
} else {
Ok(ReaderRet::Bits(self.read_bits(N * 8)?))
return Ok(ReaderRet::Bytes);
}

// Trying to keep this not in the hot path
self.read_bytes_const_other::<N>(buf)
}

#[inline(never)]
fn read_bytes_const_other<const N: usize>(
&mut self,
buf: &mut [u8; N],
) -> Result<ReaderRet, DekuError> {
match self.leftover {
Some(Leftover::Byte(byte)) => self.read_bytes_const_leftover(buf, byte),
Some(Leftover::Bits(_)) => Ok(ReaderRet::Bits(self.read_bits(N * 8)?)),
_ => unreachable!(),
}
}

#[inline(never)]
fn read_bytes_const_leftover<const N: usize>(
&mut self,
buf: &mut [u8; N],
byte: u8,
) -> Result<ReaderRet, DekuError> {
buf[0] = byte;

#[cfg(feature = "logging")]
log::trace!(
"read_bytes_const_leftover: using previous read {:02x?}",
&buf[0]
);

self.leftover = None;
let remaining = N - 1;
let buf_len = buf.len();
if buf_len < remaining {
return Err(DekuError::Incomplete(NeedSize::new(remaining * 8)));
}
if let Err(e) = self
.inner
.read_exact(&mut buf[N - remaining..][..remaining])
{
if e.kind() == ErrorKind::UnexpectedEof {
return Err(DekuError::Incomplete(NeedSize::new(remaining * 8)));
}
return Err(DekuError::Io(e.kind()));
}
self.bits_read += remaining * 8;

#[cfg(feature = "logging")]
log::trace!("read_bytes_const_leftover: returning {:02x?}", &buf);

Ok(ReaderRet::Bytes)
}
}

#[cfg(test)]
Expand Down

0 comments on commit d2397b4

Please # to comment.