"""Maze solver: collect every key (lowercase letter) in the fewest steps.

Doors (uppercase letters) can only be entered once the matching lowercase
key has been collected.  ``part_1`` walks the maze with a single explorer
starting at ``@``; ``part_2`` replaces the entrance with four robots that
move one at a time.
"""

import heapq
from collections import defaultdict
from itertools import count

from lib import get_data, Grid2D, LETTERS_UPPER, LETTERS_LOWER, add2

LETTERS = LETTERS_UPPER + LETTERS_LOWER

# Value printed when no route collects every key.
_NO_ROUTE = 10**9

# Sample maze for manual experiments; main() obtains its input from
# get_data(__file__) instead.
data = """#############
#g#f.D#..h#l#
#F###e#E###.#
#dCba...BcIJ#
#####.@.#####
#nK.L...G...#
#M###N#H###.#
#o#m..#i#jk.#
#############"""


def _build_letter_graph(g, starts, label):
    """Return the adjacency graph between start cells and letter cells.

    A breadth-first flood is run from every start position and from every
    letter discovered along the way.  The flood only spreads through ``"."``
    cells, so each edge joins two nodes that can see each other without
    another letter in between.

    Args:
        g: grid indexed by position; must provide ``neighbors_ort(pos)``.
        starts: positions to flood from initially.
        label: callable turning a position into the node identifier used as
            graph key (the cell symbol in part 1, the position in part 2).

    Returns:
        ``defaultdict(set)`` mapping node -> set of ``(node, steps)`` pairs.
        Edges are inserted in both directions.
    """
    graph = defaultdict(set)
    pending = list(starts)
    expanded = set()
    while pending:
        origin = pending.pop()
        if origin in expanded:
            continue
        expanded.add(origin)
        origin_symbol = g[origin]

        frontier = [origin]
        visited = set()
        steps = 0
        # Level-by-level flood; runs until the frontier dries up, so long
        # corridors are never cut short.
        while frontier:
            steps += 1
            next_frontier = []
            for cell in frontier:
                if cell in visited:
                    continue
                visited.add(cell)
                for nb in g.neighbors_ort(cell):
                    symbol = g[nb]
                    if symbol == ".":
                        next_frontier.append(nb)
                    elif symbol in LETTERS and symbol != origin_symbol:
                        # Letters terminate the flood: record the edge and
                        # schedule a flood from the letter itself.
                        graph[label(origin)].add((label(nb), steps))
                        graph[label(nb)].add((label(origin), steps))
                        pending.append(nb)
            frontier = next_frontier
    return graph


def _fewest_steps(graph, start_nodes, symbol_of, key_total):
    """Return the minimum total steps needed to hold ``key_total`` keys.

    Runs Dijkstra over states ``(collected keys, node of every walker)``.
    Exactly one walker moves along one graph edge per transition.

    Args:
        graph: mapping node -> iterable of ``(node, steps)`` edges.
        start_nodes: initial node of each walker.
        symbol_of: callable mapping a node to its grid symbol.
        key_total: number of distinct keys that must be collected.

    Returns:
        The smallest step count, or ``_NO_ROUTE`` when the keys cannot all
        be reached.
    """
    start_state = (frozenset(), tuple(start_nodes))
    best = {start_state: 0}
    # The counter breaks ties so heap entries never compare states directly.
    tie = count()
    heap = [(0, next(tie), start_state)]
    while heap:
        dist, _, state = heapq.heappop(heap)
        if dist > best[state]:
            continue  # stale entry; a shorter route to this state was found
        keys, nodes = state
        if len(keys) == key_total:
            # Dijkstra pops states in distance order, so this is optimal.
            return dist
        for i, node in enumerate(nodes):
            for nxt, step in graph.get(node, ()):
                symbol = symbol_of(nxt)
                if symbol in LETTERS_UPPER and symbol.lower() not in keys:
                    continue  # locked door
                if symbol in LETTERS_LOWER:
                    new_keys = keys | {symbol}
                else:
                    new_keys = keys
                new_state = (new_keys, nodes[:i] + (nxt,) + nodes[i + 1:])
                new_dist = dist + step
                if new_state not in best or new_dist < best[new_state]:
                    best[new_state] = new_dist
                    heapq.heappush(heap, (new_dist, next(tie), new_state))
    return _NO_ROUTE


def part_1(data):
    """Print the fewest steps for one explorer starting at ``@``.

    Graph nodes are cell symbols, so every letter is expected to occur at
    most once in the grid.  Prints ``10**9`` when no route collects all keys.
    """
    g = Grid2D(data)
    (start,) = g.find("@")
    graph = _build_letter_graph(g, [start], label=lambda pos: g[pos])
    key_total = len({g[p] for p in g.find(LETTERS_LOWER)})
    print(
        _fewest_steps(
            graph,
            (g[start],),
            symbol_of=lambda node: node,
            key_total=key_total,
        )
    )


def part_2(data):
    """Print the fewest total steps for four robots around the entrance.

    The entrance and its four orthogonal neighbours become walls and the
    four diagonal neighbours become robot start cells ``0``-``3``.  Graph
    nodes are grid positions.  The search is exact: no frontier entries are
    discarded.  Prints ``10**9`` when no route collects all keys.
    """
    g = Grid2D(data)
    (start,) = g.find("@")

    g[start] = "#"
    for offset in ((-1, 0), (1, 0), (0, 1), (0, -1)):
        g[add2(start, offset)] = "#"
    for robot_symbol, offset in zip("0123", ((-1, -1), (-1, 1), (1, 1), (1, -1))):
        g[add2(start, offset)] = robot_symbol

    robots = tuple(g.find("0") + g.find("1") + g.find("2") + g.find("3"))
    graph = _build_letter_graph(g, robots, label=lambda pos: pos)
    key_total = len({g[p] for p in g.find(LETTERS_LOWER)})
    print(
        _fewest_steps(
            graph,
            robots,
            symbol_of=lambda pos: g[pos],
            key_total=key_total,
        )
    )


def main():
    """Load the input associated with this file and run both parts."""
    data = get_data(__file__)
    part_1(data)
    part_2(data)


if __name__ == "__main__":
    main()