Skip to content

Commit

Permalink
Allow non-rust idents for renamed field names in SerializeValue
Browse files Browse the repository at this point in the history
this adds support for UDTs that have fields that are not valid rust
idents but are valid scylla field names
  • Loading branch information
nrxus committed Dec 3, 2024
1 parent 1c0d353 commit 8b579e6
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
28 changes: 28 additions & 0 deletions scylla-cql/src/types/serialize/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2824,4 +2824,32 @@ pub(crate) mod tests {

assert_eq!(reference, row);
}

#[test]
fn test_udt_with_non_rust_ident() {
#[derive(SerializeValue, Debug)]
#[scylla(crate = crate)]
struct UdtWithNonRustIdent {
#[scylla(rename = "a$a")]
a: i32,
}

let typ = ColumnType::UserDefinedType {
type_name: "typ".into(),
keyspace: "ks".into(),
field_types: vec![("a$a".into(), ColumnType::Int)],
};
let value = UdtWithNonRustIdent { a: 42 };

let mut reference = Vec::new();
// Total length of the struct
reference.extend_from_slice(&8i32.to_be_bytes());
// Field 'a'
reference.extend_from_slice(&(std::mem::size_of_val(&value.a) as i32).to_be_bytes());
reference.extend_from_slice(&value.a.to_be_bytes());

let udt = do_serialize(value, &typ);

assert_eq!(reference, udt);
}
}
13 changes: 6 additions & 7 deletions scylla-macros/src/serialize/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::collections::HashMap;

use darling::FromAttributes;
use proc_macro::TokenStream;
use proc_macro2::Span;
use syn::parse_quote;

use crate::Flavor;
Expand Down Expand Up @@ -327,14 +326,14 @@ impl Generator for FieldSortingGenerator<'_> {
.generate_udt_type_match(parse_quote!(#crate_path::UdtTypeCheckErrorKind::NotUdt)),
);

fn make_visited_flag_ident(field_name: &str) -> syn::Ident {
syn::Ident::new(&format!("visited_flag_{}", field_name), Span::call_site())
fn make_visited_flag_ident(field_name: &syn::Ident) -> syn::Ident {
syn::Ident::new(&format!("visited_flag_{}", field_name), field_name.span())
}

// Generate a "visited" flag for each field
let visited_flag_names = rust_field_names
let visited_flag_names = rust_field_idents
.iter()
.map(|s| make_visited_flag_ident(s))
.map(make_visited_flag_ident)
.collect::<Vec<_>>();
statements.extend::<Vec<_>>(parse_quote! {
#(let mut #visited_flag_names = false;)*
Expand All @@ -347,11 +346,11 @@ impl Generator for FieldSortingGenerator<'_> {
.fields
.iter()
.filter(|f| !f.attrs.ignore_missing)
.map(|f| f.field_name());
.map(|f| &f.ident);
// An iterator over visited flags of Rust fields that can't be ignored
// (i.e., if UDT misses a corresponding field, an error should be raised).
let nonignorable_visited_flag_names =
nonignorable_rust_field_names.map(|s| make_visited_flag_ident(&s));
nonignorable_rust_field_names.map(make_visited_flag_ident);

// Generate a variable that counts down visited fields.
let field_count = self.ctx.fields.len();
Expand Down

0 comments on commit 8b579e6

Please # to comment.