1use proc_macro::TokenStream;
32use proc_macro2::TokenStream as TokenStream2;
33use quote::{format_ident, quote};
34use syn::{
35 parse_macro_input, spanned::Spanned, FnArg, GenericArgument, ItemFn, Pat, PatIdent,
36 PathArguments, ReturnType, Type, TypeReference,
37};
38
39#[proc_macro_attribute]
42pub fn command(_attr: TokenStream, item: TokenStream) -> TokenStream {
43 let input = parse_macro_input!(item as ItemFn);
44 match expand_command(input) {
45 Ok(tokens) => tokens.into(),
46 Err(err) => err.to_compile_error().into(),
47 }
48}
49
50struct CommandDef {
53 name: syn::Ident,
54 state_ident: syn::Ident,
55 state_ty: Box<Type>,
56 state_mut: bool,
57 extra_params: Vec<(Option<syn::token::Mut>, syn::Ident, Box<Type>)>,
58 ret: ReturnShape,
59 body: syn::Block,
60}
61
62enum ReturnShape {
63 Unit,
64 Plain(Box<Type>),
65 Result { ok: Box<Type>, err: Box<Type> },
66}
67
68fn expand_command(func: ItemFn) -> syn::Result<TokenStream2> {
71 if let Some(tok) = &func.sig.asyncness {
73 return Err(syn::Error::new(
74 tok.span(),
75 "#[command] does not support async functions",
76 ));
77 }
78
79 for arg in &func.sig.inputs {
81 if let FnArg::Receiver(recv) = arg {
82 return Err(syn::Error::new(
83 recv.span(),
84 "#[command] does not support methods with `self`; \
85 use a free function with `state: &T` or `state: &mut T`",
86 ));
87 }
88 }
89
90 if func.sig.inputs.is_empty() {
92 return Err(syn::Error::new(
93 func.sig.ident.span(),
94 "#[command] requires at least one parameter: `state: &T` or `state: &mut T`",
95 ));
96 }
97
98 let first = match func.sig.inputs.first().unwrap() {
100 FnArg::Typed(pt) => pt,
101 _ => unreachable!("already rejected Receiver"),
102 };
103
104 let state_ident = match &*first.pat {
106 Pat::Ident(PatIdent { ident, .. }) => ident.clone(),
107 other => {
108 return Err(syn::Error::new(
109 other.span(),
110 "#[command] state parameter must be a simple identifier \
111 (e.g., `state: &T`)",
112 ));
113 }
114 };
115
116 let (state_ty, state_mut) = match &*first.ty {
118 Type::Reference(TypeReference {
119 elem, mutability, ..
120 }) => (elem.clone(), mutability.is_some()),
121 other => {
122 return Err(syn::Error::new(
123 other.span(),
124 "#[command] first parameter must be a reference: \
125 `&T` or `&mut T`",
126 ));
127 }
128 };
129
130 let mut extra_params = Vec::new();
132 for arg in func.sig.inputs.iter().skip(1) {
133 let typed = match arg {
134 FnArg::Typed(pt) => pt,
135 _ => unreachable!(),
136 };
137 let (mutability, ident) = match &*typed.pat {
138 Pat::Ident(PatIdent { mutability, ident, .. }) => (*mutability, ident.clone()),
139 other => {
140 return Err(syn::Error::new(
141 other.span(),
142 "#[command] parameters must use simple identifiers \
143 (no tuple or struct patterns)",
144 ));
145 }
146 };
147 if ident.to_string().starts_with("__webtau") {
148 return Err(syn::Error::new(
149 ident.span(),
150 "#[command] parameter names starting with `__webtau` are reserved \
151 for generated code",
152 ));
153 }
154 extra_params.push((mutability, ident, typed.ty.clone()));
155 }
156
157 let ret = match &func.sig.output {
159 ReturnType::Default => ReturnShape::Unit,
160 ReturnType::Type(_, ty) => parse_return_type(ty),
161 };
162
163 let def = CommandDef {
164 name: func.sig.ident.clone(),
165 state_ident,
166 state_ty,
167 state_mut,
168 extra_params,
169 ret,
170 body: (*func.block).clone(),
171 };
172
173 Ok(generate_all(&def))
174}
175
176fn parse_return_type(ty: &Type) -> ReturnShape {
177 if let Type::Path(tp) = ty {
178 if let Some(seg) = tp.path.segments.last() {
179 if seg.ident == "Result" {
180 if let PathArguments::AngleBracketed(ab) = &seg.arguments {
181 let mut types = ab.args.iter().filter_map(|a| {
182 if let GenericArgument::Type(t) = a {
183 Some(Box::new(t.clone()))
184 } else {
185 None
186 }
187 });
188 if let (Some(ok), Some(err)) = (types.next(), types.next()) {
189 return ReturnShape::Result { ok, err };
190 }
191 }
192 }
193 }
194 }
195 ReturnShape::Plain(Box::new(ty.clone()))
196}
197
198fn generate_all(def: &CommandDef) -> TokenStream2 {
201 let inner = generate_inner(def);
202 let native = generate_native(def);
203 let wasm = generate_wasm(def);
204
205 quote! {
206 #inner
207 #native
208 #wasm
209 }
210}
211
212fn generate_inner(def: &CommandDef) -> TokenStream2 {
214 let inner_name = format_ident!("__webtau_{}", def.name);
215 let body = &def.body;
216 let state_ty = &def.state_ty;
217 let state_ident = &def.state_ident;
218
219 let state_param = if def.state_mut {
220 quote! { #state_ident: &mut #state_ty }
221 } else {
222 quote! { #state_ident: &#state_ty }
223 };
224
225 let extra: Vec<_> = def
226 .extra_params
227 .iter()
228 .map(|(mutability, id, ty)| quote! { #mutability #id: #ty })
229 .collect();
230
231 let ret = ret_tokens(&def.ret);
232
233 quote! {
234 #[doc(hidden)]
235 #[inline(always)]
236 fn #inner_name(#state_param, #(#extra),*) #ret #body
237 }
238}
239
240fn generate_native(def: &CommandDef) -> TokenStream2 {
242 let pub_name = &def.name;
243 let inner_name = format_ident!("__webtau_{}", def.name);
244 let state_ty = &def.state_ty;
245
246 let extra_defs: Vec<_> = def
247 .extra_params
248 .iter()
249 .map(|(_, id, ty)| quote! { #id: #ty })
250 .collect();
251 let extra_names: Vec<_> = def
252 .extra_params
253 .iter()
254 .map(|(_, id, _)| quote! { #id })
255 .collect();
256
257 let (lock, state_ref) = if def.state_mut {
259 (
260 quote! { let mut __webtau_guard = __webtau_tauri_state.lock().unwrap(); },
261 quote! { &mut __webtau_guard },
262 )
263 } else {
264 (
265 quote! { let __webtau_guard = __webtau_tauri_state.lock().unwrap(); },
266 quote! { &__webtau_guard },
267 )
268 };
269
270 let ret = ret_tokens(&def.ret);
271
272 quote! {
273 #[cfg(not(target_arch = "wasm32"))]
274 #[::tauri::command]
275 pub fn #pub_name(
276 #(#extra_defs,)*
277 __webtau_tauri_state: ::tauri::State<'_, ::std::sync::Mutex<#state_ty>>
278 ) #ret {
279 #lock
280 #inner_name(#state_ref, #(#extra_names),*)
281 }
282 }
283}
284
285fn generate_wasm(def: &CommandDef) -> TokenStream2 {
287 let pub_name = &def.name;
288 let inner_name = format_ident!("__webtau_{}", def.name);
289 let has_extra = !def.extra_params.is_empty();
290
291 let state_accessor = if def.state_mut {
292 quote! { with_state_mut }
293 } else {
294 quote! { with_state }
295 };
296
297 let (wasm_param, args_preamble, call_args) = if has_extra {
299 let struct_name = format_ident!("__Webtau{}Args", to_pascal_case(&def.name.to_string()));
300
301 let field_defs: Vec<_> = def
302 .extra_params
303 .iter()
304 .map(|(_, id, ty)| quote! { #id: #ty })
305 .collect();
306 let field_refs: Vec<_> = def
307 .extra_params
308 .iter()
309 .map(|(_, id, _)| quote! { __args.#id })
310 .collect();
311
312 (
313 quote! { args: ::wasm_bindgen::JsValue },
314 quote! {
315 #[derive(::serde::Deserialize)]
316 struct #struct_name { #(#field_defs,)* }
317 let __args: #struct_name =
318 ::serde_wasm_bindgen::from_value(args).unwrap();
319 },
320 field_refs,
321 )
322 } else {
323 (quote! {}, quote! {}, vec![])
324 };
325
326 let (wasm_ret, body_expr) = match &def.ret {
328 ReturnShape::Unit => (
329 quote! {},
330 quote! {
331 #state_accessor(|state| {
332 #inner_name(state, #(#call_args),*);
333 })
334 },
335 ),
336 ReturnShape::Plain(_) => (
337 quote! { -> ::wasm_bindgen::JsValue },
338 quote! {
339 #state_accessor(|state| {
340 let __result = #inner_name(state, #(#call_args),*);
341 ::serde_wasm_bindgen::to_value(&__result).unwrap()
342 })
343 },
344 ),
345 ReturnShape::Result { .. } => (
346 quote! { -> ::std::result::Result<::wasm_bindgen::JsValue, ::wasm_bindgen::JsError> },
347 quote! {
348 #state_accessor(|state| {
349 match #inner_name(state, #(#call_args),*) {
350 Ok(__val) => Ok(::serde_wasm_bindgen::to_value(&__val).unwrap()),
351 Err(__err) => Err(::wasm_bindgen::JsError::new(&__err.to_string())),
352 }
353 })
354 },
355 ),
356 };
357
358 quote! {
359 #[cfg(target_arch = "wasm32")]
360 #[::wasm_bindgen::prelude::wasm_bindgen]
361 pub fn #pub_name(#wasm_param) #wasm_ret {
362 #args_preamble
363 #body_expr
364 }
365 }
366}
367
368fn ret_tokens(shape: &ReturnShape) -> TokenStream2 {
371 match shape {
372 ReturnShape::Unit => quote! {},
373 ReturnShape::Plain(ty) => quote! { -> #ty },
374 ReturnShape::Result { ok, err } => quote! { -> ::std::result::Result<#ok, #err> },
375 }
376}
377
/// Converts a `snake_case` string to `PascalCase`.
///
/// The input is split on `_`; each non-empty piece has its first character
/// upper-cased (using the full Unicode mapping, which may yield more than one
/// character) and the rest is appended unchanged. Empty pieces produced by
/// leading, trailing, or repeated underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            out.extend(head.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}