changeset 74:10392fdf5048

Use thread local memoization caches by default and move mutex caches behind SharedCache attribute
author Leodore <lconrads@ucsc.edu>
date Sun, 20 Mar 2022 23:12:31 -0700
parents 3adb990ef496
children c14b0b81328f
files inner/src/lib.rs
diffstat 1 files changed, 151 insertions(+), 134 deletions(-) [+]
line wrap: on
line diff
--- a/inner/src/lib.rs	Sat Jan 01 12:32:21 2022 +0100
+++ b/inner/src/lib.rs	Sun Mar 20 23:12:31 2022 -0700
@@ -1,31 +1,97 @@
 #![crate_type = "proc-macro"]
 #![allow(unused_imports)] // Spurious complaints about a required trait import.
 
-use syn::{self, parse_macro_input, spanned::Spanned, ItemFn};
+use syn::{self, parse, parse_macro_input, spanned::Spanned, Expr, ItemFn};
 
 use proc_macro::TokenStream;
 use quote::{self, ToTokens};
 
+mod kw {
+    syn::custom_keyword!(Capacity);
+    syn::custom_keyword!(TimeToLive);
+    syn::custom_keyword!(SharedCache);
+    syn::custom_punctuation!(Colon, :);
+}
+
+#[derive(Default, Clone)]
+struct CacheOptions {
+    lru_max_entries: Option<usize>,
+    time_to_live: Option<Expr>,
+    shared_cache: bool,
+}
+
+#[derive(Clone)]
+enum CacheOption {
+    LRUMaxEntries(usize),
+    TimeToLive(Expr),
+    SharedCache,
+}
+
+// To extend option parsing, add functionality here.
+impl parse::Parse for CacheOption {
+    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
+        let la = input.lookahead1();
+        if la.peek(kw::Capacity) {
+            #[cfg(not(feature = "full"))]
+            return Err(syn::Error::new(input.span(),
+            "memoize error: Capacity specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.",
+            ));
+
+            input.parse::<kw::Capacity>().unwrap();
+            input.parse::<kw::Colon>().unwrap();
+            let cap: syn::LitInt = input.parse().unwrap();
+
+            return Ok(CacheOption::LRUMaxEntries(cap.base10_parse()?));
+        }
+        if la.peek(kw::TimeToLive) {
+            #[cfg(not(feature = "full"))]
+            return Err(syn::Error::new(input.span(),
+            "memoize error: TimeToLive specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.",
+            ));
+
+            input.parse::<kw::TimeToLive>().unwrap();
+            input.parse::<kw::Colon>().unwrap();
+            let cap: syn::Expr = input.parse().unwrap();
+
+            return Ok(CacheOption::TimeToLive(cap));
+        }
+        if la.peek(kw::SharedCache) {
+            input.parse::<kw::SharedCache>().unwrap();
+            return Ok(CacheOption::SharedCache);
+        }
+        Err(la.error())
+    }
+}
+
+impl parse::Parse for CacheOptions {
+    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
+        let f: syn::punctuated::Punctuated<CacheOption, syn::Token![,]> =
+            input.parse_terminated(CacheOption::parse)?;
+        let mut opts = Self::default();
+
+        for opt in f {
+            match opt {
+                CacheOption::LRUMaxEntries(cap) => opts.lru_max_entries = Some(cap),
+                CacheOption::TimeToLive(sec) => opts.time_to_live = Some(sec),
+                CacheOption::SharedCache => opts.shared_cache = true,
+            }
+        }
+        Ok(opts)
+    }
+}
+
 // This implementation of the storage backend does not depend on any more crates.
 #[cfg(not(feature = "full"))]
 mod store {
+    use crate::CacheOptions;
     use proc_macro::TokenStream;
 
     /// Returns tokenstreams (for quoting) of the store type and an expression to initialize it.
-    pub fn construct_cache(
-        attr: &TokenStream,
+    pub(crate) fn construct_cache(
+        _options: &CacheOptions,
         key_type: proc_macro2::TokenStream,
         value_type: proc_macro2::TokenStream,
     ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
-        if !attr.is_empty() {
-            return (
-                syn::Error::new_spanned(proc_macro2::TokenStream::from(attr.clone()),
-                    "memoize error: Attributes are specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.",
-                )
-                .to_compile_error(),
-                proc_macro2::TokenStream::new(),
-            );
-        }
         // This is the unbounded default.
         (
             quote::quote! { std::collections::HashMap<#key_type, #value_type> },
@@ -35,8 +101,8 @@
 
     /// Returns names of methods as TokenStreams to insert and get (respectively) elements from a
     /// store.
-    pub fn cache_access_methods(
-        _attr: &TokenStream,
+    pub(crate) fn cache_access_methods(
+        _options: &CacheOptions,
     ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
         (quote::quote! { insert }, quote::quote! { get })
     }
@@ -45,75 +111,19 @@
 // This implementation of the storage backend also depends on the `lru` crate.
 #[cfg(feature = "full")]
 mod store {
+    use crate::CacheOptions;
     use proc_macro::TokenStream;
-    use syn::{parse as p, Expr};
-
-    #[derive(Default, Clone)]
-    pub(crate) struct CacheOptions {
-        lru_max_entries: Option<usize>,
-        pub(crate) time_to_live: Option<Expr>,
-    }
-
-    #[derive(Clone)]
-    enum CacheOption {
-        LRUMaxEntries(usize),
-        TimeToLive(Expr),
-    }
-
-    syn::custom_keyword!(Capacity);
-    syn::custom_keyword!(TimeToLive);
-    syn::custom_punctuation!(Colon, :);
-
-    // To extend option parsing, add functionality here.
-    impl p::Parse for CacheOption {
-        fn parse(input: p::ParseStream) -> syn::Result<Self> {
-            let la = input.lookahead1();
-            if la.peek(Capacity) {
-                let _: Capacity = input.parse().unwrap();
-                let _: Colon = input.parse().unwrap();
-                let cap: syn::LitInt = input.parse().unwrap();
-
-                return Ok(CacheOption::LRUMaxEntries(cap.base10_parse()?));
-            }
-            if la.peek(TimeToLive) {
-                let _: TimeToLive = input.parse().unwrap();
-                let _: Colon = input.parse().unwrap();
-                let cap: syn::Expr = input.parse().unwrap();
-
-                return Ok(CacheOption::TimeToLive(cap));
-            }
-            Err(la.error())
-        }
-    }
-
-    impl p::Parse for CacheOptions {
-        fn parse(input: p::ParseStream) -> syn::Result<Self> {
-            let f: syn::punctuated::Punctuated<CacheOption, syn::Token![,]> =
-                input.parse_terminated(CacheOption::parse)?;
-            let mut opts = Self::default();
-
-            for opt in f {
-                match opt {
-                    CacheOption::LRUMaxEntries(cap) => opts.lru_max_entries = Some(cap),
-                    CacheOption::TimeToLive(sec) => opts.time_to_live = Some(sec),
-                }
-            }
-            Ok(opts)
-        }
-    }
 
     /// Returns TokenStreams to be used in quote!{} for parametrizing the memoize store variable,
     /// and initializing it.
     ///
     /// First return value: Type of store ("Container<K,V>").
     /// Second return value: Initializer syntax ("Container::<K,V>::new()").
-    pub fn construct_cache(
-        attr: &TokenStream,
+    pub(crate) fn construct_cache(
+        options: &CacheOptions,
         key_type: proc_macro2::TokenStream,
         value_type: proc_macro2::TokenStream,
     ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
-        let options: CacheOptions = syn::parse(attr.clone()).unwrap();
-
         let value_type = match options.time_to_live {
             None => quote::quote! {#value_type},
             Some(_) => quote::quote! {(std::time::Instant, #value_type)},
@@ -133,11 +143,9 @@
 
     /// Returns names of methods as TokenStreams to insert and get (respectively) elements from a
     /// store.
-    pub fn cache_access_methods(
-        attr: &TokenStream,
+    pub(crate) fn cache_access_methods(
+        options: &CacheOptions,
     ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
-        let options: CacheOptions = syn::parse(attr.clone()).unwrap();
-
         // This is the unbounded default.
         match options.lru_max_entries {
             None => (quote::quote! { insert }, quote::quote! { get }),
@@ -219,13 +227,26 @@
         syn::ReturnType::Type(_, ty) => return_type = ty.to_token_stream(),
     }
 
+    // Parse options from macro attributes
+    let options: CacheOptions = syn::parse(attr.clone()).unwrap();
+
     // Construct storage for the memoized keys and return values.
     let store_ident = syn::Ident::new(&map_name.to_uppercase(), sig.span());
-    let (cache_type, cache_init) = store::construct_cache(&attr, input_tuple_type, return_type.clone());
-    let store = quote::quote! {
-        ::memoize::lazy_static::lazy_static! {
-            static ref #store_ident : std::sync::Mutex<#cache_type> =
-                std::sync::Mutex::new(#cache_init);
+    let (cache_type, cache_init) =
+        store::construct_cache(&options, input_tuple_type, return_type.clone());
+    let store = if options.shared_cache {
+        quote::quote! {
+            ::memoize::lazy_static::lazy_static! {
+                static ref #store_ident : std::sync::Mutex<#cache_type> =
+                    std::sync::Mutex::new(#cache_init);
+            }
+        }
+    } else {
+        quote::quote! {
+            std::thread_local! {
+                static #store_ident : std::cell::RefCell<#cache_type> =
+                    std::cell::RefCell::new(#cache_init);
+            }
         }
     };
 
@@ -234,74 +255,70 @@
     renamed_fn.sig.ident = syn::Ident::new(&renamed_name, func.sig.span());
     let memoized_id = &renamed_fn.sig.ident;
 
-    // Extract the function name and identifier.
-    let fn_name = func.sig.ident.clone();
-    let fn_vis = func.vis.clone();
-    
     // Construct memoizer function, which calls the original function.
     let syntax_names_tuple = quote::quote! { (#(#input_names),*) };
     let syntax_names_tuple_cloned = quote::quote! { (#(#input_names.clone()),*) };
-    let (insert_fn, get_fn) = store::cache_access_methods(&attr);
-    #[cfg(feature = "full")]
-    let memoizer = {
-        let options: store::CacheOptions = syn::parse(attr.clone().into()).unwrap();
-        match options.time_to_live {
-            None => quote::quote! {
-                #sig {
-                    {
-                        let mut hm = &mut #store_ident.lock().unwrap();
-                        if let Some(r) = hm.#get_fn(&#syntax_names_tuple_cloned) {
-                            return r.clone();
-                        }
-                    }
-                    let r = #memoized_id(#(#input_names.clone()),*);
-                    let mut hm = &mut #store_ident.lock().unwrap();
-                    hm.#insert_fn(#syntax_names_tuple, r.clone());
-                    r
-                }
+    let (insert_fn, get_fn) = store::cache_access_methods(&options);
+    let (read_memo, memoize) = match options.time_to_live {
+        None => (
+            quote::quote!(hm.#get_fn(&#syntax_names_tuple_cloned).cloned()),
+            quote::quote!(hm.#insert_fn(#syntax_names_tuple, r.clone());),
+        ),
+        Some(ttl) => (
+            quote::quote! {
+                hm.#get_fn(&#syntax_names_tuple_cloned).and_then(|(last_updated, r)|
+                    (last_updated.elapsed() < #ttl).then(|| r.clone())
+                )
             },
-            Some(ttl) => quote::quote! {
-                #sig {
-                    {
-                        let mut hm = &mut #store_ident.lock().unwrap();
-                        if let Some((last_updated, r)) = hm.#get_fn(&#syntax_names_tuple_cloned) {
-                            if last_updated.elapsed() < #ttl {
-                                return r.clone();
-                            }
-                        }
-                    }
-                    let r = #memoized_id(#(#input_names.clone()),*);
-                    let mut hm = &mut #store_ident.lock().unwrap();
-                    hm.#insert_fn(#syntax_names_tuple, (std::time::Instant::now(), r.clone()));
-                    r
-                }
-            },
-        }
+            quote::quote!(hm.#insert_fn(#syntax_names_tuple, (std::time::Instant::now(), r.clone()));),
+        ),
     };
-    #[cfg(not(feature = "full"))]
-    let memoizer = quote::quote! {
-        #fn_vis fn #fn_name (
-            #(#input_names: #input_types),*
-        ) -> #return_type {
+
+    let memoizer = if options.shared_cache {
+        quote::quote! {
             {
-                let mut hm = &mut #store_ident.lock().unwrap();
-                if let Some(r) = hm.#get_fn(&#syntax_names_tuple_cloned) {
-                    return r.clone();
+                let mut hm = #store_ident.lock().unwrap();
+                if let Some(r) = #read_memo {
+                    return r
                 }
             }
             let r = #memoized_id(#(#input_names.clone()),*);
-            let mut hm = &mut #store_ident.lock().unwrap();
-            hm.#insert_fn(#syntax_names_tuple, r.clone());
+
+            let mut hm = #store_ident.lock().unwrap();
+            #memoize
+
+            r
+        }
+    } else {
+        quote::quote! {
+            let r = #store_ident.with(|hm| {
+                let mut hm = hm.borrow_mut();
+                #read_memo
+            });
+            if let Some(r) = r {
+                return r;
+            }
+
+            let r = #memoized_id(#(#input_names.clone()),*);
+
+            #store_ident.with(|hm| {
+                let mut hm = hm.borrow_mut();
+                #memoize
+            });
+
             r
         }
     };
 
     (quote::quote! {
+        #renamed_fn
+
         #store
 
-        #renamed_fn
-
-        #memoizer
+        #[allow(unused_variables)]
+        #sig {
+            #memoizer
+        }
     })
     .into()
 }