view dynprog.cc @ 7:202259bcb331

DP: tighten code
author Lewin Bormann <lbo@spheniscida.de>
date Sun, 12 Mar 2023 22:48:09 +0100
parents 90bf264f80a5
children 60c7a574d536
line wrap: on
line source


#include "lib/exec.h"

#include <iostream>
#include <functional>
#include <ranges>
#include <string>
#include <sstream>
#include <tuple>
#include <unordered_map>
#include <vector>

using namespace std;

using MemoKey = std::tuple<std::size_t, std::size_t, std::size_t>;

/// Hash functor for MemoKey, used as the hasher of the LCS memo table.
///
/// The three components are combined polynomially, h = (i*M + j)*M + length,
/// so permuted or repeated components hash differently. A plain XOR of the
/// components sent every (i, i, 0) key to 0 and made (i, j, k) collide with
/// (j, i, k). The previous (private, unused) std::function base class is gone.
class TupleHash {
public:
    std::size_t operator()(const MemoKey& k) const noexcept {
        // Odd multiplier (golden-ratio based constant); unsigned wrap-around
        // is well defined, so overflow is intentional here.
        constexpr std::size_t kMul = 0x9E3779B1u;
        return (std::get<0>(k) * kMul + std::get<1>(k)) * kMul + std::get<2>(k);
    }
};

/// Memoized top-down search used by lcss().
///
/// Explores common subsequences of a[i..] and b[j..], given that `length`
/// elements have already been matched on the current path.
///
/// @param a, b         input sequences
/// @param i, j         current positions in a and b
/// @param length       number of elements matched so far on this path
/// @param longest      raised to the largest matched length seen so far
/// @param longest_seq  at a match, a[i] is stored at index `length` when the
///                     total reachable through this match equals `longest`;
///                     needs at least min(a.size(), b.size()) elements
/// @param cache        memo of return values keyed by (i, j, length); every
///                     in-range state that is computed gets an entry
/// @return `length` plus the length of the LCS of a[i..] and b[j..]
template <typename T>
size_t lcss_inner(const vector<T> &a, const vector<T> &b, size_t i, size_t j,
                  size_t length, size_t &longest, vector<T> &longest_seq,
                  unordered_map<tuple<size_t, size_t, size_t>, size_t, TupleHash> &cache) {
  if (i >= a.size() || j >= b.size()) {
    return length;
  }

  const auto key = make_tuple(i, j, length);
  // Single hash lookup (previously contains() followed by find()).
  if (const auto hit = cache.find(key); hit != cache.end()) {
    return hit->second;
  }

  size_t result;
  if (a[i] == b[j]) {
    // Matching heads can always be taken: LCS(i, j) = 1 + LCS(i + 1, j + 1).
    longest = max(longest, length + 1);
    result = lcss_inner(a, b, i + 1, j + 1, length + 1, longest, longest_seq, cache);
    if (result == longest) {
      // This element lies on a path as long as the best one found so far.
      // NOTE(review): which path ends up recorded depends on visitation order
      // and memo hits; lcss() should not rely on this buffer alone.
      longest_seq.at(length) = a[i];
    }
  } else {
    // Heads differ: drop one of them and keep the better outcome.
    const size_t skip_a = lcss_inner(a, b, i + 1, j, length, longest, longest_seq, cache);
    const size_t skip_b = lcss_inner(a, b, i, j + 1, length, longest, longest_seq, cache);
    result = max(skip_a, skip_b);
  }
  cache.emplace(key, result);
  return result;
}

/// Longest common subsequence of `a` and `b`.
///
/// Runs the memoized search in lcss_inner(), prints the two lengths it
/// reports, then rebuilds one optimal subsequence by walking the memo table
/// from state (0, 0, 0).
///
/// @return a longest common subsequence; its size equals the LCS length.
///         Previously the buffer of size min(a.size(), b.size()) was returned
///         as-is, padded with default-constructed T values.
template<typename T>
vector<T> lcss(const vector<T>& a, const vector<T>& b) {
    size_t longest = 0;
    // lcss_inner() requires this buffer and writes into it while searching;
    // here it is scratch space only, the answer is rebuilt from the memo below.
    vector<T> longest_seq(min(a.size(), b.size()));
    unordered_map<tuple<size_t, size_t, size_t>, size_t, TupleHash> cache(100);
    size_t longest_r = lcss_inner(a, b, 0, 0, 0, longest, longest_seq, cache);

    fmt::print("Length: {} / {}\n", longest, longest_r);

    // Value lcss_inner() returns for state (i, j, length): out-of-range states
    // are base cases returning `length`; every in-range state reached by the
    // walk below is a child of a computed state and therefore memoized.
    const auto total_from = [&](size_t i, size_t j, size_t length) -> size_t {
        if (i >= a.size() || j >= b.size()) return length;
        return cache.at(make_tuple(i, j, length));
    };

    vector<T> result;
    result.reserve(longest_r);
    size_t i = 0;
    size_t j = 0;
    while (i < a.size() && j < b.size()) {
        if (a[i] == b[j]) {
            // lcss_inner() always takes a matching pair, so the walk does too.
            result.push_back(a[i]);
            ++i;
            ++j;
        } else if (total_from(i + 1, j, result.size()) >= total_from(i, j + 1, result.size())) {
            ++i;  // dropping a[i] keeps the optimum (ties prefer this side)
        } else {
            ++j;
        }
    }
    return result;
}
// g++ only
template<typename T, template<typename> typename C>
void evaluate_lcss(const C<T>& a, const C<T>& b) requires ranges::range<C<T>> {
    vector<T> va(a.begin(), a.end());
    vector<T> vb(b.begin(), b.end());
    auto result = timeit_f(function([&va, &vb]() { return lcss(va, vb); }));

    fmt::print("{}\n", va);
    fmt::print("{}\n", vb);
    fmt::print("{}\n", result);
}

