# Advent of Code helper utilities: input loading and parsing, a 2D grid,
# A* search, and small number-theory and geometry helpers.
import re
|
|
import os
|
|
import string
|
|
import heapq
|
|
|
|
NUMBERS = string.digits
LETTERS_LOWER = string.ascii_lowercase
LETTERS_UPPER = string.ascii_uppercase

# Positive infinity as a float.
INF = float("inf")


def fst(l):
    """Return the element at index 0 of ``l``."""
    return l[0]


def snd(l):
    """Return the element at index 1 of ``l``."""
    return l[1]
|
|
|
|
|
|
def nth(n):
    """Return a getter function that indexes its argument at position ``n``."""

    def getter(l):
        return l[n]

    return getter
|
|
|
|
|
|
def maps(f, xs):
    """Apply ``f`` to every leaf of an arbitrarily nested list.

    Only ``list`` instances are descended into; anything else (tuples,
    strings, numbers, ...) is a leaf and is passed to ``f`` as a whole.
    """
    return [maps(f, x) for x in xs] if isinstance(xs, list) else f(xs)
|
|
|
|
|
|
def mape(f, xs):
    """Eagerly apply ``f`` to each item of ``xs`` and return the results as a list."""
    return [f(x) for x in xs]
|
|
|
|
|
|
def add2(a: tuple[int, int], b: tuple[int, int]) -> tuple[int, int]:
    """Add two 2D coordinates/offsets component-wise.

    Only the components at index 0 and 1 of each argument are used.
    """
    first = a[0] + b[0]
    second = a[1] + b[1]
    return first, second
