# aocpy/2019/d7.py
# 2024-08-03 16:31:06 -04:00
#
# 145 lines
# 4.4 KiB
# Python

from lib import get_data, str_to_ints
from itertools import permutations
class Amp:
def __init__(self, xs):
self.xs = list(xs)
self.i = 0
self.inputs = []
self.outputs = []
self.done = False
def feed(self, input):
self.inputs.append(input)
def pop(self):
v = self.outputs[0]
self.outputs = self.outputs[1:]
return v
def go(self):
xs = self.xs
i = self.i
while i < len(xs):
inst = str(xs[i])
inst = "0" * (5 - len(inst)) + inst
assert len(inst) == 5
op = int(inst[3:5])
mode_p1 = int(inst[2])
mode_p2 = int(inst[1])
mode_p3 = int(inst[0])
match op:
case 1:
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
assert mode_p3 == 0
xs[xs[i + 3]] = p1 + p2
i += 4
case 2:
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
assert mode_p3 == 0
xs[xs[i + 3]] = p1 * p2
i += 4
case 3:
assert mode_p1 == 0
assert len(self.inputs) > 0
xs[xs[i + 1]] = self.inputs[0]
self.inputs = self.inputs[1:]
i += 2
case 4:
if mode_p1 == 0:
v = xs[xs[i + 1]]
else:
v = xs[i + 1]
self.outputs.append(v)
i += 2
self.i = i
return
case 99:
self.done = True
return
case 5:
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
if p1 != 0:
i = p2
else:
i += 3
case 6:
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
if p1 == 0:
i = p2
else:
i += 3
case 7:
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
assert mode_p3 == 0
if p1 < p2:
xs[xs[i + 3]] = 1
else:
xs[xs[i + 3]] = 0
i += 4
case 8:
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
assert mode_p3 == 0
if p1 == p2:
xs[xs[i + 3]] = 1
else:
xs[xs[i + 3]] = 0
i += 4
self.i = i
def part_1(data):
    """Print and return the highest signal from a single pass of five amplifiers.

    Every ordering of phase settings 0-4 is tried. Each amplifier gets a fresh
    copy of the program and receives its phase setting, then the previous
    amplifier's output (0 for the first). It must emit exactly one value.
    Taking ``max`` over all orderings avoids a hard-coded starting best, so
    negative signals are handled correctly.
    """
    program = str_to_ints(data)

    def run_chain(phases):
        # Thread one signal through the amplifiers in the given phase order.
        signal = 0
        for phase in phases:
            amp = Amp(program)
            amp.feed(phase)
            amp.feed(signal)
            amp.go()
            assert len(amp.outputs) == 1
            signal = amp.pop()
        return signal

    best = max(run_chain(phases) for phases in permutations(range(5)))
    print(best)
    return best
def part_2(data):
    """Print and return the highest signal from the feedback-loop configuration.

    Every ordering of phase settings 5-9 is tried. Each amplifier is primed
    with its phase, then signals are passed round-robin (the last amplifier
    feeds the first) until an amplifier halts. Taking ``max`` over all
    orderings avoids a hard-coded starting best of 0.
    """
    program = str_to_ints(data)

    def run_loop(phases):
        # Run one feedback loop and return the last signal passed along.
        amps = [Amp(program) for _ in phases]
        for amp, phase in zip(amps, phases):
            amp.feed(phase)
        signal = 0
        idx = 0
        while True:
            amp = amps[idx]
            amp.feed(signal)
            amp.go()
            if amp.done:
                # signal is the latest output of the amplifier feeding this
                # one. Presumably the first amplifier is the one that halts
                # first, so this is the last amplifier's final output.
                # Confirm against the puzzle input.
                return signal
            assert len(amp.outputs) == 1
            signal = amp.pop()
            idx = (idx + 1) % len(amps)

    best = max(run_loop(phases) for phases in permutations(range(5, 10)))
    print(best)
    return best
def main():
    """Load this script's puzzle input and print the answers to both parts."""
    puzzle_input = get_data(__file__)
    for solve in (part_1, part_2):
        solve(puzzle_input)
if __name__ == "__main__":
    # Solve only when run as a script, not when imported.
    main()