149 lines
5.1 KiB
Python
149 lines
5.1 KiB
Python
import math
|
|
|
|
|
|
class Map(object):
|
|
# Create Map. Cluster points into regions. Calculate distances only to own
|
|
# and neighbor regions. We can actually cluster in O(n) when we know how
|
|
# high and wide the clusters are. Once we have that working we go from
|
|
# there
|
|
|
|
def __init__(self, n_clusters):
|
|
self.CLUSTERS_X = n_clusters
|
|
|
|
def calc_corners(self, points):
|
|
x_min, x_max = float("inf"), float("-inf")
|
|
y_min, y_max = float("inf"), float("-inf")
|
|
for p in points:
|
|
if p.x < x_min:
|
|
x_min = p.x
|
|
if p.x > x_max:
|
|
x_max = p.x
|
|
if p.y < y_min:
|
|
y_min = p.y
|
|
if p.y > y_max:
|
|
y_max = p.y
|
|
self.x_min = x_min
|
|
self.x_max = x_max
|
|
self.y_min = y_min
|
|
self.y_max = y_max
|
|
|
|
def calc_cluster_dim(self, points):
|
|
# clusters = len(points) // self.CLUSTER_SIZE
|
|
# Calculate number of clusters to have a square
|
|
# self.clusters_x = math.ceil(math.sqrt(clusters))
|
|
# self.clusters_y = self.clusters_x
|
|
self.clusters_x = self.CLUSTERS_X
|
|
self.clusters_y = self.CLUSTERS_X
|
|
self.clusters_total = self.clusters_x ** 2
|
|
self.cluster_x_dim = (self.x_max - self.x_min) / self.clusters_x
|
|
self.cluster_y_dim = (self.y_max - self.y_min) / self.clusters_y
|
|
|
|
def sort_points_into_clusters(self, points):
|
|
self.clusters = [[[]
|
|
for x in range(self.clusters_y)]
|
|
for y in range(self.clusters_y)]
|
|
for p in points:
|
|
cluster_x = int((p.x - self.x_min) // self.cluster_x_dim)
|
|
cluster_y = int((p.y - self.y_min) // self.cluster_y_dim)
|
|
|
|
# If the point is on the outer edge of the highest cluster
|
|
# the index will be outside the correct range. We put it
|
|
# into the closes cluster.
|
|
if cluster_x == self.clusters_x:
|
|
cluster_x -= 1
|
|
if cluster_y == self.clusters_y:
|
|
cluster_y -= 1
|
|
|
|
self.clusters[cluster_x][cluster_y].append(p)
|
|
p.cluster_x = cluster_x
|
|
p.cluster_y = cluster_y
|
|
|
|
def add_neighbors_to_points(self, points):
|
|
""" Add all points from the surrounding clusters to each point. """
|
|
for p in points:
|
|
clusters_x = [p.cluster_x]
|
|
clusters_y = [p.cluster_y]
|
|
|
|
if p.cluster_x - 1 >= 0:
|
|
clusters_x.append(p.cluster_x - 1)
|
|
if p.cluster_x + 1 < self.clusters_x:
|
|
clusters_x.append(p.cluster_x + 1)
|
|
if p.cluster_y - 1 >= 0:
|
|
clusters_y.append(p.cluster_y - 1)
|
|
if p.cluster_y + 1 < self.clusters_y:
|
|
clusters_y.append(p.cluster_y + 1)
|
|
|
|
clusters = [(x, y)
|
|
for x in clusters_x
|
|
for y in clusters_y]
|
|
neighbors = []
|
|
for x, y in clusters:
|
|
for p2 in self.clusters[x][y]:
|
|
if p is not p2:
|
|
neighbors.append(p2)
|
|
p.add_neighbors(neighbors)
|
|
|
|
def cluster(self, points):
|
|
""" Splits the map into clusters of a size so
|
|
that each cluster contains CLUSTER_SIZE points on
|
|
average. Adds all points from the current cluster
|
|
and the adjacent clusters to each point. """
|
|
self.calc_corners(points)
|
|
self.calc_cluster_dim(points)
|
|
self.sort_points_into_clusters(points)
|
|
self.add_neighbors_to_points(points)
|
|
return points
|
|
|
|
def plot_grid(self, plt):
|
|
if plt is None:
|
|
return
|
|
for x_i in range(self.clusters_x + 1):
|
|
x_1 = self.x_min + x_i * self.cluster_x_dim
|
|
x_2 = x_1
|
|
y_1 = self.y_min
|
|
y_2 = self.y_max
|
|
plt.plot([x_1, x_2], [y_1, y_2], 'y:', linewidth=0.1)
|
|
for y_i in range(self.clusters_y + 1):
|
|
x_1 = self.x_min
|
|
x_2 = self.x_max
|
|
y_1 = self.y_min + y_i * self.cluster_y_dim
|
|
y_2 = y_1
|
|
plt.plot([x_1, x_2], [y_1, y_2], 'y:', linewidth=0.1)
|
|
|
|
def plot(self, points, plt):
|
|
if plt is None:
|
|
return
|
|
try:
|
|
import matplotlib.pyplot as plt
|
|
except ModuleNotFoundError:
|
|
return
|
|
|
|
def plot_arrows():
|
|
for i in range(len_points):
|
|
p1 = points[i - 1]
|
|
p2 = points[i]
|
|
plot_arrow(p1, p2)
|
|
|
|
def plot_arrow(p1, p2):
|
|
x = p1.x
|
|
y = p1.y
|
|
dx = p2.x - x
|
|
dy = p2.y - y
|
|
opt = {'head_width': 0.4, 'head_length': 0.4, 'width': 0.05,
|
|
'length_includes_head': True}
|
|
plt.arrow(x, y, dx, dy, **opt)
|
|
|
|
def plot_points():
|
|
for i, p in enumerate(points):
|
|
plt.plot(p.x, p.y, '')
|
|
# plt.text(p.x, p.y, ' ' + str(p))
|
|
for nb, _ in p.neighbors:
|
|
# plt.plot([p.x, nb.x], [p.y, nb.y], 'r--')
|
|
pass
|
|
|
|
len_points = len(points)
|
|
plot_points()
|
|
self.plot_grid(plt)
|
|
plot_arrows()
|
|
plt.show()
|