from lib import get_data, str_to_ints from itertools import permutations class Amp: def __init__(self, xs): self.xs = list(xs) self.i = 0 self.inputs = [] self.outputs = [] self.done = False def feed(self, input): self.inputs.append(input) def pop(self): v = self.outputs[0] self.outputs = self.outputs[1:] return v def go(self): xs = self.xs i = self.i while i < len(xs): inst = str(xs[i]) inst = "0" * (5 - len(inst)) + inst assert len(inst) == 5 op = int(inst[3:5]) mode_p1 = int(inst[2]) mode_p2 = int(inst[1]) mode_p3 = int(inst[0]) match op: case 1: p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1] p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2] assert mode_p3 == 0 xs[xs[i + 3]] = p1 + p2 i += 4 case 2: p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1] p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2] assert mode_p3 == 0 xs[xs[i + 3]] = p1 * p2 i += 4 case 3: assert mode_p1 == 0 assert len(self.inputs) > 0 xs[xs[i + 1]] = self.inputs[0] self.inputs = self.inputs[1:] i += 2 case 4: if mode_p1 == 0: v = xs[xs[i + 1]] else: v = xs[i + 1] self.outputs.append(v) i += 2 self.i = i return case 99: self.done = True return case 5: p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1] p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2] if p1 != 0: i = p2 else: i += 3 case 6: p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1] p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2] if p1 == 0: i = p2 else: i += 3 case 7: p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1] p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2] assert mode_p3 == 0 if p1 < p2: xs[xs[i + 3]] = 1 else: xs[xs[i + 3]] = 0 i += 4 case 8: p1 = xs[xs[i + 1]] if mode_p1 == 0 else xs[i + 1] p2 = xs[xs[i + 2]] if mode_p2 == 0 else xs[i + 2] assert mode_p3 == 0 if p1 == p2: xs[xs[i + 3]] = 1 else: xs[xs[i + 3]] = 0 i += 4 self.i = i def part_1(data): xs_orig = str_to_ints(data) max_output = 0 for ps in permutations(list(range(5))): current_output = 0 for p in ps: a = Amp(xs_orig) a.feed(p) a.feed(current_output) a.go() assert len(a.outputs) == 1 current_output = a.outputs.pop() max_output = max(max_output, current_output) print(max_output) def part_2(data): xs_orig = str_to_ints(data) max_output = 0 for ps in permutations(list(range(5, 10))): amps = [Amp(xs_orig) for _ in range(len(ps))] for i, p in enumerate(ps): amps[i].feed(p) current_output = 0 current_amp_i = 0 while True: amps[current_amp_i].feed(current_output) amps[current_amp_i].go() if amps[current_amp_i].done: max_output = max(max_output, current_output) break assert len(amps[current_amp_i].outputs) == 1 current_output = amps[current_amp_i].outputs.pop() current_amp_i = (current_amp_i + 1) % len(amps) print(max_output) def main(): data = get_data(__file__) part_1(data) part_2(data) if __name__ == "__main__": main()