changeset 64:c9010e9a5257

Day 12 Part 2 Add memoization implementation for count function
author Lewin Bormann <lbo@spheniscida.de>
date Sat, 23 Dec 2023 15:13:15 +0100
parents f2355e1a8e8c
children f2fb41098579
files 2023/day12.ml 2023/dune 2023/input/12_test.txt
diffstat 3 files changed, 147 insertions(+), 14 deletions(-) [+]
line wrap: on
line diff
--- a/2023/day12.ml	Sat Dec 23 14:14:45 2023 +0100
+++ b/2023/day12.ml	Sat Dec 23 15:13:15 2023 +0100
@@ -9,7 +9,7 @@
 type record = { row : spring array; damaged_groups : int list }
 [@@deriving show, sexp]
 
-  (* The input to the puzzle is a list of records. *)
+(* The input to the puzzle is a list of records. *)
 type input = record list [@@deriving show, sexp]
 
 module Parse = struct
@@ -31,9 +31,16 @@
     in
     return { row; damaged_groups }
 
-  let parse_all = sep_by1 (char '\n') parse_record <* choice [ char '\n' >>| fun _ -> (); end_of_input ]
+  let parse_all =
+    sep_by1 (char '\n') parse_record
+    <* choice
+         [
+           ( char '\n' >>| fun _ ->
+             ();
+             end_of_input );
+         ]
 
-  let parse_input (s:string) : input =
+  let parse_input (s : string) : input =
     parse_string ~consume:All parse_all s |> Result.ok_or_failwith
 end
 
@@ -45,7 +52,8 @@
   (* A group (during traversal) is either a not-yet-started group of n broken springs,
      or a group of n broken springs that has already been entered.
      In the latter case, it first must be finished before allowing the next working spring. *)
-  type group = Whole of int | Entered of int [@@deriving show, sexp]
+  type group = Whole of int | Entered of int
+  [@@deriving show, sexp, eq, compare]
 
   (* count combinations for a single record: row is the array of springs, ix is the current index,
      and groups is the current list of groups. *)
@@ -60,7 +68,8 @@
         | Broken, Entered 0 :: _ -> 0
         | Broken, _ -> 0
         | Working, Entered 0 :: cs -> count row (ix + 1) cs