/// Longest non-decreasing subsequence of a[i..] that may extend a prefix.
///
/// The caller has already chosen `length` elements, the last being a[prev].
/// `prev` is only read when `length > 0`, so the top-level call
/// (i = 0, prev = 0, length = 0) leaves the first element unconstrained.
/// As before, an element equal to the previous one is accepted; only strictly
/// smaller values are rejected.
///
/// Runs in O((n - i)^2) time. The earlier include/exclude recursion was
/// exponential. It also used prev = 0 both as "nothing chosen" and as a real
/// index, so skipping a[0] still forced later elements to be >= a[0]
/// ({5, 1, 2, 3} gave 1 instead of 3). Its index recording could mix paths.
///
/// @param a        input sequence
/// @param i        first index that may still be chosen
/// @param prev     index of the last chosen element (ignored while length == 0)
/// @param length   number of elements already chosen
/// @param longest  raised to the best total length found; never lowered
/// @param seqixs   if this call's total equals `longest`, the continuation's
///                 indices are stored at seqixs[length], seqixs[length + 1], ...
///                 (bounds-checked with at())
/// @return `length` plus the length of the best continuation
template <typename T>
std::size_t longest_increasing_subsequence(const std::vector<T> &a, std::size_t i, std::size_t prev,
                                           std::size_t length, std::size_t &longest,
                                           std::vector<std::size_t> &seqixs) {
    constexpr std::size_t none = static_cast<std::size_t>(-1);
    const std::size_t n = a.size();

    // best[k]: length of the longest non-decreasing subsequence of a[k..]
    // that starts with a[k]; next[k]: the index following k in it, or `none`.
    std::vector<std::size_t> best(n, 0);
    std::vector<std::size_t> next(n, none);
    for (std::size_t k = n; k-- > i;) {
        best[k] = 1;
        for (std::size_t m = k + 1; m < n; ++m) {
            if (!(a[m] < a[k]) && best[m] + 1 > best[k]) {
                best[k] = best[m] + 1;
                next[k] = m;
            }
        }
    }

    // Best starting element among a[i..] that respects the chosen prefix.
    std::size_t start = none;
    std::size_t extra = 0;
    for (std::size_t k = i; k < n; ++k) {
        if (length > 0 && a[k] < a.at(prev)) continue;
        if (best[k] > extra) {
            extra = best[k];
            start = k;
        }
    }

    const std::size_t total = length + extra;
    longest = std::max(longest, total);
    if (total == longest) {
        std::size_t pos = length;
        for (std::size_t k = start; k != none; k = next[k]) seqixs.at(pos++) = k;
    }
    return total;
}

