Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 72 additions & 19 deletions derive/src/expand/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@ use std::collections::HashMap;

use proc_macro2::{Span, TokenStream};
use quote::{quote, TokenStreamExt};
use syn::{Ident, Index};

use syn::{Ident, Index, Type};
use syn::spanned::Spanned;
use crate::ast::{Enum, Fields, Struct, Variant, Visit};

use super::common::CalcDiscriminant;

/// Visitor which creates structs for fields in a an enum variant.
pub struct EnumStructsVisitor<'a> {
pub revision: usize,
pub types: Vec<Ident>,
pub stream: &'a mut TokenStream,
}

impl<'a> EnumStructsVisitor<'a> {
pub fn new(revision: usize, stream: &'a mut TokenStream) -> Self {
pub fn new(revision: usize, types: Vec<Ident>, stream: &'a mut TokenStream) -> Self {
Self {
revision,
types,
stream,
}
}
Expand All @@ -33,30 +35,81 @@ impl<'ast> Visit<'ast> for EnumStructsVisitor<'_> {
ref fields,
..
} => {
let fields = fields
.iter()
.filter(|x| x.attrs.options.exists_at(self.revision))
.map(|x| {
let name = &x.name;
let ty = &x.ty;
quote! {
#name: #ty
}
});
let fields = fields.iter()
.filter(|x| x.attrs.options.exists_at(self.revision));

let mut generics = vec!();
let mut body = quote::quote!();
for field in fields {
let name = &field.name;
let ty = &field.ty;

if let Type::Path(path) = ty {
let uses_generics = self.types
.iter()
.any(|ident| path.path.is_ident(ident));

if uses_generics {
generics.push(path);
}
}

body.extend(quote::quote! {
#name: #ty,
});
}

let generics = if generics.is_empty() {
quote::quote!()
} else {
quote::quote! {
< #(#generics),* >
}
};

quote! {
struct #name{ #(#fields),* }
struct #name #generics {
#body
}
}
}
Fields::Unnamed {
ref fields,
..
} => {
let fields = fields
.iter()
.filter(|x| x.attrs.options.exists_at(self.revision))
.map(|x| &x.ty);
let fields = fields.iter()
.filter(|x| x.attrs.options.exists_at(self.revision));

let mut generics = vec!();
let mut body = quote::quote!();
for field in fields {
let ty = &field.ty;

if let Type::Path(path) = ty {
let uses_generics = self.types
.iter()
.any(|ident| path.path.is_ident(ident));

if uses_generics {
generics.push(path);
}
}

body.extend(quote::quote! {
#ty,
});
}

let generics = if generics.is_empty() {
quote::quote!()
} else {
quote::quote! {
< #(#generics),* >
}
};

quote! {
struct #name( #(#fields),* );
struct #name #generics ( #body );
}
}
Fields::Unit => {
Expand Down
75 changes: 59 additions & 16 deletions derive/src/expand/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ mod validate_version;
use de::{DeserializeVisitor, EnumStructsVisitor};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::spanned::Spanned;
use syn::{Token, WhereClause};
use syn::punctuated::Punctuated;
use reexport::Reexport;
use ser::SerializeVisitor;
use validate_version::ValidateRevision;
Expand Down Expand Up @@ -51,12 +54,50 @@ pub fn revision(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStrea
.visit_item(&ast)
.unwrap();

let (name, generics) = match &ast.kind {
ast::ItemKind::Enum(x) => (&x.name, &x.generics),
ast::ItemKind::Struct(x) => (&x.name, &x.generics),
};

let mut serialise_where_clause = if let Some(where_clause) = generics.where_clause.as_ref() {
where_clause.clone()
} else {
WhereClause {
where_token: <Token![where]>::default(),
predicates: Punctuated::new(),
}
};

let mut deserialise_where_clause = if let Some(where_clause) = generics.where_clause.as_ref() {
where_clause.clone()
} else {
WhereClause {
where_token: <Token![where]>::default(),
predicates: Punctuated::new(),
}
};

let mut types = vec![];

for ty in generics.type_params() {
let span = ty.span();

serialise_where_clause.predicates.push(syn::parse_quote_spanned!{span=>
#ty: ::revision::SerializeRevisioned
});
deserialise_where_clause.predicates.push(syn::parse_quote_spanned!{span=>
#ty: ::revision::DeserializeRevisioned
});

types.push(ty.ident.clone());
}

// serialize implementation
let mut serialize = TokenStream::new();
SerializeVisitor::new(revision, &mut serialize).visit_item(&ast).unwrap();

let mut deserialize_structs = TokenStream::new();
EnumStructsVisitor::new(revision, &mut deserialize_structs).visit_item(&ast).unwrap();
EnumStructsVisitor::new(revision, types, &mut deserialize_structs).visit_item(&ast).unwrap();

// deserialize implementation
let deserialize = (1..=revision)
Expand All @@ -81,16 +122,14 @@ pub fn revision(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStrea
})
.collect::<Vec<_>>();

let name = match ast.kind {
ast::ItemKind::Enum(x) => x.name,
ast::ItemKind::Struct(x) => x.name,
};
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let revision = revision as u16;
let revision_error = format!("Invalid revision `{{}}` for type `{}`", name);

let serialize_impl = if attrs.0.serialize {
quote! {
impl ::revision::SerializeRevisioned for #name {
impl #impl_generics ::revision::SerializeRevisioned for #name #ty_generics #serialise_where_clause {
fn serialize_revisioned<W: ::std::io::Write>(&self, writer: &mut W) -> ::std::result::Result<(), ::revision::Error> {
::revision::SerializeRevisioned::serialize_revisioned(&<Self as ::revision::Revisioned>::revision(),writer)?;
#serialize
Expand All @@ -103,7 +142,7 @@ pub fn revision(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStrea

let deserialize_impl = if attrs.0.deserialize {
quote! {
impl ::revision::DeserializeRevisioned for #name {
impl #impl_generics ::revision::DeserializeRevisioned for #name #ty_generics #deserialise_where_clause {
fn deserialize_revisioned<R: ::std::io::Read>(reader: &mut R) -> ::std::result::Result<Self, ::revision::Error> {
let __revision = <u16 as ::revision::DeserializeRevisioned>::deserialize_revisioned(reader)?;
match __revision {
Expand All @@ -123,16 +162,20 @@ pub fn revision(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStrea

Ok(quote! {
#reexport
#deserialize_structs

#serialize_impl
#deserialize_impl
const _: () = {
#deserialize_structs

#serialize_impl
#deserialize_impl

impl #impl_generics ::revision::Revisioned for #name #ty_generics #where_clause {
#[inline]
fn revision() -> u16{
#revision
}
}
};

impl ::revision::Revisioned for #name {
#[inline]
fn revision() -> u16{
#revision
}
}
})
}