changeset 78:e9ccc1702af0

Merge pull request #12 from Codadillo/master Use a thread local cache for memoization.
author Lewin Bormann <lbo@spheniscida.de>
date Sat, 02 Apr 2022 13:51:05 -0700
parents b493c1f99877 (current diff) e010a33862ac (diff)
children ff66cf5c7f8c
files
diffstat 2 files changed, 183 insertions(+), 146 deletions(-) [+]
line wrap: on
line diff
--- a/README.md	Sat Apr 02 13:49:42 2022 -0700
+++ b/README.md	Sat Apr 02 13:51:05 2022 -0700
@@ -36,29 +36,46 @@
 This is expanded into (with a few simplifications):
 
 ```rust
-// This is obviously further expanded before compiling.
-lazy_static! {
-  static ref MEMOIZED_MAPPING_HELLO : Mutex<HashMap<String, bool>>;
+std::thread_local! {
+  static MEMOIZED_MAPPING_HELLO : RefCell<HashMap<(String, usize), bool>> = RefCell::new(HashMap::new());
 }
 
-fn memoized_original_hello(arg: String, arg2: usize) -> bool {
+pub fn memoized_original_hello(arg: String, arg2: usize) -> bool {
   arg.len() % 2 == arg2
 }
 
+#[allow(unused_variables)]
 fn hello(arg: String, arg2: usize) -> bool {
-  {
-    let mut hm = &mut MEMOIZED_MAPPING_HELLO.lock().unwrap();
-    if let Some(r) = hm.get(&(arg.clone(), arg2.clone())) {
-      return r.clone();
-    }
+  let r = MEMOIZED_MAPPING_HELLO.with(|hm| {
+    let mut hm = hm.borrow_mut();
+    hm.get(&(arg.clone(), arg2.clone())).cloned()
+  });
+  if let Some(r) = r {
+    return r;
   }
+
   let r = memoized_original_hello(arg.clone(), arg2.clone());
-  hm.insert((arg, arg2), r.clone());
+
+  MEMOIZED_MAPPING_HELLO.with(|hm| {
+    let mut hm = hm.borrow_mut();
+    hm.insert((arg, arg2), r.clone());
+  });
+
   r
 }
+
 ```
 
 ## Further Functionality
