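from dataclasses import dataclass, field
from collections import defaultdict
import sys


@dataclass
class Disc:
    name: str
    value: int  # this disc's own weight
    holds: list = field(default_factory=list)  # names of discs stacked on this one


discs = {}

# Each input line is "name (weight)" or "name (weight) -> child1, child2, ...".
data = open(0).read().strip()
for line in data.splitlines():
    right = None
    if "->" in line:
        left, right = line.split(" -> ")
    else:
        left = line
    name, value = left.split()
    value = int(value[1:-1])  # strip the surrounding parentheses
    d = Disc(name, value)
    if right is not None:
        d.holds.extend(right.split(", "))
    discs[d.name] = d

# Part 1: the root disc is the only one that no other disc holds.
alldiscs = set(discs)
for d in discs.values():
    for h in d.holds:
        alldiscs.remove(h)
assert len(alldiscs) == 1
top_disc = alldiscs.pop()
print(top_disc)


def weight(disc_name):
    """Return the total weight of the tower rooted at disc_name.

    Children are weighed before their parent, so the first imbalance found is
    the deepest one. There, the single child whose tower weight differs is the
    culprit: its corrected own weight is printed and the program exits.
    """
    disc = discs[disc_name]
    child_weights = {child: weight(child) for child in disc.holds}

    if len(set(child_weights.values())) > 1:
        # Group children by tower weight; the odd one out is alone in its group.
        # Assumes at least three children, so the majority weight is unambiguous.
        by_weight = defaultdict(list)
        for child_name, child_weight in child_weights.items():
            by_weight[child_weight].append(child_name)
        same_weight = other_weight = 0
        odd_child = None
        for w, children in by_weight.items():
            if len(children) == 1:
                other_weight = w
                odd_child = children[0]
            else:
                same_weight = w
        weight_delta = other_weight - same_weight
        # Adjust the odd child's own weight (not a hard-coded disc name).
        print(discs[odd_child].value - weight_delta)
        sys.exit(0)

    return disc.value + sum(child_weights.values())


weight(top_disc)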