from lib import get_data, str_to_ints from collections import defaultdict data = """mask = XXXXXXXXXXXXXXXXXXXXXXXXXXXXX1XXXX0X mem[8] = 11 mem[7] = 101 mem[8] = 0 """ data = get_data(__file__) mem = defaultdict(int) mfix, mopt = None, None for line in data.splitlines(): if line.startswith("mask"): fix = "0b" opt = "0b" for c in line[7:]: if c == "1": fix += "1" opt += "0" elif c == "0": fix += "0" opt += "0" elif c == "X": fix += "0" opt += "1" else: assert False mfix = int(fix, 2) mopt = int(opt, 2) elif line.startswith("mem"): assert mfix is not None and mopt is not None addr, value = str_to_ints(line) value = mfix | (value & mopt) mem[addr] = value print(sum(v for v in mem.values())) mem = defaultdict(int) masks = [] for line in data.splitlines(): if line.startswith("mask"): fix = ["0b"] opt = ["0b"] for c in line[7:]: if c == "0": fix = [f + "0" for f in fix] opt = [o + "1" for o in opt] elif c == "1": fix = [f + "1" for f in fix] opt = [o + "0" for o in opt] elif c == "X": nfix = [] nopt = [] for f in fix: nfix.append(f + "0") nfix.append(f + "1") for o in opt: nopt.append(o + "0") nopt.append(o + "0") fix = nfix opt = nopt else: assert False masks = tuple( zip( tuple(map(lambda f: int(f, 2), fix)), tuple(map(lambda o: int(o, 2), opt)), ) ) elif line.startswith("mem"): addr, value = str_to_ints(line) for fix, opt in masks: addr_ = fix | (addr & opt) mem[addr_] = value print(sum(v for v in mem.values()))