200 lines
5.6 KiB
Python
200 lines
5.6 KiB
Python
import heapq
import itertools
from collections import defaultdict

from lib import get_data, Grid2D, LETTERS_UPPER, LETTERS_LOWER, add2
|
|
|
|
# Every symbol that becomes a graph node in part_1 / part_2: doors
# (LETTERS_UPPER, passable only once the matching .lower() key is held) and
# keys (LETTERS_LOWER, collected by stepping on them). The flood fills stop
# at any of these cells instead of walking through them.
LETTERS = LETTERS_UPPER + LETTERS_LOWER

# Small example maze kept for manual experiments. main() does not use it: it
# reads the real input via get_data(), and part_1/part_2 take their own
# ``data`` parameter, which shadows this module-level name.
data = """#############
#g#f.D#..h#l#
#F###e#E###.#
#dCba...BcIJ#
#####.@.#####
#nK.L...G...#
#M###N#H###.#
#o#m..#i#jk.#
#############"""
|
|
|
|
|
|
def _symbol_graph(g, start):
    """Build the hop graph between the start cell and the letters reachable from it.

    Starting at ``start``, flood-fill through open '.' cells. A letter cell
    (key or door) ends a hop: it becomes a node and is flooded from in turn.
    Cells that are neither '.' nor a letter (walls, and the '@' cell when
    flooding from a letter) are never entered. Every edge is therefore a
    direct hop with no other letter on the way. Paths that run through '@'
    are still represented, because '@' is itself a node.

    Returns:
        ``symbol -> {neighbour_symbol: steps}``. Edges are stored in both
        directions and keep the shortest hop length found.
    """
    graph = defaultdict(dict)
    pending = [start]
    explored = set()
    while pending:
        origin = pending.pop()
        if origin in explored:
            continue
        explored.add(origin)
        origin_symbol = g[origin]

        # Level-by-level BFS: the first time a cell is met is at its shortest
        # distance, so each cell is marked visited immediately. The loop runs
        # until the reachable open area is exhausted (no fixed step budget).
        visited = {origin}
        frontier = [origin]
        steps = 0
        while frontier:
            steps += 1
            next_frontier = []
            for cell in frontier:
                for nb in g.neighbors_ort(cell):
                    if nb in visited:
                        continue
                    visited.add(nb)
                    symbol = g[nb]
                    if symbol == ".":
                        next_frontier.append(nb)
                    elif symbol in LETTERS:
                        for a, b in ((origin_symbol, symbol), (symbol, origin_symbol)):
                            graph[a][b] = min(graph[a].get(b, steps), steps)
                        pending.append(nb)
            frontier = next_frontier
    return graph


def _fewest_steps(graph, start_symbol, all_keys, doors, no_route=10**9):
    """Return the fewest steps needed to collect every key in ``all_keys``.

    Runs Dijkstra over states ``(current node, keys held)``. Stepping onto a
    node whose symbol is in ``all_keys`` picks that key up. A node whose
    symbol is in ``doors`` can only be entered once ``symbol.lower()`` is held.

    Args:
        graph: ``symbol -> {neighbour_symbol: steps}`` as built by
            ``_symbol_graph``.
        start_symbol: node the walk starts from.
        all_keys: key symbols that must all be collected.
        doors: container of door symbols.
        no_route: value returned when the keys cannot all be collected.
    """
    wanted = frozenset(all_keys)
    tie = itertools.count()  # tie-breaker so the heap never compares states
    start = (start_symbol, frozenset())
    best = {start: 0}
    heap = [(0, next(tie), start)]
    while heap:
        dist, _, state = heapq.heappop(heap)
        if dist > best[state]:
            continue  # stale entry: a shorter route to this state was pushed later
        node, held = state
        if held >= wanted:
            # States are popped in distance order, so the first complete one is optimal.
            return dist
        for nxt, hop in graph.get(node, {}).items():
            if nxt in doors and nxt.lower() not in held:
                continue
            new_held = (held | {nxt}) if nxt in wanted else held
            new_state = (nxt, new_held)
            new_dist = dist + hop
            known = best.get(new_state)
            if known is None or new_dist < known:
                best[new_state] = new_dist
                heapq.heappush(heap, (new_dist, next(tie), new_state))
    return no_route


def part_1(data):
    """Print (and return) the fewest steps for one walker to collect every key.

    The walk starts at the grid's single '@' cell. Lower-case letters are keys
    and upper-case letters are doors that need their matching key. Prints
    ``10**9`` if the keys cannot all be collected, and ``0`` if there are none.
    """
    g = Grid2D(data)
    (start,) = g.find("@")
    graph = _symbol_graph(g, start)
    all_keys = [g[p] for p in g.find(LETTERS_LOWER)]
    min_dist = _fewest_steps(graph, "@", all_keys, LETTERS_UPPER)
    print(min_dist)
    return min_dist
