Skip to content

Commit

Permalink
Fix codegen for codec::Compact as type parameters (#651)
Browse files Browse the repository at this point in the history
* Add failing test for compact generic parameter

* WIP refactor type generation

* Fmt

* Remove deprecated rustfmt optionns

* Remove license template path

* Update parent type parameter visitor

* Introduce different methods for generating a type path for a field

* Add comment

* Fix weights refs

* Add extra compact test cases

* Fmt
  • Loading branch information
ascjones authored Sep 21, 2022
1 parent 033ceb2 commit 3bf7ddc
Show file tree
Hide file tree
Showing 8 changed files with 310 additions and 172 deletions.
3 changes: 0 additions & 3 deletions .rustfmt.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ format_code_in_doc_comments = false
comment_width = 80
normalize_comments = true # changed
normalize_doc_attributes = false
license_template_path = "FILE_TEMPLATE" # changed
format_strings = false
format_macro_matchers = false
format_macro_bodies = true
Expand Down Expand Up @@ -57,8 +56,6 @@ skip_children = false
hide_parse_errors = false
error_on_line_overflow = false
error_on_unformatted = false
report_todo = "Always"
report_fixme = "Always"
ignore = []

# Below are `rustfmt` internal settings
Expand Down
2 changes: 1 addition & 1 deletion codegen/src/api/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub fn generate_constants(
let constant_hash = subxt_metadata::get_constant_hash(metadata, pallet_name, constant_name)
.unwrap_or_else(|_| abort_call_site!("Metadata information for the constant {}_{} could not be found", pallet_name, constant_name));

let return_ty = type_gen.resolve_type_path(constant.ty.id(), &[]);
let return_ty = type_gen.resolve_type_path(constant.ty.id());
let docs = &constant.docs;

quote! {
Expand Down
2 changes: 1 addition & 1 deletion codegen/src/api/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub fn generate_events(
}
}
});
let event_type = type_gen.resolve_type_path(event.ty.id(), &[]);
let event_type = type_gen.resolve_type_path(event.ty.id());
let event_ty = type_gen.resolve_type(event.ty.id());
let docs = event_ty.docs();

