Skip to content

Commit

Permalink
Constrain modules generated by macros
Browse files Browse the repository at this point in the history
  • Loading branch information
sug0 committed Dec 5, 2024
1 parent dad49e9 commit 5ea554f
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 8 deletions.
71 changes: 67 additions & 4 deletions crates/module-macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
extern crate proc_macro;

use quote::quote;
use quote::{quote, ToTokens};

#[proc_macro_derive(ModuleFromMiddleware)]
pub fn derive_module_from_middleware(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
Expand All @@ -12,11 +12,42 @@ fn derive_module_from_middleware_inner(
) -> proc_macro2::TokenStream {
let _struct = syn::parse2::<syn::ItemStruct>(input).expect("Expected struct definition");

let struct_name = &_struct.ident;
//let struct_generics = &_struct.generics;
let struct_generics = &_struct.generics;
let struct_generics_params = &struct_generics.params;

let (name, types) = fetch_name_with_generic_params(&_struct);
let where_clauses = {
let mut clauses =
struct_generics
.where_clause
.clone()
.unwrap_or_else(|| syn::WhereClause {
where_token: Default::default(),
predicates: syn::punctuated::Punctuated::new(),
});

for ty in types {
clauses
.predicates
.push(syn::WherePredicate::Type(syn::PredicateType {
lifetimes: None,
bounded_ty: syn::Type::Verbatim(ty),
colon_token: Default::default(),
bounds: {
let mut b = syn::punctuated::Punctuated::new();
b.push(syn::TypeParamBound::Verbatim(quote!(Module)));
b
},
}));
}

quote!(#clauses)
};

quote! {
impl Module for #struct_name {
impl<#struct_generics_params> Module for #name
#where_clauses
{
#[inline(always)]
fn on_chan_open_init_validate(
&self,
Expand Down Expand Up @@ -220,3 +251,35 @@ fn derive_module_from_middleware_inner(
}
}
}

fn fetch_name_with_generic_params(
_struct: &syn::ItemStruct,
) -> (proc_macro2::TokenStream, Vec<proc_macro2::TokenStream>) {
let mut types = vec![];
let mut consts = vec![];
let mut lifetimes = vec![];

for param in _struct.generics.params.iter() {
match param {
syn::GenericParam::Type(type_) => types.push(type_.ident.to_token_stream()),
syn::GenericParam::Lifetime(life_def) => {
lifetimes.push(life_def.lifetime.to_token_stream())
}
syn::GenericParam::Const(constant) => consts.push(constant.ident.to_token_stream()),
}
}

let ident = &_struct.ident;
let (all_params, types) = {
let (mut output, mut consts, types) = (lifetimes, consts, types);
output.append(&mut consts);
output.extend(types.iter().cloned());
(output, types)
};

let name = quote! {
#ident < #(#all_params),* >
};

(name, types)
}
11 changes: 7 additions & 4 deletions crates/module/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,13 @@ mod tests {
use super::*;

#[derive(Debug, ModuleFromMiddleware)]
struct DummyMiddleware(DummyTransferModule);
struct DummyMiddleware<M>(M);

impl MiddlewareModule for DummyMiddleware {
type NextMiddleware = DummyTransferModule;
impl<M> MiddlewareModule for DummyMiddleware<M>
where
M: Module,
{
type NextMiddleware = M;

fn next_middleware(&self) -> &Self::NextMiddleware {
&self.0
Expand All @@ -271,7 +274,7 @@ mod tests {

#[test]
fn dummy_middleware_is_module() {
assert_module_impl::<DummyMiddleware>();
assert_module_impl::<DummyMiddleware<DummyTransferModule>>();
}

#[test]
Expand Down

0 comments on commit 5ea554f

Please sign in to comment.