|
|
|
|
|
|
def _position_graph(g, starts):
    """Build the hop graph between robot start cells and the letters reachable from them.

    From each cell in ``starts``, flood-fill through open '.' cells. A letter
    cell (key or door) ends a hop: it becomes a node and is flooded from in
    turn. Nodes are grid positions rather than symbols, so the robots' start
    markers stay distinct.

    Returns:
        ``position -> {neighbour_position: steps}``. Edges are stored in both
        directions and keep the shortest hop length found.
    """
    graph = defaultdict(dict)
    pending = list(starts)
    explored = set()
    while pending:
        origin = pending.pop()
        if origin in explored:
            continue
        explored.add(origin)

        # Level-by-level BFS until the reachable open area is exhausted
        # (no fixed step budget). The first meeting with a cell is the shortest.
        visited = {origin}
        frontier = [origin]
        steps = 0
        while frontier:
            steps += 1
            next_frontier = []
            for cell in frontier:
                for nb in g.neighbors_ort(cell):
                    if nb in visited:
                        continue
                    visited.add(nb)
                    if g[nb] == ".":
                        next_frontier.append(nb)
                    elif g[nb] in LETTERS:
                        for a, b in ((origin, nb), (nb, origin)):
                            graph[a][b] = min(graph[a].get(b, steps), steps)
                        pending.append(nb)
            frontier = next_frontier
    return graph


def _fewest_steps_robots(graph, symbol_at, robots, all_keys, doors, no_route=10**9):
    """Return the fewest total steps for several robots to collect every key.

    Runs Dijkstra over states ``(robot positions, keys held)``. Each move
    advances one robot along one graph edge. The search is exhaustive (no
    beam cut), so the result is optimal.

    Args:
        graph: ``position -> {neighbour_position: steps}``.
        symbol_at: mapping/indexable returning the grid symbol at a position.
        robots: starting positions, one per robot.
        all_keys: key symbols that must all be collected.
        doors: container of door symbols; door ``X`` needs key ``X.lower()``.
        no_route: value returned when the keys cannot all be collected.
    """
    wanted = frozenset(all_keys)
    tie = itertools.count()  # tie-breaker so the heap never compares states
    start = (tuple(robots), frozenset())
    best = {start: 0}
    heap = [(0, next(tie), start)]
    while heap:
        dist, _, state = heapq.heappop(heap)
        if dist > best[state]:
            continue  # stale entry: a shorter route to this state was pushed later
        positions, held = state
        if held >= wanted:
            # States are popped in distance order, so the first complete one is optimal.
            return dist
        for i, pos in enumerate(positions):
            for nxt, hop in graph.get(pos, {}).items():
                symbol = symbol_at[nxt]
                if symbol in doors and symbol.lower() not in held:
                    continue
                new_held = (held | {symbol}) if symbol in wanted else held
                new_state = (positions[:i] + (nxt,) + positions[i + 1:], new_held)
                new_dist = dist + hop
                known = best.get(new_state)
                if known is None or new_dist < known:
                    best[new_state] = new_dist
                    heapq.heappush(heap, (new_dist, next(tie), new_state))
    return no_route


def part_2(data):
    """Print (and return) the fewest total steps for four robots to collect all keys.

    The grid is first rewritten around the single '@' cell:
    - '@' and the cells at offsets (-1, 0), (1, 0), (0, 1) and (0, -1) become
      walls '#';
    - the cells at offsets (-1, -1), (-1, 1), (1, 1) and (1, -1) become the
      robot start markers "0".."3".
    Prints ``10**9`` if the keys cannot all be collected.
    """
    g = Grid2D(data)
    (start,) = g.find("@")
    g[start] = "#"
    for offset in ((-1, 0), (1, 0), (0, 1), (0, -1)):
        g[add2(start, offset)] = "#"
    robot_markers = (
        ("0", (-1, -1)),
        ("1", (-1, 1)),
        ("2", (1, 1)),
        ("3", (1, -1)),
    )
    for marker, offset in robot_markers:
        g[add2(start, offset)] = marker

    robots = tuple(g.find("0") + g.find("1") + g.find("2") + g.find("3"))
    graph = _position_graph(g, robots)
    all_keys = [g[p] for p in g.find(LETTERS_LOWER)]
    min_dist = _fewest_steps_robots(graph, g, robots, all_keys, LETTERS_UPPER)
    print(min_dist)
    return min_dist
|
|
|
|
|
|
def main():
    """Fetch the input with get_data(__file__) and run both parts on it, in order."""
    puzzle_input = get_data(__file__)
    for solver in (part_1, part_2):
        solver(puzzle_input)


if __name__ == "__main__":
    main()
|