31 lines
982 B
Python
31 lines
982 B
Python
import numpy as np
|
|
from AbstractTreeLearner import AbstractTreeLearner
|
|
|
|
|
|
class RTLearner(AbstractTreeLearner):
    """Random-tree learner: chooses split features and split values at random.

    Only the split-selection strategy (`get_i_and_split_value`) is defined
    here; tree construction and querying are presumably inherited from
    AbstractTreeLearner (not visible in this file -- confirm).
    """

    def __init__(self, leaf_size=1, verbose=False):
        """
        @summary: Store the learner configuration.

        leaf_size: stored as self.leaf_size (presumably the maximum number of
            samples allowed in a leaf -- confirm against AbstractTreeLearner).
        verbose: stored as self.verbose.
        """
        self.leaf_size = leaf_size
        self.verbose = verbose

    def get_i_and_split_value(self, xs, y):
        """
        @summary: Pick a random feature index i and a split value for it.

        Only columns of xs that are not constant across its rows are
        candidates, so the chosen column always holds at least two distinct
        values. The split value is the mean of two distinct values drawn at
        random (without replacement) from that column.

        xs: 2-D array of feature values, one row per sample.
        y: not used by this strategy; accepted to match the caller's interface.

        Returns (i, split_value) where i is a Python int column index.

        Raises ValueError if xs has no column with two distinct values (e.g.
        zero or one row, or all rows identical). Previously this case looped
        forever.
        """
        # Columns where some row differs from the first row. Comparing against
        # xs[:1] (not xs[0]) also works when xs has zero rows: the result is
        # then all-False, i.e. no candidates.
        varies = np.any(xs != xs[:1], axis=0)
        candidates = np.flatnonzero(varies)
        if candidates.size == 0:
            raise ValueError(
                "Cannot split: no feature column has two distinct values.")

        i = int(np.random.choice(candidates))

        # Draw two *distinct* values from the column so the split actually
        # separates samples; the candidate filter guarantees >= 2 exist.
        distinct_values = np.unique(xs[:, i])
        r1, r2 = np.random.choice(distinct_values, size=2, replace=False)
        split_value = (r1 + r2) / 2.0
        return i, split_value
|
|
|