diff --git a/src/constants.rs b/src/constants.rs index fcff4d93..2600f4fb 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -65,3 +65,8 @@ pub const MAX_I32_SCALE: i32 = 9; pub const MAX_I64_SCALE: u32 = 19; #[cfg(not(feature = "legacy-ops"))] pub const U32_MAX: u64 = u32::MAX as u64; + +// Determines potential overflow for 128 bit operations +pub const OVERFLOW_U96: u128 = 1u128 << 96; +pub const WILL_OVERFLOW_U64: u64 = u64::MAX / 10 - u8::MAX as u64; +pub const BYTES_TO_OVERFLOW_U64: usize = 18; // We can probably get away with less diff --git a/src/decimal.rs b/src/decimal.rs index 55c71d8a..8e340860 100644 --- a/src/decimal.rs +++ b/src/decimal.rs @@ -83,7 +83,7 @@ const NEGATIVE_ONE: Decimal = Decimal { /// `UnpackedDecimal` contains unpacked representation of `Decimal` where each component /// of decimal-format stored in it's own field -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub struct UnpackedDecimal { pub negative: bool, pub scale: u32, diff --git a/src/error.rs b/src/error.rs index 0e7613f5..dac90557 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,4 @@ -use crate::constants::MAX_PRECISION_U32; +use crate::{constants::MAX_PRECISION_U32, Decimal}; use alloc::string::String; use core::fmt; @@ -21,6 +21,11 @@ where } } +#[cold] +pub(crate) fn tail_error(from: &'static str) -> Result { + Err(from.into()) +} + #[cfg(feature = "std")] impl std::error::Error for Error {} diff --git a/src/ops/array.rs b/src/ops/array.rs index a35cd5b2..fbea1e21 100644 --- a/src/ops/array.rs +++ b/src/ops/array.rs @@ -54,6 +54,7 @@ pub(crate) fn rescale_internal(value: &mut [u32; 3], value_scale: &mut u32, new_ } } +#[cfg(feature = "legacy-ops")] pub(crate) fn add_by_internal(value: &mut [u32], by: &[u32]) -> u32 { let mut carry: u64 = 0; let vl = value.len(); @@ -92,6 +93,25 @@ pub(crate) fn add_by_internal(value: &mut [u32], by: &[u32]) -> u32 { carry as u32 } +pub(crate) fn add_by_internal_flattened(value: &mut [u32; 3], by: u32) -> u32 { + let mut carry: u64; + let mut sum: u64; + sum = u64::from(value[0]) + u64::from(by); + value[0] = (sum & U32_MASK) as u32; + carry = sum >> 32; + if carry > 0 { + sum = u64::from(value[1]) + carry; + value[1] = (sum & U32_MASK) as u32; + carry = sum >> 32; + if carry > 0 { + sum = u64::from(value[2]) + carry; + value[2] = (sum & U32_MASK) as u32; + carry = sum >> 32; + } + } + carry as u32 +} + #[inline] pub(crate) fn add_one_internal(value: &mut [u32]) -> u32 { let mut carry: u64 = 1; // Start with one, since adding one diff --git a/src/str.rs b/src/str.rs index 531be53a..500cb604 100644 --- a/src/str.rs +++ b/src/str.rs @@ -1,7 +1,7 @@ use crate::{ - constants::{MAX_PRECISION, MAX_STR_BUFFER_SIZE}, - error::Error, - ops::array::{add_by_internal, add_one_internal, div_by_u32, is_all_zero, mul_by_10, mul_by_u32}, + constants::{BYTES_TO_OVERFLOW_U64, MAX_PRECISION, MAX_STR_BUFFER_SIZE, OVERFLOW_U96, WILL_OVERFLOW_U64}, + error::{tail_error, Error}, + ops::array::{add_by_internal_flattened, add_one_internal, div_by_u32, is_all_zero, mul_by_u32}, Decimal, }; @@ -122,160 +122,236 @@ pub(crate) fn fmt_scientific_notation( } // dedicated implementation for the most common case. +#[inline] pub(crate) fn parse_str_radix_10(str: &str) -> Result { - if str.is_empty() { - return Err(Error::from("Invalid decimal: empty")); + let bytes = str.as_bytes(); + // handle the sign + + if bytes.len() < BYTES_TO_OVERFLOW_U64 { + parse_str_radix_10_dispatch::(bytes) + } else { + parse_str_radix_10_dispatch::(bytes) } +} - let mut offset = 0; - let mut len = str.len(); - let bytes = str.as_bytes(); - let mut negative = false; // assume positive +#[inline] +fn parse_str_radix_10_dispatch(bytes: &[u8]) -> Result { + match bytes { + [b, rest @ ..] => byte_dispatch_u64::(rest, 0, 0, *b), + [] => tail_error("Invalid decimal: empty"), + } +} - // handle the sign - if bytes[offset] == b'-' { - negative = true; // leading minus means negative - offset += 1; - len -= 1; - } else if bytes[offset] == b'+' { - // leading + allowed - offset += 1; - len -= 1; +#[inline] +fn overflow_64(val: u64) -> bool { + val >= WILL_OVERFLOW_U64 +} + +#[inline] +pub fn overflow_128(val: u128) -> bool { + val >= OVERFLOW_U96 +} + +#[inline] +fn dispatch_next( + bytes: &[u8], + data64: u64, + scale: u8, +) -> Result { + if let Some((next, bytes)) = bytes.split_first() { + byte_dispatch_u64::(bytes, data64, scale, *next) + } else { + handle_data::(data64 as u128, scale) } +} - // should now be at numeric part of the significand - let mut digits_before_dot: i32 = -1; // digits before '.', -1 if no '.' - let mut coeff = ArrayVec::<_, MAX_STR_BUFFER_SIZE>::new(); // integer significand array +#[inline(never)] +fn non_digit_dispatch_u64( + bytes: &[u8], + data64: u64, + scale: u8, + b: u8, +) -> Result { + match b { + b'-' if FIRST && !HAS => dispatch_next::(bytes, data64, scale), + b'+' if FIRST && !HAS => dispatch_next::(bytes, data64, scale), + b'_' if HAS => handle_separator::(bytes, data64, scale), + b => tail_invalid_digit(b), + } +} - let mut maybe_round = false; - while len > 0 { - let b = bytes[offset]; - match b { - b'0'..=b'9' => { - coeff.push(u32::from(b - b'0')); - offset += 1; - len -= 1; +#[inline] +fn byte_dispatch_u64( + bytes: &[u8], + data64: u64, + scale: u8, + b: u8, +) -> Result { + match b { + b'0'..=b'9' => handle_digit_64::(bytes, data64, scale, b - b'0'), + b'.' if !POINT => handle_point::(bytes, data64, scale), + b => non_digit_dispatch_u64::(bytes, data64, scale, b), + } +} - // If the coefficient is longer than the max, exit early - if coeff.len() as u32 > 28 { - maybe_round = true; - break; - } - } - b'.' => { - if digits_before_dot >= 0 { - return Err(Error::from("Invalid decimal: two decimal points")); - } - digits_before_dot = coeff.len() as i32; - offset += 1; - len -= 1; - } - b'_' => { - // Must start with a number... - if coeff.is_empty() { - return Err(Error::from("Invalid decimal: must start lead with a number")); - } - offset += 1; - len -= 1; - } - _ => return Err(Error::from("Invalid decimal: unknown character")), +#[inline(never)] +fn handle_digit_64( + bytes: &[u8], + data64: u64, + scale: u8, + digit: u8, +) -> Result { + // we have already validated that we cannot overflow + let data64 = data64 * 10 + digit as u64; + let scale = if POINT { scale + 1 } else { 0 }; + + if let Some((next, bytes)) = bytes.split_first() { + let next = *next; + if POINT && BIG && scale >= 28 { + maybe_round(data64 as u128, next, scale, POINT, NEG) + } else if BIG && overflow_64(data64) { + handle_full_128::(data64 as u128, bytes, scale, next) + } else { + byte_dispatch_u64::(bytes, data64, scale, next) } + } else { + let data: u128 = data64 as u128; + + handle_data::(data, scale) } +} - // If we exited before the end of the string then do some rounding if necessary - if maybe_round && offset < bytes.len() { - let next_byte = bytes[offset]; - let digit = match next_byte { - b'0'..=b'9' => u32::from(next_byte - b'0'), - b'_' => 0, - b'.' => { - // Still an error if we have a second dp - if digits_before_dot >= 0 { - return Err(Error::from("Invalid decimal: two decimal points")); +#[inline(never)] +fn handle_point( + bytes: &[u8], + data64: u64, + scale: u8, +) -> Result { + dispatch_next::(bytes, data64, scale) +} + +#[inline(never)] +fn handle_separator( + bytes: &[u8], + data64: u64, + scale: u8, +) -> Result { + dispatch_next::(bytes, data64, scale) +} + +#[inline(never)] +#[cold] +fn tail_invalid_digit(digit: u8) -> Result { + match digit { + b'.' => tail_error("Invalid decimal: two decimal points"), + b'_' => tail_error("Invalid decimal: must start lead with a number"), + _ => tail_error("Invalid decimal: unknown character"), + } +} + +#[inline(never)] +#[cold] +fn handle_full_128( + mut data: u128, + bytes: &[u8], + scale: u8, + next_byte: u8, +) -> Result { + let b = next_byte; + match b { + b'0'..=b'9' => { + let digit = u32::from(b - b'0'); + + // If the data is going to overflow then we should go into recovery mode + let next = (data * 10) + digit as u128; + if overflow_128(next) { + if !POINT { + return tail_error("Invalid decimal: overflow from too many digits"); } - 0 - } - _ => return Err(Error::from("Invalid decimal: unknown character")), - }; - // Round at midpoint - if digit >= 5 { - let mut index = coeff.len() - 1; - loop { - let new_digit = coeff[index] + 1; - if new_digit <= 9 { - coeff[index] = new_digit; - break; - } else { - coeff[index] = 0; - if index == 0 { - coeff.insert(0, 1u32); - digits_before_dot += 1; - coeff.pop(); - break; + if digit >= 5 { + data += 1; + } + handle_data::(data, scale) + } else { + data = next; + let scale = scale + POINT as u8; + if let Some((next, bytes)) = bytes.split_first() { + let next = *next; + if POINT && scale >= 28 { + maybe_round(data, next, scale, POINT, NEG) + } else { + handle_full_128::(data, bytes, scale, next) } + } else { + handle_data::(data, scale) } - index -= 1; } } + b'.' if !POINT => { + // This call won't tail? + if let Some((next, bytes)) = bytes.split_first() { + handle_full_128::(data, bytes, scale, *next) + } else { + handle_data::(data, scale) + } + } + b'_' => { + if let Some((next, bytes)) = bytes.split_first() { + handle_full_128::(data, bytes, scale, *next) + } else { + handle_data::(data, scale) + } + } + b => tail_invalid_digit(b), } +} - // here when no characters left - if coeff.is_empty() { - return Err(Error::from("Invalid decimal: no digits found")); +#[inline(never)] +#[cold] +fn maybe_round(mut data: u128, next_byte: u8, scale: u8, point: bool, negative: bool) -> Result { + let digit = match next_byte { + b'0'..=b'9' => u32::from(next_byte - b'0'), + b'_' => 0, // this should be an invalid string? + b'.' if point => 0, + b => return tail_invalid_digit(b), + }; + + // Round at midpoint + if digit >= 5 { + data += 1; + if overflow_128(data) { + // Highly unlikely scenario which is more indicative of a bug + return tail_error("Invalid decimal: overflow when rounding"); + } } - let mut scale = if digits_before_dot >= 0 { - // we had a decimal place so set the scale - (coeff.len() as u32) - (digits_before_dot as u32) + if negative { + handle_data::(data, scale) } else { - 0 - }; + handle_data::(data, scale) + } +} - let mut data = [0u32, 0u32, 0u32]; - let mut tmp = [0u32, 0u32, 0u32]; - let len = coeff.len(); - for (i, digit) in coeff.iter().enumerate() { - // If the data is going to overflow then we should go into recovery mode - tmp[0] = data[0]; - tmp[1] = data[1]; - tmp[2] = data[2]; - let overflow = mul_by_10(&mut tmp); - if overflow > 0 { - // This means that we have more data to process, that we're not sure what to do with. - // This may or may not be an issue - depending on whether we're past a decimal point - // or not. - if (i as i32) < digits_before_dot && i + 1 < len { - return Err(Error::from("Invalid decimal: overflow from too many digits")); - } +#[inline(never)] +fn tail_no_has() -> Result { + tail_error("Invalid decimal: no digits found") +} - if *digit >= 5 { - let carry = add_one_internal(&mut data); - if carry > 0 { - // Highly unlikely scenario which is more indicative of a bug - return Err(Error::from("Invalid decimal: overflow when rounding")); - } - } - // We're also one less digit so reduce the scale - let diff = (len - i) as u32; - if diff > scale { - return Err(Error::from("Invalid decimal: overflow from scale mismatch")); - } - scale -= diff; - break; - } else { - data[0] = tmp[0]; - data[1] = tmp[1]; - data[2] = tmp[2]; - let carry = add_by_internal(&mut data, &[*digit]); - if carry > 0 { - // Highly unlikely scenario which is more indicative of a bug - return Err(Error::from("Invalid decimal: overflow from carry")); - } - } +#[inline] +fn handle_data(data: u128, scale: u8) -> Result { + debug_assert_eq!(data >> 96, 0); + if !HAS { + tail_no_has() + } else { + Ok(Decimal::from_parts( + data as u32, + (data >> 32) as u32, + (data >> 64) as u32, + NEG, + scale as u32, + )) } - - Ok(Decimal::from_parts(data[0], data[1], data[2], negative, scale)) } pub(crate) fn parse_str_radix_n(str: &str, radix: u32) -> Result { @@ -528,7 +604,7 @@ pub(crate) fn parse_str_radix_n(str: &str, radix: u32) -> Result 0 { // Highly unlikely scenario which is more indicative of a bug return Err(Error::from("Invalid decimal: overflow from carry")); @@ -541,6 +617,7 @@ pub(crate) fn parse_str_radix_n(str: &str, radix: u32) -> Result