Mercurial > lbo > hg > memoize
changeset 78:e9ccc1702af0
Merge pull request #12 from Codadillo/master
Use a thread local cache for memoization.
author | Lewin Bormann <lbo@spheniscida.de> |
---|---|
date | Sat, 02 Apr 2022 13:51:05 -0700 |
parents | b493c1f99877 (current diff) e010a33862ac (diff) |
children | ff66cf5c7f8c |
files | |
diffstat | 2 files changed, 183 insertions(+), 146 deletions(-) [+] |
line wrap: on
line diff
--- a/README.md Sat Apr 02 13:49:42 2022 -0700 +++ b/README.md Sat Apr 02 13:51:05 2022 -0700 @@ -36,29 +36,46 @@ This is expanded into (with a few simplifications): ```rust -// This is obviously further expanded before compiling. -lazy_static! { - static ref MEMOIZED_MAPPING_HELLO : Mutex<HashMap<String, bool>>; +std::thread_local! { + static MEMOIZED_MAPPING_HELLO : RefCell<HashMap<(String, usize), bool>> = RefCell::new(HashMap::new()); } -fn memoized_original_hello(arg: String, arg2: usize) -> bool { +pub fn memoized_original_hello(arg: String, arg2: usize) -> bool { arg.len() % 2 == arg2 } +#[allow(unused_variables)] fn hello(arg: String, arg2: usize) -> bool { - { - let mut hm = &mut MEMOIZED_MAPPING_HELLO.lock().unwrap(); - if let Some(r) = hm.get(&(arg.clone(), arg2.clone())) { - return r.clone(); - } + let r = MEMOIZED_MAPPING_HELLO.with(|hm| { + let mut hm = hm.borrow_mut(); + hm.get(&(arg.clone(), arg2.clone())).cloned() + }); + if let Some(r) = r { + return r; } + let r = memoized_original_hello(arg.clone(), arg2.clone()); - hm.insert((arg, arg2), r.clone()); + + MEMOIZED_MAPPING_HELLO.with(|hm| { + let mut hm = hm.borrow_mut(); + hm.insert((arg, arg2), r.clone()); + }); + r } + ``` ## Further Functionality +As can be seen in the above example, each thread has its own cache by default. If you would prefer +that every thread share the same cache, you can specify the `SharedCache` option like below to wrap +the cache in a `std::sync::Mutex`. For example: +```rust +#[memoize(SharedCache)] +fn hello(key: String) -> ComplexStruct { + // ... +} +``` You can choose to use an [LRU cache](https://crates.io/crates/lru). In fact, if you know that a memoized function has an unbounded number of different inputs,
--- a/inner/src/lib.rs Sat Apr 02 13:49:42 2022 -0700 +++ b/inner/src/lib.rs Sat Apr 02 13:51:05 2022 -0700 @@ -1,31 +1,98 @@ #![crate_type = "proc-macro"] #![allow(unused_imports)] // Spurious complaints about a required trait import. -use syn::{self, parse_macro_input, spanned::Spanned, ItemFn}; +use syn::{self, parse, parse_macro_input, spanned::Spanned, Expr, ItemFn}; use proc_macro::TokenStream; use quote::{self, ToTokens}; +mod kw { + syn::custom_keyword!(Capacity); + syn::custom_keyword!(TimeToLive); + syn::custom_keyword!(SharedCache); + syn::custom_punctuation!(Colon, :); +} + +#[derive(Default, Clone)] +struct CacheOptions { + lru_max_entries: Option<usize>, + time_to_live: Option<Expr>, + shared_cache: bool, +} + +#[derive(Clone)] +enum CacheOption { + LRUMaxEntries(usize), + TimeToLive(Expr), + SharedCache, +} + +// To extend option parsing, add functionality here. +#[allow(unreachable_code)] +impl parse::Parse for CacheOption { + fn parse(input: parse::ParseStream) -> syn::Result<Self> { + let la = input.lookahead1(); + if la.peek(kw::Capacity) { + #[cfg(not(feature = "full"))] + return Err(syn::Error::new(input.span(), + "memoize error: Capacity specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.", + )); + + input.parse::<kw::Capacity>().unwrap(); + input.parse::<kw::Colon>().unwrap(); + let cap: syn::LitInt = input.parse().unwrap(); + + return Ok(CacheOption::LRUMaxEntries(cap.base10_parse()?)); + } + if la.peek(kw::TimeToLive) { + #[cfg(not(feature = "full"))] + return Err(syn::Error::new(input.span(), + "memoize error: TimeToLive specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.", + )); + + input.parse::<kw::TimeToLive>().unwrap(); + input.parse::<kw::Colon>().unwrap(); + let cap: syn::Expr = input.parse().unwrap(); + + return Ok(CacheOption::TimeToLive(cap)); + } + if la.peek(kw::SharedCache) { + input.parse::<kw::SharedCache>().unwrap(); + return Ok(CacheOption::SharedCache); + } + Err(la.error()) + } +} + +impl parse::Parse for CacheOptions { + fn parse(input: parse::ParseStream) -> syn::Result<Self> { + let f: syn::punctuated::Punctuated<CacheOption, syn::Token![,]> = + input.parse_terminated(CacheOption::parse)?; + let mut opts = Self::default(); + + for opt in f { + match opt { + CacheOption::LRUMaxEntries(cap) => opts.lru_max_entries = Some(cap), + CacheOption::TimeToLive(sec) => opts.time_to_live = Some(sec), + CacheOption::SharedCache => opts.shared_cache = true, + } + } + Ok(opts) + } +} + // This implementation of the storage backend does not depend on any more crates. #[cfg(not(feature = "full"))] mod store { + use crate::CacheOptions; use proc_macro::TokenStream; /// Returns tokenstreams (for quoting) of the store type and an expression to initialize it. - pub fn construct_cache( - attr: &TokenStream, + pub(crate) fn construct_cache( + _options: &CacheOptions, key_type: proc_macro2::TokenStream, value_type: proc_macro2::TokenStream, ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { - if !attr.is_empty() { - return ( - syn::Error::new_spanned(proc_macro2::TokenStream::from(attr.clone()), - "memoize error: Attributes are specified, but the feature 'full' is not enabled! To fix this, compile with `--features=full`.", - ) - .to_compile_error(), - proc_macro2::TokenStream::new(), - ); - } // This is the unbounded default. ( quote::quote! { std::collections::HashMap<#key_type, #value_type> }, @@ -35,8 +102,8 @@ /// Returns names of methods as TokenStreams to insert and get (respectively) elements from a /// store. - pub fn cache_access_methods( - _attr: &TokenStream, + pub(crate) fn cache_access_methods( + _options: &CacheOptions, ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { (quote::quote! { insert }, quote::quote! { get }) } @@ -45,75 +112,19 @@ // This implementation of the storage backend also depends on the `lru` crate. #[cfg(feature = "full")] mod store { + use crate::CacheOptions; use proc_macro::TokenStream; - use syn::{parse as p, Expr}; - - #[derive(Default, Clone)] - pub(crate) struct CacheOptions { - lru_max_entries: Option<usize>, - pub(crate) time_to_live: Option<Expr>, - } - - #[derive(Clone)] - enum CacheOption { - LRUMaxEntries(usize), - TimeToLive(Expr), - } - - syn::custom_keyword!(Capacity); - syn::custom_keyword!(TimeToLive); - syn::custom_punctuation!(Colon, :); - - // To extend option parsing, add functionality here. - impl p::Parse for CacheOption { - fn parse(input: p::ParseStream) -> syn::Result<Self> { - let la = input.lookahead1(); - if la.peek(Capacity) { - let _: Capacity = input.parse().unwrap(); - let _: Colon = input.parse().unwrap(); - let cap: syn::LitInt = input.parse().unwrap(); - - return Ok(CacheOption::LRUMaxEntries(cap.base10_parse()?)); - } - if la.peek(TimeToLive) { - let _: TimeToLive = input.parse().unwrap(); - let _: Colon = input.parse().unwrap(); - let cap: syn::Expr = input.parse().unwrap(); - - return Ok(CacheOption::TimeToLive(cap)); - } - Err(la.error()) - } - } - - impl p::Parse for CacheOptions { - fn parse(input: p::ParseStream) -> syn::Result<Self> { - let f: syn::punctuated::Punctuated<CacheOption, syn::Token![,]> = - input.parse_terminated(CacheOption::parse)?; - let mut opts = Self::default(); - - for opt in f { - match opt { - CacheOption::LRUMaxEntries(cap) => opts.lru_max_entries = Some(cap), - CacheOption::TimeToLive(sec) => opts.time_to_live = Some(sec), - } - } - Ok(opts) - } - } /// Returns TokenStreams to be used in quote!{} for parametrizing the memoize store variable, /// and initializing it. /// /// First return value: Type of store ("Container<K,V>"). /// Second return value: Initializer syntax ("Container::<K,V>::new()"). - pub fn construct_cache( - attr: &TokenStream, + pub(crate) fn construct_cache( + options: &CacheOptions, key_type: proc_macro2::TokenStream, value_type: proc_macro2::TokenStream, ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { - let options: CacheOptions = syn::parse(attr.clone()).unwrap(); - let value_type = match options.time_to_live { None => quote::quote! {#value_type}, Some(_) => quote::quote! {(std::time::Instant, #value_type)}, @@ -133,11 +144,9 @@ /// Returns names of methods as TokenStreams to insert and get (respectively) elements from a /// store. - pub fn cache_access_methods( - attr: &TokenStream, + pub(crate) fn cache_access_methods( + options: &CacheOptions, ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { - let options: CacheOptions = syn::parse(attr.clone()).unwrap(); - // This is the unbounded default. match options.lru_max_entries { None => (quote::quote! { insert }, quote::quote! { get }), @@ -219,13 +228,26 @@ syn::ReturnType::Type(_, ty) => return_type = ty.to_token_stream(), } + // Parse options from macro attributes + let options: CacheOptions = syn::parse(attr.clone()).unwrap(); + // Construct storage for the memoized keys and return values. let store_ident = syn::Ident::new(&map_name.to_uppercase(), sig.span()); - let (cache_type, cache_init) = store::construct_cache(&attr, input_tuple_type, return_type.clone()); - let store = quote::quote! { - ::memoize::lazy_static::lazy_static! { - static ref #store_ident : std::sync::Mutex<#cache_type> = - std::sync::Mutex::new(#cache_init); + let (cache_type, cache_init) = + store::construct_cache(&options, input_tuple_type, return_type.clone()); + let store = if options.shared_cache { + quote::quote! { + ::memoize::lazy_static::lazy_static! { + static ref #store_ident : std::sync::Mutex<#cache_type> = + std::sync::Mutex::new(#cache_init); + } + } + } else { + quote::quote! { + std::thread_local! { + static #store_ident : std::cell::RefCell<#cache_type> = + std::cell::RefCell::new(#cache_init); + } } }; @@ -234,75 +256,73 @@ renamed_fn.sig.ident = syn::Ident::new(&renamed_name, func.sig.span()); let memoized_id = &renamed_fn.sig.ident; - // Extract the function name and identifier. - let fn_name = func.sig.ident.clone(); - let fn_vis = func.vis.clone(); - // Construct memoizer function, which calls the original function. let syntax_names_tuple = quote::quote! { (#(#input_names),*) }; let syntax_names_tuple_cloned = quote::quote! { (#(#input_names.clone()),*) }; - let (insert_fn, get_fn) = store::cache_access_methods(&attr); - #[cfg(feature = "full")] - let memoizer = { - let options: store::CacheOptions = syn::parse(attr.clone().into()).unwrap(); - match options.time_to_live { - None => quote::quote! { - #sig { - { - let mut hm = &mut #store_ident.lock().unwrap(); - if let Some(r) = hm.#get_fn(&#syntax_names_tuple_cloned) { - return r.clone(); - } - } - let r = #memoized_id(#(#input_names.clone()),*); - let mut hm = &mut #store_ident.lock().unwrap(); - hm.#insert_fn(#syntax_names_tuple, r.clone()); - r - } + let (insert_fn, get_fn) = store::cache_access_methods(&options); + let (read_memo, memoize) = match options.time_to_live { + None => ( + quote::quote!(hm.#get_fn(&#syntax_names_tuple_cloned).cloned()), + quote::quote!(hm.#insert_fn(#syntax_names_tuple, r.clone());), + ), + Some(ttl) => ( + quote::quote! { + hm.#get_fn(&#syntax_names_tuple_cloned).and_then(|(last_updated, r)| + (last_updated.elapsed() < #ttl).then(|| r.clone()) + ) }, - Some(ttl) => quote::quote! { - #sig { - { - let mut hm = &mut #store_ident.lock().unwrap(); - if let Some((last_updated, r)) = hm.#get_fn(&#syntax_names_tuple_cloned) { - if last_updated.elapsed() < #ttl { - return r.clone(); - } - } - } - let r = #memoized_id(#(#input_names.clone()),*); - let mut hm = &mut #store_ident.lock().unwrap(); - hm.#insert_fn(#syntax_names_tuple, (std::time::Instant::now(), r.clone())); - r - } - }, - } + quote::quote!(hm.#insert_fn(#syntax_names_tuple, (std::time::Instant::now(), r.clone()));), + ), }; - #[cfg(not(feature = "full"))] - let memoizer = quote::quote! { - #fn_vis fn #fn_name ( - #(#input_names: #input_types),* - ) -> #return_type { + + let memoizer = if options.shared_cache { + quote::quote! { { - let mut hm = &mut #store_ident.lock().unwrap(); - if let Some(r) = hm.#get_fn(&#syntax_names_tuple_cloned) { - return r.clone(); + let mut hm = #store_ident.lock().unwrap(); + if let Some(r) = #read_memo { + return r } } let r = #memoized_id(#(#input_names.clone()),*); - let mut hm = &mut #store_ident.lock().unwrap(); - hm.#insert_fn(#syntax_names_tuple, r.clone()); + + let mut hm = #store_ident.lock().unwrap(); + #memoize + + r + } + } else { + quote::quote! { + let r = #store_ident.with(|hm| { + let mut hm = hm.borrow_mut(); + #read_memo + }); + if let Some(r) = r { + return r; + } + + let r = #memoized_id(#(#input_names.clone()),*); + + #store_ident.with(|hm| { + let mut hm = hm.borrow_mut(); + #memoize + }); + r } }; - (quote::quote! { + let vis = &func.vis; + + quote::quote! { + #renamed_fn + #store - #renamed_fn - - #memoizer - }) + #[allow(unused_variables)] + #vis #sig { + #memoizer + } + } .into() }