extendr_macros/
wrappers.rs

1//! This is responsible for generating the C functions that act as wrappers of
2//! the exported Rust functions.
3//!
//! extendr relies on the [`.Call`-interface](https://cran.r-project.org/doc/manuals/R-exts.html#Calling-_002eCall).
//! In short, it is necessary that the signature of the C-function has [`SEXP`]
//! as the return type and as the type of every argument.
7//!
8//! For instance, if your function returns nothing, the return type is not
9//! allowed to be `void`, instead `SEXP` must be used, and one should return
10//! [`R_NilValue`].
11//!
12//! ## R wrappers
13//!
14//! Within R, you may call `rextendr::document()` to generate R functions,
15//! that use the `.Call`-interface, to call the wrapped Rust functions.
16//!
17//! You may also manually implement these wrappers, in order to do special
18//! type-checking, or other annotation, that could be more convenient to do
19//! on the R-side. The C-functions are named according to `"{WRAP_PREFIX}{prefix}{mod_name}"`.
//! See [`WRAP_PREFIX`], and note that `prefix` is set specifically for methods in
//! `extendr`-impl blocks, while plain functions have no prefix.
22//!
23//! [`R_NilValue`]: https://extendr.github.io/libR-sys/libR_sys/static.R_NilValue.html
24//! [`SEXP`]: https://extendr.github.io/libR-sys/libR_sys/type.SEXP.html
25
26use proc_macro2::Ident;
27use quote::{format_ident, quote};
28use std::{collections::HashMap, sync::Mutex};
29use syn::{parse_quote, punctuated::Punctuated, Expr, ExprLit, FnArg, ItemFn, Token, Type};
30
31use crate::extendr_options::ExtendrOptions;
32
/// Prefix of the generated metadata functions, which are named
/// `"{META_PREFIX}{prefix}{mod_name}"`.
pub const META_PREFIX: &str = "meta__";
/// Prefix of the generated `extern "C"` wrapper functions, which are named
/// `"{WRAP_PREFIX}{prefix}{mod_name}"`.
pub const WRAP_PREFIX: &str = "wrap__";

