Files
ML4T/assess_learners/DTLearner.py

39 lines
1.4 KiB
Python

import numpy as np
from AbstractTreeLearner import AbstractTreeLearner
class DTLearner(AbstractTreeLearner):
    """Decision-tree learner that splits on the most correlated feature.

    This class only chooses splits; tree construction and querying are
    presumably provided by ``AbstractTreeLearner`` (not visible here --
    confirm), which is expected to call ``get_i_and_split_value``.
    """

    def __init__(self, leaf_size=1, verbose=False):
        """Store the hyper-parameters.

        :param leaf_size: stored as ``self.leaf_size``; presumably the
            maximum number of samples aggregated into a leaf by the base
            class -- TODO confirm against AbstractTreeLearner.
        :param verbose: stored as ``self.verbose``.
        """
        self.leaf_size = leaf_size
        self.verbose = verbose

    def author(self):
        """Return the Georgia Tech username of the author."""
        return 'felixm'

    def get_correlations(self, xs, y):
        """Return ``(abs_correlation, column_index)`` pairs, highest first.

        Each column of ``xs`` is scored by the absolute Pearson coefficient
        with ``y`` as computed by ``np.corrcoef``. If a column (or ``y``) has
        zero variance the coefficient is undefined (NaN); it is reported as
        0.0 so it sorts last instead of corrupting the sort order (NaN
        compares false against every value). Ties are ordered with the
        higher column index first.

        :param xs: 2-D array of features, one column per feature.
        :param y: 1-D array of targets, one entry per row of ``xs``.
        :return: list of ``(correlation, index)`` tuples.
        """
        correlations = []
        # Zero-variance inputs make corrcoef divide 0 by 0; the resulting
        # NaN is handled explicitly below, so silence the RuntimeWarning.
        with np.errstate(divide='ignore', invalid='ignore'):
            for i in range(xs.shape[1]):
                c = abs(np.corrcoef(xs[:, i], y)[0, 1])
                if np.isnan(c):
                    c = 0.0
                correlations.append((c, i))
        correlations.sort(reverse=True)
        return correlations

    def get_i_and_split_value(self, xs, y):
        """Choose the feature column and threshold for the next split.

        Candidate columns are tried in order of decreasing absolute
        correlation with ``y`` (see ``get_correlations``). The threshold is
        the column median, and rows with ``xs[:, i] <= split_value`` form
        the left side. The first candidate that leaves at least one row on
        the right side is returned.

        :param xs: 2-D array of features, one column per feature.
        :param y: 1-D array of targets, one entry per row of ``xs``.
        :return: ``(i, split_value)`` -- column index and threshold.
        :raises ValueError: if ``xs`` has no feature columns.
        """
        correlations = self.get_correlations(xs, y)
        if not correlations:
            raise ValueError("xs must have at least one feature column")
        for _, i in correlations:
            split_value = np.median(xs[:, i])
            # ``<=`` against the median always puts at least half of the
            # rows on the left. If it puts *all* of them there, the right
            # subtree would be empty, so try the next-best feature instead.
            if not (xs[:, i] <= split_value).all():
                break
        # NOTE(review): if no column yields a non-empty right side (e.g. all
        # rows of xs are identical) the loop finishes without ``break`` and
        # the last, least-correlated candidate is returned anyway; the
        # caller must cope with an empty right side in that case -- confirm
        # in AbstractTreeLearner.
        return i, split_value