/// Maximum-sum non-decreasing subsequence of a[i..] that may extend a prefix.
///
/// The caller has already chosen `length` elements summing to `sum`, the last
/// being a[prev]. `prev` is only read when `length > 0`. Elements may be skipped
/// freely, and the empty continuation counts as zero, so the result is never
/// below `sum`.
///
/// Runs in O((n - i)^2) time. The earlier include/exclude recursion was
/// exponential. Like the LIS helper, it treated prev = 0 as both "nothing
/// chosen" and a real index, so {5, 1, 2, 3} gave 5 instead of 6. Its index
/// recording could repeat or mix indices. A leftover debug print is removed.
///
/// @param a            input sequence
/// @param i            first index that may still be chosen
/// @param prev         index of the last chosen element (ignored while length == 0)
/// @param length       number of elements already chosen
/// @param sum          sum of the elements already chosen
/// @param highest_sum  raised to the best total found; never lowered
/// @param seqixs       if this call's total equals `highest_sum`, the
///                     continuation's indices are stored at seqixs[length], ...
///                     (bounds-checked with at())
/// @return `sum` plus the best continuation sum
template<typename T>
T max_sum_incr_subseq(const std::vector<T>& a, std::size_t i, std::size_t prev, std::size_t length, T sum, T& highest_sum, std::vector<std::size_t>& seqixs) {
    constexpr std::size_t none = static_cast<std::size_t>(-1);
    const std::size_t n = a.size();
    const T zero = static_cast<T>(0);

    // best[k]: largest sum of a non-decreasing subsequence of a[k..] starting
    // with a[k]; next[k]: the index following k in it, or `none`. Only
    // continuations with a positive sum are worth taking.
    std::vector<T> best(n, zero);
    std::vector<std::size_t> next(n, none);
    for (std::size_t k = n; k-- > i;) {
        T tail = zero;
        for (std::size_t m = k + 1; m < n; ++m) {
            if (!(a[m] < a[k]) && tail < best[m]) {
                tail = best[m];
                next[k] = m;
            }
        }
        best[k] = a[k] + tail;
    }

    // Best starting element among a[i..] that respects the chosen prefix;
    // `extra` stays zero (empty continuation) unless something positive exists.
    std::size_t start = none;
    T extra = zero;
    for (std::size_t k = i; k < n; ++k) {
        if (length > 0 && a[k] < a.at(prev)) continue;
        if (extra < best[k]) {
            extra = best[k];
            start = k;
        }
    }

    const T total = sum + extra;
    highest_sum = std::max(highest_sum, total);
    if (total == highest_sum) {
        std::size_t pos = length;
        for (std::size_t k = start; k != none; k = next[k]) seqixs.at(pos++) = k;
    }
    return total;
}


/// Times longest_increasing_subsequence() on `v` and prints the reported
/// length, the recorded index buffer and the selected values.
template<typename T, template<typename> typename C>
void evaluate_liss(const C<T>& v) requires ranges::range<C<T>> {
    vector<T> a(v.begin(), v.end());
    size_t longest = 0;
    vector<size_t> seqixs(a.size());
    size_t result = timeit_f(function([&a,&longest,&seqixs]() -> size_t { return longest_increasing_subsequence(a, 0, 0, 0, longest, seqixs); }), "longest incr. subsequence");

    fmt::print("Longest incr. subsequence has length {}/{} and has indices: {}\n", longest, result, seqixs);

    // The first `longest` entries of seqixs hold the chosen indices. A size_t
    // counter avoids the signed/unsigned comparison of the old `int` loop.
    vector<T> seq;
    seq.reserve(longest);
    for (size_t k = 0; k < longest; ++k)
        seq.push_back(a.at(seqixs.at(k)));
    fmt::print("Longest incr. subsequence is: {}\n", seq);
}

/// Times max_sum_incr_subseq() on `v` and prints the best sum, the recorded
/// index buffer and the selected values (without padding).
template<typename T, template<typename> typename C>
void evaluate_msis(const C<T>& v) requires ranges::range<C<T>> {
    vector<T> a(v.begin(), v.end());
    T highest = 0;
    vector<size_t> seqixs(a.size());
    size_t result =
        timeit_f(function([&a, &highest, &seqixs]() -> size_t {
                   return max_sum_incr_subseq(a, 0, 0, 0, static_cast<T>(0), highest, seqixs);
                 }),
                 "max sum incr. subsequence");

    fmt::print("Max incr sum subseq subsequence has sum {}/{} and has indices: {}\n", highest, result, seqixs);

    // seqixs starts zero-filled and a subsequence's indices strictly increase,
    // so the first non-increasing entry ends the recorded part. This also stops
    // on repeated indices. A best sum of zero is met by the empty subsequence.
    // Previously seq was sized to seqixs.size(), which padded the output with
    // default-constructed values.
    vector<T> seq;
    if (highest != static_cast<T>(0)) {
        for (size_t k = 0; k < seqixs.size(); ++k) {
            if (k > 0 && seqixs.at(k) <= seqixs.at(k - 1)) break;
            seq.push_back(a.at(seqixs.at(k)));
        }
    }
    fmt::print("Max sum incr. subsequence is: {}\n", seq);
}

int lcss_main(void) {

    vector<int> a{1,6,2,-1,3,55,2,4,23,5,6,7,8,9,10};
    vector<int> b{-3, 1, 5, 7, 2, 4, -1, 3, 6, 7, 9, 11, 10};

    evaluate_lcss(a, b);
    evaluate_lcss(views::single(11), views::single(11));
    return 0;
}

int main(void) {
    vector<int> a{0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11};
    // Expect: 0,2,6,9,13 or ,11
    // or 0,4,6,9,13
    evaluate_liss(a);
    return 0;
}

int msis_main(void) {
    vector<int> a{-1, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11};
    // Expect: 8,12,14 (indices 1,3,7) with sum 34. The previous comment was
    // copied from the longest-increasing-subsequence demo and did not apply.
    evaluate_msis(a);
    return 0;
}