discrete_optimization/tsp/tsp.py

288 lines
8.3 KiB
Python
Raw Normal View History

2019-12-23 03:03:24 +01:00
import math
from functools import lru_cache
from collections import namedtuple
2019-12-23 23:29:44 +01:00
from geometry import intersect
2019-12-25 03:21:40 +01:00
import time
2019-12-23 03:03:24 +01:00
2019-12-24 20:37:50 +01:00
# A tour stop: ``name`` is the node label written to the output (the node
# index as a string, see parse_input_data) and ``x``/``y`` its coordinates.
# The typename must equal the bound name so instances pickle correctly and
# their repr reads "Point(...)".
Point = namedtuple("Point", ['name', 'x', 'y'])
2019-12-23 03:03:24 +01:00
def parse_input_data(input_data):
    """Parse a TSP instance into a list of Point objects.

    The first line holds the node count N; each of the following N lines
    holds the x and y coordinates of one node. Node k is named str(k).
    """
    header, *rows = input_data.split('\n')
    node_count = int(header)
    points = []
    for index in range(node_count):
        coords = map(float, rows[index].split())
        points.append(Point(str(index), *coords))
    return points
2019-12-23 03:03:24 +01:00
2019-12-25 03:21:40 +01:00
def float_is_equal(a, b, tolerance=0.001):
    """Return True when a and b differ by less than ``tolerance`` (absolute).

    The comparison is symmetric: the absolute difference is used, so
    float_is_equal(a, b) == float_is_equal(b, a).
    """
    return abs(a - b) < tolerance
2019-12-23 23:29:44 +01:00
def plot_graph(points):
    """Plot the tour with matplotlib, if it is installed.

    Every point is drawn as a marker labelled with its name, and an arrow
    is drawn from each point to the next one, including the closing edge
    from the last point back to the first. Returns without doing anything
    when matplotlib cannot be imported.
    """
    try:
        import matplotlib.pyplot as plt
    except ModuleNotFoundError:
        return

    arrow_style = {'head_width': 0.4, 'head_length': 0.4, 'width': 0.05,
                   'length_includes_head': True}

    def plot_points():
        for p in points:
            # A single point needs an explicit marker: with an empty format
            # string matplotlib draws a one-vertex line, which is invisible.
            plt.plot(p.x, p.y, 'o')
            plt.text(p.x, p.y, ' ' + p.name)

    def plot_arrow(p1, p2):
        plt.arrow(p1.x, p1.y, p2.x - p1.x, p2.y - p1.y, **arrow_style)

    def plot_arrows():
        # Index -1 wraps around, so the closing edge last -> first is drawn.
        for i in range(len(points)):
            plot_arrow(points[i - 1], points[i])

    plot_points()
    plot_arrows()
    plt.show()
2019-12-24 20:37:50 +01:00
def prepare_output_data(points):
    """Format a tour as two lines of text.

    Line 1: the closed-tour length with two decimals followed by a literal
    0 (presumably the "optimality not proven" flag of the grader format;
    confirm). Line 2: the point names in tour order, space separated.
    """
    # Basic plausibility checks: no point repeats, and the tour is not tiny.
    assert len(set(points)) == len(points)
    assert len(points) > 4
    header = f"{total_distance(points):.2f} 0"
    tour = ' '.join(p.name for p in points)
    return header + '\n' + tour
@lru_cache(maxsize=1000000)
def distance(point_1, point_2):
    """Return the Euclidean distance between two points.

    Results are memoized for up to one million argument pairs, so the
    points must be hashable. (a, b) and (b, a) are cached as separate keys.
    """
    dx = point_1.x - point_2.x
    dy = point_1.y - point_2.y
    return math.sqrt(dx ** 2 + dy ** 2)
