diff --git a/bitcode_derive/src/attribute.rs b/bitcode_derive/src/attribute.rs index a6a63bf..4454f93 100644 --- a/bitcode_derive/src/attribute.rs +++ b/bitcode_derive/src/attribute.rs @@ -9,29 +9,40 @@ enum BitcodeAttr { BoundType(Type), CrateAlias(Path), Skip, + /// `#[bitcode(with = "LocalType")]`, shorthand for both `encode_with` and `decode_with`. + With(Type), + /// `#[bitcode(encode_with = "LocalType")]`. + EncodeWith(Type), + /// `#[bitcode(decode_with = "LocalType")]`. + DecodeWith(Type), +} + +/// Parses a `#[bitcode(name = "Type")]` string literal value into a [`Type`]. +fn parse_type_attr(nested: &Meta) -> Result { + match nested { + Meta::NameValue(name_value) => { + let expr = &name_value.value; + let str_lit = match expr { + Expr::Lit(ExprLit { + lit: Lit::Str(v), .. + }) => v, + _ => return err(&expr, "expected string e.g. \"LocalType\""), + }; + let value = TokenStream::from_str(&str_lit.value()).unwrap(); + parse2(value).map_err(|e| error(str_lit, &format!("{e}"))) + } + _ => err(&nested, "expected name value"), + } } impl BitcodeAttr { fn new(nested: &Meta) -> Result { let path = path_ident_string(nested.path(), &nested)?; match path.as_str() { - "bound_type" => match nested { - Meta::NameValue(name_value) => { - let expr = &name_value.value; - let str_lit = match expr { - Expr::Lit(ExprLit { - lit: Lit::Str(v), .. - }) => v, - _ => return err(&expr, "expected string e.g. \"T\""), - }; - - let value = TokenStream::from_str(&str_lit.value()).unwrap(); - Ok(Self::BoundType( - parse2(value).map_err(|e| error(str_lit, &format!("{e}")))?, - )) - } - _ => err(&nested, "expected name value"), - }, + "bound_type" => Ok(Self::BoundType(parse_type_attr(nested)?)), + "with" => Ok(Self::With(parse_type_attr(nested)?)), + "encode_with" => Ok(Self::EncodeWith(parse_type_attr(nested)?)), + "decode_with" => Ok(Self::DecodeWith(parse_type_attr(nested)?)), "crate" => match nested { Meta::NameValue(name_value) => { let expr = &name_value.value; @@ -87,6 +98,34 @@ impl BitcodeAttr { err(nested, "can only apply skip to fields") } } + Self::With(ty) => { + Self::set_with(attrs, nested, ty.clone(), true)?; + Self::set_with(attrs, nested, ty, false) + } + Self::EncodeWith(ty) => Self::set_with(attrs, nested, ty, true), + Self::DecodeWith(ty) => Self::set_with(attrs, nested, ty, false), + } + } + + /// Sets `encode_with` (`encode == true`) or `decode_with` (`encode == false`) on a field. + fn set_with(attrs: &mut BitcodeAttrs, nested: &Meta, ty: Type, encode: bool) -> Result<()> { + if let AttrType::Field { + encode_with, + decode_with, + .. + } = &mut attrs.attr_type + { + let slot = if encode { encode_with } else { decode_with }; + if slot.is_some() { + return err(nested, "duplicate with/encode_with/decode_with"); + } + *slot = Some(ty); + Ok(()) + } else { + err( + nested, + "can only apply with/encode_with/decode_with to fields", + ) } } } @@ -104,7 +143,13 @@ pub struct BitcodeAttrs { enum AttrType { Derive, Variant, - Field { bound_type: Option }, + Field { + bound_type: Option, + /// Encode this field as a different (local) type via `From`/`Into`. + encode_with: Option, + /// Decode this field as a different (local) type via `From`/`Into`. + decode_with: Option, + }, } impl BitcodeAttrs { @@ -123,6 +168,22 @@ impl BitcodeAttrs { } } + /// The local type to encode this field as, if `#[bitcode(with)]`/`#[bitcode(encode_with)]` was set. + pub fn encode_with(&self) -> Option<&Type> { + match &self.attr_type { + AttrType::Field { encode_with, .. } => encode_with.as_ref(), + _ => None, + } + } + + /// The local type to decode this field as, if `#[bitcode(with)]`/`#[bitcode(decode_with)]` was set. + pub fn decode_with(&self) -> Option<&Type> { + match &self.attr_type { + AttrType::Field { decode_with, .. } => decode_with.as_ref(), + _ => None, + } + } + pub fn parse_derive(attrs: &[Attribute]) -> Result { let mut ret = Self::new(AttrType::Derive); ret.parse_inner(attrs)?; @@ -136,7 +197,11 @@ impl BitcodeAttrs { } pub fn parse_field(attrs: &[Attribute], _parent_attrs: &Self) -> Result { - let mut ret = Self::new(AttrType::Field { bound_type: None }); + let mut ret = Self::new(AttrType::Field { + bound_type: None, + encode_with: None, + decode_with: None, + }); ret.parse_inner(attrs)?; Ok(ret) } diff --git a/bitcode_derive/src/decode.rs b/bitcode_derive/src/decode.rs index 0421537..6cbf784 100644 --- a/bitcode_derive/src/decode.rs +++ b/bitcode_derive/src/decode.rs @@ -45,7 +45,9 @@ impl crate::shared::Item for Item { ) -> TokenStream { match self { Self::Type => { - let mut de_type = replace_lifetimes(field_type, DE_LIFETIME).to_token_stream(); + // `#[bitcode(decode_with = "Local")]` stores the decoder for `Local`, not the field. + let base_type = field_attrs.decode_with().unwrap_or(field_type); + let mut de_type = replace_lifetimes(base_type, DE_LIFETIME).to_token_stream(); if field_attrs.skip { de_type = quote! { ::core::marker::PhantomData<#de_type> }; } @@ -63,10 +65,15 @@ impl crate::shared::Item for Item { }, // Only used by enum variants. Self::Decode => { + let de_type = replace_lifetimes(field_type, DE_LIFETIME); let value = if field_attrs.skip { quote! { Default::default() } + } else if field_attrs.decode_with().is_some() { + quote! { + ::core::convert::Into::<#de_type>::into(self.#global_field_name.decode()) + } } else { quote! { self.#global_field_name.decode() @@ -86,6 +93,11 @@ impl crate::shared::Item for Item { quote! {{ (#target).write(Default::default()); }} + } else if field_attrs.decode_with().is_some() { + quote! {{ + let __local = self.#global_field_name.decode(); + (#target).write(::core::convert::Into::into(__local)); + }} } else { quote! { self.#global_field_name.decode_in_place(#target); @@ -263,6 +275,10 @@ impl crate::shared::Derive<{ Item::COUNT }> for Decode { Some(parse_quote!(Default)) } + fn with_type(&self, field_attrs: &BitcodeAttrs) -> Option { + field_attrs.decode_with().cloned() + } + fn derive_impl( &self, crate_name: &Path, diff --git a/bitcode_derive/src/encode.rs b/bitcode_derive/src/encode.rs index d134a37..b9fa895 100644 --- a/bitcode_derive/src/encode.rs +++ b/bitcode_derive/src/encode.rs @@ -37,7 +37,9 @@ impl crate::shared::Item for Item { ) -> TokenStream { match self { Self::Type => { - let mut static_type = replace_lifetimes(field_type, "static").to_token_stream(); + // `#[bitcode(encode_with = "Local")]` stores the encoder for `Local`, not the field. + let base_type = field_attrs.encode_with().unwrap_or(field_type); + let mut static_type = replace_lifetimes(base_type, "static").to_token_stream(); if field_attrs.skip { static_type = quote! { ::core::marker::PhantomData<#static_type> }; } @@ -50,6 +52,43 @@ impl crate::shared::Item for Item { #global_field_name: Default::default(), }, Self::Encode | Self::EncodeVectored => { + // `#[bitcode(encode_with = "Local")]`: convert `&Field` to `Local`, then encode `Local`. + // The conversion produces an owned `Local`, so we can't yield `&Local` into the + // vectored fast-path; instead the vectored case stamps out a loop calling `encode`. + if let Some(local) = field_attrs.encode_with().filter(|_| !field_attrs.skip) { + let local_static = replace_lifetimes(local, "static"); + let needs_transmute = &local_static != local; + let encode_one = |field_ref: TokenStream| { + if needs_transmute { + // HACK: Since encoders don't have lifetimes we can't reference as Encode>::Encoder since 'a + // does not exist. Instead we replace this with as Encode>::Encoder and transmute it to + // T<'a>. No encoder actually encodes T<'static> any differently from T<'a> so this is sound. + let local_underscore = replace_lifetimes(local, "_"); + quote! {{ + let __local: #local_underscore = ::core::convert::From::from(#field_ref); + self.#global_field_name.encode(unsafe { + ::core::mem::transmute::<&#local_underscore, &#local_static>(&__local) + }); + }} + } else { + quote! {{ + let __local: #local_static = ::core::convert::From::from(#field_ref); + self.#global_field_name.encode(&__local); + }} + } + }; + return if matches!(self, Self::EncodeVectored) { + let encode_one = encode_one(quote! { &__me.#real_field_name }); + quote! { + for __me in i.clone() { + #encode_one + } + } + } else { + encode_one(quote! { #field_name }) + }; + } + let static_type = replace_lifetimes(field_type, "static"); let value = if field_attrs.skip { quote! { @@ -248,6 +287,10 @@ impl crate::shared::Derive<{ Item::COUNT }> for Encode { None } + fn with_type(&self, field_attrs: &BitcodeAttrs) -> Option { + field_attrs.encode_with().cloned() + } + fn derive_impl( &self, crate_name: &Path, diff --git a/bitcode_derive/src/shared.rs b/bitcode_derive/src/shared.rs index 879e114..46824b7 100644 --- a/bitcode_derive/src/shared.rs +++ b/bitcode_derive/src/shared.rs @@ -130,6 +130,10 @@ pub trait Derive { /// Bound for skipped fields, e.g. `Default` fn skip_bound(&self) -> Option; + /// For `#[bitcode(with)]` fields, the local type whose generic parameters should be bounded + /// instead of the field's own type (which need not implement `Encode`/`Decode`). + fn with_type(&self, field_attrs: &BitcodeAttrs) -> Option; + /// Generates the derive implementation. fn derive_impl( &self, @@ -157,7 +161,12 @@ pub trait Derive { Some(self.bound(crate_name)) }; if let Some(bound) = bound { - bounds.add_bound_type(field.clone(), &field_attrs, bound); + let mut field = field.clone(); + if let Some(with_type) = self.with_type(&field_attrs) { + // Bound the local type instead of the field's own (remote) type. + field.ty = with_type; + } + bounds.add_bound_type(field, &field_attrs, bound); } Ok(field_attrs) }) diff --git a/src/derive/mod.rs b/src/derive/mod.rs index 235d711..da13dab 100644 --- a/src/derive/mod.rs +++ b/src/derive/mod.rs @@ -388,3 +388,113 @@ mod tests { } } } + +#[cfg(test)] +mod with_tests { + use crate::{decode, encode, Decode, Encode}; + use alloc::string::String; + + // "Remote" types that intentionally don't implement Encode/Decode themselves. + #[derive(Debug, PartialEq, Clone)] + struct RemoteStr(String); + + #[derive(Debug, PartialEq, Clone, Copy)] + struct Meters(f32); + + // Local proxy types used via `#[bitcode(with = ...)]`. + #[derive(Encode, Decode)] + struct LocalStr<'a>(&'a str); + + #[derive(Encode, Decode)] + struct LocalMeters(f32); + + impl<'a> From<&'a RemoteStr> for LocalStr<'a> { + fn from(v: &'a RemoteStr) -> Self { + LocalStr(v.0.as_str()) + } + } + impl From> for RemoteStr { + fn from(v: LocalStr<'_>) -> Self { + RemoteStr(String::from(v.0)) + } + } + impl From<&Meters> for LocalMeters { + fn from(v: &Meters) -> Self { + LocalMeters(v.0) + } + } + impl From for Meters { + fn from(v: LocalMeters) -> Self { + Meters(v.0) + } + } + + #[derive(Encode, Decode, Debug, PartialEq)] + struct User { + // Borrowing proxy via the `with` shorthand. + #[bitcode(with = "LocalStr<'a>")] + first_name: RemoteStr, + // Same, but with the encode/decode sides spelled separately. + #[bitcode(encode_with = "LocalStr<'a>", decode_with = "LocalStr<'a>")] + last_name: RemoteStr, + // Non-borrowing proxy (no lifetime, so no transmute on encode). + #[bitcode(with = "LocalMeters")] + height: Meters, + } + + #[derive(Encode, Decode, Debug, PartialEq)] + enum Shape { + Point, + Line(#[bitcode(with = "LocalMeters")] Meters), + Named { + #[bitcode(with = "LocalStr<'a>")] + name: RemoteStr, + }, + } + + #[test] + fn test_with_struct() { + let user = User { + first_name: RemoteStr("Ada".into()), + last_name: RemoteStr("Lovelace".into()), + height: Meters(1.7), + }; + assert_eq!(decode::(&encode(&user)).unwrap(), user); + } + + #[test] + fn test_with_enum() { + for shape in [ + Shape::Point, + Shape::Line(Meters(42.0)), + Shape::Named { + name: RemoteStr("triangle".into()), + }, + ] { + assert_eq!(decode::(&encode(&shape)).unwrap(), shape); + } + } + + /// Encoding via a `with` proxy must be wire-identical to encoding the proxy type directly. + #[test] + fn test_with_wire_compatible() { + #[derive(Encode)] + struct Direct<'a> { + first_name: LocalStr<'a>, + last_name: LocalStr<'a>, + height: LocalMeters, + } + + let user = User { + first_name: RemoteStr("Ada".into()), + last_name: RemoteStr("Lovelace".into()), + height: Meters(1.7), + }; + let direct = Direct { + first_name: LocalStr("Ada"), + last_name: LocalStr("Lovelace"), + height: LocalMeters(1.7), + }; + assert_eq!(encode(&user), encode(&direct)); + } +} diff --git a/src/derive/variant.rs b/src/derive/variant.rs index 2c4b37d..246215f 100644 --- a/src/derive/variant.rs +++ b/src/derive/variant.rs @@ -1,5 +1,4 @@ use crate::coder::{Buffer, Decoder, Encoder, Result, View}; -use crate::error::err; use crate::fast::{CowSlice, NextUnchecked, PushUnchecked, VecImpl}; use crate::pack::{check_less_than, pack_bytes_less_than, unpack_bytes_less_than}; use crate::pack_ints::{pack_ints, unpack_ints, Int};