view 2023/day08.ml @ 57:4a584287ebec

Day 10 Part 1
author Lewin Bormann <lbo@spheniscida.de>
date Wed, 20 Dec 2023 20:55:26 +0100
parents 52ad18a532ca
children
line wrap: on
line source

open Base
open Core
open Angstrom

(* Re-bind [Hashtbl] to Base's hashtable after [open Core], so the rest of the
   file uses Base.Hashtbl's API (first-class key module, labelled arguments). *)
module Hashtbl = Base.Hashtbl

(* One instruction: follow the left or the right edge of the current node. *)
type direction = L | R [@@deriving sexp]
(* Node identifier as written in the input, e.g. "AAA". *)
type pointer = string [@@deriving sexp]
(* One map line of the form "id = (left, right)". *)
type node = { id : pointer; left : pointer; right : pointer } [@@deriving sexp]
(* Lookup table from a node's id to the node itself. *)
type network = (pointer, node) Base.Hashtbl.t

(* Fresh, empty table keyed by node id ([pointer] = string). *)
let create_network () : (pointer, _) Hashtbl.t = Hashtbl.create (module String)

(* Renders the network as one "id => (left, right)" line per node. The lines
   come out in the reverse of the order [Hashtbl.fold] visits the bindings. *)
let string_of_network net =
  let render_line ~key ~data:{ left; right; _ } lines =
    Printf.sprintf "%s => (%s, %s)\n" key left right :: lines
  in
  let lines = Hashtbl.fold net ~init:[] ~f:render_line in
  String.concat ~sep:"" lines

(* Indexes [nodes] by their [id]. Raises (via [Hashtbl.of_alist_exn]) when two
   nodes share an id. *)
let network_of_nodes nodes : network =
  nodes
  |> List.map ~f:(fun ({ id; _ } as node) -> (id, node))
  |> Hashtbl.of_alist_exn (module String)

module Parse = struct
  (* A single 'L' or 'R' of the instruction line. *)
  let directionP = char 'L' *> return L <|> char 'R' *> return R

  (* The instruction line: one or more directions, then a line terminator.
     [end_of_line] accepts both "\n" and "\r\n". *)
  let instructionsP = many1 directionP <* end_of_line

  (* A node id: a non-empty run of uppercase letters and digits. Using
     [take_while1] rejects malformed lines such as " = (, )" instead of
     silently producing empty ids. *)
  let pointerP =
    take_while1 (fun c -> Char.is_uppercase c || Char.is_digit c)

  (* One node line "ID = (LEFT, RIGHT)", optionally followed by a line
     terminator (the last line of a file may not have one). *)
  let nodeP =
    let open Angstrom.Let_syntax in
    let%bind id = pointerP in
    let%bind _ = string " = (" in
    let%bind left = pointerP in
    let%bind _ = string ", " in
    let%bind right = pointerP in
    let%bind _ = char ')' in
    let%bind () = option () end_of_line in
    return { id; left; right }

  (* Whole input: instruction line, a blank line, then the node lines.
     Trailing whitespace (e.g. extra blank lines at EOF) is skipped so that
     [~consume:All] in [parse_all] does not reject otherwise valid input. *)
  let contentsP =
    let open Angstrom.Let_syntax in
    let%bind instructions = instructionsP in
    let%bind () = end_of_line in
    let%bind nodes = many1 nodeP in
    let%bind () = skip_while Char.is_whitespace in
    let network = network_of_nodes nodes in
    return (network, instructions)

  (* Carries Angstrom's error message when [parse_all] fails. *)
  exception Parse_exn of string

  (* Reads all of [ch] and parses it; the entire input must be consumed.
     Raises [Parse_exn] on malformed input, and whatever
     [network_of_nodes] raises on duplicate node ids. *)
  let parse_all ch =
    let input = In_channel.input_all ch in
    match parse_string ~consume:All contentsP input with
    | Ok ok -> ok
    | Error e -> raise (Parse_exn e)
end

module Part1 = struct
  (* Follows [instrs] starting at node [current], adding one to [count] per
     step. Stops as soon as the current node is "ZZZ" or the instructions run
     out, and returns the node it stopped at together with the updated count.
     Raises (from [Hashtbl.find_exn]) if a node it must step from is missing
     from [network]. *)
  let rec traverse_once network count current instrs =
    match (current, instrs) with
    | "ZZZ", _ | _, [] -> (current, count)
    | _, instr :: rest ->
        let node = Hashtbl.find_exn network current in
        let next = match instr with L -> node.left | R -> node.right in
        traverse_once network (count + 1) next rest

  let initial_node : pointer = "AAA"

  (* Number of steps from "AAA" to "ZZZ", repeating [instructions] as often
     as needed. Raises [Invalid_argument] on an empty instruction list, which
     could never make progress (previously an infinite loop). Does not
     terminate if "ZZZ" is unreachable from "AAA". *)
  let traverse network instructions =
    if List.is_empty instructions then
      raise (Invalid_argument "Part1.traverse: empty instruction list");
    (* [node] is where the previous pass stopped, [steps] the running total. *)
    let rec go node steps =
      match traverse_once network steps node instructions with
      | "ZZZ", steps -> steps
      | node, steps -> go node steps
    in
    go initial_node 0
end

module Part2 = struct
  (* Per-path state; both arrays are indexed by path and mutated in place:
     the node each path currently stands on, and the step count at which that
     path first stood on a node ending in 'Z' (0 = not seen yet). *)
  type state = pointer Array.t * int Array.t

  (* Every node id ending in 'A'; each one starts a path. *)
  let find_start_ptrs network =
    let keys = Hashtbl.keys network in
    let f = String.is_suffix ~suffix:"A" in
    Array.of_list (List.filter ~f keys)

  (* A path is at a destination when its node id ends in 'Z'. *)
  let is_end_ptr = String.is_suffix ~suffix:"Z"

  (* [all ~f ?ix a] is true iff [f] holds for every element of [a] from index
     [ix] (default 0) to the end. *)
  let rec all ~f ?(ix = 0) a =
    if Int.(ix = Array.length a) then true
    else if not (f a.(ix)) then false
    else all ~ix:(ix + 1) ~f a

  (* Debug output: prints each path's starting node. *)
  let print_initial ((st, _) : state) =
    let f ix e = Out_channel.printf "Path %d => %s\n" ix e in
    Array.iteri ~f st

  (* Debug output: prints "path => count" for each path currently standing on
     a destination node. *)
  let print_period count ((st, _) : state) =
    let f ix e =
      if is_end_ptr e then Out_channel.printf "%d => %d\n" ix count else ()
    in
    Array.iteri ~f st

  (* Records [count] as the period of each path that is on a destination node
     and has no period recorded yet. *)
  let update_periods ((st, periods) : state) count =
    let f ix e =
      if is_end_ptr e && Int.(periods.(ix) = 0) then periods.(ix) <- count
      else ()
    in
    Array.iteri ~f st

  let periods_complete = all ~f:(fun x -> Int.(x <> 0))
  let is_finished ((st, _) : state) = all ~f:is_end_ptr st

  (* Node reached from [ptr] by following one instruction. *)
  let next network ptr =
    let node = Hashtbl.find_exn network ptr in
    function L -> node.left | R -> node.right

  (* Euclid's algorithm. *)
  let rec gcd a b = if b = 0 then a else gcd b (a mod b)

  (* Least common multiple, non-negative. Divides by the gcd before
     multiplying so the intermediate value never exceeds the result; the
     previous [a * b / gcd a b] could overflow for large periods. Returns 0
     when either argument is 0 (previously [lcm 0 0] raised
     Division_by_zero). *)
  let lcm a b = if a = 0 || b = 0 then 0 else Int.abs (a / gcd a b * b)

  (* LCM of all elements; 1 (the identity) for the empty list. *)
  let multi_lcm xs = List.fold xs ~init:1 ~f:lcm

  (* One pass over [instrs]: before each step, records newly discovered
     periods, then advances every path in lock-step. Returns the updated
     count, the (same, mutated) state, and whether every period is known. *)
  let rec traverse_once network count ((st, periods) as state : state) instrs =
    if periods_complete periods then (count, state, true)
    else
      let f instr p = next network p instr in
      match instrs with
      | [] -> (count, state, false)
      | x :: xs ->
          update_periods state count;
          Array.map_inplace ~f:(f x) st;
          traverse_once network (count + 1) state xs

  (* Same walk as in Part 1, but all paths (one per node ending in 'A') move
     together and we record, for each path, the step count at which it first
     reaches a destination node (ending in 'Z'). Once every period is known,
     the number of steps for all paths to finish together is taken to be the
     LCM of the periods.
     NOTE: that LCM is only correct if each path keeps returning to a
     destination exactly every [period] steps, counted from step 0 — true for
     the puzzle inputs, not for arbitrary graphs.
     Prints each path's start node (debug output). Raises [Invalid_argument]
     on an empty instruction list, which could never make progress
     (previously an infinite loop). Does not terminate if some path never
     reaches a destination. *)
  let traverse network instrs =
    if List.is_empty instrs then
      raise (Invalid_argument "Part2.traverse: empty instruction list");
    let initial_ptrs = find_start_ptrs network in
    let periods = Array.create ~len:(Array.length initial_ptrs) 0 in
    let rec do_it count st =
      match traverse_once network count st instrs with
      | count', st', false -> do_it count' st'
      | count', _, true -> count'
    in
    let state = (initial_ptrs, periods) in
    print_initial state;
    ignore @@ do_it 0 state;
    periods

  (* Steps until every path is on a destination simultaneously (see the
     assumption documented on [traverse]). *)
  let solve network instrs =
    let periods = traverse network instrs in
    multi_lcm (Array.to_list periods)
end

(* Entry point: parses stdin, dumps the parsed network, then prints the
   answers to both parts. *)
let () =
  let network, instructions = Parse.parse_all In_channel.stdin in
  Out_channel.output_string Out_channel.stdout (string_of_network network);
  Out_channel.output_string Out_channel.stdout "\n";
  (* Part 1 starts at "AAA" and ends at "ZZZ". Some inputs (e.g. the part-2
     example) lack those nodes. A missing "AAA" makes [Hashtbl.find_exn]
     raise. A missing "ZZZ" can never be reached, so the walk would loop
     forever. Either way, Part 2 would never run. *)
  if Hashtbl.mem network Part1.initial_node && Hashtbl.mem network "ZZZ" then (
    let count = Part1.traverse network instructions in
    Out_channel.printf "1: Took %d turns\n" count)
  else Out_channel.printf "1: skipped (input has no AAA/ZZZ nodes)\n";
  let count2 = Part2.solve network instructions in
  Out_channel.printf "2: Would take %d turns\n" count2