diff --git a/.vscode/settings.json b/.vscode/settings.json index fe3047d..e84fbe9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -5,6 +5,7 @@ "hgetall", "hmap", "hset", - "raddr" + "raddr", + "respv" ] } diff --git a/Cargo.lock b/Cargo.lock index ae9ff26..5f8fbf0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -410,6 +410,7 @@ dependencies = [ "tokio-util", "tracing", "tracing-subscriber", + "winnow", ] [[package]] @@ -692,3 +693,12 @@ name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "winnow" +version = "0.6.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +dependencies = [ + "memchr", +] diff --git a/Cargo.toml b/Cargo.toml index 67436fe..2eaedcc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,3 +19,4 @@ tokio-stream = "0.1.16" tokio-util = { version = "0.7.12", features = ["codec"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +winnow = { version = "0.6.20", features = ["simd"] } diff --git a/deny.toml b/deny.toml index b6beb84..0ae1019 100644 --- a/deny.toml +++ b/deny.toml @@ -64,9 +64,9 @@ feature-depth = 1 # https://embarkstudios.github.io/cargo-deny/checks/advisories/cfg.html [advisories] # The path where the advisory databases are cloned/fetched into -#db-path = "$CARGO_HOME/advisory-dbs" +db-path = "$CARGO_HOME/advisory-dbs" # The url(s) of the advisory databases to use -#db-urls = ["https://github.com/rustsec/advisory-db"] +db-urls = ["https://github.com/rustsec/advisory-db"] # A list of advisory IDs to ignore. Note that ignored advisories will still # output a note when they are encountered. ignore = [ diff --git a/src/lib.rs b/src/lib.rs index 1b8dbce..5683a74 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,10 @@ mod backend; mod resp; +mod respv2; pub mod cmd; pub mod network; pub use backend::*; pub use resp::*; +pub use respv2::*; diff --git a/src/main.rs b/src/main.rs index 77cd84a..296c91a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ use tracing::{info, warn}; #[tokio::main] async fn main() -> Result<()> { + std::env::set_var("RUST_LOG", "info"); tracing_subscriber::fmt::init(); let addr = "0.0.0.0:6379"; diff --git a/src/network.rs b/src/network.rs index 9fcf034..05dfa06 100644 --- a/src/network.rs +++ b/src/network.rs @@ -8,7 +8,7 @@ use tracing::info; use crate::{ cmd::{Command, CommandExecutor}, - Backend, RespDecode, RespEncode, RespError, RespFrame, + Backend, RespDecodeV2, RespEncode, RespError, RespFrame, }; #[derive(Debug)] diff --git a/src/resp/array.rs b/src/resp/array.rs index 399fbc3..2a65038 100644 --- a/src/resp/array.rs +++ b/src/resp/array.rs @@ -89,6 +89,12 @@ impl Deref for RespArray { } } +impl From> for RespArray { + fn from(v: Vec) -> Self { + RespArray(v) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/resp/map.rs b/src/resp/map.rs index f7bdafa..5784977 100644 --- a/src/resp/map.rs +++ b/src/resp/map.rs @@ -82,6 +82,12 @@ impl DerefMut for RespMap { } } +impl From> for RespMap { + fn from(map: BTreeMap) -> Self { + RespMap(map) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/respv2/mod.rs b/src/respv2/mod.rs new file mode 100644 index 0000000..332e553 --- /dev/null +++ b/src/respv2/mod.rs @@ -0,0 +1,197 @@ +use bytes::BytesMut; + +use crate::{RespError, RespFrame}; + +use parser::{parse_frame, parse_frame_length}; + +mod parser; + +pub trait RespDecodeV2: Sized { + fn decode(buf: &mut BytesMut) -> Result; + fn expect_length(buf: &[u8]) -> Result; +} + +impl RespDecodeV2 for RespFrame { + fn decode(buf: &mut BytesMut) -> Result { + let len = Self::expect_length(buf)?; + let data = buf.split_to(len); + + parse_frame(&mut data.as_ref()).map_err(|e| RespError::InvalidFrame(e.to_string())) + } + + fn expect_length(buf: &[u8]) -> Result { + parse_frame_length(buf) + } +} + +#[cfg(test)] +mod test { + use std::collections::BTreeMap; + + use super::*; + use crate::{RespFrame, RespNullArray, RespNullBulkString}; + use anyhow::Result; + + #[test] + fn respv2_simple_string_length_should_work() -> Result<()> { + let buf = b"+OK\r\n"; + let len = RespFrame::expect_length(buf)?; + assert_eq!(len, buf.len()); + Ok(()) + } + + #[test] + fn respv2_simple_string_length_bad_should_fail() -> Result<()> { + let buf = b"+OK"; + let ret = RespFrame::expect_length(buf); + assert_eq!(ret.unwrap_err(), RespError::NotComplete); + Ok(()) + } + + #[test] + fn respv2_simple_string_should_work() -> Result<()> { + let mut buf = BytesMut::from("+OK\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, RespFrame::SimpleString("OK".into())); + Ok(()) + } + + #[test] + fn respv2_simple_error_length_should_work() -> Result<()> { + let buf = b"-ERR\r\n"; + let len = RespFrame::expect_length(buf)?; + assert_eq!(len, buf.len()); + Ok(()) + } + + #[test] + fn respv2_simple_error_should_work() -> Result<()> { + let mut buf = BytesMut::from("-ERR\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, RespFrame::Error("ERR".into())); + Ok(()) + } + + #[test] + fn respv2_integer_length_should_work() -> Result<()> { + let buf = b":1000\r\n"; + let len = RespFrame::expect_length(buf)?; + assert_eq!(len, buf.len()); + Ok(()) + } + + #[test] + fn respv2_integer_should_work() -> Result<()> { + let mut buf = BytesMut::from(":1000\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, RespFrame::Integer(1000)); + Ok(()) + } + + #[test] + fn respv2_bulk_string_length_should_work() -> Result<()> { + let buf = b"$5\r\nhello\r\n"; + let len = RespFrame::expect_length(buf)?; + assert_eq!(len, buf.len()); + Ok(()) + } + + #[test] + fn respv2_bulk_string_should_work() -> Result<()> { + let mut buf = BytesMut::from("$5\r\nhello\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, RespFrame::BulkString("hello".into())); + Ok(()) + } + + #[test] + fn respv2_null_bulk_string_length_should_work() -> Result<()> { + let buf = b"$-1\r\n"; + let len = RespFrame::expect_length(buf)?; + assert_eq!(len, buf.len()); + Ok(()) + } + + #[test] + fn respv2_null_bulk_string_should_work() -> Result<()> { + let mut buf = BytesMut::from("$-1\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, RespFrame::NullBulkString(RespNullBulkString)); + Ok(()) + } + + #[test] + fn respv2_array_length_should_work() -> Result<()> { + let buf = b"*2\r\n$3\r\nset\r\n$5\r\nhello\r\n"; + let len = RespFrame::expect_length(buf)?; + assert_eq!(len, buf.len()); + Ok(()) + } + + #[test] + fn respv2_array_should_work() -> Result<()> { + let mut buf = BytesMut::from("*2\r\n$3\r\nset\r\n$5\r\nhello\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!( + frame, + RespFrame::Array( + vec![ + RespFrame::BulkString("set".into()), + RespFrame::BulkString("hello".into()) + ] + .into() + ) + ); + Ok(()) + } + + #[test] + fn respv2_null_array_length_should_work() -> Result<()> { + let buf = b"*-1\r\n"; + let len = RespFrame::expect_length(buf)?; + assert_eq!(len, buf.len()); + Ok(()) + } + + #[test] + fn respv2_null_array_should_work() -> Result<()> { + let mut buf = BytesMut::from("*-1\r\n"); + let frame = RespFrame::decode(&mut buf)?; + assert_eq!(frame, RespFrame::NullArray(RespNullArray)); + Ok(()) + } + + #[test] + fn respv2_map_length_should_work() -> Result<()> { + let buf = b"%1\r\n+OK\r\n-ERR\r\n"; + let len = RespFrame::expect_length(buf).unwrap(); + assert_eq!(len, buf.len()); + Ok(()) + } + + #[test] + fn respv2_map_should_work() -> Result<()> { + let mut buf = BytesMut::from("%1\r\n+OK\r\n-ERR\r\n"); + let frame = RespFrame::decode(&mut buf).unwrap(); + let items: BTreeMap = + [("OK".to_string(), RespFrame::Error("ERR".into()))] + .into_iter() + .collect(); + assert_eq!(frame, RespFrame::Map(items.into())); + Ok(()) + } + + #[test] + fn respv2_map_with_real_data_should_work() -> Result<()> { + let mut buf = BytesMut::from("%2\r\n+hello\r\n$5\r\nworld\r\n+foo\r\n$3\r\nbar\r\n"); + let frame = RespFrame::decode(&mut buf).unwrap(); + let items: BTreeMap = [ + ("hello".to_string(), RespFrame::BulkString("world".into())), + ("foo".to_string(), RespFrame::BulkString("bar".into())), + ] + .into_iter() + .collect(); + assert_eq!(frame, RespFrame::Map(items.into())); + Ok(()) + } +} diff --git a/src/respv2/parser.rs b/src/respv2/parser.rs new file mode 100644 index 0000000..60b63ab --- /dev/null +++ b/src/respv2/parser.rs @@ -0,0 +1,220 @@ +use std::collections::BTreeMap; + +use winnow::{ + ascii::{digit1, float}, + combinator::{alt, dispatch, fail, opt, preceded, terminated}, + error::{ContextError, ErrMode}, + token::{any, take, take_until}, + PResult, Parser, +}; + +use crate::{ + BulkString, RespArray, RespError, RespFrame, RespMap, RespNull, RespNullArray, + RespNullBulkString, SimpleError, SimpleString, +}; + +const CRLF: &[u8] = b"\r\n"; + +pub fn parse_frame_length(input: &[u8]) -> Result { + let target = &mut (&*input); + let ret = parse_frame_len(target); + match ret { + Ok(_) => { + // calculate the distance between target and input + let start = input.as_ptr(); + let end = (*target).as_ptr(); + let len = end as usize - start as usize; + Ok(len) + } + Err(_) => Err(RespError::NotComplete), + } +} + +fn parse_frame_len(input: &mut &[u8]) -> PResult<()> { + let mut simple_parser = terminated(take_until(0.., CRLF), CRLF).value(()); + dispatch! {any; + b'+' => simple_parser, + b'-' => simple_parser, + b':' => simple_parser, + b'$' => bulk_string_len, + b'*' => array_len, + b'_' => simple_parser, + b'#' => simple_parser, + b',' => simple_parser, + b'%' => map_len, + // b'~' => set, + _ => fail::<_, _, _>, + } + .parse_next(input) +} + +pub fn parse_frame(input: &mut &[u8]) -> PResult { + // frame type has bean processed + dispatch! {any; + b'+' => simple_string.map(RespFrame::SimpleString), + b'-' => error.map(RespFrame::Error), + b':' => integer.map(RespFrame::Integer), + b'$' => alt((null_bulk_string.map(RespFrame::NullBulkString), bulk_string.map(RespFrame::BulkString))), + b'*' => alt((null_array.map(RespFrame::NullArray), array.map(RespFrame::Array))), + b'_' => null.map(RespFrame::Null), + b'#' => boolean.map(RespFrame::Boolean), + b',' => double.map(RespFrame::Double), + b'%' => map.map(RespFrame::Map), + // b'~' => set, + _ => fail::<_, _, _>, + } + .parse_next(input) +} + +// - simple string: "+OK\r\n" +fn simple_string(input: &mut &[u8]) -> PResult { + parse_string.map(SimpleString).parse_next(input) +} + +// - error: "-ERR unknown command 'foobar'\r\n" +fn error(input: &mut &[u8]) -> PResult { + parse_string.map(SimpleError).parse_next(input) +} + +// - integer: ":1000\r\n" +fn integer(input: &mut &[u8]) -> PResult { + let sign = opt(alt(('+', '-'))).parse_next(input)?.unwrap_or('+'); + let sign: i64 = if sign == '+' { 1 } else { -1 }; + let v: i64 = terminated(digit1.parse_to(), CRLF).parse_next(input)?; + Ok(sign * v) +} + +// - null bulk string: "$-1\r\n" +fn null_bulk_string(input: &mut &[u8]) -> PResult { + "-1\r\n".value(RespNullBulkString).parse_next(input) +} + +// - bulk string: "$6\r\nfoobar\r\n" +#[allow(clippy::comparison_chain)] +fn bulk_string(input: &mut &[u8]) -> PResult { + let len = integer.parse_next(input)?; + if len < 0 { + return Err(err_cut("bulk string length must be non-negative")); + } else if len == 0 { + return Ok(BulkString::new(vec![])); + } + + let len = len as usize; + let data = terminated(take(len), CRLF) + .map(|s: &[u8]| s.to_vec()) + .parse_next(input)?; + + Ok(BulkString(data)) +} + +fn bulk_string_len(input: &mut &[u8]) -> PResult<()> { + let len = integer.parse_next(input)?; + if len < -1 { + return Err(err_cut("bulk string length must be non-negative")); + } else if len == 0 || len == -1 { + return Ok(()); + } + + terminated(take(len as usize), CRLF) + .value(()) + .parse_next(input) +} + +fn null_array(input: &mut &[u8]) -> PResult { + "-1\r\n".value(RespNullArray).parse_next(input) +} + +// - array: "*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n" +#[allow(clippy::comparison_chain)] +fn array(input: &mut &[u8]) -> PResult { + let len = integer.parse_next(input)?; + if len < 0 { + return Err(err_cut("array length must be non-negative")); + } else if len == 0 { + return Ok(RespArray(vec![])); + } + + let mut frames = Vec::with_capacity(len as usize); + for _ in 0..len { + let frame = parse_frame(input)?; + frames.push(frame); + } + + Ok(RespArray(frames)) +} + +fn array_len(input: &mut &[u8]) -> PResult<()> { + let len = integer.parse_next(input)?; + if len < -1 { + return Err(err_cut("array length must be non-negative")); + } else if len == 0 || len == -1 { + return Ok(()); + } + + for _ in 0..len { + parse_frame_len(input)?; + } + + Ok(()) +} + +// - null: "_\r\n" +fn null(input: &mut &[u8]) -> PResult { + CRLF.value(RespNull).parse_next(input) +} + +// - boolean: "#t\r\n" +fn boolean(input: &mut &[u8]) -> PResult { + let b = alt(('t', 'f')).parse_next(input)?; + Ok(b == 't') +} + +// - float: ",3.14\r\n" +fn double(input: &mut &[u8]) -> PResult { + terminated(float, CRLF).parse_next(input) +} + +// - map: "%2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n$3\r\nbaz\r\n$3\r\nqux\r\n" +fn map(input: &mut &[u8]) -> PResult { + let len = integer.parse_next(input)?; + if len <= 0 { + return Err(err_cut("map length must be non-negative")); + } + + // let len = len as usize / 2; + let mut map = BTreeMap::new(); + for _ in 0..len { + let key = preceded('+', parse_string).parse_next(input)?; + let value = parse_frame(input)?; + map.insert(key, value); + } + + Ok(RespMap(map)) +} + +fn map_len(input: &mut &[u8]) -> PResult<()> { + let len = integer.parse_next(input)?; + if len <= 0 { + return Err(err_cut("map length must be non-negative")); + } + + for _ in 0..len { + terminated(take_until(0.., CRLF), CRLF) + .value(()) + .parse_next(input)?; + parse_frame_len(input)?; + } + + Ok(()) +} + +fn parse_string(input: &mut &[u8]) -> PResult { + terminated(take_until(0.., CRLF), CRLF) + .map(|s| String::from_utf8_lossy(s).into_owned()) + .parse_next(input) +} + +fn err_cut(_s: impl Into) -> ErrMode { + let context = ContextError::default(); + ErrMode::Cut(context) +}