Files
aocpy/2021/d19.py

122 lines
3.1 KiB
Python

from collections import Counter, defaultdict
from itertools import permutations, product

from lib import get_data
from lib import ints
# A beacon position, scanner origin or offset: three integer coordinates.
V3 = tuple[int, int, int]
data = get_data(__file__)
# Scanner reports are separated by blank lines.  The first line of each report
# is skipped (presumably its "--- scanner N ---" header).  Every remaining line
# must yield exactly three ints from ``ints``: one beacon relative to that
# scanner.  The inner ``for x, y, z in [ints(row)]`` enforces the 3-value
# unpacking.
scanners = [
    [(x, y, z) for row in report.splitlines()[1:] for x, y, z in [ints(row)]]
    for report in data.split("\n\n")
]
def rotate(vi: V3) -> list[V3]:
    """Return every axis permutation of *vi* combined with every sign pattern.

    The 6 coordinate permutations times the 8 sign patterns give 48 variants.
    They are ordered by permutation first, then by sign pattern (``+`` before
    ``-``, last axis varying fastest).  Half of them are mirror images
    (determinant -1) rather than proper rotations; they are produced as well.
    Callers rely on the ordering being the same for every input vector, so
    that index ``k`` means the same orientation for all beacons.
    """
    sign_patterns = list(product([1, -1], repeat=3))
    return [
        tuple(coord * sign for coord, sign in zip(perm, pattern))
        for perm in permutations(vi)
        for pattern in sign_patterns
    ]


assert len(rotate((8, 0, 7))) == 48
def sub(a: V3, b: V3) -> V3:
    """Return the component-wise difference ``a - b`` of two 3-vectors as a tuple."""
    return tuple(a[axis] - b[axis] for axis in range(3))
def add(a: V3, b: V3) -> V3:
    """Return the component-wise sum ``a + b`` of two 3-vectors as a tuple."""
    return tuple(a[axis] + b[axis] for axis in range(3))
def relative(scanner: list[V3]):
    """Index every ordered pair of beacons in *scanner* by their difference.

    Returns a ``defaultdict(list)`` that maps ``sub(a, b)`` to the list of
    ``(a, b)`` pairs producing it.  Both orders of each pair are recorded, so
    every delta ``k`` has a mirror entry ``-k``.  Differences do not change
    when all points are translated by the same offset.  That is why two
    scanners' indexes can be compared without knowing where the scanners are.
    """
    deltas = defaultdict(list)
    for pos, first in enumerate(scanner):
        for second in scanner[pos + 1:]:
            deltas[sub(first, second)].append((first, second))
            deltas[sub(second, first)].append((second, first))
    return deltas
def overlap(scanner_1, scanner_2, expected_overlaps=15):
    """Try to place ``scanner_2`` in the coordinate frame of ``scanner_1``.

    Every orientation produced by ``rotate`` is tried.  For each one, the
    ordered-pair deltas of both scanners (see ``relative``) are compared.  A
    delta present in both suggests the same two beacons seen from two places,
    and ``abs1 - rel1`` is then a candidate position of scanner 2.  Candidates
    are tallied as votes.  The orientation whose winning offset collects the
    most votes is chosen; earlier orientations win ties.

    Args:
        scanner_1: beacon positions already expressed in the reference frame.
        scanner_2: beacon positions relative to scanner 2, not yet oriented.
        expected_overlaps: minimum number of shared deltas, and of votes for
            the winning offset, before an orientation is accepted.

    Returns:
        ``None`` if no orientation qualifies.  Otherwise ``(origin, beacons)``:
        ``origin`` is scanner 2's position in scanner 1's frame, and
        ``beacons`` is scanner 2's beacon list translated into that frame.
    """
    r1 = relative(scanner_1)
    best = None  # (votes, origin, oriented beacons)
    # zip(*...) transposes: the k-th tuple holds every beacon in orientation k.
    for oriented in zip(*map(rotate, scanner_2)):
        r2 = relative(oriented)
        shared = [delta for delta in r1 if delta in r2]
        if not shared or len(shared) < expected_overlaps:
            continue
        votes = Counter()
        for delta in shared:
            # A delta may come from several beacon pairs, so vote for every
            # combination instead of assuming exactly one pair per delta.
            # If a1 - b1 == a2 - b2 then a1 - a2 == b1 - b2, so the first
            # endpoints alone determine the candidate offset.
            for abs1, _ in r1[delta]:
                for rel1, _ in r2[delta]:
                    votes[sub(abs1, rel1)] += 1
        origin, count = votes.most_common(1)[0]
        # Tolerate a few coincidental delta matches instead of requiring
        # every shared delta to agree on one offset.
        if count >= expected_overlaps and (best is None or count > best[0]):
            best = (count, origin, oriented)
    if best is None:
        return None
    _, orig_2, scanner_2_rel = best
    scanner_2_abs = [add(orig_2, b) for b in scanner_2_rel]
    return orig_2, scanner_2_abs
# Align every scanner into scanner 0's frame.  ``origs`` collects each aligned
# scanner's position; scanner 0 sits at the origin.  ``done`` holds indices
# whose beacon lists in ``scanners`` are already absolute; ``todo`` holds the
# rest.
origs = [(0, 0, 0)]
todo = set(range(len(scanners)))
done = {0} if todo else set()
todo -= done
# (i, j) pairs already attempted.  scanners[i] never changes once i is done,
# and scanners[j] is untouched until j is done.  So a failed attempt would fail
# again; remembering it avoids re-running the expensive overlap search.
tried = set()
while todo:
    progressed = False
    for i in range(len(scanners)):
        for j in range(len(scanners)):
            if i == j or i not in done or j in done or (i, j) in tried:
                continue
            tried.add((i, j))
            r = overlap(scanners[i], scanners[j])
            if r is None:
                continue
            o, s2 = r
            # Accept only if at least 12 beacons coincide after translation.
            no = len(set(scanners[i]).intersection(s2))
            if no >= 12:
                # Record the origin only for an accepted alignment.  Rejected
                # candidates must not reach the distance computation below.
                origs.append(o)
                scanners[j] = s2
                done.add(j)
                todo.discard(j)
                progressed = True
    if not progressed:
        # Every remaining pair has been tried; another pass would loop forever.
        raise RuntimeError(f"cannot align scanners {sorted(todo)}")
# Part 1: once the loop above has rewritten every scanner into scanner 0's
# frame, a beacon seen by several scanners has identical coordinates in each
# list.  A set of positions therefore counts each beacon once.  The name is
# ``beacons`` rather than ``all`` so the builtin is not shadowed.
beacons = {beacon for s in scanners for beacon in s}
print(len(beacons))
def mdist(a, b):
    """Return the Manhattan (L1) distance between points *a* and *b*.

    Coordinates are paired with ``zip``, so extra trailing coordinates on the
    longer point are ignored.
    """
    total = 0
    for left, right in zip(a, b):
        total += abs(left - right)
    return total
# Part 2: the largest Manhattan distance between any two recorded scanner
# origins.  Each unordered pair is visited once; fewer than two origins
# gives 0.
m = max(
    (mdist(p, q) for k, p in enumerate(origs) for q in origs[k + 1:]),
    default=0,
)
print(m)