diff --git a/serde_derive/src/de.rs b/serde_derive/src/de.rs index c5861d9bc..ef85385f6 100644 --- a/serde_derive/src/de.rs +++ b/serde_derive/src/de.rs @@ -281,11 +281,21 @@ fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment { } else if let attr::Identifier::No = cont.attrs.identifier() { match &cont.data { Data::Enum(variants) => deserialize_enum(params, variants, &cont.attrs), - Data::Struct(Style::Struct, fields) => { - deserialize_struct(params, fields, &cont.attrs, StructForm::Struct) - } + Data::Struct(Style::Struct, fields) => deserialize_struct( + params, + fields, + &cont.attrs, + cont.attrs.has_flatten(), + StructForm::Struct, + ), Data::Struct(Style::Tuple, fields) | Data::Struct(Style::Newtype, fields) => { - deserialize_tuple(params, fields, &cont.attrs, TupleForm::Tuple) + deserialize_tuple( + params, + fields, + &cont.attrs, + cont.attrs.has_flatten(), + TupleForm::Tuple, + ) } Data::Struct(Style::Unit, _) => deserialize_unit_struct(params, &cont.attrs), } @@ -459,9 +469,13 @@ fn deserialize_tuple( params: &Parameters, fields: &[Field], cattrs: &attr::Container, + has_flatten: bool, form: TupleForm, ) -> Fragment { - assert!(!cattrs.has_flatten()); + assert!( + !has_flatten, + "tuples and tuple variants cannot have flatten fields" + ); let field_count = fields .iter() @@ -579,7 +593,10 @@ fn deserialize_tuple_in_place( fields: &[Field], cattrs: &attr::Container, ) -> Fragment { - assert!(!cattrs.has_flatten()); + assert!( + !cattrs.has_flatten(), + "tuples and tuple variants cannot have flatten fields" + ); let field_count = fields .iter() @@ -910,6 +927,7 @@ fn deserialize_struct( params: &Parameters, fields: &[Field], cattrs: &attr::Container, + has_flatten: bool, form: StructForm, ) -> Fragment { let this_type = ¶ms.this_type; @@ -958,13 +976,13 @@ fn deserialize_struct( ) }) .collect(); - let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs); + let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs, has_flatten); // untagged struct variants do not get a visit_seq method. The same applies to // structs that only have a map representation. let visit_seq = match form { StructForm::Untagged(..) => None, - _ if cattrs.has_flatten() => None, + _ if has_flatten => None, _ => { let mut_seq = if field_names_idents.is_empty() { quote!(_) @@ -987,10 +1005,16 @@ fn deserialize_struct( }) } }; - let visit_map = Stmts(deserialize_map(&type_path, params, fields, cattrs)); + let visit_map = Stmts(deserialize_map( + &type_path, + params, + fields, + cattrs, + has_flatten, + )); let visitor_seed = match form { - StructForm::ExternallyTagged(..) if cattrs.has_flatten() => Some(quote! { + StructForm::ExternallyTagged(..) if has_flatten => Some(quote! { impl #de_impl_generics _serde::de::DeserializeSeed<#delife> for __Visitor #de_ty_generics #where_clause { type Value = #this_type #ty_generics; @@ -1005,7 +1029,7 @@ fn deserialize_struct( _ => None, }; - let fields_stmt = if cattrs.has_flatten() { + let fields_stmt = if has_flatten { None } else { let field_names = field_names_idents @@ -1025,7 +1049,7 @@ fn deserialize_struct( } }; let dispatch = match form { - StructForm::Struct if cattrs.has_flatten() => quote! { + StructForm::Struct if has_flatten => quote! { _serde::Deserializer::deserialize_map(__deserializer, #visitor_expr) }, StructForm::Struct => { @@ -1034,7 +1058,7 @@ fn deserialize_struct( _serde::Deserializer::deserialize_struct(__deserializer, #type_name, FIELDS, #visitor_expr) } } - StructForm::ExternallyTagged(_) if cattrs.has_flatten() => quote! { + StructForm::ExternallyTagged(_) if has_flatten => quote! { _serde::de::VariantAccess::newtype_variant_seed(__variant, #visitor_expr) }, StructForm::ExternallyTagged(_) => quote! { @@ -1116,7 +1140,7 @@ fn deserialize_struct_in_place( }) .collect(); - let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs); + let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs, false); let mut_seq = if field_names_idents.is_empty() { quote!(_) @@ -1210,10 +1234,7 @@ fn deserialize_homogeneous_enum( } } -fn prepare_enum_variant_enum( - variants: &[Variant], - cattrs: &attr::Container, -) -> (TokenStream, Stmts) { +fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) { let mut deserialized_variants = variants .iter() .enumerate() @@ -1247,7 +1268,7 @@ fn prepare_enum_variant_enum( let variant_visitor = Stmts(deserialize_generated_identifier( &variant_names_idents, - cattrs, + false, // variant identifiers does not depend on the presence of flatten fields true, None, fallthrough, @@ -1270,7 +1291,7 @@ fn deserialize_externally_tagged_enum( let expecting = format!("enum {}", params.type_name()); let expecting = cattrs.expecting().unwrap_or(&expecting); - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs); + let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants); // Match arms to extract a variant from a string let variant_arms = variants @@ -1355,7 +1376,7 @@ fn deserialize_internally_tagged_enum( cattrs: &attr::Container, tag: &str, ) -> Fragment { - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs); + let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants); // Match arms to extract a variant from a string let variant_arms = variants @@ -1409,7 +1430,7 @@ fn deserialize_adjacently_tagged_enum( split_with_de_lifetime(params); let delife = params.borrowed.de_lifetime(); - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs); + let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants); let variant_arms: &Vec<_> = &variants .iter() @@ -1810,12 +1831,14 @@ fn deserialize_externally_tagged_variant( params, &variant.fields, cattrs, + variant.attrs.has_flatten(), TupleForm::ExternallyTagged(variant_ident), ), Style::Struct => deserialize_struct( params, &variant.fields, cattrs, + variant.attrs.has_flatten(), StructForm::ExternallyTagged(variant_ident), ), } @@ -1859,6 +1882,7 @@ fn deserialize_internally_tagged_variant( params, &variant.fields, cattrs, + variant.attrs.has_flatten(), StructForm::InternallyTagged(variant_ident, deserializer), ), Style::Tuple => unreachable!("checked in serde_derive_internals"), @@ -1909,12 +1933,14 @@ fn deserialize_untagged_variant( params, &variant.fields, cattrs, + variant.attrs.has_flatten(), TupleForm::Untagged(variant_ident, deserializer), ), Style::Struct => deserialize_struct( params, &variant.fields, cattrs, + variant.attrs.has_flatten(), StructForm::Untagged(variant_ident, deserializer), ), } @@ -1985,7 +2011,7 @@ fn deserialize_untagged_newtype_variant( fn deserialize_generated_identifier( fields: &[(&str, Ident, &BTreeSet)], - cattrs: &attr::Container, + has_flatten: bool, is_variant: bool, ignore_variant: Option, fallthrough: Option, @@ -1999,11 +2025,11 @@ fn deserialize_generated_identifier( is_variant, fallthrough, None, - !is_variant && cattrs.has_flatten(), + !is_variant && has_flatten, None, )); - let lifetime = if !is_variant && cattrs.has_flatten() { + let lifetime = if !is_variant && has_flatten { Some(quote!(<'de>)) } else { None @@ -2043,8 +2069,9 @@ fn deserialize_generated_identifier( fn deserialize_field_identifier( fields: &[(&str, Ident, &BTreeSet)], cattrs: &attr::Container, + has_flatten: bool, ) -> Stmts { - let (ignore_variant, fallthrough) = if cattrs.has_flatten() { + let (ignore_variant, fallthrough) = if has_flatten { let ignore_variant = quote!(__other(_serde::__private::de::Content<'de>),); let fallthrough = quote!(_serde::__private::Ok(__Field::__other(__value))); (Some(ignore_variant), Some(fallthrough)) @@ -2058,7 +2085,7 @@ fn deserialize_field_identifier( Stmts(deserialize_generated_identifier( fields, - cattrs, + has_flatten, false, ignore_variant, fallthrough, @@ -2460,6 +2487,7 @@ fn deserialize_map( params: &Parameters, fields: &[Field], cattrs: &attr::Container, + has_flatten: bool, ) -> Fragment { // Create the field names for the fields. let fields_names: Vec<_> = fields @@ -2480,9 +2508,6 @@ fn deserialize_map( }); // Collect contents for flatten fields into a buffer - let has_flatten = fields - .iter() - .any(|field| field.attrs.flatten() && !field.attrs.skip_deserializing()); let let_collect = if has_flatten { Some(quote! { let mut __collect = _serde::__private::Vec::<_serde::__private::Option<( @@ -2681,7 +2706,10 @@ fn deserialize_map_in_place( fields: &[Field], cattrs: &attr::Container, ) -> Fragment { - assert!(!cattrs.has_flatten()); + assert!( + !cattrs.has_flatten(), + "inplace deserialization of maps doesn't support flatten fields" + ); // Create the field names for the fields. let fields_names: Vec<_> = fields diff --git a/serde_derive/src/internals/ast.rs b/serde_derive/src/internals/ast.rs index a28d3ae7e..4ec709952 100644 --- a/serde_derive/src/internals/ast.rs +++ b/serde_derive/src/internals/ast.rs @@ -85,6 +85,7 @@ impl<'a> Container<'a> { for field in &mut variant.fields { if field.attrs.flatten() { has_flatten = true; + variant.attrs.mark_has_flatten(); } field.attrs.rename_by_rules( variant diff --git a/serde_derive/src/internals/attr.rs b/serde_derive/src/internals/attr.rs index 0cfb23bf1..5064d079a 100644 --- a/serde_derive/src/internals/attr.rs +++ b/serde_derive/src/internals/attr.rs @@ -216,6 +216,22 @@ pub struct Container { type_into: Option, remote: Option, identifier: Identifier, + /// `true` if container is a `struct` and it has a field with `#[serde(flatten)]` + /// attribute or it is an `enum` with a struct variant which has a field with + /// `#[serde(flatten)]` attribute. Examples: + /// + /// ```ignore + /// struct Container { + /// #[serde(flatten)] + /// some_field: (), + /// } + /// enum Container { + /// Variant { + /// #[serde(flatten)] + /// some_field: (), + /// }, + /// } + /// ``` has_flatten: bool, serde_path: Option, is_packed: bool, @@ -794,6 +810,18 @@ pub struct Variant { rename_all_rules: RenameAllRules, ser_bound: Option>, de_bound: Option>, + /// `true` if variant is a struct variant which contains a field with `#[serde(flatten)]` + /// attribute. Examples: + /// + /// ```ignore + /// enum Enum { + /// Variant { + /// #[serde(flatten)] + /// some_field: (), + /// }, + /// } + /// ``` + has_flatten: bool, skip_deserializing: bool, skip_serializing: bool, other: bool, @@ -963,6 +991,7 @@ impl Variant { }, ser_bound: ser_bound.get(), de_bound: de_bound.get(), + has_flatten: false, skip_deserializing: skip_deserializing.get(), skip_serializing: skip_serializing.get(), other: other.get(), @@ -1005,6 +1034,14 @@ impl Variant { self.de_bound.as_ref().map(|vec| &vec[..]) } + pub fn has_flatten(&self) -> bool { + self.has_flatten + } + + pub fn mark_has_flatten(&mut self) { + self.has_flatten = true; + } + pub fn skip_deserializing(&self) -> bool { self.skip_deserializing } diff --git a/test_suite/tests/regression/issue1904.rs b/test_suite/tests/regression/issue1904.rs new file mode 100644 index 000000000..99736c078 --- /dev/null +++ b/test_suite/tests/regression/issue1904.rs @@ -0,0 +1,65 @@ +#![allow(dead_code)] // we do not read enum fields +use serde_derive::Deserialize; + +#[derive(Deserialize)] +pub struct Nested; + +#[derive(Deserialize)] +pub enum ExternallyTagged1 { + Tuple(f64, String), + Flatten { + #[serde(flatten)] + nested: Nested, + }, +} + +#[derive(Deserialize)] +pub enum ExternallyTagged2 { + Flatten { + #[serde(flatten)] + nested: Nested, + }, + Tuple(f64, String), +} + +// Internally tagged enums cannot contain tuple variants so not tested here + +#[derive(Deserialize)] +#[serde(tag = "tag", content = "content")] +pub enum AdjacentlyTagged1 { + Tuple(f64, String), + Flatten { + #[serde(flatten)] + nested: Nested, + }, +} + +#[derive(Deserialize)] +#[serde(tag = "tag", content = "content")] +pub enum AdjacentlyTagged2 { + Flatten { + #[serde(flatten)] + nested: Nested, + }, + Tuple(f64, String), +} + +#[derive(Deserialize)] +#[serde(untagged)] +pub enum Untagged1 { + Tuple(f64, String), + Flatten { + #[serde(flatten)] + nested: Nested, + }, +} + +#[derive(Deserialize)] +#[serde(untagged)] +pub enum Untagged2 { + Flatten { + #[serde(flatten)] + nested: Nested, + }, + Tuple(f64, String), +} diff --git a/test_suite/tests/regression/issue2565.rs b/test_suite/tests/regression/issue2565.rs new file mode 100644 index 000000000..65cbb0a31 --- /dev/null +++ b/test_suite/tests/regression/issue2565.rs @@ -0,0 +1,41 @@ +use serde_derive::{Serialize, Deserialize}; +use serde_test::{assert_tokens, Token}; + +#[derive(Serialize, Deserialize, Debug, PartialEq)] +enum Enum { + Simple { + a: i32, + }, + Flatten { + #[serde(flatten)] + flatten: (), + a: i32, + }, +} + +#[test] +fn simple_variant() { + assert_tokens( + &Enum::Simple { a: 42 }, + &[ + Token::StructVariant { name: "Enum", variant: "Simple", len: 1 }, + Token::Str("a"), + Token::I32(42), + Token::StructVariantEnd, + ] + ); +} + +#[test] +fn flatten_variant() { + assert_tokens( + &Enum::Flatten { flatten: (), a: 42 }, + &[ + Token::NewtypeVariant { name: "Enum", variant: "Flatten" }, + Token::Map { len: None }, + Token::Str("a"), + Token::I32(42), + Token::MapEnd, + ] + ); +} diff --git a/test_suite/tests/regression/issue2792.rs b/test_suite/tests/regression/issue2792.rs new file mode 100644 index 000000000..13c0b7103 --- /dev/null +++ b/test_suite/tests/regression/issue2792.rs @@ -0,0 +1,16 @@ +#![allow(dead_code)] // we do not read enum fields +use serde_derive::Deserialize; + +#[derive(Deserialize)] +#[serde(deny_unknown_fields)] +pub enum A { + B { + c: String, + }, + D { + #[serde(flatten)] + e: E, + }, +} +#[derive(Deserialize)] +pub struct E {} diff --git a/test_suite/tests/test_annotations.rs b/test_suite/tests/test_annotations.rs index 566f7d43f..1488c8364 100644 --- a/test_suite/tests/test_annotations.rs +++ b/test_suite/tests/test_annotations.rs @@ -2380,6 +2380,56 @@ fn test_partially_untagged_enum_desugared() { ); } +/// Regression test for https://github.com/serde-rs/serde/issues/1904 +#[test] +fn test_enum_tuple_and_struct_with_flatten() { + #[derive(Serialize, Deserialize, PartialEq, Debug)] + enum Outer { + Tuple(f64, i32), + Flatten { + #[serde(flatten)] + nested: Nested, + }, + } + + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Nested { + a: i32, + b: i32, + } + + assert_tokens( + &Outer::Tuple(1.2, 3), + &[ + Token::TupleVariant { + name: "Outer", + variant: "Tuple", + len: 2, + }, + Token::F64(1.2), + Token::I32(3), + Token::TupleVariantEnd, + ], + ); + assert_tokens( + &Outer::Flatten { + nested: Nested { a: 1, b: 2 }, + }, + &[ + Token::NewtypeVariant { + name: "Outer", + variant: "Flatten", + }, + Token::Map { len: None }, + Token::Str("a"), + Token::I32(1), + Token::Str("b"), + Token::I32(2), + Token::MapEnd, + ], + ); +} + #[test] fn test_partially_untagged_internally_tagged_enum() { #[derive(Serialize, Deserialize, PartialEq, Debug)]