"""Amphipod sorting puzzle: minimum total movement cost, solved with A*.

Part 1 solves the diagram as given (rooms of depth 2); part 2 inserts two
extra room rows (rooms of depth 4). Each answer is printed on its own line
(``None`` if no arrangement of moves reaches the goal).
"""

import heapq
from collections import defaultdict
from collections import deque

from lib import Grid2D
from lib import get_data

# Cost of a single step for each amphipod type.
CHAR_COSTS = {"A": 1, "B": 10, "C": 100, "D": 1000}

# Hallway cells (row 1) an amphipod may stop on: every hallway cell except the
# ones directly above a room entrance (room columns are 3, 5, 7 and 9).
HALL_SPOTS = {(1, x) for x in (1, 2, 4, 6, 8, 10, 11)}

# Room rows inserted (top to bottom) below the first room row for part 2.
UNFOLDED_ROWS = ("DCBA", "DBAC")


def amph_spots(n_amphs):
    """Map each amphipod type to its room cells, ordered top to bottom.

    Rooms for A, B, C, D sit in columns 3, 5, 7, 9 and span rows
    2 .. 1 + n_amphs.
    """
    return {
        char: tuple((row, 3 + 2 * i) for row in range(2, 2 + n_amphs))
        for i, char in enumerate("ABCD")
    }


def unfold(data):
    """Return the part-2 diagram with the extra room rows inserted.

    The new rows reuse the leading indentation of the last room row so that
    their letters line up with the room columns. Every line is then
    right-padded with spaces to a common width, so the grid is rectangular.
    """
    lines = data.splitlines()
    room_row = lines[-2]
    indent = room_row[: len(room_row) - len(room_row.lstrip())]
    extra = [indent + "#" + "#".join(letters) + "#" for letters in UNFOLDED_ROWS]
    # Insert just above the last room row / bottom wall, in the given order.
    lines[-2:-2] = extra
    width = max(len(line) for line in lines)
    return "\n".join(line.ljust(width) for line in lines)


def parse(data):
    """Split the diagram into an empty map and the starting configuration.

    Returns ``(empty, config)``. ``empty`` is a copy of the parsed grid with
    every amphipod letter replaced by ".". ``config`` is the sorted tuple of
    ``(row, col, char, moved)`` entries, all with ``moved=False``.
    """
    grid = Grid2D(data)
    empty = grid.clone()
    amphs = []
    for char in "ABCD":
        for coord in grid.find(char):
            amphs.append((coord[0], coord[1], char, False))
            empty[coord] = "."
    return empty, tuple(sorted(amphs))


def mdist(p1, p2):
    """Manhattan distance between two (row, col) cells."""
    return abs(p1[0] - p2[0]) + abs(p1[1] - p2[1])


def heuristic(config, spots):
    """Lower bound on the remaining cost (never overestimates).

    Each amphipod outside its own room is charged the Manhattan distance to
    the top cell of that room, times its per-step cost. Movement is
    orthogonal, so the Manhattan distance never exceeds the real walk.
    """
    return sum(
        mdist((r, c), spots[char][0]) * CHAR_COSTS[char]
        for r, c, char, _ in config
        if (r, c) not in spots[char]
    )


def is_goal(config, spots):
    """True when every amphipod stands inside its own room."""
    return all((r, c) in spots[char] for r, c, char, _ in config)


def reachable_cells(empty, start, char, occupied):
    """Breadth-first search over the open cells reachable from ``start``.

    Returns one ``(cell, cost)`` pair per reachable cell (``start`` excluded).
    ``cost`` is the shortest walking distance times the per-step cost of
    ``char``. Cells in ``occupied`` and cells that are not "." on ``empty``
    block movement.
    """
    step = CHAR_COSTS[char]
    seen = {start}
    queue = deque([(start, 0)])
    found = []
    while queue:
        pos, cost = queue.popleft()
        for nxt in empty.neighbors_ort(pos):
            if nxt in seen or nxt in occupied or empty[nxt] != ".":
                continue
            # Mark on enqueue so each cell is reported once, at its BFS distance.
            seen.add(nxt)
            entry = (nxt, cost + step)
            queue.append(entry)
            found.append(entry)
    return found


def _below_is_settled(room, depth, occupant, char):
    """True if every room cell deeper than index ``depth`` holds ``char``."""
    return all(occupant.get(cell) == char for cell in room[depth + 1:])


def moves(config, spots, empty):
    """Yield ``(next_config, cost)`` for every legal single move.

    Rules enforced:
    * An amphipod already in its own room, with only its own type below it,
      stays put. This covers amphipods that moved there and ones that started
      there; moving them again can never lower the total cost.
    * An amphipod stops either on a hallway spot (``HALL_SPOTS``) or in its
      own room. It may stop in the hallway only on its first move.
    * It may enter its own room only if every cell below the target cell
      already holds its own type.
    """
    occupant = {(r, c): char for r, c, char, _ in config}
    for idx, (r, c, char, moved) in enumerate(config):
        room = spots[char]
        if (r, c) in room and (
            moved or _below_is_settled(room, room.index((r, c)), occupant, char)
        ):
            continue
        for pos, cost in reachable_cells(empty, (r, c), char, occupant):
            if pos in HALL_SPOTS:
                if moved:
                    continue  # second stop in the hallway is not allowed
            elif pos in room:
                if not _below_is_settled(room, room.index(pos), occupant, char):
                    continue
            else:
                continue  # in front of a room entrance or inside a foreign room
            amphs = list(config)
            amphs[idx] = (pos[0], pos[1], char, True)
            # Sorted so that permutations of identical amphipods are one state.
            yield tuple(sorted(amphs)), cost


def solve(data, n_amphs):
    """Return the minimum total cost to sort the amphipods, or ``None``.

    ``data`` is the diagram (already unfolded for deeper rooms). ``n_amphs``
    is the room depth. A* runs over canonical configurations: each
    configuration keeps only its best known cost, and heap entries made stale
    by a later, cheaper push are skipped.
    """
    spots = amph_spots(n_amphs)
    empty, start = parse(data)
    best = {start: 0}
    open_set = [(heuristic(start, spots), 0, start)]
    while open_set:
        _, cost, config = heapq.heappop(open_set)
        if cost > best[config]:
            continue  # stale entry; a cheaper one was pushed later
        if is_goal(config, spots):
            return cost
        for nxt, step_cost in moves(config, spots, empty):
            new_cost = cost + step_cost
            if nxt not in best or new_cost < best[nxt]:
                best[nxt] = new_cost
                heapq.heappush(
                    open_set, (new_cost + heuristic(nxt, spots), new_cost, nxt)
                )
    return None


for n_amphs in [2, 4]:
    data = get_data(__file__)
    if n_amphs == 4:
        data = unfold(data)
    print(solve(data, n_amphs))