|
|
|
|
|
|
class Grid2D:
    """Mutable 2D grid of cells addressed by ``(row, col)`` tuples.

    Built from text: each non-blank line becomes a row and each character a
    cell. Direction constants are ``(d_row, d_col)`` offsets, so ``N``
    decreases the row index and ``E`` increases the column index.
    """

    N = (-1, 0)
    E = (0, 1)
    S = (1, 0)
    W = (0, -1)
    NW = (-1, -1)
    NE = (-1, 1)
    SE = (1, 1)
    SW = (1, -1)
    COORDS_ORTH = (N, E, S, W)
    COORDS_DIAG = (NW, NE, SE, SW)

    def __init__(self, text: str):
        """Parse ``text`` into rows of single characters, skipping blank lines.

        ``n_cols`` is taken from the first row, so rows are expected to have
        equal length. Text without any non-blank line yields an empty 0x0 grid.
        """
        lines = [line for line in text.splitlines() if line.strip() != ""]
        self.grid = [list(line) for line in lines]
        self.n_rows = len(self.grid)
        # Guard the empty case instead of failing with IndexError on grid[0].
        self.n_cols = len(self.grid[0]) if self.grid else 0

    def __getitem__(self, pos: tuple[int, int]):
        row, col = pos
        return self.grid[row][col]

    def __setitem__(self, pos: tuple[int, int], val):
        row, col = pos
        self.grid[row][col] = val

    def hash(self):
        """Return a hashable snapshot of the cell contents (tuple of row tuples)."""
        return tuple(tuple(row) for row in self.grid)

    def clone(self):
        """Return an independent deep copy of this grid."""
        from copy import deepcopy

        return deepcopy(self)

    def clone_with_val(self, val):
        """Return a new grid with the same dimensions and every cell set to ``val``."""
        # Bypass __init__: there is no text to parse, only attributes to set.
        c = Grid2D.__new__(Grid2D)
        c.n_rows = self.n_rows
        c.n_cols = self.n_cols
        c.grid = [[val for _ in range(c.n_cols)] for _ in range(c.n_rows)]
        return c

    def rows(self) -> list[list[str]]:
        """Return the rows: a new list holding the grid's own row lists."""
        return list(self.grid)

    def cols(self) -> list[list[str]]:
        """Return the columns, left to right, as newly built lists."""
        return [[row[col_i] for row in self.grid] for col_i in range(self.n_cols)]

    def find(self, chars: str) -> list[tuple[int, int]]:
        """Return coordinates (row-major order) of cells whose value is in ``chars``."""
        return [c for c in self.all_coords() if self[c] in chars]

    def find_not(self, chars: str) -> list[tuple[int, int]]:
        """Return coordinates (row-major order) of cells whose value is not in ``chars``."""
        return [c for c in self.all_coords() if self[c] not in chars]

    def all_coords(self) -> list[tuple[int, int]]:
        """Return every ``(row, col)`` coordinate in row-major order."""
        return [
            (row_i, col_i)
            for row_i in range(self.n_rows)
            for col_i in range(self.n_cols)
        ]

    def row_coords(self, row_i) -> list[tuple[int, int]]:
        """Return the ``(row, col)`` coordinates of row ``row_i``, left to right."""
        assert row_i < self.n_rows, f"{row_i=} must be smaller than {self.n_rows=}"
        # (row, col) order, matching __getitem__ and all_coords.
        return [(row_i, col_i) for col_i in range(self.n_cols)]

    def col_coords(self, col_i) -> list[tuple[int, int]]:
        """Return the ``(row, col)`` coordinates of column ``col_i``, top to bottom."""
        assert col_i < self.n_cols, f"{col_i=} must be smaller than {self.n_cols=}"
        # (row, col) order, matching __getitem__ and all_coords.
        return [(row_i, col_i) for row_i in range(self.n_rows)]

    def contains(self, pos: tuple[int, int]) -> bool:
        """Return True if ``pos`` lies inside the grid bounds."""
        row, col = pos
        return 0 <= row < self.n_rows and 0 <= col < self.n_cols

    def __contains__(self, pos: tuple[int, int]) -> bool:
        return self.contains(pos)

    def neighbors_ort(self, pos: tuple[int, int]) -> list[tuple[int, int]]:
        """Return the in-bounds orthogonal neighbours of ``pos`` (order N, E, S, W)."""
        candidates = [add2(pos, off) for off in self.dirs_ort()]
        return [p for p in candidates if self.contains(p)]

    def neighbors_vert(self, pos: tuple[int, int]) -> list[tuple[int, int]]:
        """Return the in-bounds diagonal neighbours of ``pos`` (order NE, SE, SW, NW)."""
        candidates = [add2(pos, off) for off in self.dirs_vert()]
        return [p for p in candidates if self.contains(p)]

    def neighbors_adj(self, pos: tuple[int, int]) -> list[tuple[int, int]]:
        """Return all in-bounds neighbours: orthogonal ones first, then diagonal."""
        return self.neighbors_ort(pos) + self.neighbors_vert(pos)

    def flip_ort(self, pos: tuple[int, int]) -> tuple[int, int]:
        """Return the opposite of the offset ``pos`` (both components negated)."""
        return (-pos[0], -pos[1])

    def dirs_ort(self) -> list[tuple[int, int]]:
        return [self.N, self.E, self.S, self.W]

    def dirs_vert(self) -> list[tuple[int, int]]:
        return [self.NE, self.SE, self.SW, self.NW]

    def print(self):
        """Print each row with its cells joined directly (cells must be strings)."""
        for r in self.rows():
            print("".join(r))

    def print_with_gaps(self):
        """Print each row with cells converted via ``str`` and separated by spaces."""
        for r in self.rows():
            print(" ".join(map(str, r)))
|
|
|
|
|
|
class Input:
    """Puzzle input text, loaded from a file path or given directly."""

    def __init__(self, text: str):
        """
        :param text: path of an existing file whose contents become the input,
            or otherwise the input text itself
        """
        if os.path.isfile(text):
            # Context manager so the file handle is closed promptly.
            with open(text) as f:
                self.text = f.read()
        else:
            self.text = text

    def stats(self):
        """Print the text size in characters, its line count and paragraph count."""
        print(f" size: {len(self.text)}")
        print(f"lines: {len(self.text.splitlines())}")
        ps = len(self.paras())
        print(f"paras: {ps}")

    def lines(self) -> list[str]:
        """Return the text split into lines (terminators removed)."""
        return self.text.splitlines()

    def paras(self) -> list[str]:
        """Return the text split on every two-newline separator."""
        return self.text.split("\n\n")

    def grid2(self) -> Grid2D:
        """Parse the whole text as a Grid2D."""
        return Grid2D(self.text)
|
|
|
|
|
|
def prime_factors(n):
    """
    Returns a list of prime factors for n, ascending and with multiplicity
    (e.g. ``prime_factors(12) == [2, 2, 3]``).

    :param n: positive integer for which prime factors should be returned;
        ``prime_factors(1)`` is ``[]``
    :raises ValueError: if ``n < 1`` (0 would never terminate the division
        loop, and negative values have no meaningful factorisation here)
    """
    if n < 1:
        raise ValueError(f"prime_factors requires n >= 1, got {n}")
    factors = []
    rest = n
    # Strip factors of two first so trial division can step over odd numbers only.
    while rest % 2 == 0:
        factors.append(2)
        rest //= 2
    divisor = 3
    while divisor * divisor <= rest:
        while rest % divisor == 0:
            factors.append(divisor)
            rest //= divisor
        divisor += 2
    # Anything left above 1 has no divisor up to its square root, so it is prime.
    if rest != 1:
        factors.append(rest)
    return factors
