Skip to content

Commit

Permalink
apply recursion limit when skipping fields
Browse files Browse the repository at this point in the history
  • Loading branch information
danburkert committed Jan 16, 2020
1 parent 221ebbf commit 04091d3
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 15 deletions.
2 changes: 1 addition & 1 deletion prost-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
#struct_name
match tag {
#(#merge)*
_ => ::prost::encoding::skip_field(wire_type, tag, buf),
_ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
}
}

Expand Down
7 changes: 4 additions & 3 deletions src/encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,11 @@ where
Ok(())
}

pub fn skip_field<B>(wire_type: WireType, tag: u32, buf: &mut B) -> Result<(), DecodeError>
pub fn skip_field<B>(wire_type: WireType, tag: u32, buf: &mut B, ctx: DecodeContext) -> Result<(), DecodeError>
where
B: Buf,
{
ctx.limit_reached()?;
let len = match wire_type {
WireType::Varint => decode_varint(buf).map(|_| 0)?,
WireType::ThirtyTwoBit => 4,
Expand All @@ -399,7 +400,7 @@ where
}
break 0;
}
_ => skip_field(inner_wire_type, inner_tag, buf)?,
_ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
}
},
WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
Expand Down Expand Up @@ -1219,7 +1220,7 @@ macro_rules! map {
match tag {
1 => key_merge(wire_type, key, buf, ctx),
2 => val_merge(wire_type, val, buf, ctx),
_ => skip_field(wire_type, tag, buf),
_ => skip_field(wire_type, tag, buf, ctx),
}
},
)?;
Expand Down
22 changes: 11 additions & 11 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Message for bool {
if tag == 1 {
bool::merge(wire_type, self, buf, ctx)
} else {
skip_field(wire_type, tag, buf)
skip_field(wire_type, tag, buf, ctx)
}
}
fn encoded_len(&self) -> usize {
Expand Down Expand Up @@ -76,7 +76,7 @@ impl Message for u32 {
if tag == 1 {
uint32::merge(wire_type, self, buf, ctx)
} else {
skip_field(wire_type, tag, buf)
skip_field(wire_type, tag, buf, ctx)
}
}
fn encoded_len(&self) -> usize {
Expand Down Expand Up @@ -114,7 +114,7 @@ impl Message for u64 {
if tag == 1 {
uint64::merge(wire_type, self, buf, ctx)
} else {
skip_field(wire_type, tag, buf)
skip_field(wire_type, tag, buf, ctx)
}
}
fn encoded_len(&self) -> usize {
Expand Down Expand Up @@ -152,7 +152,7 @@ impl Message for i32 {
if tag == 1 {
int32::merge(wire_type, self, buf, ctx)
} else {
skip_field(wire_type, tag, buf)
skip_field(wire_type, tag, buf, ctx)
}
}
fn encoded_len(&self) -> usize {
Expand Down Expand Up @@ -190,7 +190,7 @@ impl Message for i64 {
if tag == 1 {
int64::merge(wire_type, self, buf, ctx)
} else {
skip_field(wire_type, tag, buf)
skip_field(wire_type, tag, buf, ctx)
}
}
fn encoded_len(&self) -> usize {
Expand Down Expand Up @@ -228,7 +228,7 @@ impl Message for f32 {
if tag == 1 {
float::merge(wire_type, self, buf, ctx)
} else {
skip_field(wire_type, tag, buf)
skip_field(wire_type, tag, buf, ctx)
}
}
fn encoded_len(&self) -> usize {
Expand Down Expand Up @@ -266,7 +266,7 @@ impl Message for f64 {
if tag == 1 {
double::merge(wire_type, self, buf, ctx)
} else {
skip_field(wire_type, tag, buf)
skip_field(wire_type, tag, buf, ctx)
}
}
fn encoded_len(&self) -> usize {
Expand Down Expand Up @@ -304,7 +304,7 @@ impl Message for String {
if tag == 1 {
string::merge(wire_type, self, buf, ctx)
} else {
skip_field(wire_type, tag, buf)
skip_field(wire_type, tag, buf, ctx)
}
}
fn encoded_len(&self) -> usize {
Expand Down Expand Up @@ -342,7 +342,7 @@ impl Message for Vec<u8> {
if tag == 1 {
bytes::merge(wire_type, self, buf, ctx)
} else {
skip_field(wire_type, tag, buf)
skip_field(wire_type, tag, buf, ctx)
}
}
fn encoded_len(&self) -> usize {
Expand All @@ -369,12 +369,12 @@ impl Message for () {
tag: u32,
wire_type: WireType,
buf: &mut B,
_ctx: DecodeContext,
ctx: DecodeContext,
) -> Result<(), DecodeError>
where
B: Buf,
{
skip_field(wire_type, tag, buf)
skip_field(wire_type, tag, buf, ctx)
}
fn encoded_len(&self) -> usize {
0
Expand Down
10 changes: 10 additions & 0 deletions tests/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,16 @@ mod tests {
};
}

#[test]
fn test_267_regression() {
// Checks that skip_field will error appropriately when given a big stack of StartGroup
// tags.
//
// https://github.com/danburkert/prost/issues/267
let buf = vec![b'C'; 1 << 20];
<() as Message>::decode(&buf[..]).err().unwrap();
}

#[test]
fn test_default_enum() {
let msg = default_enum_value::Test::default();
Expand Down

0 comments on commit 04091d3

Please # to comment.