45 lines
1.8 KiB
Python
45 lines
1.8 KiB
Python
import numpy as np
|
|
from AbstractTreeLearner import AbstractTreeLearner
|
|
|
|
|
|
class DTLearner(AbstractTreeLearner):
    """Tree learner that picks its split feature by correlation with y.

    Features are tried in order of decreasing absolute Pearson correlation
    with the target; the split value is the feature's median, adjusted when
    the median would put every row on the same side.
    """

    def __init__(self, leaf_size=1, verbose=False):
        """Store the learner configuration.

        leaf_size and verbose are only stored here; how they are consumed is
        not visible in this class (presumably by AbstractTreeLearner --
        TODO confirm).
        """
        self.leaf_size = leaf_size
        self.verbose = verbose

    def get_correlations(self, xs, y):
        """Return a list of (correlation, index) 2-tuples, highest first.

        xs is a 2-D array indexed as xs[:, i] (one column per feature) and y
        is the target passed to np.corrcoef alongside each column. The
        correlation is the absolute Pearson coefficient. Ties are ordered by
        descending column index, because whole tuples are sorted in reverse.

        A column with zero variance (or a constant y) has an undefined
        correlation, which np.corrcoef reports as NaN. NaN is mapped to 0.0
        here: NaN compares False against everything, so leaving it in the
        list makes list.sort() return an arbitrarily ordered result.
        """
        correlations = []
        # The divide-by-zero-variance case is handled explicitly below, so
        # the floating-point warnings it triggers are only noise.
        with np.errstate(divide="ignore", invalid="ignore"):
            for i in range(xs.shape[1]):
                c = abs(np.corrcoef(xs[:, i], y=y)[0, 1])
                if np.isnan(c):
                    c = 0.0
                correlations.append((c, i))
        correlations.sort(reverse=True)
        return correlations

    def get_i_and_split_value(self, xs, y):
        """Return (feature index, split value) for splitting xs.

        Rows with xs[:, i] <= split_value go to one sub-tree and the rest to
        the other, and both sub-trees need at least one row. Zero-element
        sub-trees are avoided in two steps:

        1. If every value is <= the median (i.e. the median equals the
           maximum), the split value is repeatedly replaced by the average
           of itself and any smaller value encountered, which moves it
           strictly below the maximum.
        2. If that does not help (all values of the feature are identical),
           the feature with the next smaller correlation is tried.

        Raises AssertionError when no feature yields a split with rows on
        both sides (including when xs has no columns).
        """
        for _, i in self.get_correlations(xs, y):
            column = xs[:, i]
            split_value = np.median(column)
            select = column <= split_value
            if select.all():
                # Median == max: pull the split value down towards the
                # smaller values so the maximum ends up on the other side.
                for value in column:
                    if value < split_value:
                        split_value = (value + split_value) / 2.0
                select = column <= split_value
            if not select.all():
                return i, split_value
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type matches what the
        # original assert statement produced.
        raise AssertionError(
            "no feature separates the rows: every column of xs is constant")
|
|
|