diff --git a/impl/src/ast.rs b/impl/src/ast.rs index 134ce799..2f8bd6c1 100644 --- a/impl/src/ast.rs +++ b/impl/src/ast.rs @@ -79,6 +79,8 @@ impl<'a> Enum<'a> { } if let Some(display) = &mut variant.attrs.display { display.expand_shorthand(&variant.fields); + } else if variant.attrs.transparent.is_none() { + variant.attrs.transparent = attrs.transparent; } Ok(variant) }) diff --git a/impl/src/attr.rs b/impl/src/attr.rs index 0c55b316..f48d506f 100644 --- a/impl/src/attr.rs +++ b/impl/src/attr.rs @@ -12,6 +12,7 @@ pub struct Attrs<'a> { pub source: Option<&'a Attribute>, pub backtrace: Option<&'a Attribute>, pub from: Option<&'a Attribute>, + pub transparent: Option<&'a Attribute>, } #[derive(Clone)] @@ -29,18 +30,12 @@ pub fn get(input: &[Attribute]) -> Result { source: None, backtrace: None, from: None, + transparent: None, }; for attr in input { if attr.path.is_ident("error") { - let display = parse_display(attr)?; - if attrs.display.is_some() { - return Err(Error::new_spanned( - attr, - "only one #[error(...)] attribute is allowed", - )); - } - attrs.display = Some(display); + parse_error_attribute(&mut attrs, attr)?; } else if attr.path.is_ident("source") { require_empty_attribute(attr)?; if attrs.source.is_some() { @@ -68,15 +63,36 @@ pub fn get(input: &[Attribute]) -> Result { Ok(attrs) } -fn parse_display(attr: &Attribute) -> Result { +fn parse_error_attribute<'a>(attrs: &mut Attrs<'a>, attr: &'a Attribute) -> Result<()> { + syn::custom_keyword!(transparent); + attr.parse_args_with(|input: ParseStream| { - Ok(Display { + if input.parse::>()?.is_some() { + if attrs.transparent.is_some() { + return Err(Error::new_spanned( + attr, + "duplicate #[error(transparent)] attribute", + )); + } + attrs.transparent = Some(attr); + return Ok(()); + } + + let display = Display { original: attr, fmt: input.parse()?, args: parse_token_expr(input, false)?, was_shorthand: false, has_bonus_display: false, - }) + }; + if attrs.display.is_some() { + return Err(Error::new_spanned( + attr, + "only one #[error(...)] attribute is allowed", + )); + } + attrs.display = Some(display); + Ok(()) }) } diff --git a/impl/src/expand.rs b/impl/src/expand.rs index 3b821e74..51b9c9ef 100644 --- a/impl/src/expand.rs +++ b/impl/src/expand.rs @@ -1,7 +1,6 @@ use crate::ast::{Enum, Field, Input, Struct}; -use crate::valid; use proc_macro2::TokenStream; -use quote::{format_ident, quote, quote_spanned}; +use quote::{format_ident, quote, quote_spanned, ToTokens}; use syn::spanned::Spanned; use syn::{DeriveInput, Member, PathArguments, Result, Type}; @@ -18,7 +17,12 @@ fn impl_struct(input: Struct) -> TokenStream { let ty = &input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); - let source_method = input.source_field().map(|source_field| { + let source_body = if input.attrs.transparent.is_some() { + let only_field = &input.fields[0].member; + Some(quote! { + std::error::Error::source(self.#only_field.as_dyn_error()) + }) + } else if let Some(source_field) = input.source_field() { let source = &source_field.member; let asref = if type_is_option(source_field.ty) { Some(quote_spanned!(source.span()=> .as_ref()?)) @@ -26,10 +30,17 @@ fn impl_struct(input: Struct) -> TokenStream { None }; let dyn_error = quote_spanned!(source.span()=> self.#source #asref.as_dyn_error()); + Some(quote! { + std::option::Option::Some(#dyn_error) + }) + } else { + None + }; + let source_method = source_body.map(|body| { quote! { fn source(&self) -> std::option::Option<&(dyn std::error::Error + 'static)> { use thiserror::private::AsDynError; - std::option::Option::Some(#dyn_error) + #body } } }); @@ -76,7 +87,12 @@ fn impl_struct(input: Struct) -> TokenStream { } }); - let display_impl = input.attrs.display.as_ref().map(|display| { + let display_body = if input.attrs.transparent.is_some() { + let only_field = &input.fields[0].member; + Some(quote! { + std::fmt::Display::fmt(&self.#only_field, __formatter) + }) + } else if let Some(display) = &input.attrs.display { let use_as_display = if display.has_bonus_display { Some(quote! { #[allow(unused_imports)] @@ -86,13 +102,20 @@ fn impl_struct(input: Struct) -> TokenStream { None }; let pat = fields_pat(&input.fields); + Some(quote! { + #use_as_display + #[allow(unused_variables)] + let Self #pat = self; + #display + }) + } else { + None + }; + let display_impl = display_body.map(|body| { quote! { impl #impl_generics std::fmt::Display for #ty #ty_generics #where_clause { fn fmt(&self, __formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - #use_as_display - #[allow(unused_variables)] - let Self #pat = self; - #display + #body } } } @@ -128,22 +151,27 @@ fn impl_enum(input: Enum) -> TokenStream { let source_method = if input.has_source() { let arms = input.variants.iter().map(|variant| { let ident = &variant.ident; - match variant.source_field() { - Some(source_field) => { - let source = &source_field.member; - let asref = if type_is_option(source_field.ty) { - Some(quote_spanned!(source.span()=> .as_ref()?)) - } else { - None - }; - let dyn_error = quote_spanned!(source.span()=> source #asref.as_dyn_error()); - quote! { - #ty::#ident {#source: source, ..} => std::option::Option::Some(#dyn_error), - } + if variant.attrs.transparent.is_some() { + let only_field = &variant.fields[0].member; + let source = quote!(std::error::Error::source(transparent.as_dyn_error())); + quote! { + #ty::#ident {#only_field: transparent} => #source, + } + } else if let Some(source_field) = variant.source_field() { + let source = &source_field.member; + let asref = if type_is_option(source_field.ty) { + Some(quote_spanned!(source.span()=> .as_ref()?)) + } else { + None + }; + let dyn_error = quote_spanned!(source.span()=> source #asref.as_dyn_error()); + quote! { + #ty::#ident {#source: source, ..} => std::option::Option::Some(#dyn_error), } - None => quote! { + } else { + quote! { #ty::#ident {..} => std::option::Option::None, - }, + } } }); Some(quote! { @@ -228,8 +256,7 @@ fn impl_enum(input: Enum) -> TokenStream { v.attrs .display .as_ref() - .expect(valid::CHECKED) - .has_bonus_display + .map_or(false, |display| display.has_bonus_display) }) { Some(quote! { #[allow(unused_imports)] @@ -244,7 +271,16 @@ fn impl_enum(input: Enum) -> TokenStream { None }; let arms = input.variants.iter().map(|variant| { - let display = variant.attrs.display.as_ref().expect(valid::CHECKED); + let display = match &variant.attrs.display { + Some(display) => display.to_token_stream(), + None => { + let only_field = match &variant.fields[0].member { + Member::Named(ident) => ident.clone(), + Member::Unnamed(index) => format_ident!("_{}", index), + }; + quote!(std::fmt::Display::fmt(#only_field, __formatter)) + } + }; let ident = &variant.ident; let pat = fields_pat(&variant.fields); quote! { @@ -297,7 +333,7 @@ fn fields_pat(fields: &[Field]) -> TokenStream { Some(Member::Named(_)) => quote!({ #(#members),* }), Some(Member::Unnamed(_)) => { let vars = members.map(|member| match member { - Member::Unnamed(member) => format_ident!("_{}", member.index), + Member::Unnamed(member) => format_ident!("_{}", member), Member::Named(_) => unreachable!(), }); quote!((#(#vars),*)) diff --git a/impl/src/prop.rs b/impl/src/prop.rs index 940b4f89..e011848f 100644 --- a/impl/src/prop.rs +++ b/impl/src/prop.rs @@ -19,7 +19,7 @@ impl Enum<'_> { pub(crate) fn has_source(&self) -> bool { self.variants .iter() - .any(|variant| variant.source_field().is_some()) + .any(|variant| variant.source_field().is_some() || variant.attrs.transparent.is_some()) } pub(crate) fn has_backtrace(&self) -> bool { @@ -30,10 +30,15 @@ impl Enum<'_> { pub(crate) fn has_display(&self) -> bool { self.attrs.display.is_some() + || self.attrs.transparent.is_some() || self .variants .iter() .any(|variant| variant.attrs.display.is_some()) + || self + .variants + .iter() + .all(|variant| variant.attrs.transparent.is_some()) } } diff --git a/impl/src/valid.rs b/impl/src/valid.rs index 7b6750cb..ffe7488e 100644 --- a/impl/src/valid.rs +++ b/impl/src/valid.rs @@ -4,8 +4,6 @@ use quote::ToTokens; use std::collections::BTreeSet as Set; use syn::{Error, Member, Result}; -pub(crate) const CHECKED: &str = "checked in validation"; - impl Input<'_> { pub(crate) fn validate(&self) -> Result<()> { match self { @@ -18,6 +16,20 @@ impl Input<'_> { impl Struct<'_> { fn validate(&self) -> Result<()> { check_non_field_attrs(&self.attrs)?; + if let Some(transparent) = self.attrs.transparent { + if self.fields.len() != 1 { + return Err(Error::new_spanned( + transparent, + "#[error(transparent)] requires exactly one field", + )); + } + if let Some(source) = self.fields.iter().filter_map(|f| f.attrs.source).next() { + return Err(Error::new_spanned( + source, + "transparent error struct can't contain #[source]", + )); + } + } check_field_attrs(&self.fields)?; for field in &self.fields { field.validate()?; @@ -32,7 +44,8 @@ impl Enum<'_> { let has_display = self.has_display(); for variant in &self.variants { variant.validate()?; - if has_display && variant.attrs.display.is_none() { + if has_display && variant.attrs.display.is_none() && variant.attrs.transparent.is_none() + { return Err(Error::new_spanned( variant.original, "missing #[error(\"...\")] display attribute", @@ -58,6 +71,20 @@ impl Enum<'_> { impl Variant<'_> { fn validate(&self) -> Result<()> { check_non_field_attrs(&self.attrs)?; + if self.attrs.transparent.is_some() { + if self.fields.len() != 1 { + return Err(Error::new_spanned( + self.original, + "#[error(transparent)] requires exactly one field", + )); + } + if let Some(source) = self.fields.iter().filter_map(|f| f.attrs.source).next() { + return Err(Error::new_spanned( + source, + "transparent variant can't contain #[source]", + )); + } + } check_field_attrs(&self.fields)?; for field in &self.fields { field.validate()?; @@ -97,6 +124,14 @@ fn check_non_field_attrs(attrs: &Attrs) -> Result<()> { "not expected here; the #[backtrace] attribute belongs on a specific field", )); } + if let Some(display) = &attrs.display { + if attrs.transparent.is_some() { + return Err(Error::new_spanned( + display.original, + "cannot have both #[error(transparent)] and a display attribute", + )); + } + } Ok(()) } diff --git a/tests/test_transparent.rs b/tests/test_transparent.rs new file mode 100644 index 00000000..8fba5fef --- /dev/null +++ b/tests/test_transparent.rs @@ -0,0 +1,57 @@ +use anyhow::anyhow; +use std::error::Error as _; +use std::io; +use thiserror::Error; + +#[test] +fn test_transparent_struct() { + #[derive(Error, Debug)] + #[error(transparent)] + struct Error(ErrorKind); + + #[derive(Error, Debug)] + enum ErrorKind { + #[error("E0")] + E0, + #[error("E1")] + E1(#[from] io::Error), + } + + let error = Error(ErrorKind::E0); + assert_eq!("E0", error.to_string()); + assert!(error.source().is_none()); + + let io = io::Error::new(io::ErrorKind::Other, "oh no!"); + let error = Error(ErrorKind::from(io)); + assert_eq!("E1", error.to_string()); + error.source().unwrap().downcast_ref::().unwrap(); +} + +#[test] +fn test_transparent_enum() { + #[derive(Error, Debug)] + enum Error { + #[error("this failed")] + This, + #[error(transparent)] + Other(anyhow::Error), + } + + let error = Error::This; + assert_eq!("this failed", error.to_string()); + + let error = Error::Other(anyhow!("inner").context("outer")); + assert_eq!("outer", error.to_string()); + assert_eq!("inner", error.source().unwrap().to_string()); +} + +#[test] +fn test_anyhow() { + #[derive(Error, Debug)] + #[error(transparent)] + struct Any(#[from] anyhow::Error); + + let error = Any::from(anyhow!("inner").context("outer")); + assert_eq!("outer", error.to_string()); + assert_eq!("inner", error.source().unwrap().to_string()); +}