(* Source: 2023/day12.ml @ 78:ade1919a5409 (default tip)

   Commit message: Day 17: streamline PQ
   Author: Lewin Bormann <lbo@spheniscida.de>
   Date: Tue, 02 Jan 2024 18:42:41 +0100
   Parents: f2fb41098579
   Children: (none) *)

open Angstrom
open Base
open Core

(** Condition of a single spring: [Working] ('.'), [Broken] ('#') or
    [Unknown] ('?'). *)
type spring =
  | Working
  | Broken
  | Unknown
[@@deriving show, sexp]

(** One input line: the row of springs, followed by the sizes of the
    contiguous groups of broken springs in left-to-right order. *)
type record = {
  row : spring array;
  damaged_groups : int list;
}
[@@deriving show, sexp]

(** The whole puzzle input: one record per input line. *)
type input =
  record list
[@@deriving show, sexp]

module Parse = struct
  (* A single spring character: '?' is unknown, '#' is broken, '.' is
     working. *)
  let parse_spring =
    choice
      [
        char '?' *> return Unknown;
        char '#' *> return Broken;
        char '.' *> return Working;
      ]

  (* A non-empty run of spring characters. *)
  let parse_row = many1 parse_spring

  (* One record line: a row of springs, a single space, then the
     comma-separated group sizes, e.g. "???.### 1,1,3". *)
  let parse_record =
    let open Angstrom.Let_syntax in
    let%bind row = parse_row <* char ' ' >>| Array.of_list in
    let%bind damaged_groups =
      sep_by1 (char ',') (take_while1 Char.is_digit >>| Int.of_string)
    in
    return { row; damaged_groups }

  (* Newline-separated records. The input may end either directly after the
     last record or with one trailing newline; both alternatives are real
     parsers that end at [end_of_input]. *)
  let parse_all =
    sep_by1 (char '\n') parse_record
    <* choice [ char '\n' *> end_of_input; end_of_input ]

  (* Parse the whole puzzle input. Any parse error (including unconsumed
     trailing input) is raised as an exception via [Result.ok_or_failwith]. *)
  let parse_input (s : string) : input =
    parse_string ~consume:All parse_all s |> Result.ok_or_failwith
end

(* Part 1: for each record, count the ways the unknown springs can be
   resolved so that the broken springs form exactly the listed groups.

   The puzzle answer is the sum of these counts over all records. *)
module Part1 = struct
  (* Progress on one group while scanning a row:
     - [Whole n]: a group of [n] broken springs that has not been started;
     - [Entered n]: a started group still needing [n] more broken springs.
       [Entered 0] is a finished group, which must be followed by a working
       spring (or by the end of the row). *)
  type group = Whole of int | Entered of int
  [@@deriving show, sexp, eq, compare]

  (* [count row ix groups]: number of ways to resolve the springs
     [row.(ix)], [row.(ix + 1)], ... such that the broken ones form exactly
     [groups]. Plain (exponential) recursion; each [Unknown] facing a
     not-yet-started group explores both choices. *)
  let rec count row ix (groups : group list) =
    if ix = Array.length row then
      (* End of row: valid only if nothing is left, or the last group has
         just been completed. *)
      (match groups with [] | [ Entered 0 ] -> 1 | _ -> 0)
    else
      let next = ix + 1 in
      match row.(ix) with
      | Broken -> (
          match groups with
          | (Entered c | Whole c) :: rest when c > 0 ->
              count row next (Entered (c - 1) :: rest)
          (* no group left, or the current group is already complete *)
          | _ -> 0)
      | Working -> (
          match groups with
          | Entered 0 :: rest -> count row next rest
          (* a started group would be cut short by this working spring *)
          | Entered _ :: _ -> 0
          | _ -> count row next groups)
      | Unknown -> (
          match groups with
          (* nothing left to place: the spring must be working *)
          | [] -> count row next []
          (* the finished group must be separated: the spring is working *)
          | Entered 0 :: rest -> count row next rest
          | Whole c :: rest ->
              assert (c > 0);
              (* either start the group here, or leave it for later *)
              let as_broken = count row next (Entered (c - 1) :: rest) in
              let as_working = count row next groups in
              as_broken + as_working
          | Entered c :: rest ->
              assert (c > 0);
              (* inside a started group: the spring has to be broken *)
              count row next (Entered (c - 1) :: rest))

  (* Number of valid arrangements for one record. *)
  let count_combinations { row; damaged_groups } =
    count row 0 (List.map damaged_groups ~f:(fun size -> Whole size))

  (* Arrangement counts, one per record, in input order. *)
  type combinations_counts = int list [@@deriving show]

  (* Count the arrangements of every record. *)
  let count_all_combinations records : combinations_counts =
    records |> List.map ~f:count_combinations
end

(* Part 2: every record is first unfolded (see [unfold_springs]). The search
   from Part 1 is correct but far too slow on the unfolded rows, so the same
   search is memoized here. *)