|
|
|
|
|
|
def lcm(numbers: list[int]) -> int:
    """Return the least common multiple of ``numbers``.

    An empty input gives ``1``; any ``0`` in the input gives ``0``.
    Example: ``lcm([4, 6]) == 12``.
    """
    import math

    # Multiplying each distinct prime once is not enough: prime powers matter
    # (lcm(4, 6) needs 2**2). math.lcm handles this directly.
    return math.lcm(*numbers)
|
|
|
|
|
|
def str_to_int(line: str) -> int:
    """Parse the single integer in ``line``, ignoring spaces.

    Spaces are removed first, so digit groups separated by spaces are joined
    into one number (``"7 15 30"`` -> ``71530``). Raises ``ValueError``
    unless exactly one (optionally negative) integer remains.
    """
    compact = line.replace(" ", "")
    (only,) = re.findall(r"(-?\d+)", compact)
    return int(only)
|
|
|
|
|
|
def str_to_ints(line: str) -> list[int]:
    """Return every (optionally negative) integer found in ``line``, in order."""
    return [int(token) for token in re.findall(r"-?\d+", line)]


# Short alias for str_to_ints.
ints = str_to_ints
|
|
|
|
|
|
def str_to_lines_no_empty(text: str) -> list[str]:
    """Split ``text`` into lines, dropping lines that are empty or whitespace-only."""
    return [line for line in text.splitlines() if line.strip()]
|
|
|
|
|
|
def str_to_lines(text: str) -> list[str]:
    """Split ``text`` into lines, dropping terminators but keeping empty lines."""
    lines = text.splitlines()
    return lines
|
|
|
|
|
|
def count_trailing_repeats(lst):
    """Count how many elements at the end of ``lst`` equal its last element.

    Returns ``0`` for an empty sequence, otherwise at least ``1``.
    """
    size = len(lst)
    streak = 0
    # Walk backwards from the end until an element differs from the last one.
    while streak < size and not (lst[size - 1 - streak] != lst[-1]):
        streak += 1
    return streak
|
|
|
|
|
|
class A_Star(object):
    """A* search from one or more start nodes; the answer is left in ``self.cost``.

    The whole search runs inside ``__init__``. Afterwards ``cost`` is the
    accumulated cost of the first goal node popped from the open set, or
    ``None`` if the open set is exhausted without reaching a goal.
    """

    def __init__(self, starts, is_goal, h, d, neighbors):
        """
        :param starts: start nodes (must be hashable; they need not be orderable)
        :param is_goal: predicate returning True for a goal node
        :param h: heuristic function (never overestimate); must be 0 at goal
            nodes, which the assert below checks
        :param d: cost from node to node function; ``d(0, start)`` is used as
            the initial cost of each start node
        :param neighbors: neighbors function returning the nodes adjacent to a node
        """
        from itertools import count

        open_set = []
        g_score = {}
        # Unique increasing tie-breaker: heap entries with equal f-scores are
        # ordered by it, so the nodes themselves are never compared.
        tie_breaker = count()
        self.cost = None

        for start in starts:
            g_score[start] = d(0, start)
            # f = g + h, consistent with the f-scores pushed for all other nodes.
            heapq.heappush(
                open_set,
                (g_score[start] + h(start), next(tie_breaker), g_score[start], start),
            )

        while open_set:
            current_f_score, _, current_g, current = heapq.heappop(open_set)
            if current_g > g_score[current]:
                # Stale entry: a cheaper path to `current` was pushed after this one.
                continue
            if is_goal(current):
                # With h == 0 at goals, f equals the accumulated path cost.
                assert current_f_score == g_score[current]
                self.cost = g_score[current]
                break

            for neighbor in neighbors(current):
                tentative_g_score = g_score[current] + d(current, neighbor)
                if neighbor not in g_score or tentative_g_score < g_score[neighbor]:
                    g_score[neighbor] = tentative_g_score
                    f_score = tentative_g_score + h(neighbor)
                    heapq.heappush(
                        open_set,
                        (f_score, next(tie_breaker), tentative_g_score, neighbor),
                    )
