view 2023/day17.ml @ 77:85797fc052cc

Day 17 Part 1: Add comments
author Lewin Bormann <lbo@spheniscida.de>
date Sun, 31 Dec 2023 09:37:40 +0100
parents 2d05d3e059ce
children ade1919a5409
line wrap: on
line source

open Angstrom
open Base
open Core

module PrioQueue (Inner : Comparator.S) = struct
  (* Queue entries are (priority, payload) pairs. They are ordered by
     priority first; ties are broken with the payload's own comparator.
     Entries are kept in a [Set], so two entries that compare equal are
     stored only once. *)
  module Elt = struct
    type t = int * Inner.t

    let compare (p1, v1) (p2, v2) =
      let by_prio = Int.compare p1 p2 in
      if by_prio <> 0 then by_prio else Inner.comparator.compare v1 v2

    let sexp_of_t (prio, v) =
      Sexp.List [ Int.sexp_of_t prio; Inner.comparator.sexp_of_t v ]
  end

  (* Comparator witness so that entries can be stored in a [Set]. *)
  module EltComp = struct
    include Comparator.Make (Elt)

    type t = Elt.t
  end

  type t = (Elt.t, EltComp.comparator_witness) Set.t

  let empty : t = Set.empty (module EltComp)
  let is_empty = Set.is_empty

  (* Insert [elt] with priority [prio]; lower priorities come out first. *)
  let add q ~prio elt = Set.add q (prio, elt)

  (* Payload with the lowest priority. Raises on an empty queue. *)
  let min_elt_exn q =
    let _, v = Set.min_elt_exn q in
    v

  (* Pop the lowest-priority entry, returning the remaining queue and the
     payload. Raises on an empty queue. *)
  let remove_min_elt_exn q =
    let top = Set.min_elt_exn q in
    let rest = Set.remove q top in
    (rest, snd top)

  let sexp_of_t q = Sexp.List (List.map (Set.to_list q) ~f:Elt.sexp_of_t)
end

(* Priority queue of plain integers, exercised by the sanity check below. *)
module Int_pq = PrioQueue (Int)

(* Sanity check: the entry with the smallest priority comes out first. *)
let () =
  let entries = [ (1, 1); (2, 2); (0, 3) ] in
  let q =
    List.fold entries ~init:Int_pq.empty ~f:(fun q (prio, v) ->
        Int_pq.add q ~prio v)
  in
  assert (Int_pq.min_elt_exn q = 3)

(* Signature of modules accepted by [PrimPrioQueue]: a type with sexp
   conversions and a comparison function. *)
module type Comparable = sig
  type t [@@deriving sexp, compare]
end

(* Direction of travel of the crucible. *)
type direction = North | East | South | West [@@deriving sexp, compare, eq]

(* One tile of the field.
   - [value]: heat loss incurred when entering the tile (a single digit of
     the input).
   - [min_so_far]: best heat loss with which the search reached this tile;
     [Int.max_value] while the tile has not been reached.
   - [prev]: (row, column) of the tile preceding this one on the path, as
     followed by the path tracing code. *)
type heatloss = { value : int; min_so_far : int; prev : int * int }
[@@deriving sexp, compare]

(* The whole puzzle: [r] rows, [c] columns, and the tiles stored as a flat
   row-major array (see [rc_to_ix] for the index layout). *)
type field = { r : int; c : int; field : heatloss array }
[@@deriving sexp, compare]

(* The field is a flat, row-major array: row r, column c lives at index
   r * width + c, where width is the number of columns [field.c]. *)
let rc_to_ix field r c = c + (field.c * r)

(* Inverse of [rc_to_ix]: split a flat index into (row, column). *)
let ix_to_rc field ix =
  let width = field.c in
  Int.(ix / width, ix % width)

(* Tile at row r, column c. Only the flat index is bounds-checked, so an
   out-of-range column silently lands in a neighboring row. *)
let field_at field r c = Array.get field.field (rc_to_ix field r c)

(* Replace the tile at row r, column c with [f] applied to it. *)
let field_update field r c f =
  let ix = rc_to_ix field r c in
  let tiles = field.field in
  tiles.(ix) <- f tiles.(ix)

module Parse = struct
  (* Parse one digit character into a tile that the search has not reached
     yet. Raises if [c] is not a digit. *)
  let parse_tile c =
    String.of_char c |> Int.of_string |> fun value ->
    { value; min_so_far = Int.max_value; prev = (0, 0) }

  (* Parse one input line into a row of tiles. *)
  let parse_line s = String.to_list s |> List.map ~f:parse_tile

  (* Parse the puzzle input into a field.

     Lines are split with [String.split_lines], so both "\n" and "\r\n"
     endings work. Blank lines (e.g. one or more trailing newlines) are
     ignored; previously only a single trailing empty line was discounted,
     so extra newlines made [r] too large. The column count is taken from
     the first row; empty input yields an empty 0x0 field. *)
  let parse_field inp =
    let lines =
      String.split_lines inp
      |> List.filter ~f:(fun line -> not (String.is_empty line))
    in
    let r = List.length lines in
    let c = match lines with [] -> 0 | first :: _ -> String.length first in
    let field = List.concat_map lines ~f:parse_line |> Array.of_list in
    { r; c; field }
