webtau_macros/
lib.rs

1//! Proc macros for webtau dual-target command generation.
2//!
3//! # v2 `#[command]` Contract
4//!
5//! ```rust,ignore
6//! #[webtau::command]
7//! fn name(state: &T | &mut T [, arg: Type]*) [-> ReturnType] { body }
8//! ```
9//!
10//! **Supported grammar:**
11//! - First parameter **must** be a reference: `name: &T` (read-only) or `name: &mut T` (mutable).
12//!   The identifier can be any name (e.g., `state`, `world`, `game`).
13//! - Additional parameters are named, typed values forwarded as the command's args.
14//! - Return type may be:
15//!   - `T` where `T: Serialize` — value returned directly.
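//!   - `Result<T, E>` where `T: Serialize, E: Display + Serialize` — errors surface to JS.
//!     Recognized by a last path segment named `Result` with two type arguments; one-argument
//!     aliases such as `std::io::Result<T>` are handled like the plain-`T` case above.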
17//!   - Omitted (unit `()`) — command returns nothing.
18//! - The function name becomes the command name for `invoke()`.
19//!
20//! **Generated code:**
21//! - Inner function `__webtau_<name>` containing the original body.
//! - `#[cfg(not(target_arch = "wasm32"))]` — `#[tauri::command]` wrapper taking `State<Mutex<T>>`.
//! - `#[cfg(target_arch = "wasm32")]` — `#[wasm_bindgen]` wrapper that deserializes one args object.
24//!
25//! **Unsupported forms** (compile-time error):
26//! - Methods with `self`.
27//! - Missing or non-reference state parameter.
28//! - Tuple or struct patterns in parameters.
//! - Async functions.
//! - Generic type or const parameters (they are not carried over to the generated functions).
30
31use proc_macro::TokenStream;
32use proc_macro2::TokenStream as TokenStream2;
33use quote::{format_ident, quote};
34use syn::{
35    parse_macro_input, spanned::Spanned, FnArg, GenericArgument, ItemFn, Pat, PatIdent,
36    PathArguments, ReturnType, Type, TypeReference,
37};
38
39// ── Public entry point ────────────────────────────────────────────────
40
41#[proc_macro_attribute]
42pub fn command(_attr: TokenStream, item: TokenStream) -> TokenStream {
43    let input = parse_macro_input!(item as ItemFn);
44    match expand_command(input) {
45        Ok(tokens) => tokens.into(),
46        Err(err) => err.to_compile_error().into(),
47    }
48}
49
50// ── Parsed representation ─────────────────────────────────────────────
51
52struct CommandDef {
53    name: syn::Ident,
54    state_ident: syn::Ident,
55    state_ty: Box<Type>,
56    state_mut: bool,
57    extra_params: Vec<(Option<syn::token::Mut>, syn::Ident, Box<Type>)>,
58    ret: ReturnShape,
59    body: syn::Block,
60}
61
62enum ReturnShape {
63    Unit,
64    Plain(Box<Type>),
65    Result { ok: Box<Type>, err: Box<Type> },
66}
67
68// ── Parsing + diagnostics (Step 1) ────────────────────────────────────
69
70fn expand_command(func: ItemFn) -> syn::Result<TokenStream2> {
71    // Reject async
72    if let Some(tok) = &func.sig.asyncness {
73        return Err(syn::Error::new(
74            tok.span(),
75            "#[command] does not support async functions",
76        ));
77    }
78
79    // Reject methods with self
80    for arg in &func.sig.inputs {
81        if let FnArg::Receiver(recv) = arg {
82            return Err(syn::Error::new(
83                recv.span(),
84                "#[command] does not support methods with `self`; \
85                 use a free function with `state: &T` or `state: &mut T`",
86            ));
87        }
88    }
89
90    // Must have at least one parameter (the state)
91    if func.sig.inputs.is_empty() {
92        return Err(syn::Error::new(
93            func.sig.ident.span(),
94            "#[command] requires at least one parameter: `state: &T` or `state: &mut T`",
95        ));
96    }
97
98    // ── Parse state parameter (first) ──
99    let first = match func.sig.inputs.first().unwrap() {
100        FnArg::Typed(pt) => pt,
101        _ => unreachable!("already rejected Receiver"),
102    };
103
104    // Must be a simple ident pattern
105    let state_ident = match &*first.pat {
106        Pat::Ident(PatIdent { ident, .. }) => ident.clone(),
107        other => {
108            return Err(syn::Error::new(
109                other.span(),
110                "#[command] state parameter must be a simple identifier \
111                 (e.g., `state: &T`)",
112            ));
113        }
114    };
115
116    // Must be &T or &mut T
117    let (state_ty, state_mut) = match &*first.ty {
118        Type::Reference(TypeReference {
119            elem, mutability, ..
120        }) => (elem.clone(), mutability.is_some()),
121        other => {
122            return Err(syn::Error::new(
123                other.span(),
124                "#[command] first parameter must be a reference: \
125                 `&T` or `&mut T`",
126            ));
127        }
128    };
129
130    // ── Parse extra parameters ──
131    let mut extra_params = Vec::new();
132    for arg in func.sig.inputs.iter().skip(1) {
133        let typed = match arg {
134            FnArg::Typed(pt) => pt,
135            _ => unreachable!(),
136        };
137        let (mutability, ident) = match &*typed.pat {
138            Pat::Ident(PatIdent { mutability, ident, .. }) => (*mutability, ident.clone()),
139            other => {
140                return Err(syn::Error::new(
141                    other.span(),
142                    "#[command] parameters must use simple identifiers \
143                     (no tuple or struct patterns)",
144                ));
145            }
146        };
147        if ident.to_string().starts_with("__webtau") {
148            return Err(syn::Error::new(
149                ident.span(),
150                "#[command] parameter names starting with `__webtau` are reserved \
151                 for generated code",
152            ));
153        }
154        extra_params.push((mutability, ident, typed.ty.clone()));
155    }
156
157    // ── Parse return type ──
158    let ret = match &func.sig.output {
159        ReturnType::Default => ReturnShape::Unit,
160        ReturnType::Type(_, ty) => parse_return_type(ty),
161    };
162
163    let def = CommandDef {
164        name: func.sig.ident.clone(),
165        state_ident,
166        state_ty,
167        state_mut,
168        extra_params,
169        ret,
170        body: (*func.block).clone(),
171    };
172
173    Ok(generate_all(&def))
174}
175
176fn parse_return_type(ty: &Type) -> ReturnShape {
177    if let Type::Path(tp) = ty {
178        if let Some(seg) = tp.path.segments.last() {
179            if seg.ident == "Result" {
180                if let PathArguments::AngleBracketed(ab) = &seg.arguments {
181                    let mut types = ab.args.iter().filter_map(|a| {
182                        if let GenericArgument::Type(t) = a {
183                            Some(Box::new(t.clone()))
184                        } else {
185                            None
186                        }
187                    });
188                    if let (Some(ok), Some(err)) = (types.next(), types.next()) {
189                        return ReturnShape::Result { ok, err };
190                    }
191                }
192            }
193        }
194    }
195    ReturnShape::Plain(Box::new(ty.clone()))
196}
197
198// ── Code generation ───────────────────────────────────────────────────
199
200fn generate_all(def: &CommandDef) -> TokenStream2 {
201    let inner = generate_inner(def);
202    let native = generate_native(def);
203    let wasm = generate_wasm(def);
204
205    quote! {
206        #inner
207        #native
208        #wasm
209    }
210}
211
212/// Emit the inner function containing the user's original body.
213fn generate_inner(def: &CommandDef) -> TokenStream2 {
214    let inner_name = format_ident!("__webtau_{}", def.name);
215    let body = &def.body;
216    let state_ty = &def.state_ty;
217    let state_ident = &def.state_ident;
218
219    let state_param = if def.state_mut {
220        quote! { #state_ident: &mut #state_ty }
221    } else {
222        quote! { #state_ident: &#state_ty }
223    };
224
225    let extra: Vec<_> = def
226        .extra_params
227        .iter()
228        .map(|(mutability, id, ty)| quote! { #mutability #id: #ty })
229        .collect();
230
231    let ret = ret_tokens(&def.ret);
232
233    quote! {
234        #[doc(hidden)]
235        #[inline(always)]
236        fn #inner_name(#state_param, #(#extra),*) #ret #body
237    }
238}
239
240/// Emit the `#[tauri::command]` wrapper (Step 2 — native codegen).
241fn generate_native(def: &CommandDef) -> TokenStream2 {
242    let pub_name = &def.name;
243    let inner_name = format_ident!("__webtau_{}", def.name);
244    let state_ty = &def.state_ty;
245
246    let extra_defs: Vec<_> = def
247        .extra_params
248        .iter()
249        .map(|(_, id, ty)| quote! { #id: #ty })
250        .collect();
251    let extra_names: Vec<_> = def
252        .extra_params
253        .iter()
254        .map(|(_, id, _)| quote! { #id })
255        .collect();
256
257    // Use `__webtau_` prefix to avoid collisions with user arg names
258    let (lock, state_ref) = if def.state_mut {
259        (
260            quote! { let mut __webtau_guard = __webtau_tauri_state.lock().unwrap(); },
261            quote! { &mut __webtau_guard },
262        )
263    } else {
264        (
265            quote! { let __webtau_guard = __webtau_tauri_state.lock().unwrap(); },
266            quote! { &__webtau_guard },
267        )
268    };
269
270    let ret = ret_tokens(&def.ret);
271
272    quote! {
273        #[cfg(not(target_arch = "wasm32"))]
274        #[::tauri::command]
275        pub fn #pub_name(
276            #(#extra_defs,)*
277            __webtau_tauri_state: ::tauri::State<'_, ::std::sync::Mutex<#state_ty>>
278        ) #ret {
279            #lock
280            #inner_name(#state_ref, #(#extra_names),*)
281        }
282    }
283}
284
285/// Emit the `#[wasm_bindgen]` wrapper (WASM codegen).
286fn generate_wasm(def: &CommandDef) -> TokenStream2 {
287    let pub_name = &def.name;
288    let inner_name = format_ident!("__webtau_{}", def.name);
289    let has_extra = !def.extra_params.is_empty();
290
291    let state_accessor = if def.state_mut {
292        quote! { with_state_mut }
293    } else {
294        quote! { with_state }
295    };
296
297    // ── Args handling ──
298    let (wasm_param, args_preamble, call_args) = if has_extra {
299        let struct_name = format_ident!("__Webtau{}Args", to_pascal_case(&def.name.to_string()));
300
301        let field_defs: Vec<_> = def
302            .extra_params
303            .iter()
304            .map(|(_, id, ty)| quote! { #id: #ty })
305            .collect();
306        let field_refs: Vec<_> = def
307            .extra_params
308            .iter()
309            .map(|(_, id, _)| quote! { __args.#id })
310            .collect();
311
312        (
313            quote! { args: ::wasm_bindgen::JsValue },
314            quote! {
315                #[derive(::serde::Deserialize)]
316                struct #struct_name { #(#field_defs,)* }
317                let __args: #struct_name =
318                    ::serde_wasm_bindgen::from_value(args).unwrap();
319            },
320            field_refs,
321        )
322    } else {
323        (quote! {}, quote! {}, vec![])
324    };
325
326    // ── Return handling ──
327    let (wasm_ret, body_expr) = match &def.ret {
328        ReturnShape::Unit => (
329            quote! {},
330            quote! {
331                #state_accessor(|state| {
332                    #inner_name(state, #(#call_args),*);
333                })
334            },
335        ),
336        ReturnShape::Plain(_) => (
337            quote! { -> ::wasm_bindgen::JsValue },
338            quote! {
339                #state_accessor(|state| {
340                    let __result = #inner_name(state, #(#call_args),*);
341                    ::serde_wasm_bindgen::to_value(&__result).unwrap()
342                })
343            },
344        ),
345        ReturnShape::Result { .. } => (
346            quote! { -> ::std::result::Result<::wasm_bindgen::JsValue, ::wasm_bindgen::JsError> },
347            quote! {
348                #state_accessor(|state| {
349                    match #inner_name(state, #(#call_args),*) {
350                        Ok(__val) => Ok(::serde_wasm_bindgen::to_value(&__val).unwrap()),
351                        Err(__err) => Err(::wasm_bindgen::JsError::new(&__err.to_string())),
352                    }
353                })
354            },
355        ),
356    };
357
358    quote! {
359        #[cfg(target_arch = "wasm32")]
360        #[::wasm_bindgen::prelude::wasm_bindgen]
361        pub fn #pub_name(#wasm_param) #wasm_ret {
362            #args_preamble
363            #body_expr
364        }
365    }
366}
367
368// ── Helpers ───────────────────────────────────────────────────────────
369
370fn ret_tokens(shape: &ReturnShape) -> TokenStream2 {
371    match shape {
372        ReturnShape::Unit => quote! {},
373        ReturnShape::Plain(ty) => quote! { -> #ty },
374        ReturnShape::Result { ok, err } => quote! { -> ::std::result::Result<#ok, #err> },
375    }
376}
377
/// Convert a snake_case identifier to PascalCase (`move_player` -> `MovePlayer`).
///
/// Used to name the generated wasm args struct. A raw-identifier prefix (`r#`) is
/// dropped first: `Ident::to_string()` keeps it for raw identifiers such as
/// `r#type`, and `#` is not allowed inside a generated identifier. Empty segments
/// from leading, trailing, or doubled underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    let mut out = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            out.extend(first.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}