lazy_static::lazy_static! {
    // Doc strings keyed by name: written by `register_struct_doc` and read
    // back by `get_struct_doc`. The `Mutex` guards every access to the map.
    static ref STRUCT_DOCS: Mutex<HashMap<String, String>> = Mutex::new(HashMap::new());
}
39
40/// Called by the struct‐level #[extendr] macro to register docstrings.
41pub fn register_struct_doc(name: &str, doc: &str) {
42    STRUCT_DOCS
43        .lock()
44        .unwrap()
45        .insert(name.to_string(), doc.to_string());
46}
47
48/// Retrieve the struct‐level docs (or empty if none).
49pub fn get_struct_doc(name: &str) -> String {
50    STRUCT_DOCS
51        .lock()
52        .unwrap()
53        .get(name)
54        .cloned()
55        .unwrap_or_default()
56}
57
/// Generate the wrappers for a specific function or method.
///
/// Two items are appended to `wrappers`:
///
/// 1. an `extern "C"` function named `"{WRAP_PREFIX}{prefix}{mod_name}"` whose
///    arguments and return value are all `extendr_api::SEXP`; it converts the
///    arguments to `Robj`, calls the Rust function inside `catch_unwind`, and
///    reports conversion errors and panics through `extendr_api::throw_r_error`;
/// 2. a function named `"{META_PREFIX}{prefix}{mod_name}"` that pushes an
///    `extendr_api::metadata::Func` describing the wrapped function.
///
/// # Arguments
///
/// * `opts` - options of the `#[extendr]` attribute; `r_name`, `mod_name`,
///   `invisible` and `use_rng` are read here.
/// * `wrappers` - output vector receiving the generated items.
/// * `prefix` - inserted between the fixed prefix and the name of the function.
/// * `attrs` - attributes of the function; only `doc` attributes are used.
/// * `sig` - signature of the function. It is taken mutably because collecting
///   the argument metadata removes the default-value attributes from the inputs.
/// * `self_ty` - the type of the surrounding `impl` block, if any.
///
/// # Errors
///
/// Returns the `syn::Error` produced while translating the arguments, e.g. for
/// an argument pattern that is not a plain identifier, or for a receiver that
/// is not `&self` / `&mut self` inside an impl block.
pub(crate) fn make_function_wrappers(
    opts: &ExtendrOptions,
    wrappers: &mut Vec<ItemFn>,
    prefix: &str,
    attrs: &[syn::Attribute],
    sig: &mut syn::Signature,
    self_ty: Option<&syn::Type>,
) -> syn::Result<()> {
    let rust_name = sig.ident.clone();

    // name under which the function is exposed to R; defaults to the Rust name
    let r_name_str = if let Some(r_name) = opts.r_name.as_ref() {
        r_name.clone()
    } else {
        sig.ident.to_string()
    };

    // name used to build the C symbol names; defaults to the Rust name
    let mod_name = if let Some(mod_name) = opts.mod_name.as_ref() {
        format_ident!("{}", mod_name)
    } else {
        sig.ident.clone()
    };

    // a raw identifier prefix (`r#`) must not end up inside the symbol names
    let mod_name = sanitize_identifier(mod_name);
    let wrap_name = format_ident!("{}{}{}", WRAP_PREFIX, prefix, mod_name);
    let meta_name = format_ident!("{}{}{}", META_PREFIX, prefix, mod_name);

    let rust_name_str = format!("{}", rust_name);
    let c_name_str = format!("{}", mod_name);
    let wrap_name_str = format!("{}", wrap_name);
    let doc_string = get_doc_string(attrs);
    let return_type_string = get_return_type(sig);
    // re-emit the `Option<bool>` as tokens for the metadata struct literal
    let opts_invisible = match opts.invisible {
        Some(true) => quote!(Some(true)),
        Some(false) => quote!(Some(false)),
        None => quote!(None),
    };

    let inputs = &mut sig.inputs;
    let has_self = matches!(inputs.iter().next(), Some(FnArg::Receiver(_)));

    // the callee expression placed in front of the argument list
    let call_name = if has_self {
        let is_mut = match inputs.iter().next() {
            Some(FnArg::Receiver(ref receiver)) => receiver.mutability.is_some(),
            _ => false,
        };
        if is_mut {
            // eg. Person::name(&mut self)
            quote! { extendr_api::unwrap_or_throw_error(
                <&mut #self_ty>::try_from(&mut _self_robj)
            ).#rust_name }
        } else {
            // eg. Person::name(&self)
            quote! { extendr_api::unwrap_or_throw_error(
                <&#self_ty>::try_from(&_self_robj)
            ).#rust_name }
        }
    } else if let Some(ref self_ty) = &self_ty {
        // eg. Person::new()
        quote! { <#self_ty>::#rust_name }
    } else {
        // eg. aux_func()
        quote! { #rust_name }
    };

    // arguments for the wrapper with type being `SEXP`
    let formal_args = inputs
        .iter()
        .map(|input| translate_formal(input, self_ty))
        .collect::<syn::Result<Punctuated<FnArg, Token![,]>>>()?;

    // extract the names of the arguments only (`mut` are ignored in `formal_args` already)
    let sexp_args = formal_args
        .clone()
        .into_iter()
        .map(|x| match x {
            // the wrapper doesn't use `self` arguments
            FnArg::Receiver(_) => unreachable!(),
            FnArg::Typed(ref typed) => match typed.pat.as_ref() {
                syn::Pat::Ident(ref pat_ident) => pat_ident.ident.clone(),
                _ => unreachable!(),
            },
        })
        .collect::<Vec<Ident>>();

    // arguments from R (`SEXP`s) are converted to `Robj`
    let convert_args: Vec<syn::Stmt> = inputs
        .iter()
        .map(translate_to_robj)
        .collect::<syn::Result<Vec<syn::Stmt>>>()?;

    // one `_<name>_robj.try_into()?` expression per non-receiver argument
    let actual_args: Punctuated<Expr, Token![,]> =
        inputs.iter().filter_map(translate_actual).collect();

    // `iter_mut` because the default-value attributes are removed from the
    // arguments while the metadata is collected
    let meta_args: Vec<Expr> = inputs
        .iter_mut()
        .map(|input| translate_meta_arg(input, self_ty))
        .collect::<syn::Result<Vec<Expr>>>()?;

    // Generate wrappers for rust functions to be called from R.
    // Simplified sketch of the generated code for `fn hello()`:
    // ```
    // #[no_mangle]
    // #[allow(non_snake_case, clippy::not_unsafe_ptr_arg_deref)]
    // pub extern "C" fn wrap__hello() -> extendr_api::SEXP {
    //     let wrap_result_state = unsafe {
    //         std::panic::catch_unwind(/* || Ok(extendr_api::Robj::from(hello())) */)
    //     };
    //     match wrap_result_state {
    //         Ok(Ok(zz)) => return unsafe { zz.get() },
    //         /* errors and panics are passed to extendr_api::throw_r_error */
    //     }
    // }
    // ```
    let rng_start = if opts.use_rng {
        {
            quote!(single_threaded(|| unsafe {
                extendr_api::GetRNGstate();
            });)
        }
    } else {
        Default::default()
    };
    let rng_end = if opts.use_rng {
        {
            quote!(single_threaded(|| unsafe {
                extendr_api::PutRNGstate();
            });)
        }
    } else {
        Default::default()
    };

    // figure out if the return type is
    // -> &Self
    // -> &mut Self
    // Or if instead of `Self` the type name is used directly
    // -> &ImplType / &mut ImplType
    let return_is_ref_self = {
        match sig.output {
            // matches -> () or no-return type
            syn::ReturnType::Default => false,
            // ignoring the `-> Self` or `-> ImplType`, as that is not a Reference-type
            // matches -> &T or &mut T
            syn::ReturnType::Type(_, ref return_type) => match return_type.as_ref() {
                Type::Reference(ref reference_type) => {
                    // checks if T is Self or explicit impl type name
                    if let Type::Path(path) = reference_type.elem.as_ref() {
                        let is_typename_impl_type = self_ty
                            .map(|x| x == reference_type.elem.as_ref())
                            .unwrap_or(false);
                        path.path.is_ident("Self") || is_typename_impl_type
                    } else {
                        false
                    }
                }
                _ => false,
            },
        }
    };

    let return_type_conversion = if return_is_ref_self {
        // instead of converting &Self / &mut Self, pass on the passed
        // ExternalPtr<Self>: the address of the returned reference is compared
        // with the payload of every argument, and the first argument that
        // matches is handed back as the result
        quote!(
            let return_ref_to_self = #call_name(#actual_args);

            #(
            let arg_ref = extendr_api::R_ExternalPtrAddr(#sexp_args)
                .cast::<Box<dyn std::any::Any>>()
                .as_ref()
                .unwrap()
                .downcast_ref::<#self_ty>()
                .unwrap();
            if std::ptr::addr_eq(
                arg_ref,
                std::ptr::from_ref(return_ref_to_self)) {
                    return Ok(extendr_api::Robj::from_sexp(#sexp_args))
                }
            )*
            Err(Error::ExpectedExternalPtrReference.into())
        )
    } else {
        quote!(Ok(extendr_api::Robj::from(#call_name(#actual_args))))
    };

    // TODO: the unsafe in here is unnecessary
    wrappers.push(parse_quote!(
        #[no_mangle]
        #[allow(non_snake_case, clippy::not_unsafe_ptr_arg_deref)]
        pub extern "C" fn #wrap_name(#formal_args) -> extendr_api::SEXP {
            use extendr_api::robj::*;

            // pull RNG state before evaluation
            #rng_start

            let wrap_result_state: std::result::Result<
                std::result::Result<extendr_api::Robj, Box<dyn std::error::Error>>,
                Box<dyn std::any::Any + Send>
            > = unsafe {
                    std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || -> std::result::Result<extendr_api::Robj, Box<dyn std::error::Error>> {
                        #(#convert_args)*
                        #return_type_conversion
                    }))
                };

            // return RNG state back to r after evaluation
            #rng_end

            // any obj created in above unsafe scope, which are not moved into wrap_result_state are now dropped
            match wrap_result_state {
                Ok(Ok(zz)) => {
                    return unsafe { zz.get() };
                }
                // any conversion error bubbled from #actual_args conversions of incoming args from R.
                Ok(Err(conversion_err)) => {
                    let err_string = conversion_err.to_string();
                    drop(conversion_err); // try_from=true errors contain Robj, this must be dropped to not leak
                    extendr_api::throw_r_error(&err_string);
                }
                // any panic (induced by user func code or if user func yields a Result-Err as return value)
                Err(unwind_err) => {
                    let panic_msg = if let Some(s) = unwind_err.downcast_ref::<&str>() {
                        (*s).to_string()
                    } else if let Some(s) = unwind_err.downcast_ref::<String>() {
                        s.clone()
                    } else {
                        format!("User function panicked: {}", #r_name_str)
                    };

                    extendr_api::throw_r_error(&panic_msg);
                }
            }
        }
    ));

    // Generate a function to push the metadata for a function.
    wrappers.push(parse_quote!(
        #[allow(non_snake_case)]
        fn #meta_name(metadata: &mut Vec<extendr_api::metadata::Func>) {
            let mut args = vec![#(#meta_args,)*];

            metadata.push(extendr_api::metadata::Func {
                doc: #doc_string,
                rust_name: #rust_name_str,
                r_name: #r_name_str,
                c_name: #wrap_name_str,
                mod_name: #c_name_str,
                args: args,
                return_type: #return_type_string,
                func_ptr: #wrap_name as * const u8,
                hidden: false,
                invisible: #opts_invisible,
            })
        }
    ));

    Ok(())
}
314
315// Extract doc strings from attributes.
316pub fn get_doc_string(attrs: &[syn::Attribute]) -> String {
317    let mut res = String::new();
318    for attr in attrs {
319        if !attr.path().is_ident("doc") {
320            continue;
321        }
322
323        if let syn::Meta::NameValue(ref nv) = attr.meta {
324            if let Expr::Lit(ExprLit {
325                lit: syn::Lit::Str(ref litstr),
326                ..
327            }) = nv.value
328            {
329                if !res.is_empty() {
330                    res.push('\n');
331                }
332                res.push_str(&litstr.value());
333            }
334        }
335    }
336    res
337}
338
339pub fn get_return_type(sig: &syn::Signature) -> String {
340    match &sig.output {
341        syn::ReturnType::Default => "()".into(),
342        syn::ReturnType::Type(_, ref rettype) => type_name(rettype),
343    }
344}
345
346pub fn mangled_type_name(type_: &Type) -> String {
347    let src = quote!( #type_ ).to_string();
348    let mut res = String::new();
349    for c in src.chars() {
350        if c != ' ' {
351            if c.is_alphanumeric() {
352                res.push(c)
353            } else {
354                let f = format!("_{:02x}", c as u32);
355                res.push_str(&f);
356            }
357        }
358    }
359    res
360}
361
362/// Return a simplified type name that will be meaningful to R. Defaults to a digest.
363// For example:
364// & Fred -> Fred
365// * Fred -> Fred
366// && Fred -> Fred
367// Fred<'a> -> Fred
368// &[i32] -> _hex_hex_hex_hex
369//
370pub fn type_name(type_: &Type) -> String {
371    match type_ {
372        Type::Path(syn::TypePath { path, .. }) => {
373            if let Some(ident) = path.get_ident() {
374                ident.to_string()
375            } else if path.segments.len() == 1 {
376                let seg = path.segments.clone().into_iter().next().unwrap();
377                seg.ident.to_string()
378            } else {
379                mangled_type_name(type_)
380            }
381        }
382        Type::Group(syn::TypeGroup { elem, .. }) => type_name(elem),
383        Type::Reference(syn::TypeReference { elem, .. }) => type_name(elem),
384        Type::Paren(syn::TypeParen { elem, .. }) => type_name(elem),
385        Type::Ptr(syn::TypePtr { elem, .. }) => type_name(elem),
386        _ => mangled_type_name(type_),
387    }
388}
389
390// Generate a list of arguments for the wrapper. All arguments are SEXP for .Call in R.
391pub fn translate_formal(input: &FnArg, self_ty: Option<&syn::Type>) -> syn::Result<FnArg> {
392    match input {
393        // function argument.
394        FnArg::Typed(ref pattype) => {
395            let pat = pattype.pat.as_ref();
396            // ensure that `mut` in args are ignored in the wrapper
397            let pat_ident = translate_only_alias(pat)?;
398            Ok(parse_quote! { #pat_ident: extendr_api::SEXP })
399        }
400        // &self / &mut self
401        FnArg::Receiver(ref receiver) => {
402            if !receiver.attrs.is_empty() || receiver.reference.is_none() {
403                return Err(syn::Error::new_spanned(
404                    input,
405                    "expected &self or &mut self",
406                ));
407            }
408            if self_ty.is_none() {
409                return Err(syn::Error::new_spanned(
410                    input,"found &self in non-impl function - have you missed the #[extendr] before the impl?"
411                ));
412            }
413            Ok(parse_quote! { _self : extendr_api::SEXP })
414        }
415    }
416}
417
418/// Returns only the alias from a function argument.
419///
420/// For example `mut x: Vec<i32>`, the alias is `x`, but the `mut` would still
421/// be present if only the `Ident` of `PatType` was used.
422fn translate_only_alias(pat: &syn::Pat) -> Result<&Ident, syn::Error> {
423    Ok(match pat {
424        syn::Pat::Ident(ref pat_ident) => &pat_ident.ident,
425        _ => {
426            return Err(syn::Error::new_spanned(
427                pat,
428                "failed to translate name of argument",
429            ));
430        }
431    })
432}
433
434// Generate code to make a metadata::Arg.
435fn translate_meta_arg(input: &mut FnArg, self_ty: Option<&syn::Type>) -> syn::Result<Expr> {
436    match input {
437        // function argument.
438        FnArg::Typed(ref mut pattype) => {
439            let pat = pattype.pat.as_ref();
440            let ty = pattype.ty.as_ref();
441            // here the argument name is extracted, without the `mut` keyword,
442            // ensuring the generated r-wrappers, can use these argument names
443            let pat_ident = translate_only_alias(pat)?;
444            let name_string = quote! { #pat_ident }.to_string();
445            let type_string = type_name(ty);
446            let default = if let Some(default) = get_defaults(&mut pattype.attrs) {
447                quote!(Some(#default))
448            } else if let Some(default) = get_named_lit(&mut pattype.attrs, "default") {
449                quote!(Some(#default))
450            } else {
451                quote!(None)
452            };
453            Ok(parse_quote! {
454                extendr_api::metadata::Arg {
455                    name: #name_string,
456                    arg_type: #type_string,
457                    default: #default
458                }
459            })
460        }
461        // &self
462        FnArg::Receiver(ref receiver) => {
463            if !receiver.attrs.is_empty() || receiver.reference.is_none() {
464                return Err(syn::Error::new_spanned(
465                    input,
466                    "expected &self or &mut self",
467                ));
468            }
469            if self_ty.is_none() {
470                return Err(syn::Error::new_spanned(
471                    input,
472            "found &self in non-impl function - have you missed the #[extendr] before the impl?"
473        )
474    );
475            }
476            let type_string = type_name(self_ty.unwrap());
477            Ok(parse_quote! {
478                extendr_api::metadata::Arg {
479                    name: "self",
480                    arg_type: #type_string,
481                    default: None
482                }
483            })
484        }
485    }
486}
487
488// Get defaults from #[extendr(default = "value")] attribute.
489fn get_defaults(attrs: &mut Vec<syn::Attribute>) -> Option<String> {
490    use syn::Lit;
491
492    let mut new_attrs = Vec::new();
493    let mut res = None;
494
495    for i in attrs.drain(0..) {
496        if let syn::Meta::List(ref meta_list) = i.meta {
497            if meta_list.path.is_ident("extendr") {
498                let mut default_value = None;
499                let mut theres_default = false;
500
501                let parse_result = meta_list.parse_nested_meta(|meta| {
502                    if meta.path.is_ident("default") {
503                        theres_default = true;
504                        let value = meta.value()?;
505                        if let Ok(Lit::Str(litstr)) = value.parse() {
506                            default_value = Some(litstr.value());
507                        }
508                    }
509                    Ok(())
510                });
511
512                if parse_result.is_ok() && theres_default {
513                    res = default_value;
514                    continue;
515                }
516            }
517        }
518
519        new_attrs.push(i);
520    }
521    *attrs = new_attrs;
522    res
523}
524
525/// Convert `SEXP` arguments into `Robj`.
526/// This maintains the lifetime of references.
527///
528/// These conversions are from R into Rust
529fn translate_to_robj(input: &FnArg) -> syn::Result<syn::Stmt> {
530    match input {
531        FnArg::Typed(ref pattype) => {
532            let pat = &pattype.pat.as_ref();
533            if let syn::Pat::Ident(ref ident) = pat {
534                let varname = format_ident!("_{}_robj", ident.ident);
535                let ident = &ident.ident;
536                // TODO: these do not need protection, as they come from R
537                Ok(parse_quote! { let #varname = extendr_api::robj::Robj::from_sexp(#ident); })
538            } else {
539                Err(syn::Error::new_spanned(
540                    input,
541                    "expect identifier as arg name",
542                ))
543            }
544        }
545        FnArg::Receiver(_) => {
546            // this is `mut`, in case of a mutable reference
547            Ok(parse_quote! { let mut _self_robj = extendr_api::robj::Robj::from_sexp(_self); })
548        }
549    }
550}
551
552// Generate actual argument list for the call (ie. a list of conversions).
553fn translate_actual(input: &FnArg) -> Option<Expr> {
554    match input {
555        FnArg::Typed(ref pattype) => {
556            let pat = &pattype.pat.as_ref();
557            if let syn::Pat::Ident(ref ident) = pat {
558                let varname = format_ident!("_{}_robj", ident.ident);
559                Some(parse_quote! {
560                    #varname.try_into()?
561                })
562            } else {
563                None
564            }
565        }
566        FnArg::Receiver(_) => {
567            // Do not use self explicitly as an actual arg.
568            None
569        }
570    }
571}
572
573// Get a single named literal from a list of attributes.
574// eg. #[default="xyz"]
575// Remove the attribute from the list.
576fn get_named_lit(attrs: &mut Vec<syn::Attribute>, name: &str) -> Option<String> {
577    let mut new_attrs = Vec::new();
578    let mut res = None;
579    for a in attrs.drain(0..) {
580        if let syn::Meta::NameValue(ref nv) = a.meta {
581            if nv.path.is_ident(name) {
582                if let Expr::Lit(ExprLit {
583                    lit: syn::Lit::Str(ref litstr),
584                    ..
585                }) = nv.value
586                {
587                    eprintln!("#[default = \"arg\"] is deprecated. Use #[extendr(default = \"arg\")] instead.");
588                    res = Some(litstr.value());
589                    continue;
590                }
591            }
592        }
593
594        new_attrs.push(a);
595    }
596    *attrs = new_attrs;
597    res
598}
599
600// Remove the raw identifier prefix (`r#`) from an [`Ident`]
601// If the `Ident` does not start with the prefix, it is returned as is.
602fn sanitize_identifier(ident: Ident) -> Ident {
603    static PREFIX: &str = "r#";
604    let (ident, span) = (ident.to_string(), ident.span());
605    let ident = match ident.strip_prefix(PREFIX) {
606        Some(ident) => ident.into(),
607        None => ident,
608    };
609
610    Ident::new(&ident, span)
611}