"""Simulate a grid of energy levels whose cells "flash" when they exceed 9.

Each step raises every cell by one. Any cell above 9 then flashes exactly
once, raising each of its ``neighbors_adj`` cells by one, which can push
those cells over 9 and cascade. Every cell that flashed is reset to "0" at
the end of the step.

Prints two lines: the total number of flashes over the first 100 steps,
and the first (1-based) step on which every cell flashed (``None`` if that
never happens within the simulated range).
"""
from lib import get_data
from lib import Grid2D

# Number of leading steps whose flashes are summed into ``flashes_100``.
N_COUNTED_STEPS = 100


def step(grid):
    """Advance *grid* by one step in place and return the set of cells that flashed.

    Cell values are stored as strings and converted with ``int`` on read,
    matching how this script writes them back.
    NOTE(review): assumes ``grid.neighbors_adj(cell)`` yields only in-bounds
    cells, as the original loop did — confirm in lib.
    """
    cells = [
        (row, col) for row in range(grid.n_rows) for col in range(grid.n_cols)
    ]
    for cell in cells:
        grid[cell] = str(int(grid[cell]) + 1)

    # Worklist of cells that may need to flash. A cell can be queued more
    # than once (e.g. already above 9 and then bumped again), so the
    # `flashed` check deduplicates; each cell flashes at most once per step.
    pending = [cell for cell in cells if int(grid[cell]) > 9]
    flashed = set()
    while pending:
        cell = pending.pop()
        if cell in flashed:
            continue
        flashed.add(cell)
        for nb in grid.neighbors_adj(cell):
            grid[nb] = str(int(grid[nb]) + 1)
            if nb not in flashed and int(grid[nb]) > 9:
                pending.append(nb)

    for cell in flashed:
        grid[cell] = "0"
    return flashed


data = get_data(__file__).strip()
g = Grid2D(data)

flashes_100 = 0
all_flash_round = None
n_cells = g.n_rows * g.n_cols
for i in range(10**9):
    has_flashed = step(g)
    if i < N_COUNTED_STEPS:
        flashes_100 += len(has_flashed)
    if all_flash_round is None and len(has_flashed) == n_cells:
        all_flash_round = i + 1
    # Stop only once both answers are known. Breaking on the first all-flash
    # step alone would undercount flashes_100 whenever that step falls within
    # the first N_COUNTED_STEPS steps.
    if all_flash_round is not None and i + 1 >= N_COUNTED_STEPS:
        break

print(flashes_100)
print(all_flash_round)