changeset 16:2c254214df2e

Implement using an LRU cache.
author Lewin Bormann <lbo@spheniscida.de>
date Thu, 15 Oct 2020 20:34:20 +0200
parents a794b6862ef4
children cf5a71de7e50
files Cargo.toml examples/test1.rs src/lib.rs
diffstat 3 files changed, 153 insertions(+), 21 deletions(-) [+]
line wrap: on
line diff
--- a/Cargo.toml	Thu Oct 15 19:24:47 2020 +0200
+++ b/Cargo.toml	Thu Oct 15 20:34:20 2020 +0200
@@ -19,3 +19,9 @@
 proc-macro2 = "1.0"
 quote = "1.0"
 syn = { version = "1.0", features = ["full"] }
+
+lru = { version = "0.6", optional = true }
+
+[features]
+default = []
+full = ["lru"]
--- a/examples/test1.rs	Thu Oct 15 19:24:47 2020 +0200
+++ b/examples/test1.rs	Thu Oct 15 20:34:20 2020 +0200
@@ -7,7 +7,7 @@
     i: i32,
 }
 
-#[memoize]
+#[memoize(Capacity: 123)]
 fn hello(key: String) -> ComplexStruct {
     println!("hello: {}", key);
     ComplexStruct {
--- a/src/lib.rs	Thu Oct 15 19:24:47 2020 +0200
+++ b/src/lib.rs	Thu Oct 15 20:34:20 2020 +0200
@@ -6,16 +6,130 @@
 use proc_macro::TokenStream;
 use quote::{self, ToTokens};
 
-/*
- * TODO:
- */
+/// Uniform `get`/`put` interface over key-value stores; implemented in this file for `HashMap`.
+trait MemoizeStore<K: std::hash::Hash, V> {
+    /// Returns a reference to the value stored under `k`, or `None` if there is none.
+    fn get(&mut self, k: &K) -> Option<&V>;
+    /// Stores `v` under the key `k`.
+    fn put(&mut self, k: K, v: V);
+}
+
+impl<K: std::hash::Hash + Eq + Clone, V> MemoizeStore<K, V> for std::collections::HashMap<K, V> {
+    fn get(&mut self, k: &K) -> Option<&V> {
+        std::collections::HashMap::<K, V>::get(self, k) // inherent `get`, not this trait method
+    }
+    fn put(&mut self, k: K, v: V) {
+        self.insert(k, v); // the value returned by `insert` is discarded
+    }
+}
+
+#[cfg(not(feature = "full"))]
+mod store {
+    use proc_macro::TokenStream;
+
+    /// Returns tokenstreams (for quoting) of the store type and an expression to initialize it.
+    pub fn construct_cache(
+        _attr: &TokenStream,
+        key_type: proc_macro2::TokenStream,
+        value_type: proc_macro2::TokenStream,
+    ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
+        // `_attr` is unused: without the `full` feature the store is always the unbounded map.
+        (
+            quote::quote! { std::collections::HashMap<#key_type, #value_type> },
+            quote::quote! { std::collections::HashMap::new() },
+        )
+    }
+
+    /// Returns tokenstreams (for quoting) of method names for inserting/getting (first/second
+    /// return tuple value).
+    pub fn cache_access_methods(
+        _attr: &TokenStream,
+    ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
+        (quote::quote! { insert }, quote::quote! { get })
+    }
+}
+
+#[cfg(feature = "full")]
+mod store {
+    use super::MemoizeStore;
+    use proc_macro::TokenStream;
+    use syn::parse as p;
+
+    impl<K: std::hash::Hash + Eq + Clone, V> MemoizeStore<K, V> for lru::LruCache<K, V> {
+        fn get(&mut self, k: &K) -> Option<&V> {
+            lru::LruCache::<K, V>::get(self, k)
+        }
+        fn put(&mut self, k: K, v: V) {
+            lru::LruCache::<K, V>::put(self, k, v);
+        }
+    }
+
+    /// Options given as attribute arguments; `lru_max_entries` is set by `Capacity: <integer>`.
+    #[derive(Default, Debug, Clone)]
+    struct CacheOptions {
+        lru_max_entries: Option<usize>,
+    }
+
+    syn::custom_keyword!(Capacity);
+    syn::custom_punctuation!(Colon, :);
+
+    impl p::Parse for CacheOptions {
+        fn parse(input: p::ParseStream) -> syn::Result<Self> {
+            // No `Capacity` keyword: default options. Parse errors are returned, not panicked on.
+            if !input.peek(Capacity) {
+                return Ok(Default::default());
+            }
+            input.parse::<Capacity>()?;
+            input.parse::<Colon>()?;
+            let cap: syn::LitInt = input.parse()?;
+            Ok(CacheOptions {
+                lru_max_entries: Some(cap.base10_parse()?),
+            })
+        }
+    }
+
+    /// Returns tokenstreams (for quoting) of the store type and an expression to initialize it.
+    pub fn construct_cache(
+        attr: &TokenStream,
+        key_type: proc_macro2::TokenStream,
+        value_type: proc_macro2::TokenStream,
+    ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
+        // Panics if the attribute arguments are not valid `CacheOptions`.
+        let options: CacheOptions = syn::parse(attr.clone()).unwrap();
+        // No capacity given: unbounded `HashMap`. Otherwise an `LruCache` created with `cap`.
+        match options.lru_max_entries {
+            None => (
+                quote::quote! { std::collections::HashMap<#key_type, #value_type> },
+                quote::quote! { std::collections::HashMap::new() },
+            ),
+            Some(cap) => (
+                quote::quote! { lru::LruCache<#key_type, #value_type> },
+                quote::quote! { lru::LruCache::new(#cap) },
+            ),
+        }
+    }
+
+    /// Returns tokenstreams (for quoting) of method names for inserting/getting (first/second
+    /// return tuple value).
+    pub fn cache_access_methods(
+        attr: &TokenStream,
+    ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
+        let options: CacheOptions = syn::parse(attr.clone()).unwrap();
+
+        // `HashMap` inserts with `insert`, `LruCache` with `put`; both look up with `get`.
+        match options.lru_max_entries {
+            None => (quote::quote! { insert }, quote::quote! { get }),
+            Some(_) => (quote::quote! { put }, quote::quote! { get }),
+        }
+    }
+}
 
 /**
  * memoize is an attribute to create a memoized version of a (simple enough) function.
  *
- * So far, it works on functions with one or more arguments which are `Clone`-able, returning a
- * `Clone`-able value. Several clones happen within the storage and recall layer, with the
- * assumption being that `memoize` is used to cache such expensive functions that very few
+ * So far, it works on functions with one or more arguments which are `Clone`, `Eq` and `Hash`,
+ * returning a `Clone`-able value. Several clones happen within the storage and recall layer, with
+ * the assumption being that `memoize` is used to cache such expensive functions that very few
  * `clone()`s do not matter.
  *
  * Calls are memoized for the lifetime of a program, using a statically allocated, Mutex-protected
@@ -39,10 +153,13 @@
  * If you need to use the un-memoized function, it is always available as `memoized_original_{fn}`,
  * in this case: `memoized_original_hello()`.
  *
+ * The `memoize` attribute can take further parameters in order to use an LRU cache:
+ * `#[memoize(Capacity: 1234)]`. This needs the `full` feature; otherwise the arguments are ignored.
+ *
  * See the `examples` for concrete applications.
  */
 #[proc_macro_attribute]
-pub fn memoize(_attr: TokenStream, item: TokenStream) -> TokenStream {
+pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream {
     let func = parse_macro_input!(item as ItemFn);
     let sig = &func.sig;
 
@@ -56,8 +173,11 @@
     let return_type;
 
     match check_signature(sig) {
-        Ok((t, n)) => { input_types = t; input_names = n; },
-        Err(e) => return e.to_compile_error().into()
+        Ok((t, n)) => {
+            input_types = t;
+            input_names = n;
+        }
+        Err(e) => return e.to_compile_error().into(),
     }
 
     let input_tuple_type = quote::quote! { (#(#input_types),*) };
@@ -69,10 +189,11 @@
 
     // Construct storage for the memoized keys and return values.
     let store_ident = syn::Ident::new(&map_name.to_uppercase(), sig.span());
+    let (cache_type, cache_init) = store::construct_cache(&attr, input_tuple_type, return_type);
     let store = quote::quote! {
         lazy_static::lazy_static! {
-            static ref #store_ident : std::sync::Mutex<std::collections::HashMap<#input_tuple_type, #return_type>> =
-                std::sync::Mutex::new(std::collections::HashMap::new());
+            static ref #store_ident : std::sync::Mutex<#cache_type> =
+                std::sync::Mutex::new(#cache_init);
         }
     };
 
@@ -84,14 +205,15 @@
     // Construct memoizer function, which calls the original function.
     let syntax_names_tuple = quote::quote! { (#(#input_names),*) };
     let syntax_names_tuple_cloned = quote::quote! { (#(#input_names.clone()),*) };
+    let (insert_fn, get_fn) = store::cache_access_methods(&attr);
     let memoizer = quote::quote! {
         #sig {
             let mut hm = &mut #store_ident.lock().unwrap();
-            if let Some(r) = hm.get(&#syntax_names_tuple_cloned) {
+            if let Some(r) = hm.#get_fn(&#syntax_names_tuple_cloned) {
                 return r.clone();
             }
             let r = #memoized_id(#(#input_names.clone()),*);
-            hm.insert(#syntax_names_tuple, r.clone());
+            hm.#insert_fn(#syntax_names_tuple, r.clone());
             r
         }
     };
@@ -106,13 +228,14 @@
     .into()
 }
 
-fn check_signature(sig: &syn::Signature) -> Result<(Vec<Box<syn::Type>>, Vec<Box<syn::Pat>>), syn::Error> {
+fn check_signature(
+    sig: &syn::Signature,
+) -> Result<(Vec<Box<syn::Type>>, Vec<Box<syn::Pat>>), syn::Error> {
     if let syn::FnArg::Receiver(_) = sig.inputs[0] {
-        return Err(
-            syn::Error::new(
-                sig.span(),
-                "Cannot memoize method (self-receiver) without arguments!",
-            ));
+        return Err(syn::Error::new(
+            sig.span(),
+            "Cannot memoize method (self-receiver) without arguments!",
+        ));
     }
 
     let mut types = vec![];
@@ -124,7 +247,10 @@
             if let syn::Pat::Ident(_) = &*arg.pat {
                 names.push(arg.pat.clone());
             } else {
-                return Err(syn::Error::new(sig.span(), "Cannot memoize arbitrary patterns!"));
+                return Err(syn::Error::new(
+                    sig.span(),
+                    "Cannot memoize arbitrary patterns!",
+                ));
             }
         }
     }