|
|
|
|
|
|
def shoelace_area(corners):
    """Return the polygon area computed with the shoelace formula.

    ``corners`` are ``(x, y)`` pairs in traversal order; the last corner is
    connected back to the first. The result is always a float.
    """
    count = len(corners)

    def cross(k):
        # z-component of the cross product of consecutive corners k and k+1.
        x1, y1 = corners[k]
        x2, y2 = corners[(k + 1) % count]
        return (x1 * y2) - (x2 * y1)

    return abs(sum(cross(k) for k in range(count))) / 2.0
|
|
|
|
|
|
def extract_year_and_date(scriptname) -> tuple[str, str]:
    """Extract ``(year, day)`` strings from a path like ``.../aoc2023/d12.py``.

    Despite the function name, the second element is the puzzle *day*.
    Both ``/`` and ``\\`` are accepted as the separator before ``d<day>.py``.
    Raises ``ValueError`` unless the path contains exactly one match.
    """
    # The dot before "py" is escaped so that it only matches a literal ".".
    pattern = re.compile(r"aoc(\d{4})[/\\]d(\d+)\.py")
    [(year, day)] = pattern.findall(scriptname)
    return (year, day)
|
|
|
|
|
|
def get_data(filename):
    """Return the puzzle input for the solution script ``filename``.

    The year is the last four characters of the script's directory, and the
    day is the file name with every "d" and ".py" removed. The input is read
    from ``d<day>.txt`` in the current working directory. If that file is
    missing, ``../get.py <year> <day>`` is run first to download it.
    """
    directory, script = os.path.split(filename)
    year = directory[-4:]
    day = script.replace("d", "").replace(".py", "")
    txt_file = f"d{day}.txt"

    if not os.path.isfile(txt_file):
        import subprocess

        # The exit status is ignored; the assert below detects a failed download.
        subprocess.call(["../get.py", year, day])
        assert os.path.isfile(txt_file), "Could not download AoC file"

    with open(txt_file) as f:
        return f.read()
|
|
|
|
|
|
def mod_inverse(a, m):
    """Return ``x`` in ``[0, m)`` such that ``a * x`` is congruent to 1 modulo ``m``.

    :param a: value to invert; any integer, including negative values
    :param m: modulus, must be >= 1
    :raises ValueError: if ``m < 1``, or with message
        "Modular inverse does not exist" if ``gcd(a, m) != 1``
    """
    if m < 1:
        raise ValueError("Modulus must be a positive integer")
    # Reduce into [0, m) first. All remainders are then non-negative, so the
    # gcd found below is positive and can be compared against 1 directly.
    old_r, r = a % m, m
    old_s, s = 1, 0
    # Iterative extended Euclid. Invariant: old_s * a ≡ old_r (mod m).
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
    if old_r != 1:
        raise ValueError("Modular inverse does not exist")
    return old_s % m
|
|
|
|
|
|
class V:
    """Small vector type wrapping a tuple of components.

    Compares equal to another ``V`` or to any sized sequence with the same
    length and element-wise equal components. It hashes like the underlying
    tuple, so ``V(1, 2)`` and ``(1, 2)`` collide consistently in sets and dicts.
    """

    def __init__(self, *args):
        self.xs: tuple[int, ...] = tuple(args)

    def __eq__(self, other):
        """Element-wise equality with a V or a sized sequence; lengths must match."""
        if isinstance(other, V):
            other = other.xs
        elif not hasattr(other, "__len__"):
            return False
        # Check the length explicitly: zip() alone would silently truncate.
        return len(self.xs) == len(other) and all(
            v1 == v2 for v1, v2 in zip(self.xs, other)
        )

    def __getitem__(self, i: int):
        return self.xs[i]

    def __hash__(self):
        return hash(self.xs)

    def __add__(self, other):
        """Return the element-wise sum with a V or a same-length sized sequence."""
        components = other.xs if isinstance(other, V) else other
        assert hasattr(components, "__len__"), f"V.__add__({self}, {other}) missing `len`"
        # The same length check applies to V operands, which previously truncated.
        assert len(self.xs) == len(components), f"V.__add__({self}, {other}) `len` mismatch"
        return V(*[v1 + v2 for v1, v2 in zip(self.xs, components)])

    def __repr__(self):
        s = ", ".join(map(str, self.xs))
        return f"V({s})"
|