Mercurial > lbo > hg > ccplay
changeset 5:a1ac40ee3a1e
Memoization for DP
author | Lewin Bormann <lbo@spheniscida.de> |
---|---|
date | Sun, 12 Mar 2023 16:11:39 +0100 |
parents | bb5c00d28022 |
children | 90bf264f80a5 |
files | dynprog.cc lib/exec.cc lib/exec.h |
diffstat | 3 files changed, 71 insertions(+), 36 deletions(-) [+] |
line wrap: on
line diff
--- a/dynprog.cc Sun Mar 12 14:12:10 2023 +0100 +++ b/dynprog.cc Sun Mar 12 16:11:39 2023 +0100 @@ -1,8 +1,13 @@ + +#include "lib/exec.h" #include <iostream> +#include <functional> #include <ranges> #include <string> #include <sstream> +#include <tuple> +#include <unordered_map> #include <vector> #include <fmt/core.h> @@ -10,59 +15,72 @@ using namespace std; -template<typename T> -size_t lcss_inner(const vector<T>& a, const vector<T>& b, size_t i, size_t j, size_t length, size_t& longest, vector<T>& longest_seq) { - if (i >= a.size() || j >= b.size()) { - return length; - } +using MemoKey = tuple<size_t, size_t, size_t>; + +class TupleHash : function<size_t(const MemoKey&)> { +public: + size_t operator()(const MemoKey& k) const { return get<0>(k) ^ get<1>(k) ^ get<2>(k); } +}; + +template <typename T> +size_t lcss_inner(const vector<T> &a, const vector<T> &b, size_t i, size_t j, + size_t length, size_t &longest, vector<T> &longest_seq, + unordered_map<tuple<size_t, size_t, size_t>, size_t, TupleHash> &cache) { + if (i >= a.size() || j >= b.size()) { + return length; + } + + const auto key = make_tuple(i, j, length); + if (cache.contains(key)) { + return get<1>(*cache.find(key)); + } + - if (a.at(i) == b.at(j)) { - // Current (i, j) is member of a common subsequence. - longest = max(longest, length+1); - size_t remaining_longest = lcss_inner(a, b, i+1, j+1, length+1, longest, longest_seq); - if (remaining_longest == longest) { - // This element is part of the longest common subsequence. - longest_seq.at(length) = a.at(i); - } - return remaining_longest; - } else { - // Skip element from a. - size_t remaining_longest_skip_a = lcss_inner(a, b, i+1, j, length, longest, longest_seq); - // Skip element from b. - size_t remaining_longest_skip_b = lcss_inner(a, b, i, j+1, length, longest, longest_seq); - return max(remaining_longest_skip_a, remaining_longest_skip_b); + if (a.at(i) == b.at(j)) { + // Current (i, j) is member of a common subsequence. + longest = max(longest, length + 1); + size_t remaining_longest = + lcss_inner(a, b, i + 1, j + 1, length + 1, longest, longest_seq, cache); + if (remaining_longest == longest) { + // This element is part of the longest common subsequence. + longest_seq.at(length) = a.at(i); } + cache.insert(make_pair(key, remaining_longest)); + return remaining_longest; + } else { + // Skip element from a. + size_t remaining_longest_skip_a = + lcss_inner(a, b, i + 1, j, length, longest, longest_seq, cache); + // Skip element from b. + size_t remaining_longest_skip_b = + lcss_inner(a, b, i, j + 1, length, longest, longest_seq, cache); + size_t result = max(remaining_longest_skip_a, remaining_longest_skip_b); + cache.insert(make_pair(key, result)); + return result; + } } template<typename T> vector<T> lcss(const vector<T>& a, const vector<T>& b) { size_t longest = 0; vector<T> longest_seq(min(a.size(), b.size())); - size_t longest_r = lcss_inner(a, b, 0, 0, 0, longest, longest_seq); + unordered_map<tuple<size_t, size_t, size_t>, size_t, TupleHash> cache(100); + size_t longest_r = lcss_inner(a, b, 0, 0, 0, longest, longest_seq, cache); fmt::print("Length: {} / {}\n", longest, longest_r); return longest_seq; } - -template<typename T, typename T2, template<typename,typename> typename C> -void evaluate_lcss(const C<T,T2>& a, const C<T,T2>& b) requires ranges::range<C<T,T2>> { - vector<T> va(a.begin(), a.end()); - vector<T> vb(b.begin(), b.end()); - - fmt::print("{}\n", va); - fmt::print("{}\n", vb); - fmt::print("{}\n", lcss(va, vb)); -} // g++ only template<typename T, template<typename> typename C> void evaluate_lcss(const C<T>& a, const C<T>& b) requires ranges::range<C<T>> { vector<T> va(a.begin(), a.end()); vector<T> vb(b.begin(), b.end()); + auto result = timeit_f(function([&va, &vb]() { return lcss(va, vb); })); fmt::print("{}\n", va); fmt::print("{}\n", vb); - fmt::print("{}\n", lcss(va, vb)); + fmt::print("{}\n", result); } int main(void) { @@ -71,7 +89,6 @@ vector<int> b{-3, 1, 5, 7, 2, 4, -1, 3, 6, 7, 9, 11, 10}; evaluate_lcss(a, b); - evaluate_lcss(views::iota(1, 10), views::iota(2, 11)); evaluate_lcss(views::single(11), views::single(11)); return 0; }
--- a/lib/exec.cc Sun Mar 12 14:12:10 2023 +0100 +++ b/lib/exec.cc Sun Mar 12 16:11:39 2023 +0100 @@ -1,8 +1,6 @@ #include "exec.h" -#include <chrono> -#include <iostream> using namespace std; using namespace std::chrono;
--- a/lib/exec.h Sun Mar 12 14:12:10 2023 +0100 +++ b/lib/exec.h Sun Mar 12 16:11:39 2023 +0100 @@ -1,7 +1,27 @@ + +#include <chrono> #include <functional> +#include <iostream> #include <string> - void timeit(std::function<void()> f, std::string name = std::string()); void benchmarkit(std::function<void()> f, std::string name = std::string(), std::function<void()> setup = [](){}); + +using std::function, std::string, std::chrono::high_resolution_clock; + +template<typename T> +T timeit_f(function<T()> f, string name = string()) { + auto begin = high_resolution_clock::now(); + + T result = f(); + + auto after = high_resolution_clock::now(); + + auto d = after - begin; + decltype(d)::rep count = d.count(); + decltype(d)::period p; + + std::cout << name << " :: " << static_cast<double>(p.num)/p.den * count << std::endl; + return result; +}