add refutable pattern function macro
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
68f42baf73
commit
96f6a75bc8
2 changed files with 65 additions and 0 deletions
|
@ -1,6 +1,7 @@
|
|||
mod admin;
|
||||
mod cargo;
|
||||
mod debug;
|
||||
mod refutable;
|
||||
mod rustc;
|
||||
mod utils;
|
||||
|
||||
|
@ -19,3 +20,6 @@ pub fn recursion_depth(args: TokenStream, input: TokenStream) -> TokenStream { d
|
|||
|
||||
#[proc_macro]
|
||||
pub fn rustc_flags_capture(args: TokenStream) -> TokenStream { rustc::flags_capture(args) }
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn refutable(args: TokenStream, input: TokenStream) -> TokenStream { refutable::refutable(args, input) }
|
||||
|
|
61
src/macros/refutable.rs
Normal file
61
src/macros/refutable.rs
Normal file
|
@ -0,0 +1,61 @@
|
|||
use proc_macro::{Span, TokenStream};
|
||||
use quote::{quote, ToTokens};
|
||||
use syn::{parse_macro_input, AttributeArgs, FnArg::Typed, Ident, ItemFn, Pat, PatIdent, PatType, Stmt};
|
||||
|
||||
pub(super) fn refutable(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
let _args = parse_macro_input!(args as AttributeArgs);
|
||||
let mut item = parse_macro_input!(input as ItemFn);
|
||||
|
||||
let inputs = item.sig.inputs.clone();
|
||||
let stmt = &mut item.block.stmts;
|
||||
let sig = &mut item.sig;
|
||||
for (i, input) in inputs.iter().enumerate() {
|
||||
let Typed(PatType {
|
||||
pat,
|
||||
..
|
||||
}) = input
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Pat::Struct(ref pat) = **pat else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let variant = &pat.path;
|
||||
let fields = &pat.fields;
|
||||
|
||||
// new versions of syn can replace this kronecker kludge with get_mut()
|
||||
for (j, input) in sig.inputs.iter_mut().enumerate() {
|
||||
if i != j {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Typed(PatType {
|
||||
ref mut pat,
|
||||
..
|
||||
}) = input
|
||||
else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let name = format!("_args_{i}");
|
||||
*pat = Box::new(Pat::Ident(PatIdent {
|
||||
ident: Ident::new(&name, Span::call_site().into()),
|
||||
attrs: Vec::new(),
|
||||
by_ref: None,
|
||||
mutability: None,
|
||||
subpat: None,
|
||||
}));
|
||||
|
||||
let field = fields.iter();
|
||||
let refute = quote! {
|
||||
let #variant { #( #field ),*, .. } = #name else { panic!("incorrect variant passed to function argument {i}"); };
|
||||
};
|
||||
|
||||
stmt.insert(0, syn::parse2::<Stmt>(refute).expect("syntax error"));
|
||||
}
|
||||
}
|
||||
|
||||
item.into_token_stream().into()
|
||||
}
|
Loading…
Add table
Reference in a new issue