Skip to content

Commit

Permalink
Improve Cow<str> deserialisation (serenity-rs#2738)
Browse files Browse the repository at this point in the history
serde's default Cow deserialisation impl wouldn't cut it here, because
it can't specialize on Cow<str>, but we can.
  • Loading branch information
GnomedDev authored Jan 26, 2024
1 parent 9fff1aa commit 27f98e1
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/model/channel/attachment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use reqwest::Client as ReqwestClient;
#[cfg(feature = "model")]
use crate::internal::prelude::*;
use crate::model::prelude::*;
use crate::model::utils::is_false;
use crate::model::utils::{is_false, CowStr};

fn base64_bytes<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
where
Expand All @@ -13,10 +13,10 @@ where
use base64::Engine as _;
use serde::de::Error;

let base64 = <Option<String>>::deserialize(deserializer)?;
let base64 = <Option<CowStr<'de>>>::deserialize(deserializer)?;
let bytes = match base64 {
Some(base64) => {
Some(base64::prelude::BASE64_STANDARD.decode(base64).map_err(D::Error::custom)?)
Some(CowStr(base64)) => {
Some(base64::prelude::BASE64_STANDARD.decode(&*base64).map_err(D::Error::custom)?)
},
None => None,
};
Expand Down
43 changes: 42 additions & 1 deletion src/model/utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt;
use std::hash::Hash;
Expand Down Expand Up @@ -83,6 +84,44 @@ where
}
}

pub(super) struct CowStr<'de>(pub Cow<'de, str>);

impl<'de> serde::Deserialize<'de> for CowStr<'de> {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> StdResult<Self, D::Error> {
struct CowStrVisitor;
impl<'de> serde::de::Visitor<'de> for CowStrVisitor {
type Value = CowStr<'de>;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
formatter.write_str("a string")
}

fn visit_borrowed_str<E>(self, val: &'de str) -> StdResult<Self::Value, E>
where
E: serde::de::Error,
{
Ok(CowStr(Cow::Borrowed(val)))
}

fn visit_str<E>(self, val: &str) -> StdResult<Self::Value, E>
where
E: serde::de::Error,
{
self.visit_string(val.into())
}

fn visit_string<E>(self, val: String) -> StdResult<Self::Value, E>
where
E: serde::de::Error,
{
Ok(CowStr(Cow::Owned(val)))
}
}

deserializer.deserialize_string(CowStrVisitor)
}
}

pub(super) enum StrOrInt<'de> {
String(String),
Str(&'de str),
Expand Down Expand Up @@ -283,10 +322,12 @@ pub mod stickers {
pub mod comma_separated_string {
use serde::{Deserialize, Deserializer, Serializer};

use super::CowStr;

pub fn deserialize<'de, D: Deserializer<'de>>(
deserializer: D,
) -> Result<Vec<String>, D::Error> {
let str_sequence = String::deserialize(deserializer)?;
let str_sequence = CowStr::deserialize(deserializer)?.0;
let vec = str_sequence.split(", ").map(str::to_owned).collect();

Ok(vec)
Expand Down

0 comments on commit 27f98e1

Please # to comment.