Mercurial > lbo > hg > aoc22
changeset 64:c9010e9a5257
Day 12 Part 2
Add memoization implementation for count function
author | Lewin Bormann <lbo@spheniscida.de> |
---|---|
date | Sat, 23 Dec 2023 15:13:15 +0100 |
parents | f2355e1a8e8c |
children | f2fb41098579 |
files | 2023/day12.ml 2023/dune 2023/input/12_test.txt |
diffstat | 3 files changed, 147 insertions(+), 14 deletions(-) [+] |
line wrap: on
line diff
--- a/2023/day12.ml Sat Dec 23 14:14:45 2023 +0100 +++ b/2023/day12.ml Sat Dec 23 15:13:15 2023 +0100 @@ -9,7 +9,7 @@ type record = { row : spring array; damaged_groups : int list } [@@deriving show, sexp] - (* The input to the puzzle is a list of records. *) +(* The input to the puzzle is a list of records. *) type input = record list [@@deriving show, sexp] module Parse = struct @@ -31,9 +31,16 @@ in return { row; damaged_groups } - let parse_all = sep_by1 (char '\n') parse_record <* choice [ char '\n' >>| fun _ -> (); end_of_input ] + let parse_all = + sep_by1 (char '\n') parse_record + <* choice + [ + ( char '\n' >>| fun _ -> + (); + end_of_input ); + ] - let parse_input (s:string) : input = + let parse_input (s : string) : input = parse_string ~consume:All parse_all s |> Result.ok_or_failwith end @@ -45,7 +52,8 @@ (* A group (during traversal) is either a not-yet-started group of n broken springs, or a group of n broken springs that has already been entered. In the latter case, it first must be finished before allowing the next working spring. *) - type group = Whole of int | Entered of int [@@deriving show, sexp] + type group = Whole of int | Entered of int + [@@deriving show, sexp, eq, compare] (* count combinations for a single record: row is the array of springs, ix is the current index, and groups is the current list of groups. *) @@ -60,7 +68,8 @@ | Broken, Entered 0 :: _ -> 0 | Broken, _ -> 0 | Working, Entered 0 :: cs -> count row (ix + 1) cs - | Working, Entered _ :: _ -> 0 (* we're in a non-finished group of broken springs *) + | Working, Entered _ :: _ -> + 0 (* we're in a non-finished group of broken springs *) | Working, cs -> count row (ix + 1) cs | Unknown, [] -> (* assume ? = working *) count row (ix + 1) [] | Unknown, Entered 0 :: cs -> @@ -72,8 +81,8 @@ let with_broken = count row (ix + 1) (Entered (c - 1) :: cs) in (* then assume ? = working; i.e. skip group *) let with_working = count row (ix + 1) (Whole c :: cs) in - Out_channel.printf "broken: %d working: %d (ix %d)\n" with_broken - with_working ix; + (*Out_channel.printf "broken: %d working: %d (ix %d)\n" with_broken + with_working ix;*) with_broken + with_working | Unknown, Entered c :: cs -> assert (c > 0); @@ -82,21 +91,145 @@ (* count combinations for a single record *) let count_combinations { row; damaged_groups } = - Out_channel.printf "\n"; let groups = List.map ~f:(fun c -> Whole c) damaged_groups in count row 0 groups - (* a list of counts of combinations per record. *) + (* a list of counts of combinations per record. *) type combinations_counts = int list [@@deriving show] - (* count combinations for each record *) + (* count combinations for each record *) let count_all_combinations records : combinations_counts = List.map ~f:count_combinations records end +(* The algorithm in Part 1 is too slow for large lists, but it is the correct approach. + Use memoization to speed it up. *) +module Part2 = struct + open Part1 + + module Memo_key = struct + type t = { ix : int; head : group; groupsleft : int } + [@@deriving compare, eq, sexp] + + let hash t = + let open Int in + (t.ix * 31) + match t.head with Whole c -> c | Entered c -> c + end + + let create_memo () : (Memo_key.t, int) Hashtbl.t = + Hashtbl.create (module Memo_key) + + let check_memo memo ix = function + | head :: groups -> + Hashtbl.mem memo + Memo_key.{ ix; head; groupsleft = 1 + List.length groups } + | [] -> false + + let get_memo memo ix = function + | head :: groups -> + Hashtbl.find_exn memo + Memo_key.{ ix; head; groupsleft = 1 + List.length groups } + | [] -> assert false + + let set_memo memo ix groups value = + match groups with + | head :: groups -> + Hashtbl.set memo + ~key:Memo_key.{ ix; head; groupsleft = 1 + List.length groups } + ~data:value + | [] -> () + + (* a crude memoization scheme: the key is (index, head group, number of groups left to process) + and identifies the state enough to reliably cache the outcome. *) + let rec count_memo memo row ix groups : int = + if check_memo memo ix groups then get_memo memo ix groups + else + let r = count (count_memo memo) row ix groups in + set_memo memo ix groups r; + r + + (* count combinations for a single record: row is the array of springs, ix is the current index, + and groups is the current list of groups. + + The result is the number of combinations if starting at ix with remaining groups. + + Mutual recursion: for memoization to work, the inner count function is passed as countrec. + Maybe not the most elegant, but is compact enough. + *) + and count countrec row ix (groups : group list) : int = + let len = Array.length row in + match ix with + | l when l = len -> ( match groups with [] | [ Entered 0 ] -> 1 | _ -> 0) + | ix -> ( + match (row.(ix), groups) with + | Broken, (Entered c | Whole c) :: cs when c > 0 -> + countrec row (ix + 1) (Entered (c - 1) :: cs) + | Broken, Entered 0 :: _ -> 0 + | Broken, _ -> 0 + | Working, Entered 0 :: cs -> countrec row (ix + 1) cs + | Working, Entered _ :: _ -> + 0 (* we're in a non-finished group of broken springs *) + | Working, cs -> countrec row (ix + 1) cs + | Unknown, [] -> (* assume ? = working *) countrec row (ix + 1) [] + | Unknown, Entered 0 :: cs -> + (* assume ? = working because previous group is separated by working spring *) + countrec row (ix + 1) cs + | Unknown, Whole c :: cs -> + assert (c > 0); + (* first assume ? = broken *) + let with_broken = countrec row (ix + 1) (Entered (c - 1) :: cs) in + (* then assume ? = working; i.e. skip group *) + let with_working = countrec row (ix + 1) (Whole c :: cs) in + (*Out_channel.printf "broken: %d working: %d (ix %d)\n" with_broken + with_working ix;*) + with_broken + with_working + | Unknown, Entered c :: cs -> + assert (c > 0); + (* forced assumption: ? = broken because we are in a group. *) + countrec row (ix + 1) (Entered (c - 1) :: cs)) + + (* replace each spring list with five copies of itsef, separated by Unknown. + + replace list of damaged spring groups with five copies of itself. + *) + let unfold_springs { row; damaged_groups } = + let multiplier = 5 in + let multiply_list ?sep l = + let range = List.range 0 multiplier in + let multiplied = + List.fold ~init:[] + ~f:(fun acc _ -> + (match sep with Some s -> s :: l | None -> l) @ acc) + range + in + match sep with Some _ -> List.tl_exn multiplied | None -> multiplied + in + let damaged_groups' = multiply_list damaged_groups in + let row_list = Array.to_list row in + let row_list' = multiply_list ~sep:Unknown row_list in + let row' = Array.of_list row_list' in + { row = row'; damaged_groups = damaged_groups' } + + let count_combinations { row; damaged_groups } = + let memo = create_memo () in + let groups = List.map ~f:(fun c -> Whole c) damaged_groups in + count_memo memo row 0 groups + + (* a list of counts of combinations per record. *) + (* count combinations for a single record: row is the array of springs, ix is the current index, + and groups is the current list of groups. *) + let count_all_combinations records : Part1.combinations_counts = + let multiplied = List.map ~f:unfold_springs records in + List.map ~f:count_combinations multiplied +end + let () = let input = In_channel.(input_all stdin) in let parsed = Parse.parse_input input in let combos = Part1.count_all_combinations parsed in - let sum = List.fold combos ~init:0 ~f:Int.(+) in - Out_channel.(printf "%s\nsum: %d\n" (Part1.show_combinations_counts combos) sum) + let sum = List.fold combos ~init:0 ~f:Int.( + ) in + let combos2 = Part2.count_all_combinations parsed in + let sum2 = List.fold combos2 ~init:0 ~f:Int.( + ) in + Out_channel.( + printf "%s\nsum: %d\n" (Part1.show_combinations_counts combos) sum; + printf "%s\nsum: %d\n" (Part1.show_combinations_counts combos2) sum2)
--- a/2023/dune Sat Dec 23 14:14:45 2023 +0100 +++ b/2023/dune Sat Dec 23 15:13:15 2023 +0100 @@ -92,4 +92,4 @@ (modules day12) (libraries base core angstrom) (preprocess - (pps ppx_let ppx_sexp_conv ppx_compare ppx_deriving.show))) + (pps ppx_let ppx_sexp_conv ppx_compare ppx_deriving.show ppx_deriving.eq)))