changeset 5:a1ac40ee3a1e

Memoization for DP
author Lewin Bormann <lbo@spheniscida.de>
date Sun, 12 Mar 2023 16:11:39 +0100
parents bb5c00d28022
children 90bf264f80a5
files dynprog.cc lib/exec.cc lib/exec.h
diffstat 3 files changed, 71 insertions(+), 36 deletions(-) [+]
line wrap: on
line diff
--- a/dynprog.cc	Sun Mar 12 14:12:10 2023 +0100
+++ b/dynprog.cc	Sun Mar 12 16:11:39 2023 +0100
@@ -1,8 +1,13 @@
+
+#include "lib/exec.h"
 
 #include <iostream>
+#include <functional>
 #include <ranges>
 #include <string>
 #include <sstream>
+#include <tuple>
+#include <unordered_map>
 #include <vector>
 
 #include <fmt/core.h>
@@ -10,59 +15,72 @@
 
 using namespace std;
 
-template<typename T>
-size_t lcss_inner(const vector<T>& a, const vector<T>& b, size_t i, size_t j, size_t length, size_t& longest, vector<T>& longest_seq) {
-    if (i >= a.size() || j >= b.size()) {
-        return length;
-    }
+using MemoKey = tuple<size_t, size_t, size_t>;
+
+class TupleHash : function<size_t(const MemoKey&)> {
+public:
+    size_t operator()(const MemoKey& k) const { return get<0>(k) ^ get<1>(k) ^ get<2>(k); }
+};
+
+template <typename T>
+size_t lcss_inner(const vector<T> &a, const vector<T> &b, size_t i, size_t j,
+                  size_t length, size_t &longest, vector<T> &longest_seq,
+                  unordered_map<tuple<size_t, size_t, size_t>, size_t, TupleHash> &cache) {
+  if (i >= a.size() || j >= b.size()) {
+    return length;
+  }
+
+  const auto key = make_tuple(i, j, length);
+  if (cache.contains(key)) {
+    return get<1>(*cache.find(key));
+  }
+
 
-    if (a.at(i) == b.at(j)) {
-        // Current (i, j) is member of a common subsequence.
-        longest = max(longest, length+1);
-        size_t remaining_longest = lcss_inner(a, b, i+1, j+1, length+1, longest, longest_seq);
-        if (remaining_longest == longest) {
-            // This element is part of the longest common subsequence.
-            longest_seq.at(length) = a.at(i);
-        }
-        return remaining_longest;
-    } else {
-        // Skip element from a.
-        size_t remaining_longest_skip_a = lcss_inner(a, b, i+1, j, length, longest, longest_seq);
-        // Skip element from b.
-        size_t remaining_longest_skip_b = lcss_inner(a, b, i, j+1, length, longest, longest_seq);
-        return max(remaining_longest_skip_a, remaining_longest_skip_b);
+  if (a.at(i) == b.at(j)) {
+    // Current (i, j) is member of a common subsequence.
+    longest = max(longest, length + 1);
+    size_t remaining_longest =
+        lcss_inner(a, b, i + 1, j + 1, length + 1, longest, longest_seq, cache);
+    if (remaining_longest == longest) {
+      // This element is part of the longest common subsequence.
+      longest_seq.at(length) = a.at(i);
     }
+    cache.insert(make_pair(key, remaining_longest));
+    return remaining_longest;
+  } else {
+    // Skip element from a.
+    size_t remaining_longest_skip_a =
+        lcss_inner(a, b, i + 1, j, length, longest, longest_seq, cache);
+    // Skip element from b.
+    size_t remaining_longest_skip_b =
+        lcss_inner(a, b, i, j + 1, length, longest, longest_seq, cache);
+    size_t result = max(remaining_longest_skip_a, remaining_longest_skip_b);
+    cache.insert(make_pair(key, result));
+    return result;
+  }
 }
 
 template<typename T>
 vector<T> lcss(const vector<T>& a, const vector<T>& b) {
     size_t longest = 0;
     vector<T> longest_seq(min(a.size(), b.size()));
-    size_t longest_r = lcss_inner(a, b, 0, 0, 0, longest, longest_seq);
+    unordered_map<tuple<size_t, size_t, size_t>, size_t, TupleHash> cache(100);
+    size_t longest_r = lcss_inner(a, b, 0, 0, 0, longest, longest_seq, cache);
 
     fmt::print("Length: {} / {}\n", longest, longest_r);
 
     return longest_seq;
 }
-
-template<typename T, typename T2, template<typename,typename> typename C>
-void evaluate_lcss(const C<T,T2>& a, const C<T,T2>& b) requires ranges::range<C<T,T2>> {
-    vector<T> va(a.begin(), a.end());
-    vector<T> vb(b.begin(), b.end());
-
-    fmt::print("{}\n", va);
-    fmt::print("{}\n", vb);
-    fmt::print("{}\n", lcss(va, vb));
-}
 // g++ only
 template<typename T, template<typename> typename C>
 void evaluate_lcss(const C<T>& a, const C<T>& b) requires ranges::range<C<T>> {
     vector<T> va(a.begin(), a.end());
     vector<T> vb(b.begin(), b.end());
+    auto result = timeit_f(function([&va, &vb]() { return lcss(va, vb); }));
 
     fmt::print("{}\n", va);
     fmt::print("{}\n", vb);
-    fmt::print("{}\n", lcss(va, vb));
+    fmt::print("{}\n", result);
 }
 
 int main(void) {
@@ -71,7 +89,6 @@
     vector<int> b{-3, 1, 5, 7, 2, 4, -1, 3, 6, 7, 9, 11, 10};
 
     evaluate_lcss(a, b);
-    evaluate_lcss(views::iota(1, 10), views::iota(2, 11));
     evaluate_lcss(views::single(11), views::single(11));
     return 0;
 }
--- a/lib/exec.cc	Sun Mar 12 14:12:10 2023 +0100
+++ b/lib/exec.cc	Sun Mar 12 16:11:39 2023 +0100
@@ -1,8 +1,6 @@
 
 #include "exec.h"
 
-#include <chrono>
-#include <iostream>
 
 using namespace std;
 using namespace std::chrono;
--- a/lib/exec.h	Sun Mar 12 14:12:10 2023 +0100
+++ b/lib/exec.h	Sun Mar 12 16:11:39 2023 +0100
@@ -1,7 +1,27 @@
 
+
+#include <chrono>
 #include <functional>
+#include <iostream>
 #include <string>
-
 void timeit(std::function<void()> f, std::string name = std::string());
 
 void benchmarkit(std::function<void()> f, std::string name = std::string(), std::function<void()> setup = [](){});
+
+using std::function, std::string, std::chrono::high_resolution_clock;
+
+template<typename T>
+T timeit_f(function<T()> f, string name = string()) {
+    auto begin = high_resolution_clock::now();
+
+    T result = f();
+
+    auto after = high_resolution_clock::now();
+
+    auto d = after - begin;
+    decltype(d)::rep count = d.count();
+    decltype(d)::period p;
+
+    std::cout << name << " :: " << static_cast<double>(p.num)/p.den * count << std::endl;
+    return result;
+}