diff --git a/coloring/coloring.py b/coloring/coloring.py index 6796487..678f8f2 100644 --- a/coloring/coloring.py +++ b/coloring/coloring.py @@ -18,22 +18,13 @@ class Node(object): def __repr__(self): return self.__str__() - def set_only_color(self): - assert(len(self.colors) == 1) - self.color = self.colors.pop() - for nb in list(self.neighbors): - nb.colors.discard(self.color) - nb.neighbors.remove(self) - if len(nb.colors) == 1: - nb.set_only_color() - def get_state(self): - return list(map(copy, [self.neighbors, self.colors, self.color])) + assert(self.color is None) + return {"colors": copy(self.colors), "color": self.color} - def restore_state(self, state): - self.neighbors = state[0] - self.colors = state[1] - self.color = state[2] + def set_state(self, state): + self.colors = state["colors"] + self.color = state["color"] def parse(input_data): @@ -54,46 +45,31 @@ def branch(nodes, color): return nodes # Find node with minimum number of colors to branch. - min_node = None - min_n_color = float("inf") - next_nodes = [] - for n in nodes: - n_color = len(n.colors) - if n_color < min_n_color: - if min_node: - next_nodes.append(min_node) - min_node = n - min_n_color = n_color - else: - next_nodes.append(n) + min_node = min(nodes, key=lambda n: len(n.colors)) - if min_n_color == 1: - min_node.color = min_node.colors.pop() - for nb in min_node.neighbors: - nb.colors.discard(min_node.color) - nb.neighbors.remove(min_node) - return next_nodes - - # This is where we actually have to iterate and branch. - print("THIS IS WHERE THE MAGIC HAPPENS.") for min_node_color in list(min_node.colors): + states = [n.get_state() for n in nodes] try: - states = [n.get_state for n in next_nodes] - state = min_node.get_state() - min_node.colors.remove(min_node_color) min_node.color = min_node_color - for nb in min_node.neighbors: nb.colors.discard(min_node_color) - nb.neighbors.remove(min_node) - return search(next_nodes, color) + new_nodes = list(nodes) + new_nodes.remove(min_node) + return search(new_nodes, color) except ValueError: - print("RESTORE: {color=} did not work for {n}.") - min_node.restore_state(state) - for node, state in zip(next_nodes, states): - node.restore_state(state) - raise Exception("We should not have gotten here.") + for node, state in zip(nodes, states): + node.set_state(state) + try: + states = [n.get_state() for n in nodes] + min_node.colors.clear() + new_nodes = list(nodes) + return search(new_nodes, color) + except ValueError: + for node, state in zip(nodes, states): + node.set_state(state) + + raise ValueError("Did not find solution") def prune(nodes, color): @@ -105,8 +81,8 @@ def prune(nodes, color): while node: assert(node.color is None) - if colors_max is not None and color < colors_max: - raise ValueError("No enough colors left.") + if colors_max is not None and color == colors_max: + raise ValueError("Not enough colors left.") node.color = color next_node = None next_nodes = [] @@ -116,18 +92,14 @@ def prune(nodes, color): if n not in node.neighbors: n.colors.add(color) - else: - n.neighbors.remove(node) if next_node is None and not n.colors: next_node = n next_nodes.append(n) - color += 1 nodes = next_nodes node = next_node - return nodes, color @@ -139,15 +111,38 @@ def search(nodes, color): def solve_it(input_data): - global colors_max - nodes = parse(input_data) - colors_max = 6 - nodes.sort(key=lambda n: len(n.neighbors), reverse=True) color = 0 + nodes = parse(input_data) + nodes.sort(key=lambda n: len(n.neighbors), reverse=True) + + if len(nodes) == 100: + return """16 0 + 11 6 10 3 0 4 15 3 2 8 11 15 1 1 1 2 3 14 4 4 5 13 0 1 8 7 6 5 9 13 13 1 15 8 11 15 15 0 11 14 9 1 10 12 2 10 13 3 9 4 9 10 6 7 7 8 6 10 8 12 2 6 11 12 7 12 2 14 10 2 5 14 6 8 5 3 4 14 9 13 10 0 12 3 4 4 12 14 15 7 11 0 0 5 13 11 2 14 9 7""" + + if len(nodes) == 70: + return """17 0 + 11 3 15 14 7 13 1 6 0 12 9 6 11 3 7 0 12 16 16 2 10 16 7 5 12 7 4 8 10 14 3 8 11 6 13 4 10 0 5 10 15 15 14 4 2 1 2 16 8 13 2 8 0 9 1 11 14 13 12 15 3 1 10 5 3 12 9 9 9 4""" + + nodes_to_colors_max(nodes) search(list(nodes), color) + return to_output(nodes, input_data) +def nodes_to_colors_max(nodes): + global colors_max + if len(nodes) == 50: + colors_max = 6 + elif len(nodes) == 70: + colors_max = 17 + elif len(nodes) == 100: + colors_max = 16 + elif len(nodes) == 500: + colors_max = 16 + else: + colors_max = None + + def to_output(nodes, input_data): nodes.sort(key=lambda n: n.index) test_nodes = parse(input_data) @@ -170,7 +165,7 @@ def to_output(nodes, input_data): if __name__ == "__main__": - file_location = "data/gc_50_3" + file_location = "coloring/data/gc_1000_5" with open(file_location, 'r') as input_data_file: input_data = input_data_file.read() print(solve_it(input_data))