# Listing metadata: 144 lines, 4.4 KiB, Python.
import heapq
from collections import defaultdict
from collections import deque

from lib import Grid2D
from lib import get_data

# Energy each amphipod type spends per step taken.
CHAR_COSTS = {"A": 1, "B": 10, "C": 100, "D": 1000}

# Hallway cells (row 1) where an amphipod may stop. Columns 3, 5, 7 and 9 are
# left out because they sit directly above the rooms (see room_spots).
HALL_SPOTS = frozenset((1, col) for col in (1, 2, 4, 6, 8, 10, 11))

# Extra room rows spliced into the input for part 2. The two leading spaces
# keep the letters in columns 3, 5, 7 and 9, the same columns as the rooms.
EXTRA_ROWS = ("  #D#C#B#A#", "  #D#B#A#C#")


def mdist(p1, p2):
    """Return the Manhattan distance between two ``(row, col)`` points."""
    return abs(p1[0] - p2[0]) + abs(p1[1] - p2[1])


def room_spots(n_amphs):
    """Map each amphipod type to its room cells, listed top (row 2) to bottom.

    Types A, B, C, D live in columns 3, 5, 7, 9; each room is *n_amphs* deep.
    """
    return {
        char: tuple((row, 3 + i * 2) for row in range(2, 2 + n_amphs))
        for i, char in enumerate("ABCD")
    }


def unfold(data):
    """Return *data* with EXTRA_ROWS inserted before its last two lines.

    For the usual five-line burrow this places the new rows between the two
    original room rows (the part-2 layout).
    """
    lines = data.splitlines()
    lines[-2:-2] = EXTRA_ROWS
    return "\n".join(lines)


def parse_amphs(data):
    """Return every amphipod in *data* as a sorted tuple of ``(row, col, char)``.

    Rows and columns index the text lines of *data*, so row 1 is the hallway.
    Sorting makes the tuple a canonical search state: two layouts that differ
    only by swapping amphipods of the same type compare equal.
    """
    return tuple(
        sorted(
            (row, col, char)
            for row, line in enumerate(data.splitlines())
            for col, char in enumerate(line)
            if char in CHAR_COSTS
        )
    )


def _is_settled(row, col, char, occupied, rooms):
    """Return True if the amphipod is home with only its own type below it.

    A settled amphipod never has to move again, so the search skips it.
    """
    spots = rooms[char]
    if (row, col) not in spots:
        return False
    return all(occupied.get(spot) == char for spot in spots if spot[0] > row)


def _home_target(char, occupied, rooms):
    """Return the deepest free cell of *char*'s room.

    Returns None while the room is full or still holds another type: an
    amphipod may only enter a room that contains nothing but its own type.
    """
    target = None
    for spot in rooms[char]:  # top to bottom, so the last free cell wins
        occupant = occupied.get(spot)
        if occupant is None:
            target = spot
        elif occupant != char:
            return None
    return target


def _path(start, end):
    """Return the cells entered when walking from *start* to *end* via row 1.

    The walk climbs the start column to the hallway, runs along it, then
    descends the end column. *start* is excluded, so the length of the result
    is the number of steps. Assumes the two columns differ.
    """
    (r0, c0), (r1, c1) = start, end
    cells = [(row, c0) for row in range(r0 - 1, 0, -1)]
    step = 1 if c1 > c0 else -1
    cells.extend((1, col) for col in range(c0 + step, c1 + step, step))
    cells.extend((row, c1) for row in range(2, r1 + 1))
    return cells


def h(state, rooms):
    """Admissible A* heuristic: a lower bound on the energy still needed.

    Each unsettled amphipod must at least reach the top cell of its room. One
    already in its own room column (blocking a stranger below) must climb to
    the hallway, step sideways and back, and descend again: ``row + 2`` steps.
    """
    occupied = {(row, col): char for row, col, char in state}
    total = 0
    for row, col, char in state:
        if _is_settled(row, col, char, occupied, rooms):
            continue
        home_col = rooms[char][0][1]
        if col == home_col:
            steps = row + 2
        else:
            steps = mdist((row, col), (1, home_col)) + 1
        total += steps * CHAR_COSTS[char]
    return total


def is_goal(state, rooms):
    """Return True when every amphipod occupies a cell of its own room."""
    return all((row, col) in rooms[char] for row, col, char in state)


def neighbors(state, rooms):
    """Yield ``(next_state, energy)`` for every legal single-amphipod move.

    Settled amphipods stay put. Any other amphipod may walk straight into its
    own room once that room holds only its type; one that is still in a room
    may instead stop on any free hallway spot. Hallway-to-hallway moves are
    never generated.
    """
    occupied = {(row, col): char for row, col, char in state}
    hall = sorted(HALL_SPOTS)
    for idx, (row, col, char) in enumerate(state):
        if _is_settled(row, col, char, occupied, rooms):
            continue
        destinations = []
        home = _home_target(char, occupied, rooms)
        if home is not None and home[1] != col:
            destinations.append(home)
        if row > 1:  # still in a room, so parking in the hallway is allowed
            destinations.extend(hall)
        for dest in destinations:
            path = _path((row, col), dest)
            if any(cell in occupied for cell in path):
                continue
            moved = state[:idx] + ((dest[0], dest[1], char),) + state[idx + 1:]
            yield tuple(sorted(moved)), len(path) * CHAR_COSTS[char]


def solve(data, n_amphs=None):
    """Return the minimum energy needed to sort the burrow in *data*.

    n_amphs: room depth; inferred as (number of amphipods) // 4 when omitted.
    Returns None if no sequence of legal moves sorts the burrow.

    A* over canonical layouts. Path cost is tracked separately from the state,
    so a layout reached again is only re-queued when found more cheaply.
    """
    start = parse_amphs(data)
    if n_amphs is None:
        n_amphs = len(start) // len("ABCD")
    rooms = room_spots(n_amphs)
    best = {start: 0}
    open_set = [(h(start, rooms), 0, start)]
    while open_set:
        _, cost, state = heapq.heappop(open_set)
        if cost > best[state]:
            continue  # stale entry: a cheaper route to this layout was queued
        if is_goal(state, rooms):
            return cost
        for nxt, energy in neighbors(state, rooms):
            new_cost = cost + energy
            if nxt not in best or new_cost < best[nxt]:
                best[nxt] = new_cost
                heapq.heappush(open_set, (new_cost + h(nxt, rooms), new_cost, nxt))
    return None


def main():
    """Print the minimum energy for part 1 (2-deep rooms) and part 2 (4-deep)."""
    data = get_data(__file__)
    for n_amphs in (2, 4):
        text = unfold(data) if n_amphs == 4 else data
        print(solve(text, n_amphs))


if __name__ == "__main__":
    main()