-        | Working, Entered _ :: _ -> 0 (* we're in a non-finished group of broken springs *)
+        | Working, Entered _ :: _ ->
+            0 (* we're in a non-finished group of broken springs *)
         | Working, cs -> count row (ix + 1) cs
         | Unknown, [] -> (* assume ? = working *) count row (ix + 1) []
         | Unknown, Entered 0 :: cs ->
@@ -72,8 +81,8 @@
             let with_broken = count row (ix + 1) (Entered (c - 1) :: cs) in
             (* then assume ? = working; i.e. skip group *)
             let with_working = count row (ix + 1) (Whole c :: cs) in
-            Out_channel.printf "broken: %d working: %d (ix %d)\n" with_broken
-              with_working ix;
+            (*Out_channel.printf "broken: %d working: %d (ix %d)\n" with_broken
+              with_working ix;*)
             with_broken + with_working
         | Unknown, Entered c :: cs ->
             assert (c > 0);
@@ -82,21 +91,145 @@
 
   (* count combinations for a single record *)
   let count_combinations { row; damaged_groups } =
-    Out_channel.printf "\n";
     let groups = List.map ~f:(fun c -> Whole c) damaged_groups in
     count row 0 groups
 
-      (* a list of counts of combinations per record. *)
+  (* a list of counts of combinations per record. *)
   type combinations_counts = int list [@@deriving show]
 
-    (* count combinations for each record *)
+  (* count combinations for each record *)
   let count_all_combinations records : combinations_counts =
     List.map ~f:count_combinations records
 end
 
+(* The algorithm in Part 1 is too slow for large lists, but it is the correct approach.
+   Use memoization to speed it up. *)
+module Part2 = struct
+  open Part1
+
+  module Memo_key = struct
+    type t = { ix : int; head : group; groupsleft : int }
+    [@@deriving compare, eq, sexp]
+
+    let hash t =
+      let open Int in
+      (t.ix * 31) + match t.head with Whole c -> c | Entered c -> c
+  end
+
+  let create_memo () : (Memo_key.t, int) Hashtbl.t =
+    Hashtbl.create (module Memo_key)
+
+  let check_memo memo ix = function
+    | head :: groups ->
+        Hashtbl.mem memo
+          Memo_key.{ ix; head; groupsleft = 1 + List.length groups }
+    | [] -> false
+
+  let get_memo memo ix = function
+    | head :: groups ->
+        Hashtbl.find_exn memo
+          Memo_key.{ ix; head; groupsleft = 1 + List.length groups }
+    | [] -> assert false
+
+  let set_memo memo ix groups value =
+    match groups with
+    | head :: groups ->
+        Hashtbl.set memo
+          ~key:Memo_key.{ ix; head; groupsleft = 1 + List.length groups }
+          ~data:value
+    | [] -> ()
+
+  (* a crude memoization scheme: the key is (index, head group, number of groups left to process)
+     and identifies the state enough to reliably cache the outcome. *)
+  let rec count_memo memo row ix groups : int =
+    if check_memo memo ix groups then get_memo memo ix groups
+    else
+      let r = count (count_memo memo) row ix groups in
+      set_memo memo ix groups r;
+      r
+
+  (* count combinations for a single record: row is the array of springs, ix is the current index,
+     and groups is the current list of groups.
+
+     The result is the number of combinations if starting at ix with remaining groups.
+
+     Mutual recursion: for memoization to work, the inner count function is passed as countrec.
+     Maybe not the most elegant, but is compact enough.
+  *)
+  and count countrec row ix (groups : group list) : int =
+    let len = Array.length row in
+    match ix with
+    | l when l = len -> ( match groups with [] | [ Entered 0 ] -> 1 | _ -> 0)
+    | ix -> (
+        match (row.(ix), groups) with
+        | Broken, (Entered c | Whole c) :: cs when c > 0 ->
+            countrec row (ix + 1) (Entered (c - 1) :: cs)
+        | Broken, Entered 0 :: _ -> 0
+        | Broken, _ -> 0
+        | Working, Entered 0 :: cs -> countrec row (ix + 1) cs
+        | Working, Entered _ :: _ ->
+            0 (* we're in a non-finished group of broken springs *)
+        | Working, cs -> countrec row (ix + 1) cs
+        | Unknown, [] -> (* assume ? = working *) countrec row (ix + 1) []
+        | Unknown, Entered 0 :: cs ->
+            (* assume ? = working because previous group is separated by working spring *)
+            countrec row (ix + 1) cs
+        | Unknown, Whole c :: cs ->
+            assert (c > 0);
+            (* first assume ? = broken *)
+            let with_broken = countrec row (ix + 1) (Entered (c - 1) :: cs) in
+            (* then assume ? = working; i.e. skip group *)
+            let with_working = countrec row (ix + 1) (Whole c :: cs) in
+            (*Out_channel.printf "broken: %d working: %d (ix %d)\n" with_broken
+              with_working ix;*)
+            with_broken + with_working
+        | Unknown, Entered c :: cs ->
+            assert (c > 0);
+            (* forced assumption: ? = broken because we are in a group. *)
+            countrec row (ix + 1) (Entered (c - 1) :: cs))
+
+  (* replace each spring list with five copies of itsef, separated by Unknown.
+
+     replace list of damaged spring groups with five copies of itself.
+  *)
+  let unfold_springs { row; damaged_groups } =
+    let multiplier = 5 in
+    let multiply_list ?sep l =
+      let range = List.range 0 multiplier in
+      let multiplied =
+        List.fold ~init:[]
+          ~f:(fun acc _ ->
+            (match sep with Some s -> s :: l | None -> l) @ acc)
+          range
+      in
+      match sep with Some _ -> List.tl_exn multiplied | None -> multiplied
+    in
+    let damaged_groups' = multiply_list damaged_groups in
+    let row_list = Array.to_list row in
+    let row_list' = multiply_list ~sep:Unknown row_list in
+    let row' = Array.of_list row_list' in
+    { row = row'; damaged_groups = damaged_groups' }
+
+  let count_combinations { row; damaged_groups } =
+    let memo = create_memo () in
+    let groups = List.map ~f:(fun c -> Whole c) damaged_groups in
+    count_memo memo row 0 groups
+
+  (* a list of counts of combinations per record. *)
+  (* count combinations for a single record: row is the array of springs, ix is the current index,
+     and groups is the current list of groups. *)
+  let count_all_combinations records : Part1.combinations_counts =
+    let multiplied = List.map ~f:unfold_springs records in
+    List.map ~f:count_combinations multiplied
+end
+
 let () =
   let input = In_channel.(input_all stdin) in
   let parsed = Parse.parse_input input in
   let combos = Part1.count_all_combinations parsed in
-  let sum = List.fold combos ~init:0 ~f:Int.(+) in
-  Out_channel.(printf "%s\nsum: %d\n" (Part1.show_combinations_counts combos) sum)
+  let sum = List.fold combos ~init:0 ~f:Int.( + ) in
+  let combos2 = Part2.count_all_combinations parsed in
+  let sum2 = List.fold combos2 ~init:0 ~f:Int.( + ) in
+  Out_channel.(
+    printf "%s\nsum: %d\n" (Part1.show_combinations_counts combos) sum;
+    printf "%s\nsum: %d\n" (Part1.show_combinations_counts combos2) sum2)
--- a/2023/dune	Sat Dec 23 14:14:45 2023 +0100
+++ b/2023/dune	Sat Dec 23 15:13:15 2023 +0100
@@ -92,4 +92,4 @@
  (modules day12)
  (libraries base core angstrom)
  (preprocess
-  (pps ppx_let ppx_sexp_conv ppx_compare ppx_deriving.show)))
+  (pps ppx_let ppx_sexp_conv ppx_compare ppx_deriving.show ppx_deriving.eq)))
--- a/2023/input/12_test.txt	Sat Dec 23 14:14:45 2023 +0100
+++ b/2023/input/12_test.txt	Sat Dec 23 15:13:15 2023 +0100
@@ -3,4 +3,4 @@
 ?#?#?#?#?#?#?#? 1,3,1,6
 ????.#...#... 4,1,1
 ????.######..#####. 1,6,5
-?###???????? 3,2,1
\ No newline at end of file
+?###???????? 3,2,1