Files
aocpy/2021/d23.py
2024-11-27 17:31:52 -05:00

144 lines
4.4 KiB
Python

import heapq
from lib import get_data
from lib import Grid2D
from collections import deque
from collections import defaultdict
# Energy spent per step by each amphipod type.
CHAR_COSTS = {"A": 1, "B": 10, "C": 100, "D": 1000}
# Hallway (row 1) cells where an amphipod may stop. Columns 3, 5, 7 and 9 are
# deliberately absent: they are the hallway cells directly above the rooms.
HALL_SPOTS = frozenset((1, col) for col in (1, 2, 4, 6, 8, 10, 11))
def mdist(p1, p2):
    """Return the Manhattan distance between two (row, col) points."""
    return abs(p1[0] - p2[0]) + abs(p1[1] - p2[1])


def amph_spots_for(n_amphs: int) -> dict[str, tuple[tuple[int, int], ...]]:
    """Map each amphipod type to its room cells, ordered top (row 2) downward.

    Room "A" is column 3, "B" column 5, "C" column 7 and "D" column 9; every
    room is ``n_amphs`` cells deep.
    """
    return {
        char: tuple((row, 3 + 2 * i) for row in range(2, 2 + n_amphs))
        for i, char in enumerate("ABCD")
    }


def load_burrow(n_amphs: int):
    """Parse the puzzle input into ``(start_state, empty_grid)``.

    For the 4-deep variant two extra room rows are inserted before the
    second-to-last input line (inserted one after the other, so they end up
    in the listed order).

    Returns:
        start_state: sorted tuple of ``(row, col, char, moved)`` records with
            ``moved`` False for every amphipod.
        empty_grid: a clone of the parsed grid in which every amphipod cell
            has been replaced by ``"."``.
    """
    lines = get_data(__file__).splitlines()
    if n_amphs == 4:
        lines.insert(-2, " #D#C#B#A# ")
        lines.insert(-2, " #D#B#A#C# ")
        # NOTE(review): presumably pads the short trailing lines so Grid2D sees
        # rows of equal width -- confirm against Grid2D's parser.
        lines[-2] = lines[-2] + " "
        lines[-1] = lines[-1] + " "
    grid = Grid2D("\n".join(lines))
    empty = grid.clone()
    amphs = []
    for char in "ABCD":
        for coord in grid.find(char):
            amphs.append((coord[0], coord[1], char, False))
            empty[coord] = "."
    return tuple(sorted(amphs)), empty


def h(state, amph_spots) -> int:
    """A* heuristic: a lower bound on the energy still required.

    Each amphipod outside its own room must at least cover the Manhattan
    distance to the top cell of that room, paying its per-step cost.
    """
    return sum(
        mdist((r, c), amph_spots[char][0]) * CHAR_COSTS[char]
        for r, c, char, _ in state
        if (r, c) not in amph_spots[char]
    )


def is_goal(state, amph_spots) -> bool:
    """Return True when every amphipod stands somewhere in its own room."""
    return all((r, c) in amph_spots[char] for r, c, char, _ in state)


def find_accessible_spots(pos, char, blocked, grid_empty):
    """Breadth-first search of the open cells reachable from ``pos``.

    A cell is open when ``grid_empty[cell] == "."`` and it is not in
    ``blocked`` (any container supporting ``in``). Returns a list of
    ``(cell, energy)`` pairs, where energy is the number of steps times
    ``CHAR_COSTS[char]``; ``pos`` itself is never included.
    """
    step = CHAR_COSTS[char]
    seen = {pos}
    queue = deque([(pos, 0)])
    options = []
    while queue:
        cur, energy = queue.popleft()
        for nb in grid_empty.neighbors_ort(cur):
            if nb in seen or nb in blocked or grid_empty[nb] != ".":
                continue
            # Mark on enqueue so a cell can never be queued (or reported) twice.
            seen.add(nb)
            queue.append((nb, energy + step))
            options.append((nb, energy + step))
    return options


def neighbors(state, amph_spots, grid_empty):
    """Return ``(next_state, energy)`` for every legal single-amphipod move.

    Rules enforced:
      * an amphipod that has moved and stands in its own room never moves again;
      * an amphipod that has already moved may not stop in the hallway again;
      * a move must end on a hallway stop (HALL_SPOTS) or in the amphipod's own
        room, and only if every room cell below the target already holds an
        amphipod of the same type.

    Next states are sorted, so configurations that differ only by swapping two
    same-type (indistinguishable) amphipods compare equal.
    """
    occupant = {(r, c): char for r, c, char, _ in state}
    moves = []
    for idx, (r, c, char, moved) in enumerate(state):
        room = amph_spots[char]
        if moved and (r, c) in room:
            continue
        for pos, energy in find_accessible_spots((r, c), char, occupant, grid_empty):
            in_hall = pos in HALL_SPOTS
            in_room = pos in room
            # cannot move to hall spots twice
            if in_hall and moved:
                continue
            # can only move to a hall spot or a room cell of the right type
            if not (in_hall or in_room):
                continue
            # only enter the room if every lower cell holds the correct type
            if in_room and not all(
                occupant.get(below) == char for below in room[room.index(pos) + 1 :]
            ):
                continue
            amphs = list(state)
            amphs[idx] = (pos[0], pos[1], char, True)
            moves.append((tuple(sorted(amphs)), energy))
    return moves


def solve(n_amphs: int):
    """Return the minimum total energy to organise an ``n_amphs``-deep burrow.

    A* search whose state is the amphipod configuration alone; the path cost
    is tracked separately in ``best`` so a configuration reached by several
    paths is merged and expanded at its cheapest cost. Returns None if the
    goal is unreachable.
    """
    amph_spots = amph_spots_for(n_amphs)
    start, grid_empty = load_burrow(n_amphs)
    best = {start: 0}
    open_set = [(h(start, amph_spots), 0, start)]
    while open_set:
        _, cost, state = heapq.heappop(open_set)
        if cost > best[state]:
            continue  # stale entry: a cheaper path to this state was queued later
        if is_goal(state, amph_spots):
            return cost
        for nxt, energy in neighbors(state, amph_spots, grid_empty):
            new_cost = cost + energy
            if nxt not in best or new_cost < best[nxt]:
                best[nxt] = new_cost
                heapq.heappush(open_set, (new_cost + h(nxt, amph_spots), new_cost, nxt))
    return None


for n_amphs in [2, 4]:
    print(solve(n_amphs))