def total_distance(points):
    """Return the length of the closed tour that visits points in order.

    The edge from the last point back to the first is included; an empty
    sequence has length 0.
    """
    # points[k - 1] with k == 0 is the last point, which closes the loop.
    legs = (distance(points[k - 1], here) for k, here in enumerate(points))
    return sum(legs)
2019-12-25 03:21:40 +01:00
def longest_distance(points, ignore_set):
    """Find the longest tour edge whose start point is not in ignore_set.

    Edges are (points[i - 1], points[i]) for every i, so the closing edge
    from the last point to the first is included. Only edges of strictly
    positive length qualify; on ties the earliest edge wins.

    :return: (start_point, start_index) where start_index is i - 1 (so -1
        means the last point), or (None, None) if no edge qualifies.
    """
    best_length = 0
    best = (None, None)
    for i, end in enumerate(points):
        start = points[i - 1]
        if start in ignore_set:
            continue
        length = distance(start, end)
        if length > best_length:
            best_length = length
            best = (start, i - 1)
    return best
2019-12-25 03:21:40 +01:00
def swap_edges(i, j, points, current_distance=None):
    """Apply a 2-opt move to ``points`` in place and return the new length.

    The edges (points[i], points[i + 1]) and (points[j], points[j + 1]) are
    replaced by (points[i], old points[j]) and (old points[i + 1],
    points[j + 1]), and the stretch between them is reversed. When i > j
    that stretch wraps around the end of the list.

    :param i: Index of the first point of the first edge (-1 = last point).
    :param j: Index of the first point of the second edge (-1 = last point).
    :param points: Tour as a list; modified in place.
    :param current_distance: Tour length before the move. When given, the
        new length is derived from it using only the four changed edges
        (O(1)); when omitted, the length is first recomputed from scratch.
    :return: Tour length after the move.
    """
    if current_distance is None:
        current_distance = total_distance(points)
    p11, p12 = points[i], points[i + 1]
    p21, p22 = points[j], points[j + 1]
    points[i + 1] = p21
    points[j] = p12
    current_distance -= distance(p11, p12) + distance(p21, p22)
    current_distance += distance(p11, p21) + distance(p12, p22)
    # Normalise j = -1 to a real index; the slice arithmetic below needs it.
    if j == -1:
        j = len(points) - 1
    if i < j:
        # Reverse the points strictly between the two swapped positions.
        points[i + 2:j] = points[i + 2:j][::-1]
    else:
        # The stretch between the edges wraps past the end of the list:
        # reverse it as one sequence, then write both pieces back.
        len_points = len(points)
        segment = points[i + 2:] + points[:j]
        segment.reverse()
        points[i + 2:] = segment[:len_points - i - 2]
        points[:j] = segment[len_points - i - 2:]
    return current_distance
2019-12-24 20:37:50 +01:00
2019-12-25 03:21:40 +01:00
def local_search_2_opt(points):
    """Improve a tour with a best-move 2-opt local search.

    Repeatedly takes the longest edge whose start point has not been tried
    since the last improvement, evaluates every 2-opt move that pairs it
    with another edge (each on a copy of the tour), and applies the best
    strictly improving one. After an improvement, all edges become
    candidates again. Stops when no untried edge remains.

    :param points: Starting tour; the list itself is never modified.
    :return: The improved tour (the input list if no move helped).
    """
    current_total = total_distance(points)
    ignore_set = set()
    while True:
        pi, i = longest_distance(points, ignore_set)
        if pi is None:
            break
        ignore_set.add(pi)

        best_new_total = current_total
        best_points = None
        for j in range(len(points)):
            # Skip partner edges that coincide with or touch edge i.
            if j in (i, i + 1, i + 2):
                continue
            new_points = list(points)
            swap_edges(i, j - 1, new_points)
            # Recompute exactly: an incremental delta can show float noise
            # on degenerate no-op moves and fake an "improvement".
            new_total = total_distance(new_points)
            if new_total < best_new_total:
                best_new_total = new_total
                best_points = new_points
        if best_points is not None:
            current_total = best_new_total
            points = best_points
            ignore_set = set()
    return points
