Resolve split_value issue in DTLearner and pass all tests.
This commit is contained in:
@@ -8,9 +8,6 @@ class DTLearner(AbstractTreeLearner):
|
||||
self.leaf_size = leaf_size
|
||||
self.verbose = verbose
|
||||
|
||||
def author(self):
|
||||
return 'felixm' # replace tb34 with your Georgia Tech username
|
||||
|
||||
def get_correlations(self, xs, y):
|
||||
""" Return a list of sorted 2-tuples where the first element
|
||||
is the correlation and the second element is the index. Sorted by
|
||||
@@ -25,14 +22,23 @@ class DTLearner(AbstractTreeLearner):
|
||||
return correlations
|
||||
|
||||
def get_i_and_split_value(self, xs, y):
|
||||
# If all elements are true we would get one sub-tree with zero
|
||||
# elements, but we need at least one element in both trees. We avoid
|
||||
# zero-trees in two steps. First we take the average between the median
|
||||
# value and a smaller value an use that as the new split value. If that
|
||||
# doesn't work (when all values are the same) we choose the X with the
|
||||
# next smaller correlation. We assert that not all values are
|
||||
# smaller/equal to the split value at the end.
|
||||
for _, i in self.get_correlations(xs, y):
|
||||
split_value = np.median(xs[:,i])
|
||||
select = xs[:, i] <= split_value
|
||||
# If all elements are true we would get one sub-tree with zero
|
||||
# elements, but we need at least one element. Therefore, we only
|
||||
# choose the index if not all elements are true. If they are we go
|
||||
# to the next smaller correlation.
|
||||
if select.all():
|
||||
for value in xs[:, i]:
|
||||
if value < split_value:
|
||||
split_value = (value + split_value) / 2.0
|
||||
select = xs[:, i] <= split_value
|
||||
if not select.all():
|
||||
break
|
||||
assert(not select.all())
|
||||
return i, split_value
|
||||
|
||||
|
||||
Reference in New Issue
Block a user