end

module Position = struct
  (* A position of the crucible during the search.
     - [r], [c]: row and column of the current tile.
     - [prev]: (row, column) of the tile this position was reached from.
     - [dir]: direction of travel when entering the current tile.
     - [straight]: number of consecutive tiles travelled in direction [dir].
     - [heatloss]: accumulated heat loss (the start tile is not counted). *)
  type t = {
    r : int;
    c : int;
    prev : int * int;
    dir : direction;
    straight : int;
    heatloss : int;
  }
  [@@deriving sexp]

  (* Total order on positions: by heat loss, then row and column, then the
     remaining fields. Every field takes part so that sets keyed by positions
     (such as the Set-backed priority queue) never merge two distinct search
     states that merely share heat loss and tile. *)
  let compare a b =
    let key { r; c; prev; dir; straight; heatloss } =
      (heatloss, r, c, dir, straight, prev)
    in
    [%compare: int * int * int * direction * int * (int * int)] (key a) (key b)

  (* Starting position: top-left tile, facing East, nothing travelled yet. *)
  let initial =
    { r = 0; c = 0; dir = East; prev = (0, 0); straight = 0; heatloss = 0 }
end

(* Directly create a priority queue from a module containing compare and sexp
   conversions: [Comparable.Make] derives from [Inner.compare] the comparator
   that [PrioQueue] requires. *)
module PrimPrioQueue (Inner : Comparable) = struct
  module CompPos = struct
    include Inner
    include Comparable.Make (Inner)
  end

  include PrioQueue (CompPos)
end

(* Priority queue of search positions, ordered by priority and then by
   [Position.compare]. *)
module Pospq = PrimPrioQueue (Position)

