from lib import get_data, add2
from collections import defaultdict

# Example regexes from the puzzle statement, mapped to the number of doors on
# the shortest path to the furthest room. ``data`` is the one ``main`` runs.
EXAMPLES = {
    "^WNE$": 3,
    "^ENWWW(NEEE|SSE(EE|N))$": 10,
    "^ESSWWN(E|NNENN(EESS(WNSE|)SSS|WWWSSSSE(SW|NNNE)))$": 23,
    "^WSSEESWWWNW(S|NENNEEEENN(ESSSSW(NWSW|SSEN)|WSWWN(E|WWS(E|SS))))$": 31,
}
data = "^WSSEESWWWNW(S|NENNEEEENN(ESSSSW(NWSW|SSEN)|WSWWN(E|WWS(E|SS))))$"

# (row, col) offsets: "N" decreases the row, "E" increases the column.
DIRS = {
    "N": (-1, 0),
    "E": (0, 1),
    "S": (1, 0),
    "W": (0, -1),
}

Pos = tuple[int, int]


def _build_map(regex: str) -> dict[Pos, set[Pos]]:
    """Return the door graph described by the route *regex*.

    The result maps each room position to the set of neighbouring rooms
    connected to it by a door. Rather than enumerating every path through the
    regex (whose count multiplies with each group), this tracks the *set* of
    rooms the walk could currently be in:

    * a direction letter moves every candidate room one step;
    * ``(`` remembers the current rooms as the start of each alternative;
    * ``|`` stores where the finished alternative ended and restarts from the
      rooms remembered at ``(``;
    * ``)`` continues from the union of all alternatives' end rooms.

    Reading stops at ``$``; ``^`` is ignored.

    Raises:
        ValueError: on an unexpected character or unbalanced parentheses.
    """
    graph: dict[Pos, set[Pos]] = defaultdict(set)
    current: set[Pos] = {(0, 0)}
    # Rooms at the "(" of the innermost open group, and the end rooms of that
    # group's alternatives finished so far. Sets are only ever rebound, never
    # mutated in place, except ``group_ends`` which is always a fresh set.
    group_starts: set[Pos] = current
    group_ends: set[Pos] = set()
    stack: list[tuple[set[Pos], set[Pos]]] = []

    for c in regex:
        if c in DIRS:
            moved = set()
            for pos in current:
                npos = add2(pos, DIRS[c])
                graph[pos].add(npos)
                graph[npos].add(pos)
                moved.add(npos)
            current = moved
        elif c == "(":
            stack.append((group_starts, group_ends))
            group_starts, group_ends = current, set()
        elif c == "|":
            if not stack:
                raise ValueError("Encountered | outside of a group")
            group_ends |= current
            current = group_starts
        elif c == ")":
            if not stack:
                raise ValueError("Encountered ) without matching (")
            current = group_ends | current
            group_starts, group_ends = stack.pop()
        elif c == "$":
            break
        elif c == "^":
            continue
        else:
            raise ValueError(f"Unexpected character {c!r} in regex")

    if stack:
        raise ValueError("Unclosed ( in regex")
    return graph


def _door_distances(graph: dict[Pos, set[Pos]], start: Pos = (0, 0)) -> dict[Pos, int]:
    """Return the fewest doors needed to reach each room from *start* (BFS)."""
    dists = {start: 0}
    frontier = [start]
    while frontier:
        next_frontier = []
        for room in frontier:
            for nb in graph.get(room, ()):
                if nb not in dists:
                    dists[nb] = dists[room] + 1
                    next_frontier.append(nb)
        frontier = next_frontier
    return dists


def part_1(data):
    """Print and return the most doors on the shortest path to any room.

    *data* is the route regex (e.g. ``"^WNE$"``). The answer is printed, as
    before, and also returned so callers can use it programmatically.
    """
    dists = _door_distances(_build_map(data))
    furthest = max(dists.values())  # never empty: the start room is at 0
    print(furthest)
    return furthest


def main():
    """Solve part 1 for the example regex held in ``data``."""
    # To run on the real puzzle input instead, use:
    #     part_1(get_data(__file__).strip())
    part_1(data)


if __name__ == "__main__":
    main()