Mercurial > lbo > hg > memoize
changeset 16:2c254214df2e
Implement using an LRU cache.
author | Lewin Bormann <lbo@spheniscida.de> |
---|---|
date | Thu, 15 Oct 2020 20:34:20 +0200 |
parents | a794b6862ef4 |
children | cf5a71de7e50 |
files | Cargo.toml examples/test1.rs src/lib.rs |
diffstat | 3 files changed, 153 insertions(+), 21 deletions(-) [+] |
line wrap: on
line diff
--- a/Cargo.toml Thu Oct 15 19:24:47 2020 +0200 +++ b/Cargo.toml Thu Oct 15 20:34:20 2020 +0200 @@ -19,3 +19,9 @@ proc-macro2 = "1.0" quote = "1.0" syn = { version = "1.0", features = ["full"] } + +lru = { version = "0.6", optional = true } + +[features] +default = [] +full = ["lru"]
--- a/examples/test1.rs Thu Oct 15 19:24:47 2020 +0200 +++ b/examples/test1.rs Thu Oct 15 20:34:20 2020 +0200 @@ -7,7 +7,7 @@ i: i32, } -#[memoize] +#[memoize(Capacity: 123)] fn hello(key: String) -> ComplexStruct { println!("hello: {}", key); ComplexStruct {
--- a/src/lib.rs Thu Oct 15 19:24:47 2020 +0200 +++ b/src/lib.rs Thu Oct 15 20:34:20 2020 +0200 @@ -6,16 +6,130 @@ use proc_macro::TokenStream; use quote::{self, ToTokens}; -/* - * TODO: - */ +trait MemoizeStore<K, V> +where + K: std::hash::Hash, +{ + fn get(&mut self, k: &K) -> Option<&V>; + fn put(&mut self, k: K, v: V); +} + +impl<K: std::hash::Hash + Eq + Clone, V> MemoizeStore<K, V> for std::collections::HashMap<K, V> { + fn get(&mut self, k: &K) -> Option<&V> { + std::collections::HashMap::<K, V>::get(self, k) + } + fn put(&mut self, k: K, v: V) { + self.insert(k, v); + } +} + +#[cfg(not(feature = "full"))] +mod store { + use proc_macro::TokenStream; + + /// Returns tokenstreams (for quoting) of the store type and an expression to initialize it. + pub fn construct_cache( + _attr: TokenStream, + key_type: proc_macro2::TokenStream, + value_type: proc_macro2::TokenStream, + ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { + // This is the unbounded default. + ( + quote::quote! { std::collections::HashMap<#key_type, #value_type> }, + quote::quote! { std::collections::HashMap::new() }, + ) + } + + /// Returns tokenstreams (for quoting) of method names for inserting/getting (first/second + /// return tuple value). + pub fn cache_access_methods( + _attr: &TokenStream, + ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { + (quote::quote! { insert }, quote::quote! { get }) + } +} + +#[cfg(feature = "full")] +mod store { + use super::MemoizeStore; + use proc_macro::TokenStream; + use syn::parse as p; + + impl<K: std::hash::Hash + Eq + Clone, V> MemoizeStore<K, V> for lru::LruCache<K, V> { + fn get(&mut self, k: &K) -> Option<&V> { + lru::LruCache::<K, V>::get(self, k) + } + fn put(&mut self, k: K, v: V) { + lru::LruCache::<K, V>::put(self, k, v); + } + } + + #[derive(Default, Debug, Clone)] + struct CacheOptions { + lru_max_entries: Option<usize>, + } + + syn::custom_keyword!(Capacity); + syn::custom_punctuation!(Colon, :); + + impl p::Parse for CacheOptions { + fn parse(input: p::ParseStream) -> syn::Result<Self> { + let la = input.lookahead1(); + if la.peek(Capacity) { + let _: Capacity = input.parse().unwrap(); + let _: Colon = input.parse().unwrap(); + let cap: syn::LitInt = input.parse().unwrap(); + + return Ok(CacheOptions { + lru_max_entries: Some(cap.base10_parse()?), + }); + } + Ok(Default::default()) + } + } + + /// Returns tokenstreams (for quoting) of the store type and an expression to initialize it. + pub fn construct_cache( + attr: &TokenStream, + key_type: proc_macro2::TokenStream, + value_type: proc_macro2::TokenStream, + ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { + let options: CacheOptions = syn::parse(attr.clone()).unwrap(); + + // This is the unbounded default. + match options.lru_max_entries { + None => ( + quote::quote! { std::collections::HashMap<#key_type, #value_type> }, + quote::quote! { std::collections::HashMap::new() }, + ), + Some(cap) => ( + quote::quote! { lru::LruCache<#key_type, #value_type> }, + quote::quote! { lru::LruCache::new(#cap) }, + ), + } + } + + /// Returns tokenstreams (for quoting) of method names for inserting/getting (first/second + /// return tuple value). + pub fn cache_access_methods( + attr: &TokenStream, + ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { + let options: CacheOptions = syn::parse(attr.clone()).unwrap(); + + // This is the unbounded default. + match options.lru_max_entries { + None => (quote::quote! { insert }, quote::quote! { get }), + Some(cap) => (quote::quote! { put }, quote::quote! { get }), + } + } +} /** * memoize is an attribute to create a memoized version of a (simple enough) function. * - * So far, it works on functions with one or more arguments which are `Clone`-able, returning a - * `Clone`-able value. Several clones happen within the storage and recall layer, with the - * assumption being that `memoize` is used to cache such expensive functions that very few + * So far, it works on functions with one or more arguments which are `Clone`- and `Hash`-able, + * returning a `Clone`-able value. Several clones happen within the storage and recall layer, with + * the assumption being that `memoize` is used to cache such expensive functions that very few * `clone()`s do not matter. * * Calls are memoized for the lifetime of a program, using a statically allocated, Mutex-protected @@ -39,10 +153,13 @@ * If you need to use the un-memoized function, it is always available as `memoized_original_{fn}`, * in this case: `memoized_original_hello()`. * + * The `memoize` attribute can take further parameters in order to use an LRU cache: + * `#[memoize(Capacity: 1234)]`. + * * See the `examples` for concrete applications. */ #[proc_macro_attribute] -pub fn memoize(_attr: TokenStream, item: TokenStream) -> TokenStream { +pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream { let func = parse_macro_input!(item as ItemFn); let sig = &func.sig; @@ -56,8 +173,11 @@ let return_type; match check_signature(sig) { - Ok((t, n)) => { input_types = t; input_names = n; }, - Err(e) => return e.to_compile_error().into() + Ok((t, n)) => { + input_types = t; + input_names = n; + } + Err(e) => return e.to_compile_error().into(), } let input_tuple_type = quote::quote! { (#(#input_types),*) }; @@ -69,10 +189,11 @@ // Construct storage for the memoized keys and return values. let store_ident = syn::Ident::new(&map_name.to_uppercase(), sig.span()); + let (cache_type, cache_init) = store::construct_cache(&attr, input_tuple_type, return_type); let store = quote::quote! { lazy_static::lazy_static! { - static ref #store_ident : std::sync::Mutex<std::collections::HashMap<#input_tuple_type, #return_type>> = - std::sync::Mutex::new(std::collections::HashMap::new()); + static ref #store_ident : std::sync::Mutex<#cache_type> = + std::sync::Mutex::new(#cache_init); } }; @@ -84,14 +205,15 @@ // Construct memoizer function, which calls the original function. let syntax_names_tuple = quote::quote! { (#(#input_names),*) }; let syntax_names_tuple_cloned = quote::quote! { (#(#input_names.clone()),*) }; + let (insert_fn, get_fn) = store::cache_access_methods(&attr); let memoizer = quote::quote! { #sig { let mut hm = &mut #store_ident.lock().unwrap(); - if let Some(r) = hm.get(&#syntax_names_tuple_cloned) { + if let Some(r) = hm.#get_fn(&#syntax_names_tuple_cloned) { return r.clone(); } let r = #memoized_id(#(#input_names.clone()),*); - hm.insert(#syntax_names_tuple, r.clone()); + hm.#insert_fn(#syntax_names_tuple, r.clone()); r } }; @@ -106,13 +228,14 @@ .into() } -fn check_signature(sig: &syn::Signature) -> Result<(Vec<Box<syn::Type>>, Vec<Box<syn::Pat>>), syn::Error> { +fn check_signature( + sig: &syn::Signature, +) -> Result<(Vec<Box<syn::Type>>, Vec<Box<syn::Pat>>), syn::Error> { if let syn::FnArg::Receiver(_) = sig.inputs[0] { - return Err( - syn::Error::new( - sig.span(), - "Cannot memoize method (self-receiver) without arguments!", - )); + return Err(syn::Error::new( + sig.span(), + "Cannot memoize method (self-receiver) without arguments!", + )); } let mut types = vec![]; @@ -124,7 +247,10 @@ if let syn::Pat::Ident(_) = &*arg.pat { names.push(arg.pat.clone()); } else { - return Err(syn::Error::new(sig.span(), "Cannot memoize arbitrary patterns!")); + return Err(syn::Error::new( + sig.span(), + "Cannot memoize arbitrary patterns!", + )); } } }