import sys

# Starting 3x3 pattern; '#' marks an "on" pixel, '.' an "off" pixel.
IMAGE = """.#.
..#
###"""


def flipv(pattern):
    """Return *pattern* mirrored left-to-right (each row reversed), as nested tuples."""
    return tuple(tuple(reversed(row)) for row in pattern)


def fliph(pattern):
    """Return *pattern* mirrored top-to-bottom (row order reversed)."""
    return tuple(reversed(pattern))


def rot90(pattern):
    """Return *pattern* rotated a quarter turn clockwise, as nested tuples.

    Implemented as transpose followed by reversing each resulting row.
    """
    return tuple(tuple(reversed(column)) for column in zip(*pattern))


def parse_rule(line: str) -> dict:
    """Parse one ``"in/put => out/put"`` rule into a lookup table.

    The input pattern becomes a tuple-of-tuples key (hashable). It is
    registered under all eight rotations/reflections so that any
    orientation of a block matches the rule. The output pattern is a
    list of lists of characters, shared by every key.
    """
    lhs, rhs = line.strip().split(" => ")
    lhs = tuple(map(tuple, lhs.split("/")))
    rhs = [list(row) for row in rhs.split("/")]
    rules = {}
    # The identity plus one mirror, each under 4 rotations, cover all 8 symmetries.
    for variant in (lhs, flipv(lhs)):
        for _ in range(4):
            rules[variant] = rhs
            variant = rot90(variant)
    return rules


def print_image(image):
    """Print *image* one row per line."""
    for row in image:
        print("".join(row))


def slice_get(matrix, row, col, size):
    """Return the *size* x *size* sub-grid of *matrix* whose top-left corner is (row, col)."""
    return [matrix[r][col:col + size] for r in range(row, row + size)]


def slice_set(matrix, row, col, new):
    """Overwrite *matrix* in place with the rows of *new*, starting at (row, col)."""
    for ri, r in enumerate(new):
        matrix[ri + row][col:col + len(r)] = r


def slice_append(matrix, slice):
    """Append the rows of *slice* to the right end of the last ``len(slice)`` rows of *matrix*.

    Mutates *matrix* in place; the caller must already have created those rows.
    """
    for ri, row in enumerate(slice):
        matrix[-len(slice) + ri].extend(row)


def _enhance(image, rules):
    """Return the image produced by one enhancement step of *image* using *rules*."""
    # Even-sized images are split into 2x2 blocks, otherwise into 3x3 blocks.
    size = 2 if len(image) % 2 == 0 else 3
    new_image = []
    for row in range(0, len(image), size):
        # Each size x size block is replaced by a (size + 1) x (size + 1) block,
        # so reserve that many output rows before filling them left to right.
        for _ in range(size + 1):
            new_image.append([])
        for col in range(0, len(image[0]), size):
            block = tuple(map(tuple, slice_get(image, row, col, size)))
            slice_append(new_image, rules[block])
    return new_image


def part_1(data, iterations=5):
    """Enhance the starting image *iterations* times and report lit pixels.

    Args:
        data: Rule text, one ``"in => out"`` rule per line; blank lines are ignored.
        iterations: Number of enhancement steps to apply.

    Returns:
        The number of '#' pixels after the final step (also printed).

    Raises:
        ValueError: if two rules map the same pattern (in any orientation)
            to different outputs, or a line is not a valid rule.
        KeyError: if some block of the image matches no rule.
    """
    rules = {}
    for line in data.splitlines():
        line = line.strip()
        if not line:
            continue
        for key, value in parse_rule(line).items():
            if key in rules and rules[key] != value:
                raise ValueError(f"conflicting rules for pattern {key!r}: {line!r}")
            rules[key] = value

    image = [list(row) for row in IMAGE.splitlines()]
    for _ in range(iterations):
        image = _enhance(image, rules)

    count = sum(row.count("#") for row in image)
    print(count)
    return count


def main():
    """Read the rules from stdin and report lit pixels after 5 and 18 iterations."""
    data = sys.stdin.read()
    part_1(data)
    part_1(data, 18)


if __name__ == "__main__":
    main()