module Part2 = struct
  open Part1

  (* Per-record cache of [count] results. *)
  module Memoize = struct
    (* A search state: current index [ix], first remaining group [head], and
       the number of remaining groups [groupsleft]. Within one record the
       remaining groups are always a suffix of the original group list in
       which only the first element may have changed (Whole -> Entered), so
       these three values identify the state. *)
    module Memo_key = struct
      type t = { ix : int; head : group; groupsleft : int }
      [@@deriving compare, eq, sexp]

      (* Mixes all fields, including the Whole/Entered tag, so states that
         differ only in [groupsleft] or in the tag get distinct hashes.
         Equal keys always produce equal hashes. *)
      let hash { ix; head; groupsleft } =
        let head_code =
          match head with Whole c -> 2 * c | Entered c -> (2 * c) + 1
        in
        (((ix * 31) + groupsleft) * 31) + head_code
    end

    type t = (Memo_key.t, int) Hashtbl.t

    let create () : t = Hashtbl.create (module Memo_key)

    (* Cache key for state ([ix], [groups]); [None] for an empty group list,
       which is never cached. *)
    let key_of ix = function
      | head :: rest ->
          Some Memo_key.{ ix; head; groupsleft = 1 + List.length rest }
      | [] -> None

    (* Cached result for state ([ix], [groups]), if any. *)
    let get memo ix groups =
      match key_of ix groups with
      | Some key -> Hashtbl.find memo key
      | None -> None

    (* Record [value] for state ([ix], [groups]); no-op for empty groups. *)
    let set memo ix groups value =
      match key_of ix groups with
      | Some key -> Hashtbl.set memo ~key ~data:value
      | None -> ()
  end

  (* [count] with a cache lookup in front of it. The key (index, head group,
     number of groups left) identifies the state enough to reliably reuse a
     result within one record. *)
  let rec memoized_count memo row ix groups : int =
    match Memoize.get memo ix groups with
    | Some cached -> cached
    | None ->
        let result = count (memoized_count memo) row ix groups in
        Memoize.set memo ix groups result;
        result

  (* Same search as [Part1.count], but recursive calls go through [countrec]
     so that [memoized_count] can intercept them (open recursion).

     Returns the number of arrangements when starting at [ix] with the
     remaining [groups]. *)
  and count countrec row ix (groups : group list) : int =
    let len = Array.length row in
    match ix with
    | l when l = len -> ( match groups with [] | [ Entered 0 ] -> 1 | _ -> 0)
    | ix -> (
        match (row.(ix), groups) with
        | Broken, (Entered c | Whole c) :: cs when c > 0 ->
            countrec row (ix + 1) (Entered (c - 1) :: cs)
        (* no group left, or the current group is already complete *)
        | Broken, _ -> 0
        | Working, Entered 0 :: cs -> countrec row (ix + 1) cs
        | Working, Entered _ :: _ ->
            0 (* we're in a non-finished group of broken springs *)
        | Working, cs -> countrec row (ix + 1) cs
        | Unknown, [] -> (* assume ? = working *) countrec row (ix + 1) []
        | Unknown, Entered 0 :: cs ->
            (* assume ? = working because previous group is separated by working spring *)
            countrec row (ix + 1) cs
        | Unknown, Whole c :: cs ->
            assert (c > 0);
            (* first assume ? = broken *)
            let with_broken = countrec row (ix + 1) (Entered (c - 1) :: cs) in
            (* then assume ? = working; i.e. skip group *)
            let with_working = countrec row (ix + 1) (Whole c :: cs) in
            with_broken + with_working
        | Unknown, Entered c :: cs ->
            assert (c > 0);
            (* forced assumption: ? = broken because we are in a group. *)
            countrec row (ix + 1) (Entered (c - 1) :: cs))

  (* Replace the spring row with five copies of itself, separated by single
     [Unknown] springs, and the damaged group list with five copies of itself
     (no separator). *)
  let unfold_springs { row; damaged_groups } =
    let copies = 5 in
    let repeat l = List.init copies ~f:(fun _ -> l) in
    let row' =
      repeat (Array.to_list row)
      |> List.intersperse ~sep:[ Unknown ]
      |> List.concat |> Array.of_list
    in
    { row = row'; damaged_groups = List.concat (repeat damaged_groups) }

  (* Number of arrangements for one (already unfolded) record, using a fresh
     cache per record because cache keys are only meaningful within one
     record. *)
  let count_combinations { row; damaged_groups } =
    let memo = Memoize.create () in
    let groups = List.map ~f:(fun c -> Whole c) damaged_groups in
    memoized_count memo row 0 groups

  (* Unfold every record, then count its arrangements; one count per record,
     in input order. *)
  let count_all_combinations records : Part1.combinations_counts =
    let multiplied = List.map ~f:unfold_springs records in
    List.map ~f:count_combinations multiplied
end

(* Entry point: read the puzzle from stdin, then print the per-record counts
   and their sum for Part 1 followed by Part 2. *)
let () =
  let records = Parse.parse_input (In_channel.input_all In_channel.stdin) in
  let total counts = List.fold counts ~init:0 ~f:Int.( + ) in
  let report counts =
    Out_channel.printf "%s\nsum: %d\n"
      (Part1.show_combinations_counts counts)
      (total counts)
  in
  (* Both parts are computed before anything is printed. *)
  let part1 = Part1.count_all_combinations records in
  let part2 = Part2.count_all_combinations records in
  report part1;
  report part2