This commit is contained in:
2019-12-22 21:03:24 -05:00
parent 233c763805
commit 7481ee3f0a
3 changed files with 250 additions and 38 deletions

175
tsp/tsp.py Normal file
View File

@@ -0,0 +1,175 @@
import math
from functools import lru_cache
from collections import namedtuple
from geometry import do_intersect
Point = namedtuple("Point", ['index', 'x', 'y'])
def parse_input_data(input_data):
lines = input_data.split('\n')
node_count = int(lines[0])
return [Point(i, *map(float, lines[i].split()))
for i in range(1, node_count + 1)]
def plot_solution(solution, points):
import matplotlib.pyplot as plt
def plot_arrows():
for i in range(len(solution)):
p_1_index = solution[i - 1]
p_2_index = solution[i]
plot_arrow(p_1_index, p_2_index)
def plot_arrow(p_1_index, p_2_index):
p_1 = points[p_1_index]
p_2 = points[p_2_index]
x = p_1.x
y = p_1.y
dx = p_2.x - x
dy = p_2.y - y
plt.arrow(x, y, dx, dy,
head_width=0.5, head_length=0.5)
def plot_points():
for i, p in enumerate(points):
plt.plot(p.x, p.y, '')
plt.text(p.x, p.y, ' ' + str(i))
plot_points()
plot_arrows()
plt.show()
def solve_it(input_data):
@lru_cache(maxsize=1000000)
def length(point_1_index, point_2_index):
p_1 = points[point_1_index]
p_2 = points[point_2_index]
return math.sqrt((p_1.x - p_2.x)**2 + (p_1.y - p_2.y)**2)
def prepare_output_data(solution):
obj = calculate_length(solution)
output_data = '%.2f' % obj + ' ' + str(0) + '\n'
output_data += ' '.join(map(str, solution))
return output_data
def is_valid(solution):
points = set(range(len(solution)))
assert(set(solution) == points)
return True
def calculate_length(solution):
obj = 0
for i in range(0, len(solution)):
point_1_index = solution[i - 1]
point_2_index = solution[i]
obj += length(point_1_index, point_2_index)
return obj
def initial_solution_naiv():
return list(range(0, node_count))
def does_edge_cause_intersection(edge, edges):
p1 = points[edge[0]]
p2 = points[edge[1]]
for existing_edge in edges:
q1 = points[existing_edge[0]]
q2 = points[existing_edge[1]]
print(p1, p2, q1, q2)
if do_intersect(p1, q1, p2, q2):
return True
return False
def get_dimensions(point_indices):
p = points[0]
x_min_p = p
x_max_p = p
y_min_p = p
y_max_p = p
x_min = p.x
x_max = p.x
y_min = p.y
y_max = p.y
for p in points:
if p.x < x_min:
x_min = p.x
x_min_p = p
if p.y < y_min:
y_min = p.y
y_min_p = p
if p.x > x_max:
x_max = p.x
x_max_p = p
if p.y > y_max:
y_max = p.y
y_max_p = p
return (x_min_p, x_max_p, y_min_p, y_max_p)
def initial_solution_greedy(point_indices):
current_point = get_dimensions(point_indices)[0].index
xs = [current_point]
points = set(point_indices)
points.remove(current_point)
while points:
min_length = 999999
min_point = None
for next_point in points:
new_length = length(current_point, next_point)
if new_length < min_length:
min_length = new_length
min_point = next_point
xs.append(min_point)
points.remove(min_point)
current_point = min_point
return xs
def local_search(solution):
# Find longest edges to swap
max_len_1 = 0
max_len_2 = 0
edge_1 = None
edge_2 = None
for i in range(node_count):
new_len = length(solution[i - 1], solution[i])
if new_len > max_len_1:
max_len_1 = new_len
edge_1 = (i - 1, i)
for i in range(node_count):
new_len = length(solution[i - 1], solution[i])
if new_len > max_len_2 and new_len != max_len_1:
max_len_2 = new_len
edge_2 = (i - 1, i)
a_1, a_2 = edge_1
print(edge_1, edge_2)
n_a_1, n_a_2 = [solution[a_1], solution[a_2]]
print(n_a_1, n_a_2)
b_1, b_2 = edge_2
n_b_1, n_b_2 = [solution[b_1], solution[b_2]]
print(n_b_1, n_b_2)
solution[a_2] = n_b_2
solution[b_2] = n_a_2
return solution
points = parse_input_data(input_data)
node_count = len(points)
solution = initial_solution_greedy(list(range(node_count)))
local_search(solution)
# solution = initial_solution_naiv()
plot_solution(solution, points)
is_valid(solution)
return prepare_output_data(solution)