module Part1 = struct
  type neighbor = int * int * direction

  (* Search state: row, column, direction of travel when entering the tile
     and number of consecutive tiles travelled in that direction. Positions
     with equal states have identical futures, so each state only needs to
     be expanded once, with its cheapest heat loss. *)
  type state = int * int * direction * int

  (* Destination: the bottom-right tile of the field. *)
  let dst field = (field.r - 1, field.c - 1)

  (* Starting positions of the search. *)
  let initial = [ Position.initial ]

  (* Maximum number of tiles that may be travelled in a straight line. *)
  let max_straight = 3

  (* Return potential neighbors at position (r, c) when travelling in
     direction dir: straight ahead or a 90 degree turn (no reversing). *)
  let neighbors r c : direction -> neighbor list = function
    | North -> [ (r - 1, c, North); (r, c - 1, West); (r, c + 1, East) ]
    | East -> [ (r - 1, c, North); (r + 1, c, South); (r, c + 1, East) ]
    | South -> [ (r + 1, c, South); (r, c - 1, West); (r, c + 1, East) ]
    | West -> [ (r - 1, c, North); (r + 1, c, South); (r, c - 1, West) ]

  (* Filter out neighbors that leave the field or that would exceed
     [max_straight] tiles in a row. Continuing straight from a position makes
     the run [straight + 1] tiles long, so it is only allowed while
     [straight < max_straight]. *)
  let valid_neighbors field Position.{ dir; straight; _ } neighbors :
      neighbor list =
    let straight_ok (_, _, dir') =
      (not (equal_direction dir dir')) || straight < max_straight
    in
    let within_field (r, c, _) =
      r >= 0 && r < field.r && c >= 0 && c < field.c
    in
    List.filter neighbors ~f:(fun n -> straight_ok n && within_field n)

  (* From position pos, return the positions reachable in one step. *)
  let next_options field (Position.{ r; c; dir; straight; _ } as pos) =
    let pos_of_neighbor (r', c', dir') =
      Position.
        {
          r = r';
          c = c';
          (* Direction this position is facing when entering its tile *)
          dir = dir';
          (* Allow tracking of path by remembering previous tile *)
          prev = (r, c);
          (* Tiles travelled in a row in direction dir'. Any turn resets the
             run to 1: the new tile is the first one in the new direction. *)
          straight = (if equal_direction dir dir' then straight + 1 else 1);
          heatloss = pos.heatloss + (field_at field r' c').value;
        }
    in
    neighbors r c dir
    |> valid_neighbors field pos
    |> List.map ~f:pos_of_neighbor

  (* Project a position onto its search state. *)
  let state_of_pos Position.{ r; c; dir; straight; _ } : state =
    (r, c, dir, straight)

  (* Dijkstra's algorithm over search states (see [state]) to find the
     minimal heat loss from the top-left tile to [dst field]. Returns [None]
     if the destination cannot be reached.

     Side effects on [field]:
     - [min_so_far] of every expanded tile is set to the lowest heat loss
       with which any state on that tile was expanded.
     - The [prev] links of the tiles on the optimal path are written, so
       that [trace_path] can reconstruct it.

     A tile cannot be marked visited on first arrival, because the arrival
     direction and straight count restrict the continuation. The queue is a
     polymorphic set of (heat loss, position) pairs: it pops the lowest heat
     loss first and never merges positions that differ in any field. *)
  let solve field =
    let dstr, dstc = dst field in
    (* Best known heat loss per state, plus the state it was reached from
       ([None] for starting states). *)
    let best : (state, int * state option) Hashtbl.t = Hashtbl.Poly.create () in
    (* States that were already expanded; their heat loss is final. *)
    let settled : (state, unit) Hashtbl.t = Hashtbl.Poly.create () in
    (* Enqueue pos if it strictly improves on its state's best heat loss. *)
    let relax q ~from (pos : Position.t) =
      let key = state_of_pos pos in
      match Hashtbl.find best key with
      | Some (known, _) when known <= pos.heatloss -> q
      | Some _ | None ->
          Hashtbl.set best ~key ~data:(pos.heatloss, from);
          Set.add q (pos.heatloss, pos)
    in
    (* Write prev links along the chain of states ending in key. Walking
       backwards, the earliest visit of a tile is written last. Following
       prev links from the destination therefore always moves towards the
       start, even if the path crossed a tile twice. *)
    let rec record_path key =
      match Hashtbl.find best key with
      | None | Some (_, None) -> ()
      | Some (_, Some ((pr, pc, _, _) as from)) ->
          let r, c, _, _ = key in
          field_update field r c (fun v -> { v with prev = (pr, pc) });
          record_path from
    in
    let rec loop q =
      match Set.min_elt q with
      (* No path could be found, options exhausted *)
      | None -> None
      | Some ((_, (pos : Position.t)) as top) ->
          let q = Set.remove q top in
          let key = state_of_pos pos in
          if Hashtbl.mem settled key then
            (* Stale entry: this state was already expanded more cheaply *)
            loop q
          else (
            Hashtbl.set settled ~key ~data:();
            field_update field pos.r pos.c (fun v ->
                { v with min_so_far = Int.min v.min_so_far pos.heatloss });
            if pos.r = dstr && pos.c = dstc then (
              (* Arrived at destination *)
              record_path key;
              Some pos.heatloss)
            else
              let q =
                List.fold (next_options field pos) ~init:q ~f:(fun q next ->
                    relax q ~from:(Some key) next)
              in
              loop q)
    in
    loop
      (List.fold initial ~init:Set.Poly.empty ~f:(fun q pos ->
           relax q ~from:None pos))

  (* Recursively follow `prev` links in the field array, prepending each
     visited tile to acc until the start tile (0, 0) is reached. *)
  let rec trace_path ?(acc = []) field r c =
    match (r, c) with
    | 0, 0 -> (0, 0) :: acc
    | _ ->
        let entry = field_at field r c in
        let r', c' = entry.prev in
        trace_path ~acc:((r, c) :: acc) field r' c'

  (* Trace back the path from the destination to the start *)
  let start_trace_path field =
    let r, c = dst field in
    trace_path field r c

  (* Create a flat array (same layout as the field) with 1 on every tile of
     path and 0 elsewhere. *)
  let visualizer_field field path =
    let a = Array.map field.field ~f:(fun _ -> 0) in
    let f (r, c) = a.(rc_to_ix field r c) <- 1 in
    List.iter path ~f;
    a

  (* Print a flat row-major array with r rows and c columns, one row per
     line. The row stride is the number of columns c, matching [rc_to_ix];
     using r here was only correct for square fields. *)
  let print_visualizer_field field r c =
    let rows = Sequence.range 0 r and cols = Sequence.range 0 c in
    let f r' =
      Sequence.iter cols ~f:(fun c' -> printf "%d" field.((r' * c) + c'));
      printf "\n"
    in
    Sequence.iter rows ~f

  (* Create and print the path on a 2D map *)
  let visualize_field field path =
    let f = visualizer_field field path in
    print_visualizer_field f field.r field.c
end

(* Entry point: read the puzzle from stdin and solve part 1. Then print the
   answer, the traced path as a sexp and a 0/1 map of the path. *)
let () =
  let field = Parse.parse_field In_channel.(input_all stdin) in
  let part1 = Option.value_exn (Part1.solve field) in
  let trace = Part1.start_trace_path field in
  let sexp_of_coord (r, c) = Sexp.List [ Int.sexp_of_t r; Int.sexp_of_t c ] in
  let path_sexp = List.sexp_of_t sexp_of_coord trace in
  Out_channel.printf "Part1: %d\n" part1;
  Out_channel.printf "Path: %s\n" (Sexp.to_string_hum path_sexp);
  Part1.visualize_field field trace