Expand Down
6 changes: 3 additions & 3 deletions codegen/src/api/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ fn generate_storage_entry_fns(
.enumerate()
.map(|(i, f)| {
let field_name = format_ident!("_{}", syn::Index::from(i));
let field_type = type_gen.resolve_type_path(f.id(), &[]);
let field_type = type_gen.resolve_type_path(f.id());
(field_name, field_type)
})
.collect::<Vec<_>>();
Expand Down Expand Up @@ -142,7 +142,7 @@ fn generate_storage_entry_fns(
(fields, key_impl)
}
_ => {
let ty_path = type_gen.resolve_type_path(key.id(), &[]);
let ty_path = type_gen.resolve_type_path(key.id());
let fields = vec![(format_ident!("_0"), ty_path)];
let hasher = hashers.get(0).unwrap_or_else(|| {
abort_call_site!("No hasher found for single key")
Expand Down Expand Up @@ -173,7 +173,7 @@ fn generate_storage_entry_fns(
StorageEntryType::Plain(ref ty) => ty,
StorageEntryType::Map { ref value, .. } => value,
};
let storage_entry_value_ty = type_gen.resolve_type_path(storage_entry_ty.id(), &[]);
let storage_entry_value_ty = type_gen.resolve_type_path(storage_entry_ty.id());

let docs = &storage_entry.docs;
let docs_token = quote! { #( #[doc = #docs ] )* };
Expand Down
2 changes: 1 addition & 1 deletion codegen/src/types/composite_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ impl CompositeDefFields {

for field in fields {
let type_path =
type_gen.resolve_type_path(field.ty().id(), parent_type_params);
type_gen.resolve_field_type_path(field.ty().id(), parent_type_params);
let field_type = CompositeDefFieldType::new(
field.ty().id(),
type_path,
Expand Down
159 changes: 126 additions & 33 deletions codegen/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ pub use self::{
type_path::{
TypeParameter,
TypePath,
TypePathSubstitute,
TypePathType,
},
};
Expand Down Expand Up @@ -145,12 +144,46 @@ impl<'a> TypeGenerator<'a> {
.clone()
}

/// Get the type path for a field of a struct or an enum variant, providing any generic
/// type parameters from the containing type. This is for identifying where a generic type
/// parameter is used in a field type e.g.
///
/// ```rust
/// struct S<T> {
/// a: T, // `T` is the "parent" type param from the containing type.
/// b: Vec<Option<T>>, // nested use of generic type param `T`.
/// }
/// ```
///
/// This allows generating the correct generic field type paths.
///
/// # Panics
///
/// If no type with the given id found in the type registry.
pub fn resolve_field_type_path(
&self,
id: u32,
parent_type_params: &[TypeParameter],
) -> TypePath {
self.resolve_type_path_recurse(id, true, parent_type_params)
}

/// Get the type path for the given type identifier.
///
/// # Panics
///
/// If no type with the given id found in the type registry.
pub fn resolve_type_path(
pub fn resolve_type_path(&self, id: u32) -> TypePath {
self.resolve_type_path_recurse(id, false, &[])
}

/// Visit each node in a possibly nested type definition to produce a type path.
///
/// e.g `Result<GenericStruct<NestedGenericStruct<T>>, String>`
fn resolve_type_path_recurse(
&self,
id: u32,
is_field: bool,
parent_type_params: &[TypeParameter],
) -> TypePath {
if let Some(parent_type_param) = parent_type_params
Expand All @@ -171,40 +204,100 @@ impl<'a> TypeGenerator<'a> {
)
}

let params_type_ids = match ty.type_def() {
TypeDef::Array(arr) => vec![arr.type_param().id()],
TypeDef::Sequence(seq) => vec![seq.type_param().id()],
TypeDef::Tuple(tuple) => tuple.fields().iter().map(|f| f.id()).collect(),
TypeDef::Compact(compact) => vec![compact.type_param().id()],
TypeDef::BitSequence(seq) => {
vec![seq.bit_order_type().id(), seq.bit_store_type().id()]
let params = ty
.type_params()
.iter()
.filter_map(|f| {
f.ty().map(|f| {
self.resolve_type_path_recurse(f.id(), false, parent_type_params)
})
})
.collect();

let ty = match ty.type_def() {
TypeDef::Composite(_) | TypeDef::Variant(_) => {
let joined_path = ty.path().segments().join("::");
if let Some(substitute_type_path) =
self.type_substitutes.get(&joined_path)
{
TypePathType::Path {
path: substitute_type_path.clone(),
params,
}
} else {
TypePathType::from_type_def_path(
ty.path(),
self.types_mod_ident.clone(),
params,
)
}
}
TypeDef::Primitive(primitive) => {
TypePathType::Primitive {
def: primitive.clone(),
}
}
TypeDef::Array(arr) => {
TypePathType::Array {
len: arr.len() as usize,
of: Box::new(self.resolve_type_path_recurse(
arr.type_param().id(),
false,
parent_type_params,
)),
}
}
TypeDef::Sequence(seq) => {
TypePathType::Vec {
of: Box::new(self.resolve_type_path_recurse(
seq.type_param().id(),
false,
parent_type_params,
)),
}
}
TypeDef::Tuple(tuple) => {
TypePathType::Tuple {
elements: tuple
.fields()
.iter()
.map(|f| {
self.resolve_type_path_recurse(
f.id(),
false,
parent_type_params,
)
})
.collect(),
}
}
TypeDef::Compact(compact) => {
TypePathType::Compact {
inner: Box::new(self.resolve_type_path_recurse(
compact.type_param().id(),
false,
parent_type_params,
)),
is_field,
}
}
_ => {
ty.type_params()
.iter()
.filter_map(|f| f.ty().map(|f| f.id()))
.collect()
TypeDef::BitSequence(bitseq) => {
TypePathType::BitVec {
bit_order_type: Box::new(self.resolve_type_path_recurse(
bitseq.bit_order_type().id(),
false,
parent_type_params,
)),
bit_store_type: Box::new(self.resolve_type_path_recurse(
bitseq.bit_store_type().id(),
false,
parent_type_params,
)),
}
}
};

let params = params_type_ids
.iter()
.map(|tp| self.resolve_type_path(*tp, parent_type_params))
.collect::<Vec<_>>();

let joined_path = ty.path().segments().join("::");
if let Some(substitute_type_path) = self.type_substitutes.get(&joined_path) {
TypePath::Substitute(TypePathSubstitute {
path: substitute_type_path.clone(),
params,
})
} else {
TypePath::Type(TypePathType {
ty,
params,
root_mod_ident: self.types_mod_ident.clone(),
})
}
TypePath::Type(ty)
}

/// Returns the derives to be applied to all generated types.
Expand All @@ -228,7 +321,7 @@ pub struct Module {
name: Ident,
root_mod: Ident,
children: BTreeMap<Ident, Module>,
types: BTreeMap<scale_info::Path<scale_info::form::PortableForm>, TypeDefGen>,
types: BTreeMap<scale_info::Path<PortableForm>, TypeDefGen>,
}

impl ToTokens for Module {
Expand Down
48 changes: 48 additions & 0 deletions codegen/src/types/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::*;
use pretty_assertions::assert_eq;
use scale_info::{
meta_type,
scale,
Registry,
TypeInfo,
};
Expand Down Expand Up @@ -371,6 +372,53 @@ fn compact_fields() {
)
}

#[test]
fn compact_generic_parameter() {
use scale::Compact;

#[allow(unused)]
#[derive(TypeInfo)]
struct S {
a: Option<<u128 as codec::HasCompact>::Type>,
nested: Option<Result<Compact<u128>, u8>>,
vector: Vec<Compact<u16>>,
array: [Compact<u8>; 32],
tuple: (Compact<u8>, Compact<u16>),
}

let mut registry = Registry::new();
registry.register_type(&meta_type::<S>());
let portable_types: PortableRegistry = registry.into();

let type_gen = TypeGenerator::new(
&portable_types,
"root",
Default::default(),
Default::default(),
);
let types = type_gen.generate_types_mod();
let tests_mod = get_mod(&types, MOD_PATH).unwrap();

assert_eq!(
tests_mod.into_token_stream().to_string(),
quote! {
pub mod tests {
use super::root;

#[derive(::subxt::ext::codec::Decode, ::subxt::ext::codec::Encode, Debug)]
pub struct S {
pub a: ::core::option::Option<::subxt::ext::codec::Compact<::core::primitive::u128> >,
pub nested: ::core::option::Option<::core::result::Result<::subxt::ext::codec::Compact<::core::primitive::u128>, ::core::primitive::u8 > >,
pub vector: ::std::vec::Vec<::subxt::ext::codec::Compact<::core::primitive::u16> >,
pub array: [::subxt::ext::codec::Compact<::core::primitive::u8>; 32usize],
pub tuple: (::subxt::ext::codec::Compact<::core::primitive::u8>, ::subxt::ext::codec::Compact<::core::primitive::u16>,),
}
}
}
.to_string()
)
}

#[test]
fn generate_array_field() {
#[allow(unused)]
Expand Down
Loading

0 comments on commit 3bf7ddc

Please # to comment.