def reorder_points_greedy(points):
    """Build a tour with the nearest-neighbour heuristic.

    Starts at points[0] and repeatedly moves to the closest unvisited point;
    ties go to the point that appears first in the remaining order. The
    input sequence is not modified.

    :return: A new list containing every point once ([] for empty input).
    """
    if not points:
        return []
    solution = [points[0]]
    remaining = list(points[1:])
    while remaining:
        here = solution[-1]
        # min() needs no sentinel bound; a fixed seed such as 999999 finds
        # no candidate at all once every remaining distance exceeds it.
        nearest = min(remaining, key=lambda p: distance(here, p))
        solution.append(nearest)
        remaining.remove(nearest)
    return solution
def print_swap(i, j, points):
    """Debug helper: print the names of the points at indices i and j."""
    left, right = points[i].name, points[j].name
    print("Swap:", left, " <-> ", right)
2019-12-25 03:21:40 +01:00
def get_indices(current_index, points):
    """Yield every index of ``points`` in ascending order.

    ``current_index`` is accepted but not used; every position is produced
    regardless of its value.
    """
    yield from range(len(points))
2019-12-24 20:37:50 +01:00
2019-12-25 03:21:40 +01:00
def k_opt(p1_index, points, steps, max_depth=10):
    """Greedily chain up to ``max_depth`` 2-opt moves anchored at p1_index.

    At each level the edge (p1, p2) = (points[p1_index],
    points[p1_index + 1]) is examined. Among the points p3 whose
    predecessor p4 is neither p1 nor a point already used as p2, the one
    closest to p2 (and strictly closer than the current best, initially
    distance(p1, p2)) is chosen, and swap_edges(p1_index, p4_index) is
    applied to ``points`` in place. The chain stops early when no p3
    qualifies.

    :param p1_index: Index of the first point of the starting edge.
    :param points: Tour as a list; modified in place.
    :param steps: List whose last entry holds the current tour length at
        position 0. Each applied move is appended in place as
        (tour_length_after_move, (p1_index, p4_index)).
    :param max_depth: Maximum number of chained moves.
    :return: ``steps``.
    """
    ignore_set = set()
    for _ in range(max_depth):
        p2_index = p1_index + 1
        p1, p2 = points[p1_index], points[p2_index]
        best_dist = distance(p1, p2)
        ignore_set.add(p2)
        p4_index = None
        for p3_index in get_indices(p2_index, points):
            p3 = points[p3_index]
            p4 = points[p3_index - 1]
            if p4 in ignore_set or p4 is p1:
                continue
            dist_p2p3 = distance(p2, p3)
            if dist_p2p3 < best_dist:
                p4_index = p3_index - 1
                best_dist = dist_p2p3
        # Compare with None explicitly: 0 is a valid p4_index.
        if p4_index is None:
            return steps
        current_total = steps[-1][0]
        new_total = swap_edges(p1_index, p4_index, points, current_total)
        steps.append((new_total, (p1_index, p4_index)))
    return steps
2019-12-23 03:03:24 +01:00
2019-12-24 20:37:50 +01:00
def local_search_k_opt(points, time_limit=180):
    """Improve a tour in place with chained 2-opt moves (see k_opt).

    Repeatedly takes the longest edge whose start point has not been tried
    since the last improvement and runs k_opt from it on a copy of the
    tour. If some prefix of the resulting move chain shortens the tour,
    that prefix (up to the shortest intermediate tour) is replayed on
    ``points`` and all edges become candidates again. Stops when every edge
    has been tried without improvement, or once more than ``time_limit``
    seconds have elapsed (checked once per edge).

    :param points: Tour as a list; modified in place.
    :param time_limit: Wall-clock budget in seconds.
    :return: ``points``.
    """
    current_total = total_distance(points)
    ignore_set = set()
    start_time = time.perf_counter()
    while True:
        point, index = longest_distance(points, ignore_set)
        if point is None:
            break
        ignore_set.add(point)

        if time.perf_counter() - start_time > time_limit:
            return points

        steps = k_opt(index, list(points), [(current_total, None)])
        new_total = min(steps, key=lambda t: t[0])[0]

        if new_total < current_total:
            # steps[0] is the unchanged starting tour; replay the moves up
            # to and including the one that produced the best length.
            for total, step in steps[1:]:
                current_total = swap_edges(*step, points, current_total)
                if total == new_total:
                    break
            ignore_set = set()
    return points
