Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 84 additions & 19 deletions bitcode_derive/src/attribute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,40 @@ enum BitcodeAttr {
BoundType(Type),
CrateAlias(Path),
Skip,
/// `#[bitcode(with = "LocalType")]`, shorthand for both `encode_with` and `decode_with`.
With(Type),
/// `#[bitcode(encode_with = "LocalType")]`.
EncodeWith(Type),
/// `#[bitcode(decode_with = "LocalType")]`.
DecodeWith(Type),
}

/// Parses a `#[bitcode(name = "Type")]` string literal value into a [`Type`].
fn parse_type_attr(nested: &Meta) -> Result<Type> {
match nested {
Meta::NameValue(name_value) => {
let expr = &name_value.value;
let str_lit = match expr {
Expr::Lit(ExprLit {
lit: Lit::Str(v), ..
}) => v,
_ => return err(&expr, "expected string e.g. \"LocalType\""),
};
let value = TokenStream::from_str(&str_lit.value()).unwrap();
parse2(value).map_err(|e| error(str_lit, &format!("{e}")))
}
_ => err(&nested, "expected name value"),
}
}

impl BitcodeAttr {
fn new(nested: &Meta) -> Result<Self> {
let path = path_ident_string(nested.path(), &nested)?;
match path.as_str() {
"bound_type" => match nested {
Meta::NameValue(name_value) => {
let expr = &name_value.value;
let str_lit = match expr {
Expr::Lit(ExprLit {
lit: Lit::Str(v), ..
}) => v,
_ => return err(&expr, "expected string e.g. \"T\""),
};

let value = TokenStream::from_str(&str_lit.value()).unwrap();
Ok(Self::BoundType(
parse2(value).map_err(|e| error(str_lit, &format!("{e}")))?,
))
}
_ => err(&nested, "expected name value"),
},
"bound_type" => Ok(Self::BoundType(parse_type_attr(nested)?)),
"with" => Ok(Self::With(parse_type_attr(nested)?)),
"encode_with" => Ok(Self::EncodeWith(parse_type_attr(nested)?)),
"decode_with" => Ok(Self::DecodeWith(parse_type_attr(nested)?)),
"crate" => match nested {
Meta::NameValue(name_value) => {
let expr = &name_value.value;
Expand Down Expand Up @@ -87,6 +98,34 @@ impl BitcodeAttr {
err(nested, "can only apply skip to fields")
}
}
Self::With(ty) => {
Self::set_with(attrs, nested, ty.clone(), true)?;
Self::set_with(attrs, nested, ty, false)
}
Self::EncodeWith(ty) => Self::set_with(attrs, nested, ty, true),
Self::DecodeWith(ty) => Self::set_with(attrs, nested, ty, false),
}
}

/// Sets `encode_with` (`encode == true`) or `decode_with` (`encode == false`) on a field.
fn set_with(attrs: &mut BitcodeAttrs, nested: &Meta, ty: Type, encode: bool) -> Result<()> {
if let AttrType::Field {
encode_with,
decode_with,
..
} = &mut attrs.attr_type
{
let slot = if encode { encode_with } else { decode_with };
if slot.is_some() {
return err(nested, "duplicate with/encode_with/decode_with");
}
*slot = Some(ty);
Ok(())
} else {
err(
nested,
"can only apply with/encode_with/decode_with to fields",
)
}
}
}
Expand All @@ -104,7 +143,13 @@ pub struct BitcodeAttrs {
enum AttrType {
Derive,
Variant,
Field { bound_type: Option<Type> },
Field {
bound_type: Option<Type>,
/// Encode this field as a different (local) type via `From`/`Into`.
encode_with: Option<Type>,
/// Decode this field as a different (local) type via `From`/`Into`.
decode_with: Option<Type>,
},
}

impl BitcodeAttrs {
Expand All @@ -123,6 +168,22 @@ impl BitcodeAttrs {
}
}

/// The local type to encode this field as, if `#[bitcode(with)]`/`#[bitcode(encode_with)]` was set.
pub fn encode_with(&self) -> Option<&Type> {
match &self.attr_type {
AttrType::Field { encode_with, .. } => encode_with.as_ref(),
_ => None,
}
}

/// The local type to decode this field as, if `#[bitcode(with)]`/`#[bitcode(decode_with)]` was set.
pub fn decode_with(&self) -> Option<&Type> {
match &self.attr_type {
AttrType::Field { decode_with, .. } => decode_with.as_ref(),
_ => None,
}
}

pub fn parse_derive(attrs: &[Attribute]) -> Result<Self> {
let mut ret = Self::new(AttrType::Derive);
ret.parse_inner(attrs)?;
Expand All @@ -136,7 +197,11 @@ impl BitcodeAttrs {
}

pub fn parse_field(attrs: &[Attribute], _parent_attrs: &Self) -> Result<Self> {
let mut ret = Self::new(AttrType::Field { bound_type: None });
let mut ret = Self::new(AttrType::Field {
bound_type: None,
encode_with: None,
decode_with: None,
});
ret.parse_inner(attrs)?;
Ok(ret)
}
Expand Down
18 changes: 17 additions & 1 deletion bitcode_derive/src/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ impl crate::shared::Item for Item {
) -> TokenStream {
match self {
Self::Type => {
let mut de_type = replace_lifetimes(field_type, DE_LIFETIME).to_token_stream();
// `#[bitcode(decode_with = "Local")]` stores the decoder for `Local`, not the field.
let base_type = field_attrs.decode_with().unwrap_or(field_type);
let mut de_type = replace_lifetimes(base_type, DE_LIFETIME).to_token_stream();
if field_attrs.skip {
de_type = quote! { ::core::marker::PhantomData<#de_type> };
}
Expand All @@ -63,10 +65,15 @@ impl crate::shared::Item for Item {
},
// Only used by enum variants.
Self::Decode => {
let de_type = replace_lifetimes(field_type, DE_LIFETIME);
let value = if field_attrs.skip {
quote! {
Default::default()
}
} else if field_attrs.decode_with().is_some() {
quote! {
::core::convert::Into::<#de_type>::into(self.#global_field_name.decode())
}
} else {
quote! {
self.#global_field_name.decode()
Expand All @@ -86,6 +93,11 @@ impl crate::shared::Item for Item {
quote! {{
(#target).write(Default::default());
}}
} else if field_attrs.decode_with().is_some() {
quote! {{
let __local = self.#global_field_name.decode();
(#target).write(::core::convert::Into::into(__local));
}}
} else {
quote! {
self.#global_field_name.decode_in_place(#target);
Expand Down Expand Up @@ -263,6 +275,10 @@ impl crate::shared::Derive<{ Item::COUNT }> for Decode {
Some(parse_quote!(Default))
}

fn with_type(&self, field_attrs: &BitcodeAttrs) -> Option<Type> {
field_attrs.decode_with().cloned()
}

fn derive_impl(
&self,
crate_name: &Path,
Expand Down
45 changes: 44 additions & 1 deletion bitcode_derive/src/encode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ impl crate::shared::Item for Item {
) -> TokenStream {
match self {
Self::Type => {
let mut static_type = replace_lifetimes(field_type, "static").to_token_stream();
// `#[bitcode(encode_with = "Local")]` stores the encoder for `Local`, not the field.
let base_type = field_attrs.encode_with().unwrap_or(field_type);
let mut static_type = replace_lifetimes(base_type, "static").to_token_stream();
if field_attrs.skip {
static_type = quote! { ::core::marker::PhantomData<#static_type> };
}
Expand All @@ -50,6 +52,43 @@ impl crate::shared::Item for Item {
#global_field_name: Default::default(),
},
Self::Encode | Self::EncodeVectored => {
// `#[bitcode(encode_with = "Local")]`: convert `&Field` to `Local`, then encode `Local`.
// The conversion produces an owned `Local`, so we can't yield `&Local` into the
// vectored fast-path; instead the vectored case stamps out a loop calling `encode`.
if let Some(local) = field_attrs.encode_with().filter(|_| !field_attrs.skip) {
let local_static = replace_lifetimes(local, "static");
let needs_transmute = &local_static != local;
let encode_one = |field_ref: TokenStream| {
if needs_transmute {
// HACK: Since encoders don't have lifetimes we can't reference <T<'a> as Encode>::Encoder since 'a
// does not exist. Instead we replace this with <T<'static> as Encode>::Encoder and transmute it to
// T<'a>. No encoder actually encodes T<'static> any differently from T<'a> so this is sound.
let local_underscore = replace_lifetimes(local, "_");
quote! {{
let __local: #local_underscore = ::core::convert::From::from(#field_ref);
self.#global_field_name.encode(unsafe {
::core::mem::transmute::<&#local_underscore, &#local_static>(&__local)
});
}}
} else {
quote! {{
let __local: #local_static = ::core::convert::From::from(#field_ref);
self.#global_field_name.encode(&__local);
}}
}
};
return if matches!(self, Self::EncodeVectored) {
let encode_one = encode_one(quote! { &__me.#real_field_name });
quote! {
for __me in i.clone() {
#encode_one
}
}
} else {
encode_one(quote! { #field_name })
};
}

let static_type = replace_lifetimes(field_type, "static");
let value = if field_attrs.skip {
quote! {
Expand Down Expand Up @@ -248,6 +287,10 @@ impl crate::shared::Derive<{ Item::COUNT }> for Encode {
None
}

fn with_type(&self, field_attrs: &BitcodeAttrs) -> Option<Type> {
field_attrs.encode_with().cloned()
}

fn derive_impl(
&self,
crate_name: &Path,
Expand Down
11 changes: 10 additions & 1 deletion bitcode_derive/src/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,10 @@ pub trait Derive<const ITEM_COUNT: usize> {
/// Bound for skipped fields, e.g. `Default`
fn skip_bound(&self) -> Option<Path>;

/// For `#[bitcode(with)]` fields, the local type whose generic parameters should be bounded
/// instead of the field's own type (which need not implement `Encode`/`Decode`).
fn with_type(&self, field_attrs: &BitcodeAttrs) -> Option<Type>;

/// Generates the derive implementation.
fn derive_impl(
&self,
Expand Down Expand Up @@ -157,7 +161,12 @@ pub trait Derive<const ITEM_COUNT: usize> {
Some(self.bound(crate_name))
};
if let Some(bound) = bound {
bounds.add_bound_type(field.clone(), &field_attrs, bound);
let mut field = field.clone();
if let Some(with_type) = self.with_type(&field_attrs) {
// Bound the local type instead of the field's own (remote) type.
field.ty = with_type;
}
bounds.add_bound_type(field, &field_attrs, bound);
}
Ok(field_attrs)
})
Expand Down
110 changes: 110 additions & 0 deletions src/derive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,113 @@ mod tests {
}
}
}

#[cfg(test)]
mod with_tests {
use crate::{decode, encode, Decode, Encode};
use alloc::string::String;

// "Remote" types that intentionally don't implement Encode/Decode themselves.
#[derive(Debug, PartialEq, Clone)]
struct RemoteStr(String);

#[derive(Debug, PartialEq, Clone, Copy)]
struct Meters(f32);

// Local proxy types used via `#[bitcode(with = ...)]`.
#[derive(Encode, Decode)]
struct LocalStr<'a>(&'a str);

#[derive(Encode, Decode)]
struct LocalMeters(f32);

impl<'a> From<&'a RemoteStr> for LocalStr<'a> {
fn from(v: &'a RemoteStr) -> Self {
LocalStr(v.0.as_str())
}
}
impl From<LocalStr<'_>> for RemoteStr {
fn from(v: LocalStr<'_>) -> Self {
RemoteStr(String::from(v.0))
}
}
impl From<&Meters> for LocalMeters {
fn from(v: &Meters) -> Self {
LocalMeters(v.0)
}
}
impl From<LocalMeters> for Meters {
fn from(v: LocalMeters) -> Self {
Meters(v.0)
}
}

#[derive(Encode, Decode, Debug, PartialEq)]
struct User {
// Borrowing proxy via the `with` shorthand.
#[bitcode(with = "LocalStr<'a>")]
first_name: RemoteStr,
// Same, but with the encode/decode sides spelled separately.
#[bitcode(encode_with = "LocalStr<'a>", decode_with = "LocalStr<'a>")]
last_name: RemoteStr,
// Non-borrowing proxy (no lifetime, so no transmute on encode).
#[bitcode(with = "LocalMeters")]
height: Meters,
}

#[derive(Encode, Decode, Debug, PartialEq)]
enum Shape {
Point,
Line(#[bitcode(with = "LocalMeters")] Meters),
Named {
#[bitcode(with = "LocalStr<'a>")]
name: RemoteStr,
},
}

#[test]
fn test_with_struct() {
let user = User {
first_name: RemoteStr("Ada".into()),
last_name: RemoteStr("Lovelace".into()),
height: Meters(1.7),
};
assert_eq!(decode::<User>(&encode(&user)).unwrap(), user);
}

#[test]
fn test_with_enum() {
for shape in [
Shape::Point,
Shape::Line(Meters(42.0)),
Shape::Named {
name: RemoteStr("triangle".into()),
},
] {
assert_eq!(decode::<Shape>(&encode(&shape)).unwrap(), shape);
}
}

/// Encoding via a `with` proxy must be wire-identical to encoding the proxy type directly.
#[test]
fn test_with_wire_compatible() {
#[derive(Encode)]
struct Direct<'a> {
first_name: LocalStr<'a>,
last_name: LocalStr<'a>,
height: LocalMeters,
}

let user = User {
first_name: RemoteStr("Ada".into()),
last_name: RemoteStr("Lovelace".into()),
height: Meters(1.7),
};
let direct = Direct {
first_name: LocalStr("Ada"),
last_name: LocalStr("Lovelace"),
height: LocalMeters(1.7),
};
assert_eq!(encode(&user), encode(&direct));
}
}
Loading
Loading