changeset 9:69a47879ec78

Implement caching for functions with multiple arguments.
author Lewin Bormann <lbo@spheniscida.de>
date Thu, 15 Oct 2020 14:16:02 +0200
parents 8d44466dbb15
children dd0035137133
files README.md examples/test1.rs examples/test2.rs src/lib.rs
diffstat 4 files changed, 76 insertions(+), 43 deletions(-) [+]
line wrap: on
line diff
--- a/README.md	Thu Oct 15 13:34:53 2020 +0200
+++ b/README.md	Thu Oct 15 14:16:02 2020 +0200
@@ -1,6 +1,7 @@
 # memoize
 
-A `#[memoize]` attribute for somewhat simple Rust functions. That's it.
+A `#[memoize]` attribute for somewhat simple Rust functions: That is, functions
+with one or more `Clone`-able arguments, and a `Clone`-able return type. That's it.
 
 Read the documentation (`cargo doc --open`) for the sparse details, or take a
 look at the `examples/`, if you want to know more:
@@ -27,19 +28,20 @@
   static ref MEMOIZED_MAPPING_HELLO : Mutex<HashMap<String, bool>>;
 }
 
-fn memoized_original_hello(arg: String) -> bool {
-    arg.len()%2 == 0
+fn memoized_original_hello(arg: String, arg2: usize) -> bool {
+    arg.len() % 2 == arg2
 }
 
-fn hello(arg: String) -> bool {
+fn hello(arg: String, arg2: usize) -> bool {
     let mut hm = &mut MEMOIZED_MAPPING_HELLO.lock().unwrap();
-    if let Some(r) = hm.get(&arg) {
+    if let Some(r) = hm.get(&(arg.clone(), arg2.clone())) {
         return r.clone();
     }
-    let r = memoized_original_hello(arg.clone());
-    hm.insert(arg, r.clone());
+    let r = memoized_original_hello(arg.clone(), arg2.clone());
+    hm.insert((arg, arg2), r.clone());
     r
 }
+
 ```
 
 Intentionally not yet on crates.rs.
--- a/examples/test1.rs	Thu Oct 15 13:34:53 2020 +0200
+++ b/examples/test1.rs	Thu Oct 15 14:16:02 2020 +0200
@@ -10,7 +10,11 @@
 #[memoize]
 fn hello(key: String) -> ComplexStruct {
     println!("hello: {}", key);
-    ComplexStruct { s: key, b: false, i: 332 }
+    ComplexStruct {
+        s: key,
+        b: false,
+        i: 332,
+    }
 }
 
 fn main() {
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/examples/test2.rs	Thu Oct 15 14:16:02 2020 +0200
@@ -0,0 +1,15 @@
+use memoize::memoize;
+
+#[memoize]
+fn hello(arg: String, arg2: usize) -> bool {
+    println!("{} => {}", arg, arg2);
+    arg.len() % 2 == arg2
+}
+
+fn main() {
+    // `hello` is only called once here.
+    assert!(!hello("World".to_string(), 0));
+    assert!(!hello("World".to_string(), 0));
+    // Sometimes one might need the original function.
+    assert!(!memoized_original_hello("World".to_string(), 0));
+}
--- a/src/lib.rs	Thu Oct 15 13:34:53 2020 +0200
+++ b/src/lib.rs	Thu Oct 15 14:16:02 2020 +0200
@@ -8,14 +8,15 @@
 
 /*
  * TODO:
- * - Functions with multiple arguments.
  */
 
 /**
  * memoize is an attribute to create a memoized version of a (simple enough) function.
  *
- * So far, it works on functions with one argument which is Clone-able, returning a Clone-able
- * value.
+ * So far, it works on functions with one or more arguments which are `Clone`-able, returning a
+ * `Clone`-able value. Several clones happen within the storage and recall layer, with the
+ * assumption being that `memoize` is used to cache such expensive functions that very few
+ * `clone()`s do not matter.
  *
  * Calls are memoized for the lifetime of a program, using a statically allocated, Mutex-protected
  * HashMap.
@@ -26,13 +27,13 @@
  * ```
  * use memoize::memoize;
  * #[memoize]
- * fn hello(arg: String) -> bool {
- *      arg.len()%2 == 0
+ * fn hello(arg: String, arg2: usize) -> bool {
+ *      arg.len()%2 == arg2
  * }
  *
  * // `hello` is only called once.
- * assert!(! hello("World".to_string()));
- * assert!(! hello("World".to_string()));
+ * assert!(! hello("World".to_string(), 0));
+ * assert!(! hello("World".to_string(), 0));
  * ```
  *
  * If you need to use the un-memoized function, it is always available as `memoized_original_{fn}`,
@@ -45,66 +46,77 @@
     let func = parse_macro_input!(item as ItemFn);
     let sig = &func.sig;
 
-    let fn_name = &func.sig.ident.to_string();
+    let fn_name = &sig.ident.to_string();
     let renamed_name = format!("memoized_original_{}", fn_name);
     let map_name = format!("memoized_mapping_{}", fn_name);
 
-    let mut type_in = None;
-    let mut name_in = None;
+    let input_type;
+    let input_names;
     let type_out;
 
     // Only one argument
-    // TODO: cache multiple arguments
-    if sig.inputs.len() == 1 {
-        if let syn::FnArg::Typed(ref arg) = sig.inputs[0] {
-            type_in = Some(arg.ty.clone());
+    if let syn::FnArg::Receiver(_) = sig.inputs[0] {
+        return TokenStream::from(
+            syn::Error::new(
+                sig.span(),
+                "Cannot memoize method (self-receiver) without arguments!",
+            )
+            .to_compile_error(),
+        );
+    }
+    let mut types = vec![];
+    let mut names = vec![];
+    for a in &sig.inputs {
+        if let syn::FnArg::Typed(ref arg) = a {
+            types.push(arg.ty.clone());
+
             if let syn::Pat::Ident(_) = &*arg.pat {
-                name_in = Some(arg.pat.clone());
+                names.push(arg.pat.clone());
             } else {
-                return syn::Error::new(
-                    sig.span(),
-                    "Cannot memoize method (self-receiver) without arguments!",
-                )
-                .to_compile_error()
-                .into();
+                return syn::Error::new(sig.span(), "Cannot memoize arbitrary patterns!")
+                    .to_compile_error()
+                    .into();
             }
-        } else {
-            return TokenStream::from(
-                syn::Error::new(
-                    sig.span(),
-                    "Cannot memoize method (self-receiver) without arguments!",
-                )
-                .to_compile_error(),
-            );
         }
     }
+
+    // We treat functions with one or with multiple arguments the same: The type is made into a
+    // tuple.
+    input_type = Some(quote::quote! { (#(#types),*) });
+    input_names = Some(names);
+
     match &sig.output {
         syn::ReturnType::Default => type_out = quote::quote! { () },
         syn::ReturnType::Type(_, ty) => type_out = ty.to_token_stream(),
     }
 
-    let type_in = type_in.unwrap();
-    let name_in = name_in.unwrap();
+    // Construct storage for the memoized keys and return values.
+    let input_type = input_type.unwrap();
+    let input_names = input_names.unwrap();
     let store_ident = syn::Ident::new(&map_name.to_uppercase(), sig.span());
     let store = quote::quote! {
         lazy_static::lazy_static! {
-            static ref #store_ident : std::sync::Mutex<std::collections::HashMap<#type_in, #type_out>> =
+            static ref #store_ident : std::sync::Mutex<std::collections::HashMap<#input_type, #type_out>> =
                 std::sync::Mutex::new(std::collections::HashMap::new());
         }
     };
 
+    // Rename original function.
     let mut renamed_fn = func.clone();
     renamed_fn.sig.ident = syn::Ident::new(&renamed_name, func.sig.span());
     let memoized_id = &renamed_fn.sig.ident;
 
+    // Construct memoizer function, which calls the original function.
+    let syntax_names_tuple = quote::quote! { (#(#input_names),*) };
+    let syntax_names_tuple_cloned = quote::quote! { (#(#input_names.clone()),*) };
     let memoizer = quote::quote! {
         #sig {
             let mut hm = &mut #store_ident.lock().unwrap();
-            if let Some(r) = hm.get(&#name_in) {
+            if let Some(r) = hm.get(&#syntax_names_tuple_cloned) {
                 return r.clone();
             }
-            let r = #memoized_id(#name_in.clone());
-            hm.insert(#name_in, r.clone());
+            let r = #memoized_id(#(#input_names.clone()),*);
+            hm.insert(#syntax_names_tuple, r.clone());
             r
         }
     };