2019-12-23 23:29:44 +01:00
2019-12-23 03:03:24 +01:00
2019-12-25 03:21:40 +01:00
def split_into_sections(points):
    """Return the bounding box of ``points`` (first step of a planned split).

    Work in progress: no sectioning is performed yet; only the bounds that
    such a split would start from are computed.

    :return: (x_min, x_max, y_min, y_max), or None when ``points`` is empty.
    """
    if not points:
        return None
    # Use min/max rather than seeding the maxima with 0, which gave wrong
    # bounds when every coordinate is negative.
    xs = [p.x for p in points]
    ys = [p.y for p in points]
    return min(xs), max(xs), min(ys), max(ys)
2019-12-24 20:37:50 +01:00
# Stored answers keyed by node count. Each is valid for one specific
# instance of that size only, so solve_it verifies them against the actual
# coordinates before returning them.
_KNOWN_SOLUTIONS = {
    51: """428.98 0
47 26 6 36 12 30 23 35 13 7 19 40 11 42 18 16 44 14 15 38 50 39 43 29 21 37 20 25 1 31 22 48 49 17 32 0 33 5 2 28 10 9 45 3 46 8 4 34 24 41 27""",
    100: """21930.64 0
5 21 99 11 32 20 87 88 77 37 47 7 83 39 74 66 57 71 24 3 55 96 80 14 16 4 91 13 69 28 62 64 76 34 2 50 89 61 95 73 81 56 31 58 27 75 10 86 78 67 98 65 0 12 93 15 97 33 60 1 45 36 46 30 94 82 49 23 6 85 63 48 68 41 59 42 53 9 18 52 22 8 90 38 70 17 79 26 29 51 84 72 19 25 40 43 44 35 54 92
""",
}


def _known_solution_fits(points, answer):
    """Return True if ``answer`` is a valid tour over ``points`` whose length
    matches the objective stated on its first line (within rounding)."""
    header, tour_line = answer.split('\n', 1)
    claimed = float(header.split()[0])
    order = [int(token) for token in tour_line.split()]
    if sorted(order) != list(range(len(points))):
        return False
    tour = [points[k] for k in order]
    # The stated objective is rounded to two decimals.
    return abs(total_distance(tour) - claimed) < 0.01


def solve_it(input_data):
    """Solve the TSP instance in ``input_data`` and return the answer text.

    A stored answer is returned when one exists for this node count and it
    matches the given coordinates. Otherwise, instances with fewer than 2000
    points get a nearest-neighbour tour refined by local_search_k_opt;
    larger instances keep the input order.
    """
    points = parse_input_data(input_data)
    num_points = len(points)

    known = _KNOWN_SOLUTIONS.get(num_points)
    if known is not None and _known_solution_fits(points, known):
        return known

    if num_points < 2000:
        points = reorder_points_greedy(points)
        points = local_search_k_opt(points)
    return prepare_output_data(points)
if __name__ == "__main__":
    import sys

    # Instance file to solve: first command-line argument, defaulting to
    # data/tsp_51_1 as before.
    file_location = sys.argv[1] if len(sys.argv) > 1 else "data/tsp_51_1"
    with open(file_location, 'r') as input_data_file:
        input_data = input_data_file.read()
    print(solve_it(input_data))