aocpy/2019/d9.py

155 lines
4.4 KiB
Python

from lib import get_data, str_to_ints
class Amp:
def __init__(self, xs, buffer_extra=10000):
self.xs = list(xs)
self.xs += [0 for _ in range(buffer_extra)]
self.i = 0
self.inputs = []
self.outputs = []
self.done = False
self.input_required = False
self.rel_base = 0
def feed(self, input):
self.input_required = False
self.inputs.append(input)
def pop(self):
v = self.outputs[0]
self.outputs = self.outputs[1:]
return v
def get_param(self, offset, mode):
if mode == 0:
p = self.xs[self.xs[offset]]
elif mode == 1:
p = self.xs[offset]
elif mode == 2:
assert self.rel_base + offset >= 0
p = self.xs[self.rel_base + self.xs[offset]]
else:
assert False
return p
def get_addr(self, offset, mode):
if mode == 0:
return self.xs[offset]
elif mode == 2:
return self.rel_base + self.xs[offset]
else:
assert False
def go(self):
xs = self.xs
i = self.i
while i < len(xs):
assert xs[i] >= 0
inst = str(xs[i])
inst = "0" * (5 - len(inst)) + inst
assert len(inst) == 5
op = int(inst[3:5])
mode_p1 = int(inst[2])
mode_p2 = int(inst[1])
mode_p3 = int(inst[0])
match op:
case 1:
p1 = self.get_param(i + 1, mode_p1)
p2 = self.get_param(i + 2, mode_p2)
addr = self.get_addr(i + 3, mode_p3)
xs[addr] = p1 + p2
i += 4
case 2:
p1 = self.get_param(i + 1, mode_p1)
p2 = self.get_param(i + 2, mode_p2)
addr = self.get_addr(i + 3, mode_p3)
xs[addr] = p1 * p2
i += 4
case 3:
# read input
if len(self.inputs) == 0:
self.i = i
self.input_required = True
return
addr = self.get_addr(i + 1, mode_p1)
xs[addr] = self.inputs[0]
self.inputs = self.inputs[1:]
i += 2
case 4:
# output
v = self.get_param(i + 1, mode_p1)
self.outputs.append(v)
i += 2
self.i = i
return
case 99:
self.done = True
return
case 5:
p1 = self.get_param(i + 1, mode_p1)
p2 = self.get_param(i + 2, mode_p2)
if p1 != 0:
i = p2
else:
i += 3
case 6:
p1 = self.get_param(i + 1, mode_p1)
p2 = self.get_param(i + 2, mode_p2)
if p1 == 0:
i = p2
else:
i += 3
case 7:
p1 = self.get_param(i + 1, mode_p1)
p2 = self.get_param(i + 2, mode_p2)
addr = self.get_addr(i + 3, mode_p3)
if p1 < p2:
xs[addr] = 1
else:
xs[addr] = 0
i += 4
case 8:
p1 = self.get_param(i + 1, mode_p1)
p2 = self.get_param(i + 2, mode_p2)
addr = self.get_addr(i + 3, mode_p3)
if p1 == p2:
xs[addr] = 1
else:
xs[addr] = 0
i += 4
case 9:
p1 = self.get_param(i + 1, mode_p1)
self.rel_base += p1
i += 2
case _:
assert False
self.i = i
def part_1(data):
    """Run the Intcode program in *data* with input 1, then with input 2.

    Despite the name, this covers both runs (presumably puzzle parts 1 and
    2). Each run starts from a fresh ``Amp`` built from the parsed program,
    is fed a single value, executes until halt, and prints its first output.

    Raises RuntimeError if the program asks for more input than the single
    value fed, rather than looping forever.
    """
    # Materialize once so both runs see the full program even if
    # str_to_ints returns a one-shot iterator.
    program = list(str_to_ints(data))
    for value in (1, 2):
        amp = Amp(program)
        amp.feed(value)
        while not amp.done:
            amp.go()
            if amp.input_required:
                raise RuntimeError(f"program requested more input than [{value}]")
        print(amp.pop())
def main():
    """Load input via lib.get_data (keyed on this file's path) and solve."""
    part_1(get_data(__file__))


if __name__ == "__main__":
    # Solve only when executed as a script, not when imported.
    main()