"""Count start-to-end paths through an undirected cave graph.

Each input line is an edge ``a-b``. Caves whose names are uppercase may be
visited any number of times on a path; lowercase caves at most once (part 1),
or a single lowercase cave may be visited twice (part 2). ``start`` is never
re-entered, and a path is complete as soon as it reaches ``end``.
"""
from lib import get_data
from lib import Grid2D
from lib import ints
from collections import deque
from collections import defaultdict
from functools import lru_cache

data = get_data(__file__).strip()

# Undirected adjacency list: cave name -> list of neighbouring cave names.
g = defaultdict(list)
for line in data.splitlines():
    a, b = line.strip().split("-")
    g[a].append(b)
    g[b].append(a)


def no_doubles(path):
    """Return True if no lowercase name occurs more than once in *path*.

    No longer used by the solver below; kept so existing importers of this
    name keep working.
    """
    path = [p for p in path if p.islower()]
    return len(path) == len(set(path))


def count_paths(graph, allow_revisit):
    """Return the number of distinct paths from ``start`` to ``end``.

    Args:
        graph: mapping of cave name -> iterable of adjacent cave names.
            Repeated entries for the same edge are treated as one edge.
        allow_revisit: if True, at most one lowercase cave (never ``start``)
            may be visited twice on a single path.

    Raises:
        ValueError: if traversal reaches a cave name that is neither all
            lowercase nor all uppercase, or crosses an edge between two
            uppercase caves (which would admit infinitely many paths).
    """
    # Deduplicate neighbours so a repeated input line does not multiply paths.
    adjacency = {cave: set(neighbors) for cave, neighbors in graph.items()}

    # The number of ways to finish a path depends only on the current cave,
    # which small caves are already used, and whether the single revisit is
    # still available -- not on the exact route taken. Memoising on that
    # state avoids enumerating every path individually.
    @lru_cache(maxsize=None)
    def walk(cave, seen, revisit_left):
        total = 0
        for neighbor in adjacency.get(cave, ()):
            if neighbor == "end":
                total += 1
            elif neighbor == "start":
                continue
            elif neighbor.islower():
                if neighbor not in seen:
                    total += walk(neighbor, seen | {neighbor}, revisit_left)
                elif revisit_left:
                    # Spend the one allowed revisit on this small cave.
                    total += walk(neighbor, seen, False)
            elif neighbor.isupper():
                if cave.isupper():
                    raise ValueError(
                        f"big caves {cave!r} and {neighbor!r} are adjacent; "
                        "the number of paths is unbounded"
                    )
                total += walk(neighbor, seen, revisit_left)
            else:
                raise ValueError(f"unexpected cave name: {neighbor!r}")
        return total

    return walk("start", frozenset(), allow_revisit)


for part in [1, 2]:
    print(count_paths(g, allow_revisit=(part == 2)))