from lib import get_data from itertools import pairwise from collections import defaultdict data = get_data(__file__) orig_t, p2 = data.split("\n\n") pairs = {} for line in p2.splitlines(): ls, rs = line.split(" -> ") assert ls not in pairs pairs[ls] = rs for steps in [10, 40]: t = str(orig_t) start_letter = t[0] end_letter = t[-1] td = defaultdict(int) for a, b in pairwise(t): td[a + b] += 1 for _ in range(steps): ntd = defaultdict(int) for pair, count in td.items(): if pair in pairs: a, b = pair c = pairs[pair] ntd[a + c] += count ntd[c + b] += count else: ntd[pair] += count td = ntd counts = defaultdict(int) for (a, b), count in td.items(): counts[a] += count counts[b] += count counts = {k: v // 2 for k, v in counts.items()} counts[start_letter] += 1 counts[end_letter] += 1 r = max(counts.values()) - min(counts.values()) print(r)