+As can be seen in the above example, each thread has its own cache by default. If you would prefer
+that every thread share the same cache, you can specify the `SharedCache` option as shown below
+to wrap the cache in a `std::sync::Mutex`. For example:
+```rust
+#[memoize(SharedCache)]
+fn hello(key: String) -> ComplexStruct {
+  // ...
+}
+```
 
 You can choose to use an [LRU cache](https://crates.io/crates/lru). In fact, if
 you know that a memoized function has an unbounded number of different inputs,
--- a/inner/src/lib.rs	Sat Apr 02 13:49:42 2022 -0700
+++ b/inner/src/lib.rs	Sat Apr 02 13:51:05 2022 -0700
@@ -1,31 +1,98 @@
 #![crate_type = "proc-macro"]
 #![allow(unused_imports)] // Spurious complaints about a required trait import.
 
-use syn::{self, parse_macro_input, spanned::Spanned, ItemFn};
+use syn::{self, parse, parse_macro_input, spanned::Spanned, Expr, ItemFn};
 
 use proc_macro::TokenStream;
 use quote::{self, ToTokens};
 
+mod kw {
+    syn::custom_keyword!(Capacity);
+    syn::custom_keyword!(TimeToLive);
+    syn::custom_keyword!(SharedCache);
+    syn::custom_punctuation!(Colon, :);
+}
+
+#[derive(Default, Clone)]
+struct CacheOptions {
+    lru_max_entries: Option<usize>,
+    time_to_live: Option<Expr>,
+    shared_cache: bool,
+}
+
+#[derive(Clone)]
+enum CacheOption {
+    LRUMaxEntries(usize),
+    TimeToLive(Expr),
+    SharedCache,
+}
+
+// To extend option parsing, add functionality here.
+#[allow(unreachable_code)]
+impl parse::Parse for CacheOption {
+    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
+        let la = input.lookahead1();
+        if la.peek(kw::Capacity) {
+            #[cfg(not(feature = "full"))]
+            return Err(syn::Error::new(input.span(),
+            "memoize error: Capacity specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.",
+            ));
+
+            input.parse::<kw::Capacity>().unwrap();
+            input.parse::<kw::Colon>().unwrap();
+            let cap: syn::LitInt = input.parse().unwrap();
+
+            return Ok(CacheOption::LRUMaxEntries(cap.base10_parse()?));
+        }
+        if la.peek(kw::TimeToLive) {
+            #[cfg(not(feature = "full"))]
+            return Err(syn::Error::new(input.span(),
+            "memoize error: TimeToLive specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.",
+            ));
+
+            input.parse::<kw::TimeToLive>().unwrap();
+            input.parse::<kw::Colon>().unwrap();
+            let cap: syn::Expr = input.parse().unwrap();
+
+            return Ok(CacheOption::TimeToLive(cap));
+        }
+        if la.peek(kw::SharedCache) {
+            input.parse::<kw::SharedCache>().unwrap();
+            return Ok(CacheOption::SharedCache);
+        }
+        Err(la.error())
+    }
+}
+
+impl parse::Parse for CacheOptions {
+    fn parse(input: parse::ParseStream) -> syn::Result<Self> {
+        let f: syn::punctuated::Punctuated<CacheOption, syn::Token![,]> =
+            input.parse_terminated(CacheOption::parse)?;
+        let mut opts = Self::default();
+
+        for opt in f {
+            match opt {
+                CacheOption::LRUMaxEntries(cap) => opts.lru_max_entries = Some(cap),
+                CacheOption::TimeToLive(sec) => opts.time_to_live = Some(sec),
+                CacheOption::SharedCache => opts.shared_cache = true,
+            }
+        }
+        Ok(opts)
+    }
+}
+
 // This implementation of the storage backend does not depend on any more crates.
 #[cfg(not(feature = "full"))]
 mod store {
+    use crate::CacheOptions;
     use proc_macro::TokenStream;
 
     /// Returns tokenstreams (for quoting) of the store type and an expression to initialize it.
-    pub fn construct_cache(
-        attr: &TokenStream,
+    pub(crate) fn construct_cache(
+        _options: &CacheOptions,
         key_type: proc_macro2::TokenStream,
         value_type: proc_macro2::TokenStream,
     ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
-        if !attr.is_empty() {
-            return (
-                syn::Error::new_spanned(proc_macro2::TokenStream::from(attr.clone()),
-                    "memoize error: Attributes are specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.",
-                )
-                .to_compile_error(),
-                proc_macro2::TokenStream::new(),
-            );
-        }
         // This is the unbounded default.
         (
             quote::quote! { std::collections::HashMap<#key_type, #value_type> },
@@ -35,8 +102,8 @@
 
     /// Returns names of methods as TokenStreams to insert and get (respectively) elements from a
     /// store.
-    pub fn cache_access_methods(
-        _attr: &TokenStream,
+    pub(crate) fn cache_access_methods(
+        _options: &CacheOptions,
     ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
         (quote::quote! { insert }, quote::quote! { get })
     }
@@ -45,75 +112,19 @@
 // This implementation of the storage backend also depends on the `lru` crate.
 #[cfg(feature = "full")]
 mod store {
+    use crate::CacheOptions;
     use proc_macro::TokenStream;
-    use syn::{parse as p, Expr};
-
-    #[derive(Default, Clone)]
-    pub(crate) struct CacheOptions {
-        lru_max_entries: Option<usize>,
-        pub(crate) time_to_live: Option<Expr>,
-    }
-
-    #[derive(Clone)]
-    enum CacheOption {
-        LRUMaxEntries(usize),
-        TimeToLive(Expr),
-    }
-
-    syn::custom_keyword!(Capacity);
-    syn::custom_keyword!(TimeToLive);
-    syn::custom_punctuation!(Colon, :);
-
-    // To extend option parsing, add functionality here.
-    impl p::Parse for CacheOption {
-        fn parse(input: p::ParseStream) -> syn::Result<Self> {
-            let la = input.lookahead1();
-            if la.peek(Capacity) {
-                let _: Capacity = input.parse().unwrap();
-                let _: Colon = input.parse().unwrap();
-                let cap: syn::LitInt = input.parse().unwrap();
-
-                return Ok(CacheOption::LRUMaxEntries(cap.base10_parse()?));
-            }
-            if la.peek(TimeToLive) {
-                let _: TimeToLive = input.parse().unwrap();
-                let _: Colon = input.parse().unwrap();
-                let cap: syn::Expr = input.parse().unwrap();
-
-                return Ok(CacheOption::TimeToLive(cap));
-            }
-            Err(la.error())
-        }
-    }
-
-    impl p::Parse for CacheOptions {
-        fn parse(input: p::ParseStream) -> syn::Result<Self> {
-            let f: syn::punctuated::Punctuated<CacheOption, syn::Token![,]> =
-                input.parse_terminated(CacheOption::parse)?;
-            let mut opts = Self::default();
-
-            for opt in f {
-                match opt {
-                    CacheOption::LRUMaxEntries(cap) => opts.lru_max_entries = Some(cap),
-                    CacheOption::TimeToLive(sec) => opts.time_to_live = Some(sec),
-                }
-            }
-            Ok(opts)
-        }
-    }
 
     /// Returns TokenStreams to be used in quote!{} for parametrizing the memoize store variable,
     /// and initializing it.
     ///
     /// First return value: Type of store ("Container<K,V>").
     /// Second return value: Initializer syntax ("Container::<K,V>::new()").
-    pub fn construct_cache(
-        attr: &TokenStream,
+    pub(crate) fn construct_cache(
+        options: &CacheOptions,
         key_type: proc_macro2::TokenStream,
         value_type: proc_macro2::TokenStream,
     ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
-        let options: CacheOptions = syn::parse(attr.clone()).unwrap();
-
         let value_type = match options.time_to_live {
             None => quote::quote! {#value_type},
             Some(_) => quote::quote! {(std::time::Instant, #value_type)},
@@ -133,11 +144,9 @@
 
     /// Returns names of methods as TokenStreams to insert and get (respectively) elements from a
     /// store.
-    pub fn cache_access_methods(
-        attr: &TokenStream,
+    pub(crate) fn cache_access_methods(
+        options: &CacheOptions,
     ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
-        let options: CacheOptions = syn::parse(attr.clone()).unwrap();
-
         // This is the unbounded default.
         match options.lru_max_entries {
             None => (quote::quote! { insert }, quote::quote! { get }),
@@ -219,13 +228,26 @@
         syn::ReturnType::Type(_, ty) => return_type = ty.to_token_stream(),
     }
 
+    // Parse options from macro attributes
+    let options: CacheOptions = syn::parse(attr.clone()).unwrap();
+
     // Construct storage for the memoized keys and return values.
     let store_ident = syn::Ident::new(&map_name.to_uppercase(), sig.span());
-    let (cache_type, cache_init) = store::construct_cache(&attr, input_tuple_type, return_type.clone());
-    let store = quote::quote! {
-        ::memoize::lazy_static::lazy_static! {
-            static ref #store_ident : std::sync::Mutex<#cache_type> =
-                std::sync::Mutex::new(#cache_init);
+    let (cache_type, cache_init) =
+        store::construct_cache(&options, input_tuple_type, return_type.clone());
+    let store = if options.shared_cache {
+        quote::quote! {
+            ::memoize::lazy_static::lazy_static! {
+                static ref #store_ident : std::sync::Mutex<#cache_type> =
+                    std::sync::Mutex::new(#cache_init);
+            }
+        }
+    } else {
+        quote::quote! {
+            std::thread_local! {
+                static #store_ident : std::cell::RefCell<#cache_type> =
+                    std::cell::RefCell::new(#cache_init);
+            }
         }
     };
 
@@ -234,75 +256,73 @@
     renamed_fn.sig.ident = syn::Ident::new(&renamed_name, func.sig.span());
     let memoized_id = &renamed_fn.sig.ident;
 
-    // Extract the function name and identifier.
-    let fn_name = func.sig.ident.clone();
-    let fn_vis = func.vis.clone();
-    
     // Construct memoizer function, which calls the original function.
     let syntax_names_tuple = quote::quote! { (#(#input_names),*) };
     let syntax_names_tuple_cloned = quote::quote! { (#(#input_names.clone()),*) };
-    let (insert_fn, get_fn) = store::cache_access_methods(&attr);
-    #[cfg(feature = "full")]
-    let memoizer = {
-        let options: store::CacheOptions = syn::parse(attr.clone().into()).unwrap();
-        match options.time_to_live {
-            None => quote::quote! {
-                #sig {
-                    {
-                        let mut hm = &mut #store_ident.lock().unwrap();
-                        if let Some(r) = hm.#get_fn(&#syntax_names_tuple_cloned) {
-                            return r.clone();
-                        }
-                    }
-                    let r = #memoized_id(#(#input_names.clone()),*);
-                    let mut hm = &mut #store_ident.lock().unwrap();
-                    hm.#insert_fn(#syntax_names_tuple, r.clone());
-                    r
-                }
+    let (insert_fn, get_fn) = store::cache_access_methods(&options);
+    let (read_memo, memoize) = match options.time_to_live {
+        None => (
+            quote::quote!(hm.#get_fn(&#syntax_names_tuple_cloned).cloned()),
+            quote::quote!(hm.#insert_fn(#syntax_names_tuple, r.clone());),
+        ),
+        Some(ttl) => (
+            quote::quote! {
+                hm.#get_fn(&#syntax_names_tuple_cloned).and_then(|(last_updated, r)|
+                    (last_updated.elapsed() < #ttl).then(|| r.clone())
+                )
             },
-            Some(ttl) => quote::quote! {
-                #sig {
-                    {
-                        let mut hm = &mut #store_ident.lock().unwrap();
-                        if let Some((last_updated, r)) = hm.#get_fn(&#syntax_names_tuple_cloned) {
-                            if last_updated.elapsed() < #ttl {
-                                return r.clone();
-                            }
-                        }
-                    }
-                    let r = #memoized_id(#(#input_names.clone()),*);
-                    let mut hm = &mut #store_ident.lock().unwrap();
-                    hm.#insert_fn(#syntax_names_tuple, (std::time::Instant::now(), r.clone()));
-                    r
-                }
-            },
-        }
+            quote::quote!(hm.#insert_fn(#syntax_names_tuple, (std::time::Instant::now(), r.clone()));),
+        ),
     };
-    #[cfg(not(feature = "full"))]
-    let memoizer = quote::quote! {
-        #fn_vis fn #fn_name (
-            #(#input_names: #input_types),*
-        ) -> #return_type {
+
+    let memoizer = if options.shared_cache {
+        quote::quote! {
             {
-                let mut hm = &mut #store_ident.lock().unwrap();
-                if let Some(r) = hm.#get_fn(&#syntax_names_tuple_cloned) {
-                    return r.clone();
+                let mut hm = #store_ident.lock().unwrap();
+                if let Some(r) = #read_memo {
+                    return r
                 }
             }
             let r = #memoized_id(#(#input_names.clone()),*);
-            let mut hm = &mut #store_ident.lock().unwrap();
-            hm.#insert_fn(#syntax_names_tuple, r.clone());
+
+            let mut hm = #store_ident.lock().unwrap();
+            #memoize
+
+            r
+        }
+    } else {
+        quote::quote! {
+            let r = #store_ident.with(|hm| {
+                let mut hm = hm.borrow_mut();
+                #read_memo
+            });
+            if let Some(r) = r {
+                return r;
+            }
+
+            let r = #memoized_id(#(#input_names.clone()),*);
+
+            #store_ident.with(|hm| {
+                let mut hm = hm.borrow_mut();
+                #memoize
+            });
+
             r
         }
     };
 
-    (quote::quote! {
+    let vis = &func.vis;
+
+    quote::quote! {
+        #renamed_fn
+
         #store
 
-        #renamed_fn
-
-        #memoizer
-    })
+        #[allow(unused_variables)]
+        #vis #sig {
+            #memoizer
+        }
+    }
     .into()
 }