From 041af9873b25d807a0c3c125d42c84ab7ce4d824 Mon Sep 17 00:00:00 2001 From: rcrwhyg Date: Thu, 5 Sep 2024 18:04:04 +0800 Subject: [PATCH] feature: frame decode --- Cargo.lock | 21 ++ Cargo.toml | 1 + src/resp/decode.rs | 656 ++++++++++++++++++++++++++++++++++++++++++++- src/resp/mod.rs | 81 +++++- 4 files changed, 753 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5bd5129..f9f4b17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -57,6 +57,7 @@ dependencies = [ "anyhow", "bytes", "enum_dispatch", + "thiserror", ] [[package]] @@ -70,6 +71,26 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "thiserror" +version = "1.0.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" version = "1.0.12" diff --git a/Cargo.toml b/Cargo.toml index 3432ea8..91fa9bb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,3 +10,4 @@ license = "MIT" anyhow = "1.0.86" bytes = "1.7.1" enum_dispatch = "0.3.13" +thiserror = "1.0.63" diff --git a/src/resp/decode.rs b/src/resp/decode.rs index 4bd8276..1692705 100644 --- a/src/resp/decode.rs +++ b/src/resp/decode.rs @@ -1 +1,655 @@ -// empty file +/* +- 如何解析 Frame + - simple string: "+OK\r\n" + - error: "-Error message\r\n" + - bulk error: "!\r\n\r\n" + - integer: ":[<+|->]\r\n" + - bulk string: "$\r\n\r\n" + - null bulk string: "$-1\r\n" + - array: "*\r\n..." + - "*2\r\n$3\r\nget\r\n$5\r\nhello\r\n" + - null array: "*-1\r\n" + - null: "_\r\n" + - boolean: "#\r\n" + - double: ",[<+|->][.][[sign]]\r\n" + - map: "%\r\n..." + - set: "~\r\n..." + */ + +use bytes::{Buf, BytesMut}; + +use super::{ + BulkString, RespArray, RespDecode, RespError, RespFrame, RespMap, RespNull, RespNullArray, + RespNullBulkString, RespSet, SimpleError, SimpleString, +}; + +const CRLF: &[u8] = b"\r\n"; +const CRLF_LEN: usize = CRLF.len(); + +impl RespDecode for RespFrame { + const PREFIX: &'static str = ""; + + fn decode(buf: &mut BytesMut) -> Result { + let mut iter = buf.iter().peekable(); + match iter.peek() { + Some(b'+') => { + todo!() + } + Some(b'-') => { + let frame = SimpleError::decode(buf)?; + Ok(frame.into()) + } + Some(b':') => { + let frame = i64::decode(buf)?; + Ok(frame.into()) + } + Some(b'$') => { + // try null bulk string first + match RespNullBulkString::decode(buf) { + Ok(frame) => Ok(frame.into()), + Err(RespError::NotComplete) => Err(RespError::NotComplete), + Err(_) => { + let frame = BulkString::decode(buf)?; + Ok(frame.into()) + } + } + } + Some(b'*') => { + // try null array first + match RespNullArray::decode(buf) { + Ok(frame) => Ok(frame.into()), + Err(RespError::NotComplete) => Err(RespError::NotComplete), + Err(_) => { + let frame = RespArray::decode(buf)?; + Ok(frame.into()) + } + } + } + Some(b'_') => { + let frame = RespNull::decode(buf)?; + Ok(frame.into()) + } + Some(b'#') => { + let frame = bool::decode(buf)?; + Ok(frame.into()) + } + Some(b',') => { + let frame = f64::decode(buf)?; + Ok(frame.into()) + } + Some(b'%') => { + let frame = RespMap::decode(buf)?; + Ok(frame.into()) + } + Some(b'~') => { + let frame = RespSet::decode(buf)?; + Ok(frame.into()) + } + _ => Err(RespError::InvalidFrameType(format!( + "expect_length: unknown frame type: {:?}", + buf + ))), + } + } + + fn expect_length(buf: &[u8]) -> std::result::Result { + let mut iter = buf.iter().peekable(); + match iter.peek() { + Some(b'*') => RespArray::expect_length(buf), + Some(b'~') => RespSet::expect_length(buf), + Some(b'%') => RespMap::expect_length(buf), + Some(b'$') => BulkString::expect_length(buf), + Some(b':') => i64::expect_length(buf), + Some(b'+') => SimpleString::expect_length(buf), + Some(b'-') => SimpleError::expect_length(buf), + Some(b'#') => bool::expect_length(buf), + Some(b',') => f64::expect_length(buf), + Some(b'_') => RespNull::expect_length(buf), + _ => Err(RespError::NotComplete), + } + } +} + +impl RespDecode for SimpleString { + const PREFIX: &'static str = "+"; + + fn decode(buf: &mut BytesMut) -> Result { + // search for "\r\n" + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + + // split the buffer + let data = buf.split_to(end + CRLF_LEN); + let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]); + Ok(Self::new(s)) + } + + fn expect_length(buf: &[u8]) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + Ok(end + CRLF_LEN) + } +} + +impl RespDecode for SimpleError { + const PREFIX: &'static str = "-"; + + fn decode(buf: &mut BytesMut) -> Result { + // search for "\r\n" + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + + // split the buffer + let data = buf.split_to(end + CRLF_LEN); + let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]); + Ok(Self::new(s)) + } + fn expect_length(buf: &[u8]) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + Ok(end + CRLF_LEN) + } +} + +impl RespDecode for RespNull { + const PREFIX: &'static str = "_"; + + fn decode(buf: &mut BytesMut) -> Result { + extract_fixed_data(buf, "_\r\n", "Null")?; + Ok(Self) + } + + fn expect_length(_buf: &[u8]) -> Result { + Ok(3) + } +} + +impl RespDecode for RespNullArray { + const PREFIX: &'static str = "*"; + + fn decode(buf: &mut BytesMut) -> Result { + extract_fixed_data(buf, "*-1\r\n", "NullArray")?; + Ok(Self) + } + + fn expect_length(_buf: &[u8]) -> Result { + Ok(4) + } +} + +impl RespDecode for RespNullBulkString { + const PREFIX: &'static str = "$"; + + fn decode(buf: &mut BytesMut) -> Result { + extract_fixed_data(buf, "$-1\r\n", "NullBulkString")?; + Ok(Self) + } + + fn expect_length(_buf: &[u8]) -> Result { + Ok(5) + } +} + +impl RespDecode for i64 { + const PREFIX: &'static str = ":"; + + fn decode(buf: &mut BytesMut) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + + // split the buffer + let data = buf.split_to(end + CRLF_LEN); + let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]); + + Ok(s.parse()?) + } + + fn expect_length(buf: &[u8]) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + Ok(end + CRLF_LEN) + } +} + +impl RespDecode for bool { + const PREFIX: &'static str = "#"; + + fn decode(buf: &mut BytesMut) -> Result { + match extract_fixed_data(buf, "#t\r\n", "Bool") { + Ok(_) => Ok(true), + Err(RespError::NotComplete) => Err(RespError::NotComplete), + Err(_) => match extract_fixed_data(buf, "#f\r\n", "Bool") { + Ok(_) => Ok(false), + Err(e) => Err(e), + }, + } + } + + fn expect_length(_buf: &[u8]) -> Result { + Ok(4) + } +} + +impl RespDecode for BulkString { + const PREFIX: &'static str = "$"; + + fn decode(buf: &mut BytesMut) -> Result { + let (end, len) = parse_length(buf, Self::PREFIX)?; + let remained = &buf[end + CRLF_LEN..]; + if remained.len() < len + CRLF_LEN { + return Err(RespError::NotComplete); + } + + buf.advance(end + CRLF_LEN); + + let data = buf.split_to(len + CRLF_LEN); + Ok(Self(data[..len].to_vec())) + } + + fn expect_length(buf: &[u8]) -> Result { + let (end, len) = parse_length(buf, Self::PREFIX)?; + Ok(end + CRLF_LEN + len + CRLF_LEN) + } +} + +// - array: "*\r\n..." +// - "*2\r\n$3\r\nget\r\n$5\r\nhello\r\n" +// FIXME: need to handle incomplete +impl RespDecode for RespArray { + const PREFIX: &'static str = "*"; + + fn decode(buf: &mut BytesMut) -> Result { + let (end, len) = parse_length(buf, Self::PREFIX)?; + let total_len = calc_total_length(buf, end, len, Self::PREFIX)?; + + if buf.len() < total_len { + return Err(RespError::NotComplete); + } + + buf.advance(end + CRLF_LEN); + + let mut frames = Vec::with_capacity(len); + for _ in 0..len { + frames.push(RespFrame::decode(buf)?); + } + + Ok(Self::new(frames)) + } + + fn expect_length(buf: &[u8]) -> Result { + let (end, len) = parse_length(buf, Self::PREFIX)?; + calc_total_length(buf, end, len, Self::PREFIX) + } +} + +// - double: ",[<+|->][.][[sign]]\r\n" +impl RespDecode for f64 { + const PREFIX: &'static str = ","; + + fn decode(buf: &mut BytesMut) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + let data = buf.split_to(end + CRLF_LEN); + let s = String::from_utf8_lossy(&data[Self::PREFIX.len()..end]); + Ok(s.parse()?) + } + + fn expect_length(buf: &[u8]) -> Result { + let end = extract_simple_frame_data(buf, Self::PREFIX)?; + Ok(end + CRLF_LEN) + } +} + +// - map: "%\r\n..." +impl RespDecode for RespMap { + const PREFIX: &'static str = "%"; + + fn decode(buf: &mut BytesMut) -> Result { + let (end, len) = parse_length(buf, Self::PREFIX)?; + let total_len = calc_total_length(buf, end, len, Self::PREFIX)?; + + if buf.len() < total_len { + return Err(RespError::NotComplete); + } + + buf.advance(end + CRLF_LEN); + + let mut frames = RespMap::new(); + for _ in 0..len { + let key = SimpleString::decode(buf)?; + let value = RespFrame::decode(buf)?; + frames.insert(key.0, value); + } + + Ok(frames) + } + + fn expect_length(buf: &[u8]) -> Result { + let (end, len) = parse_length(buf, Self::PREFIX)?; + calc_total_length(buf, end, len, Self::PREFIX) + } +} + +// - set: "~\r\n..." +impl RespDecode for RespSet { + const PREFIX: &'static str = "~"; + + fn decode(buf: &mut BytesMut) -> Result { + let (end, len) = parse_length(buf, Self::PREFIX)?; + + let total_len = calc_total_length(buf, end, len, Self::PREFIX)?; + + if buf.len() < total_len { + return Err(RespError::NotComplete); + } + + buf.advance(end + CRLF_LEN); + + let mut frames = Vec::new(); + for _ in 0..len { + frames.push(RespFrame::decode(buf)?); + } + + Ok(Self::new(frames)) + } + + fn expect_length(buf: &[u8]) -> Result { + let (end, len) = parse_length(buf, Self::PREFIX)?; + calc_total_length(buf, end, len, Self::PREFIX) + } +} + +fn extract_fixed_data( + buf: &mut BytesMut, + expect: &str, + expect_type: &str, +) -> Result<(), RespError> { + if buf.len() < expect.len() { + return Err(RespError::NotComplete); + } + + if !buf.starts_with(expect.as_bytes()) { + return Err(RespError::InvalidFrameType(format!( + "expect: {}, got: {:?}", + expect_type, buf + ))); + } + + buf.advance(expect.len()); + + Ok(()) +} + +fn extract_simple_frame_data(buf: &[u8], prefix: &str) -> Result { + if buf.len() < 3 { + return Err(RespError::NotComplete); + } + + if !buf.starts_with(prefix.as_bytes()) { + return Err(RespError::InvalidFrameType(format!( + "expect: SimpleString({}), got: {:?}", + prefix, buf + ))); + } + + let end = find_crlf(buf, 1).ok_or(RespError::NotComplete)?; + + Ok(end) +} + +// find nth CRLF in the buffer +fn find_crlf(buf: &[u8], nth: usize) -> Option { + let mut count = 0; + for i in 1..buf.len() - 1 { + if buf[i] == b'\r' && buf[i + 1] == b'\n' { + count += 1; + if count == nth { + return Some(i); + } + } + } + + None +} + +fn parse_length(buf: &[u8], prefix: &str) -> Result<(usize, usize), RespError> { + let end = extract_simple_frame_data(buf, prefix)?; + let s = String::from_utf8_lossy(&buf[prefix.len()..end]); + Ok((end, s.parse()?)) +} + +fn calc_total_length(buf: &[u8], end: usize, len: usize, prefix: &str) -> Result { + let mut total = end + CRLF_LEN; + let mut data = &buf[total..]; + + match prefix { + "*" | "~" => { + // find nth CRLF in the buffer, for array and set, we need to find 1 CRLF for each element + for _ in 0..len { + let len = RespFrame::expect_length(data)?; + data = &data[len..]; + total += len; + } + Ok(total) + } + "%" => { + // find nth CRLF in the buffer. For map, we need to find 2 CRLF for each key-value pair + for _ in 0..len { + let len = SimpleString::expect_length(data)?; + + data = &data[len..]; + total += len; + + let len = RespFrame::expect_length(data)?; + data = &data[len..]; + total += len; + } + Ok(total) + } + _ => Ok(len + CRLF_LEN), + } +} + +#[cfg(test)] +mod tests { + use anyhow::Result; + use bytes::BufMut; + + use super::*; + + #[test] + fn test_simple_string_encode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"+OK\r\n"); + + let frame = SimpleString::decode(&mut buf)?; + assert_eq!(frame, SimpleString::new("OK")); + + buf.extend_from_slice(b"+hello\r"); + + let ret = SimpleString::decode(&mut buf); + assert!(matches!(ret.unwrap_err(), RespError::NotComplete)); + + buf.put_u8(b'\n'); + let frame = SimpleString::decode(&mut buf)?; + assert_eq!(frame, SimpleString::new("hello")); + + Ok(()) + } + + #[test] + fn test_simple_error_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"-Error message\r\n"); + + let frame = SimpleError::decode(&mut buf)?; + assert_eq!(frame, SimpleError::new("Error message")); + + Ok(()) + } + + #[test] + fn test_integer_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b":+123\r\n"); + + let frame = i64::decode(&mut buf)?; + assert_eq!(frame, 123); + + buf.extend_from_slice(b":-123\r\n"); + + let frame = i64::decode(&mut buf)?; + assert_eq!(frame, -123); + + Ok(()) + } + + #[test] + fn test_bulk_string_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"$5\r\nhello\r\n"); + + let frame = BulkString::decode(&mut buf)?; + assert_eq!(frame, BulkString::new(b"hello")); + + buf.extend_from_slice(b"$5\r\nhello"); + let ret = BulkString::decode(&mut buf); + assert_eq!(ret.unwrap_err(), RespError::NotComplete); + + buf.extend_from_slice(b"\r\n"); + let frame = BulkString::decode(&mut buf)?; + assert_eq!(frame, BulkString::new(b"hello")); + + Ok(()) + } + + #[test] + fn test_null_bulk_string_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"$-1\r\n"); + + let frame = RespNullBulkString::decode(&mut buf)?; + assert_eq!(frame, RespNullBulkString); + + Ok(()) + } + + #[test] + fn test_null_array_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"*-1\r\n"); + + let frame = RespNullArray::decode(&mut buf)?; + assert_eq!(frame, RespNullArray); + + Ok(()) + } + + #[test] + fn test_null_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"_\r\n"); + + let frame = RespNull::decode(&mut buf)?; + assert_eq!(frame, RespNull); + + Ok(()) + } + + #[test] + fn test_boolean_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"#t\r\n"); + + let frame = bool::decode(&mut buf)?; + assert!(frame); + + buf.extend_from_slice(b"#f\r\n"); + + let frame = bool::decode(&mut buf)?; + assert!(!frame); + + buf.extend_from_slice(b"#f\r"); + let ret = bool::decode(&mut buf); + assert_eq!(ret.unwrap_err(), RespError::NotComplete); + + buf.put_u8(b'\n'); + let frame = bool::decode(&mut buf)?; + assert!(!frame); + + Ok(()) + } + + #[test] + fn test_array_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"*2\r\n$3\r\nset\r\n$5\r\nhello\r\n"); + + let frame = RespArray::decode(&mut buf)?; + assert_eq!(frame, RespArray::new([b"set".into(), b"hello".into()])); + + buf.extend_from_slice(b"*2\r\n$3\r\nset\r\n"); + let ret = RespArray::decode(&mut buf); + assert_eq!(ret.unwrap_err(), RespError::NotComplete); + + buf.extend_from_slice(b"$5\r\nhello\r\n"); + let frame = RespArray::decode(&mut buf)?; + assert_eq!(frame, RespArray::new([b"set".into(), b"hello".into()])); + + Ok(()) + } + + #[test] + fn test_double_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b",123.45\r\n"); + + let frame = f64::decode(&mut buf)?; + assert_eq!(frame, 123.45); + + buf.extend_from_slice(b",+1.23456e-9\r\n"); + let frame = f64::decode(&mut buf)?; + assert_eq!(frame, 1.23456e-9); + + Ok(()) + } + + #[test] + fn test_map_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"%2\r\n+hello\r\n$5\r\nworld\r\n+foo\r\n$3\r\nbar\r\n"); + + let frame = RespMap::decode(&mut buf)?; + let mut map = RespMap::new(); + map.insert("hello".into(), BulkString::new(b"world").into()); + map.insert("foo".into(), BulkString::new(b"bar").into()); + assert_eq!(frame, map); + + Ok(()) + } + + #[test] + fn test_set_decode() -> Result<()> { + let mut buf = BytesMut::new(); + buf.extend_from_slice(b"~2\r\n$3\r\nset\r\n$5\r\nhello\r\n"); + + let frame = RespSet::decode(&mut buf)?; + assert_eq!( + frame, + RespSet::new(vec![ + BulkString::new(b"set").into(), + BulkString::new(b"hello").into() + ]) + ); + + Ok(()) + } + + #[test] + fn test_calc_array_length() -> Result<()> { + let buf = b"*2\r\n$3\r\nset\r\n$5\r\nhello\r\n"; + let (end, len) = parse_length(buf, "*")?; + let total_len = calc_total_length(buf, end, len, "*")?; + assert_eq!(total_len, buf.len()); + + let buf = b"*2\r\n$3\r\nset\r\n"; + let (end, len) = parse_length(buf, "*")?; + let ret = calc_total_length(buf, end, len, "*"); + assert_eq!(ret.unwrap_err(), RespError::NotComplete); + + Ok(()) + } +} diff --git a/src/resp/mod.rs b/src/resp/mod.rs index 2f03f0f..232561d 100644 --- a/src/resp/mod.rs +++ b/src/resp/mod.rs @@ -1,19 +1,42 @@ +mod decode; +mod encode; + +use bytes::BytesMut; use enum_dispatch::enum_dispatch; use std::{ collections::BTreeMap, ops::{Deref, DerefMut}, }; - -mod decode; -mod encode; +use thiserror::Error; #[enum_dispatch] pub trait RespEncode { fn encode(self) -> Vec; } -pub trait RespDecode { - fn decode(buf: Self) -> Result; +pub trait RespDecode: Sized { + const PREFIX: &'static str; + fn decode(buf: &mut BytesMut) -> Result; + fn expect_length(buf: &[u8]) -> Result; +} + +#[derive(Debug, Error, PartialEq, Eq)] +pub enum RespError { + #[error("Invalid frame")] + InvalidFrame(String), + #[error("Invalid frame type: {0}")] + InvalidFrameType(String), + #[error("Invalid frame length: {0}")] + InvalidFrameLength(isize), + #[error("Frame is not complete")] + NotComplete, + + #[error("Parse int error: {0}")] + ParseIntError(#[from] std::num::ParseIntError), + #[error("Utf8 error: {0}")] + Utf8Error(#[from] std::str::Utf8Error), + #[error("Parse float error: {0}")] + ParseFloatError(#[from] std::num::ParseFloatError), } #[enum_dispatch(RespEncode)] @@ -155,3 +178,51 @@ impl RespSet { RespSet(v.into()) } } + +impl From<&str> for SimpleString { + fn from(s: &str) -> Self { + SimpleString(s.to_string()) + } +} + +impl From<&str> for RespFrame { + fn from(s: &str) -> Self { + SimpleError(s.to_string()).into() + } +} + +impl From<&str> for SimpleError { + fn from(s: &str) -> Self { + SimpleError(s.to_string()) + } +} + +impl From<&str> for BulkString { + fn from(s: &str) -> Self { + BulkString(s.as_bytes().to_vec()) + } +} + +impl From<&[u8]> for BulkString { + fn from(s: &[u8]) -> Self { + BulkString(s.to_vec()) + } +} + +impl From<&[u8]> for RespFrame { + fn from(s: &[u8]) -> Self { + BulkString(s.to_vec()).into() + } +} + +impl From<&[u8; N]> for BulkString { + fn from(s: &[u8; N]) -> Self { + BulkString(s.to_vec()) + } +} + +impl From<&[u8; N]> for RespFrame { + fn from(s: &[u8; N]) -> Self { + BulkString(s.to_vec()).into() + } +}