"""Solve both parts of a portal-maze puzzle.

The maze is loaded from this script's input via ``lib.get_data`` and wrapped
in ``lib.Grid2D``. Floor tiles are ``"."``. Two capital letters directly next
to a floor tile (read left-to-right or top-to-bottom) label that tile:
``AA`` is the entrance, ``ZZ`` is the exit, and every other label must occur
exactly twice, naming the two ends of a portal. A label that touches the
border of the grid makes its tile an *outer* end; any other label makes it an
*inner* end.

Two numbers are printed:

1. the fewest steps from ``AA`` to ``ZZ`` when stepping through a portal
   costs one step and simply moves to its other end;
2. the fewest steps when portals additionally change a nesting level: an
   inner end leads one level deeper, an outer end leads one level back up,
   outer ends cannot be used at level 0, and ``AA``/``ZZ`` are only usable
   at level 0.
"""
import heapq
from collections import defaultdict
from itertools import count

from lib import get_data, Grid2D, LETTERS_UPPER


def _labelled_tiles(grid):
    """Yield ``(label, tile, on_border)`` for every labelled floor tile.

    Each cell is used as the first cell of a horizontal and of a vertical
    run of three cells. A run made of two capital letters followed or
    preceded by ``"."`` labels that floor tile. ``on_border`` is true when
    the letters occupy the first or last two rows/columns of the grid.
    """
    for row in range(grid.n_rows - 2):
        for col in range(grid.n_cols - 2):
            a = x = grid[(row, col)]
            b, c = grid[(row, col + 1)], grid[(row, col + 2)]
            y, z = grid[(row + 1, col)], grid[(row + 2, col)]

            # Horizontal run: "XY." or ".XY".
            if a in LETTERS_UPPER and b in LETTERS_UPPER and c == ".":
                yield a + b, (row, col + 2), col == 0
            elif a == "." and b in LETTERS_UPPER and c in LETTERS_UPPER:
                yield b + c, (row, col), col + 3 == grid.n_cols

            # Vertical run: letters above the tile, or below it.
            if x in LETTERS_UPPER and y in LETTERS_UPPER and z == ".":
                yield x + y, (row + 2, col), row == 0
            elif x == "." and y in LETTERS_UPPER and z in LETTERS_UPPER:
                yield y + z, (row, col), row + 3 == grid.n_rows


def _walk_distances(grid, origin, stops):
    """Return ``{stop: steps}`` for every stop reachable on foot from *origin*.

    Breadth-first search over ``"."`` tiles using ``grid.neighbors_ort``
    (presumably the orthogonal neighbours of a cell -- defined in ``lib``).
    A tile in *stops* ends the walk: its distance is recorded but the search
    never continues through it, so every result is a direct stretch between
    two special tiles.
    """
    found = {}
    seen = {origin}
    frontier = [origin]
    steps = 0
    while frontier:
        steps += 1
        reached = []
        for tile in frontier:
            for nb in grid.neighbors_ort(tile):
                if nb in seen:
                    continue
                if nb in stops:
                    seen.add(nb)
                    found[nb] = steps
                elif grid[nb] == ".":
                    seen.add(nb)
                    reached.append(nb)
        frontier = reached
    return found


def _link(adjacency, src, dst, cost, is_jump):
    """Add the edge ``src -> dst`` to *adjacency* unless it is already there."""
    edge = (dst, cost, is_jump)
    if edge not in adjacency[src]:
        adjacency[src].append(edge)


def _dijkstra(source, is_goal, successors):
    """Return the cheapest cost from *source* to a goal state, or ``None``.

    *successors(state)* must return ``(next_state, step_cost)`` pairs. The
    answer is taken when a goal state is popped from the heap, i.e. once no
    cheaper route to it can exist -- not when it is first seen as a neighbour.
    ``None`` is returned if the reachable states are exhausted first.
    """
    best = {source: 0}
    # The running counter breaks ties so heap entries never compare states.
    tie = count()
    heap = [(0, next(tie), source)]
    settled = set()
    while heap:
        cost, _, state = heapq.heappop(heap)
        if state in settled:
            continue
        settled.add(state)
        if is_goal(state):
            return cost
        for nxt, step in successors(state):
            total = cost + step
            if nxt not in best or total < best[nxt]:
                best[nxt] = total
                heapq.heappush(heap, (total, next(tie), nxt))
    return None


data = get_data(__file__)
g = Grid2D(data)

# --- Locate entrance, exit and portal ends -------------------------------
start = None
end = None
warps = defaultdict(list)  # label -> floor tiles carrying that label
inner = set()  # portal ends whose label does not touch the grid border
outer = set()  # portal ends whose label touches the grid border
for label, tile, on_border in _labelled_tiles(g):
    # AA and ZZ are recognised in every orientation, not only with the
    # letters above (AA) or below (ZZ) the tile.
    if label == "AA":
        start = tile
    elif label == "ZZ":
        end = tile
    else:
        warps[label].append(tile)
        (outer if on_border else inner).add(tile)

if start is None or end is None:
    raise ValueError("maze needs both an AA entrance and a ZZ exit")
for label, tiles in warps.items():
    if len(tiles) != 2:
        raise ValueError(
            "portal {!r} has {} end(s), expected exactly 2".format(
                label, len(tiles)
            )
        )

# --- Build the graph of special tiles ------------------------------------
# graph[tile] holds (neighbour, cost, is_jump) triples. Portal jumps are
# tagged explicitly so that a one-step walk between two adjacent special
# tiles is never mistaken for a jump.
graph = defaultdict(list)
allnodes = {start, end}
for first, second in warps.values():
    _link(graph, first, second, 1, True)
    _link(graph, second, first, 1, True)
    allnodes.update((first, second))

for node in allnodes:
    for stop, steps in _walk_distances(g, node, allnodes).items():
        _link(graph, node, stop, steps, False)
        _link(graph, stop, node, steps, False)


# --- Part one: portals are plain one-step shortcuts ----------------------
def _flat_moves(tile):
    """Return ``(neighbour, cost)`` pairs for *tile*, ignoring levels."""
    return [(nb, cost) for nb, cost, _ in graph[tile]]


answer = _dijkstra(start, lambda tile: tile == end, _flat_moves)
if answer is None:
    raise ValueError("ZZ cannot be reached from AA")
print(answer)


# --- Part two: portals change the nesting level --------------------------
def _recursive_moves(state):
    """Return ``((neighbour, level), cost)`` pairs for a ``(tile, level)``."""
    tile, level = state
    moves = []
    for nb, cost, is_jump in graph[tile]:
        # The entrance is never re-entered; the exit exists only at level 0.
        if nb == start or (nb == end and level != 0):
            continue
        if not is_jump:
            new_level = level
        elif tile in inner:
            new_level = level + 1
        elif tile in outer:
            new_level = level - 1
        else:
            raise RuntimeError(
                "portal end {!r} is neither inner nor outer".format(tile)
            )
        if new_level < 0:
            continue  # outer ends are closed at the outermost level
        moves.append(((nb, new_level), cost))
    return moves


# NOTE: levels have no upper bound, so this search may not terminate when the
# exit cannot be reached at all.
answer = _dijkstra((start, 0), lambda state: state[0] == end, _recursive_moves)
if answer is None:
    raise ValueError("ZZ cannot be reached from AA with nested levels")
print(answer)