from lib import get_data, str_to_ints class Amp: def __init__(self, xs, buffer_extra=10000): self.xs = list(xs) self.xs += [0 for _ in range(buffer_extra)] self.i = 0 self.inputs = [] self.outputs = [] self.done = False self.input_required = False self.rel_base = 0 def feed(self, input): self.input_required = False self.inputs.append(input) def pop(self): v = self.outputs[0] self.outputs = self.outputs[1:] return v def get_param(self, offset, mode): if mode == 0: p = self.xs[self.xs[offset]] elif mode == 1: p = self.xs[offset] elif mode == 2: assert self.rel_base + offset >= 0 p = self.xs[self.rel_base + self.xs[offset]] else: assert False return p def get_addr(self, offset, mode): if mode == 0: return self.xs[offset] elif mode == 2: return self.rel_base + self.xs[offset] else: assert False def go(self): xs = self.xs i = self.i while i < len(xs): assert xs[i] >= 0 inst = str(xs[i]) inst = "0" * (5 - len(inst)) + inst assert len(inst) == 5 op = int(inst[3:5]) mode_p1 = int(inst[2]) mode_p2 = int(inst[1]) mode_p3 = int(inst[0]) match op: case 1: p1 = self.get_param(i + 1, mode_p1) p2 = self.get_param(i + 2, mode_p2) addr = self.get_addr(i + 3, mode_p3) xs[addr] = p1 + p2 i += 4 case 2: p1 = self.get_param(i + 1, mode_p1) p2 = self.get_param(i + 2, mode_p2) addr = self.get_addr(i + 3, mode_p3) xs[addr] = p1 * p2 i += 4 case 3: # read input if len(self.inputs) == 0: self.i = i self.input_required = True return addr = self.get_addr(i + 1, mode_p1) xs[addr] = self.inputs[0] self.inputs = self.inputs[1:] i += 2 case 4: # output v = self.get_param(i + 1, mode_p1) self.outputs.append(v) i += 2 self.i = i return case 99: self.done = True return case 5: p1 = self.get_param(i + 1, mode_p1) p2 = self.get_param(i + 2, mode_p2) if p1 != 0: i = p2 else: i += 3 case 6: p1 = self.get_param(i + 1, mode_p1) p2 = self.get_param(i + 2, mode_p2) if p1 == 0: i = p2 else: i += 3 case 7: p1 = self.get_param(i + 1, mode_p1) p2 = self.get_param(i + 2, mode_p2) addr = self.get_addr(i + 3, mode_p3) if p1 < p2: xs[addr] = 1 else: xs[addr] = 0 i += 4 case 8: p1 = self.get_param(i + 1, mode_p1) p2 = self.get_param(i + 2, mode_p2) addr = self.get_addr(i + 3, mode_p3) if p1 == p2: xs[addr] = 1 else: xs[addr] = 0 i += 4 case 9: p1 = self.get_param(i + 1, mode_p1) self.rel_base += p1 i += 2 case _: assert False self.i = i def part_1(data): xs_orig = str_to_ints(data) a = Amp(xs_orig) a.feed(1) while not a.done: a.go() a.go() print(a.pop()) a = Amp(xs_orig) a.feed(2) while not a.done: a.go() a.go() print(a.pop()) def main(): data = get_data(__file__) part_1(data) if __name__ == "__main__": main()