# (File-viewer metadata captured with this paste: 145 lines, 4.4 KiB, Python.)
from lib import get_data, str_to_ints
|
|
from itertools import permutations
|
|
|
|
|
|
class Amp:
|
|
def __init__(self, xs):
|
|
self.xs = list(xs)
|
|
self.i = 0
|
|
self.inputs = []
|
|
self.outputs = []
|
|
self.done = False
|
|
|
|
def feed(self, input):
|
|
self.inputs.append(input)
|
|
|
|
def pop(self):
|
|
v = self.outputs[0]
|
|
self.outputs = self.outputs[1:]
|
|
return v
|
|
|
|
def go(self):
|
|
xs = self.xs
|
|
i = self.i
|
|
while i < len(xs):
|
|
inst = str(xs[i])
|
|
inst = "0" * (5 - len(inst)) + inst
|
|
assert len(inst) == 5
|
|
op = int(inst[3:5])
|
|
mode_p1 = int(inst[2])
|
|
mode_p2 = int(inst[1])
|
|
mode_p3 = int(inst[0])
|
|
match op:
|
|
case 1:
|
|
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
|
|
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
|
|
assert mode_p3 == 0
|
|
xs[xs[i + 3]] = p1 + p2
|
|
i += 4
|
|
case 2:
|
|
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
|
|
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
|
|
assert mode_p3 == 0
|
|
xs[xs[i + 3]] = p1 * p2
|
|
i += 4
|
|
case 3:
|
|
assert mode_p1 == 0
|
|
assert len(self.inputs) > 0
|
|
xs[xs[i + 1]] = self.inputs[0]
|
|
self.inputs = self.inputs[1:]
|
|
i += 2
|
|
case 4:
|
|
if mode_p1 == 0:
|
|
v = xs[xs[i + 1]]
|
|
else:
|
|
v = xs[i + 1]
|
|
self.outputs.append(v)
|
|
i += 2
|
|
self.i = i
|
|
return
|
|
case 99:
|
|
self.done = True
|
|
return
|
|
case 5:
|
|
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
|
|
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
|
|
if p1 != 0:
|
|
i = p2
|
|
else:
|
|
i += 3
|
|
case 6:
|
|
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
|
|
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
|
|
if p1 == 0:
|
|
i = p2
|
|
else:
|
|
i += 3
|
|
case 7:
|
|
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
|
|
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
|
|
assert mode_p3 == 0
|
|
if p1 < p2:
|
|
xs[xs[i + 3]] = 1
|
|
else:
|
|
xs[xs[i + 3]] = 0
|
|
i += 4
|
|
case 8:
|
|
p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1]
|
|
p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2]
|
|
assert mode_p3 == 0
|
|
if p1 == p2:
|
|
xs[xs[i + 3]] = 1
|
|
else:
|
|
xs[xs[i + 3]] = 0
|
|
i += 4
|
|
self.i = i
|
|
|
|
|
|
def part_1(data):
    """Print the highest signal reachable through a serial chain of amplifiers.

    Every ordering of phase settings 0-4 is tried. Each amplifier gets a fresh
    copy of the program and is fed its phase setting, then the previous
    amplifier's output (0 for the first). It is run until it emits one value,
    which becomes the next amplifier's input.

    The best chain output is taken directly with ``max`` rather than compared
    against a 0 seed, so an all-negative result is reported correctly.
    """
    xs_orig = str_to_ints(data)

    def run_chain(phases):
        # Push a signal through one amplifier per phase setting, in order.
        signal = 0
        for phase in phases:
            amp = Amp(xs_orig)
            amp.feed(phase)
            amp.feed(signal)
            amp.go()
            assert len(amp.outputs) == 1
            signal = amp.pop()
        return signal

    print(max(run_chain(phases) for phases in permutations(range(5))))
|
|
|
|
|
|
def part_2(data):
    """Print the highest thruster signal from amplifiers wired in a feedback loop.

    Every ordering of phase settings 5-9 is tried. Each amplifier gets its own
    copy of the program plus its phase setting. A signal starting at 0 is then
    passed round the ring (first -> ... -> last -> first), with each amplifier
    run until it emits one value. As soon as an amplifier halts, that
    ordering's result is the last value emitted by the final amplifier.

    The best result is taken directly with ``max`` rather than compared
    against a 0 seed, so negative results are reported correctly.
    """
    xs_orig = str_to_ints(data)

    def run_loop(phases):
        # Build one amplifier per phase setting and prime it with that phase.
        amps = [Amp(xs_orig) for _ in phases]
        for amp, phase in zip(amps, phases):
            amp.feed(phase)

        signal = 0
        # Last value produced by the final amplifier (the one wired to the
        # thrusters); stays 0 if it never produces anything.
        thruster = 0
        last = len(amps) - 1
        amp_i = 0
        while True:
            amp = amps[amp_i]
            amp.feed(signal)
            amp.go()
            if amp.done:
                return thruster
            assert len(amp.outputs) == 1
            signal = amp.pop()
            if amp_i == last:
                thruster = signal
            amp_i = (amp_i + 1) % len(amps)

    print(max(run_loop(phases) for phases in permutations(range(5, 10))))
|
|
|
|
|
|
def main():
    """Load this puzzle's input and print the answers to both parts, in order."""
    puzzle_input = get_data(__file__)
    for solve in (part_1, part_2):
        solve(puzzle_input)


# Solve only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|