From 04091d3e745c27590a5f1b7f581793e4159486b5 Mon Sep 17 00:00:00 2001 From: Dan Burkert Date: Thu, 16 Jan 2020 08:48:43 -0800 Subject: [PATCH] apply recursion limit when skipping fields Fixes #267 --- prost-derive/src/lib.rs | 2 +- src/encoding.rs | 7 ++++--- src/types.rs | 22 +++++++++++----------- tests/src/lib.rs | 10 ++++++++++ 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 7ac535802..13393e95e 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -194,7 +194,7 @@ fn try_message(input: TokenStream) -> Result { #struct_name match tag { #(#merge)* - _ => ::prost::encoding::skip_field(wire_type, tag, buf), + _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx), } } diff --git a/src/encoding.rs b/src/encoding.rs index c4ac2cdf4..4ed88882e 100644 --- a/src/encoding.rs +++ b/src/encoding.rs @@ -381,10 +381,11 @@ where Ok(()) } -pub fn skip_field(wire_type: WireType, tag: u32, buf: &mut B) -> Result<(), DecodeError> +pub fn skip_field(wire_type: WireType, tag: u32, buf: &mut B, ctx: DecodeContext) -> Result<(), DecodeError> where B: Buf, { + ctx.limit_reached()?; let len = match wire_type { WireType::Varint => decode_varint(buf).map(|_| 0)?, WireType::ThirtyTwoBit => 4, @@ -399,7 +400,7 @@ where } break 0; } - _ => skip_field(inner_wire_type, inner_tag, buf)?, + _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?, } }, WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")), @@ -1219,7 +1220,7 @@ macro_rules! map { match tag { 1 => key_merge(wire_type, key, buf, ctx), 2 => val_merge(wire_type, val, buf, ctx), - _ => skip_field(wire_type, tag, buf), + _ => skip_field(wire_type, tag, buf, ctx), } }, )?; diff --git a/src/types.rs b/src/types.rs index 00ca5afdb..43e7355ef 100644 --- a/src/types.rs +++ b/src/types.rs @@ -38,7 +38,7 @@ impl Message for bool { if tag == 1 { bool::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -76,7 +76,7 @@ impl Message for u32 { if tag == 1 { uint32::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -114,7 +114,7 @@ impl Message for u64 { if tag == 1 { uint64::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -152,7 +152,7 @@ impl Message for i32 { if tag == 1 { int32::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -190,7 +190,7 @@ impl Message for i64 { if tag == 1 { int64::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -228,7 +228,7 @@ impl Message for f32 { if tag == 1 { float::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -266,7 +266,7 @@ impl Message for f64 { if tag == 1 { double::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -304,7 +304,7 @@ impl Message for String { if tag == 1 { string::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -342,7 +342,7 @@ impl Message for Vec { if tag == 1 { bytes::merge(wire_type, self, buf, ctx) } else { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } } fn encoded_len(&self) -> usize { @@ -369,12 +369,12 @@ impl Message for () { tag: u32, wire_type: WireType, buf: &mut B, - _ctx: DecodeContext, + ctx: DecodeContext, ) -> Result<(), DecodeError> where B: Buf, { - skip_field(wire_type, tag, buf) + skip_field(wire_type, tag, buf, ctx) } fn encoded_len(&self) -> usize { 0 diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 47b26e46c..3c808037c 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -462,6 +462,16 @@ mod tests { }; } + #[test] + fn test_267_regression() { + // Checks that skip_field will error appropriately when given a big stack of StartGroup + // tags. + // + // https://github.com/danburkert/prost/issues/267 + let buf = vec![b'C'; 1 << 20]; + <() as Message>::decode(&buf[..]).err().unwrap(); + } + #[test] fn test_default_enum() { let msg = default_enum_value::Test::default();