Start working on defeat learners assignment.
parent
8ee47c9a1d
commit
db537d7043
|
@ -0,0 +1,64 @@
|
||||||
|
"""
|
||||||
|
A simple wrapper for linear regression. (c) 2015 Tucker Balch
|
||||||
|
Note, this is NOT a correct DTLearner; Replace with your own implementation.
|
||||||
|
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
||||||
|
Atlanta, Georgia 30332
|
||||||
|
All Rights Reserved
|
||||||
|
|
||||||
|
Template code for CS 4646/7646
|
||||||
|
|
||||||
|
Georgia Tech asserts copyright ownership of this template and all derivative
|
||||||
|
works, including solutions to the projects assigned in this course. Students
|
||||||
|
and other users of this template code are advised not to share it with others
|
||||||
|
or to make it available on publicly viewable websites including repositories
|
||||||
|
such as github and gitlab. This copyright statement should not be removed
|
||||||
|
or edited.
|
||||||
|
|
||||||
|
We do grant permission to share solutions privately with non-students such
|
||||||
|
as potential employers. However, sharing with other current or future
|
||||||
|
students of CS 7646 is prohibited and subject to being investigated as a
|
||||||
|
GT honor code violation.
|
||||||
|
|
||||||
|
-----do not edit anything above this line---
|
||||||
|
|
||||||
|
Student Name: Tucker Balch (replace with your name)
|
||||||
|
GT User ID: tb34 (replace with your User ID)
|
||||||
|
GT ID: 900897987 (replace with your GT ID)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
class DTLearner(object):
|
||||||
|
|
||||||
|
def __init__(self, leaf_size=1, verbose = False):
|
||||||
|
warnings.warn("\n\n WARNING! THIS IS NOT A CORRECT DTLearner IMPLEMENTATION! REPLACE WITH YOUR OWN CODE\n")
|
||||||
|
pass # move along, these aren't the drones you're looking for
|
||||||
|
|
||||||
|
def author(self):
|
||||||
|
return 'tb34' # replace tb34 with your Georgia Tech username
|
||||||
|
|
||||||
|
def addEvidence(self,dataX,dataY):
|
||||||
|
"""
|
||||||
|
@summary: Add training data to learner
|
||||||
|
@param dataX: X values of data to add
|
||||||
|
@param dataY: the Y training values
|
||||||
|
"""
|
||||||
|
|
||||||
|
# slap on 1s column so linear regression finds a constant term
|
||||||
|
newdataX = np.ones([dataX.shape[0],dataX.shape[1]+1])
|
||||||
|
newdataX[:,0:dataX.shape[1]]=dataX
|
||||||
|
|
||||||
|
# build and save the model
|
||||||
|
self.model_coefs, residuals, rank, s = np.linalg.lstsq(newdataX, dataY, rcond=None)
|
||||||
|
|
||||||
|
def query(self,points):
|
||||||
|
"""
|
||||||
|
@summary: Estimate a set of test points given the model we built.
|
||||||
|
@param points: should be a numpy array with each row corresponding to a specific query.
|
||||||
|
@returns the estimated values according to the saved model.
|
||||||
|
"""
|
||||||
|
return (self.model_coefs[:-1] * points).sum(axis = 1) + self.model_coefs[-1]
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
print("the secret clue is 'zzyzx'")
|
|
@ -0,0 +1,57 @@
|
||||||
|
"""
|
||||||
|
A simple wrapper for linear regression. (c) 2015 Tucker Balch
|
||||||
|
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
||||||
|
Atlanta, Georgia 30332
|
||||||
|
All Rights Reserved
|
||||||
|
|
||||||
|
Template code for CS 4646/7646
|
||||||
|
|
||||||
|
Georgia Tech asserts copyright ownership of this template and all derivative
|
||||||
|
works, including solutions to the projects assigned in this course. Students
|
||||||
|
and other users of this template code are advised not to share it with others
|
||||||
|
or to make it available on publicly viewable websites including repositories
|
||||||
|
such as github and gitlab. This copyright statement should not be removed
|
||||||
|
or edited.
|
||||||
|
|
||||||
|
We do grant permission to share solutions privately with non-students such
|
||||||
|
as potential employers. However, sharing with other current or future
|
||||||
|
students of CS 7646 is prohibited and subject to being investigated as a
|
||||||
|
GT honor code violation.
|
||||||
|
|
||||||
|
-----do not edit anything above this line---
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class LinRegLearner(object):
|
||||||
|
|
||||||
|
def __init__(self, verbose = False):
|
||||||
|
pass # move along, these aren't the drones you're looking for
|
||||||
|
|
||||||
|
def author(self):
|
||||||
|
return 'tb34' # replace tb34 with your Georgia Tech username
|
||||||
|
|
||||||
|
def addEvidence(self,dataX,dataY):
|
||||||
|
"""
|
||||||
|
@summary: Add training data to learner
|
||||||
|
@param dataX: X values of data to add
|
||||||
|
@param dataY: the Y training values
|
||||||
|
"""
|
||||||
|
|
||||||
|
# slap on 1s column so linear regression finds a constant term
|
||||||
|
newdataX = np.ones([dataX.shape[0],dataX.shape[1]+1])
|
||||||
|
newdataX[:,0:dataX.shape[1]]=dataX
|
||||||
|
|
||||||
|
# build and save the model
|
||||||
|
self.model_coefs, residuals, rank, s = np.linalg.lstsq(newdataX, dataY, rcond=None)
|
||||||
|
|
||||||
|
def query(self,points):
|
||||||
|
"""
|
||||||
|
@summary: Estimate a set of test points given the model we built.
|
||||||
|
@param points: should be a numpy array with each row corresponding to a specific query.
|
||||||
|
@returns the estimated values according to the saved model.
|
||||||
|
"""
|
||||||
|
return (self.model_coefs[:-1] * points).sum(axis = 1) + self.model_coefs[-1]
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
print("the secret clue is 'zzyzx'")
|
|
@ -0,0 +1,52 @@
|
||||||
|
"""
|
||||||
|
template for generating data to fool learners (c) 2016 Tucker Balch
|
||||||
|
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
||||||
|
Atlanta, Georgia 30332
|
||||||
|
All Rights Reserved
|
||||||
|
|
||||||
|
Template code for CS 4646/7646
|
||||||
|
|
||||||
|
Georgia Tech asserts copyright ownership of this template and all derivative
|
||||||
|
works, including solutions to the projects assigned in this course. Students
|
||||||
|
and other users of this template code are advised not to share it with others
|
||||||
|
or to make it available on publicly viewable websites including repositories
|
||||||
|
such as github and gitlab. This copyright statement should not be removed
|
||||||
|
or edited.
|
||||||
|
|
||||||
|
We do grant permission to share solutions privately with non-students such
|
||||||
|
as potential employers. However, sharing with other current or future
|
||||||
|
students of CS 7646 is prohibited and subject to being investigated as a
|
||||||
|
GT honor code violation.
|
||||||
|
|
||||||
|
-----do not edit anything above this line---
|
||||||
|
|
||||||
|
Student Name: Tucker Balch (replace with your name)
|
||||||
|
GT User ID: tb34 (replace with your User ID)
|
||||||
|
GT ID: 900897987 (replace with your GT ID)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
|
||||||
|
# this function should return a dataset (X and Y) that will work
|
||||||
|
# better for linear regression than decision trees
|
||||||
|
def best4LinReg(seed=1489683273):
|
||||||
|
np.random.seed(seed)
|
||||||
|
X = np.zeros((100,2))
|
||||||
|
Y = np.random.random(size = (100,))*200-100
|
||||||
|
# Here's is an example of creating a Y from randomly generated
|
||||||
|
# X with multiple columns
|
||||||
|
# Y = X[:,0] + np.sin(X[:,1]) + X[:,2]**2 + X[:,3]**3
|
||||||
|
return X, Y
|
||||||
|
|
||||||
|
def best4DT(seed=1489683273):
|
||||||
|
np.random.seed(seed)
|
||||||
|
X = np.zeros((100,2))
|
||||||
|
Y = np.random.random(size = (100,))*200-100
|
||||||
|
return X, Y
|
||||||
|
|
||||||
|
def author():
|
||||||
|
return 'tb34' #Change this to your user ID
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
print("they call me Tim.")
|
|
@ -0,0 +1,230 @@
|
||||||
|
"""MC3-H1: Best4{LR,DT} - grading script.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
- Switch to a student feedback directory first (will write "points.txt" and "comments.txt" in pwd).
|
||||||
|
- Run this script with both ml4t/ and student solution in PYTHONPATH, e.g.:
|
||||||
|
PYTHONPATH=ml4t:MC3-P1/jdoe7 python ml4t/mc3_p1_grading/grade_learners.py
|
||||||
|
|
||||||
|
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
||||||
|
Atlanta, Georgia 30332
|
||||||
|
All Rights Reserved
|
||||||
|
|
||||||
|
Template code for CS 4646/7646
|
||||||
|
|
||||||
|
Georgia Tech asserts copyright ownership of this template and all derivative
|
||||||
|
works, including solutions to the projects assigned in this course. Students
|
||||||
|
and other users of this template code are advised not to share it with others
|
||||||
|
or to make it available on publicly viewable websites including repositories
|
||||||
|
such as github and gitlab. This copyright statement should not be removed
|
||||||
|
or edited.
|
||||||
|
|
||||||
|
We do grant permission to share solutions privately with non-students such
|
||||||
|
as potential employers. However, sharing with other current or future
|
||||||
|
students of CS 7646 is prohibited and subject to being investigated as a
|
||||||
|
GT honor code violation.
|
||||||
|
|
||||||
|
-----do not edit anything above this line---
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from grading.grading import grader, GradeResult, time_limit, run_with_timeout, IncorrectOutput
|
||||||
|
# These two lines will be commented out in the final grading script.
|
||||||
|
from LinRegLearner import LinRegLearner
|
||||||
|
from DTLearner import DTLearner
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback as tb
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import time
|
||||||
|
import functools
|
||||||
|
|
||||||
|
seconds_per_test_case = 5
|
||||||
|
|
||||||
|
max_points = 100.0
|
||||||
|
html_pre_block = True # surround comments with HTML <pre> tag (for T-Square comments field)
|
||||||
|
|
||||||
|
# Test cases
|
||||||
|
Best4TestCase = namedtuple('Best4TestCase', ['description', 'group','max_tests','needed_wins','row_limits','col_limits','seed'])
|
||||||
|
best4_test_cases = [
|
||||||
|
Best4TestCase(
|
||||||
|
description="Test Case 1: Best4LinReg",
|
||||||
|
group="best4lr",
|
||||||
|
max_tests=15,
|
||||||
|
needed_wins=10,
|
||||||
|
row_limits=(10,1000),
|
||||||
|
col_limits=(2,10),
|
||||||
|
seed=1489683274
|
||||||
|
),
|
||||||
|
Best4TestCase(
|
||||||
|
description="Test Case 2: Best4DT",
|
||||||
|
group="best4dt",
|
||||||
|
max_tests=15,
|
||||||
|
needed_wins=10,
|
||||||
|
row_limits=(10,1000),
|
||||||
|
col_limits=(2,10),
|
||||||
|
seed=1489683274
|
||||||
|
),
|
||||||
|
Best4TestCase(
|
||||||
|
description='Test for author() method',
|
||||||
|
group='author',
|
||||||
|
max_tests=None,
|
||||||
|
needed_wins=None,
|
||||||
|
row_limits=None,
|
||||||
|
col_limits=None,
|
||||||
|
seed=None,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test functon(s)
|
||||||
|
@pytest.mark.parametrize("description,group,max_tests,needed_wins,row_limits,col_limits,seed", best4_test_cases)
|
||||||
|
def test_learners(description, group, max_tests, needed_wins, row_limits, col_limits, seed, grader):
|
||||||
|
"""Test data generation methods beat given learner.
|
||||||
|
|
||||||
|
Requires test description, test case group, and a grader fixture.
|
||||||
|
"""
|
||||||
|
|
||||||
|
points_earned = 0.0 # initialize points for this test case
|
||||||
|
incorrect = True
|
||||||
|
msgs = []
|
||||||
|
try:
|
||||||
|
dataX, dataY = None,None
|
||||||
|
same_dataX, same_dataY = None,None
|
||||||
|
diff_dataX, diff_dataY = None,None
|
||||||
|
betterLearner, worseLearner = None, None
|
||||||
|
if group=='author':
|
||||||
|
try:
|
||||||
|
from gen_data import author
|
||||||
|
auth_string = run_with_timeout(author,seconds_per_test_case,(),{})
|
||||||
|
if auth_string == 'tb34':
|
||||||
|
incorrect = True
|
||||||
|
msgs.append(" Incorrect author name (tb34)")
|
||||||
|
points_earned = -10
|
||||||
|
elif auth_string == '':
|
||||||
|
incorrect = True
|
||||||
|
msgs.append(" Empty author name")
|
||||||
|
points_earned = -10
|
||||||
|
else:
|
||||||
|
incorrect = False
|
||||||
|
except Exception as e:
|
||||||
|
incorrect = True
|
||||||
|
msgs.append(" Exception occured when calling author() method: {}".format(e))
|
||||||
|
points_earned = -10
|
||||||
|
else:
|
||||||
|
if group=="best4dt":
|
||||||
|
from gen_data import best4DT
|
||||||
|
dataX, dataY = run_with_timeout(best4DT,seconds_per_test_case,(),{'seed':seed})
|
||||||
|
same_dataX,same_dataY = run_with_timeout(best4DT,seconds_per_test_case,(),{'seed':seed})
|
||||||
|
diff_dataX,diff_dataY = run_with_timeout(best4DT,seconds_per_test_case,(),{'seed':seed+1})
|
||||||
|
betterLearner = DTLearner
|
||||||
|
worseLearner = LinRegLearner
|
||||||
|
elif group=='best4lr':
|
||||||
|
from gen_data import best4LinReg
|
||||||
|
dataX, dataY = run_with_timeout(best4LinReg,seconds_per_test_case,(),{'seed':seed})
|
||||||
|
same_dataX, same_dataY = run_with_timeout(best4LinReg,seconds_per_test_case,(),{'seed':seed})
|
||||||
|
diff_dataX, diff_dataY = run_with_timeout(best4LinReg,seconds_per_test_case,(),{'seed':seed+1})
|
||||||
|
betterLearner = LinRegLearner
|
||||||
|
worseLearner = DTLearner
|
||||||
|
|
||||||
|
num_samples = dataX.shape[0]
|
||||||
|
cutoff = int(num_samples*0.6)
|
||||||
|
worse_better_err = []
|
||||||
|
for run in range(max_tests):
|
||||||
|
permutation = np.random.permutation(num_samples)
|
||||||
|
train_X,train_Y = dataX[permutation[:cutoff]], dataY[permutation[:cutoff]]
|
||||||
|
test_X,test_Y = dataX[permutation[cutoff:]], dataY[permutation[cutoff:]]
|
||||||
|
better = betterLearner()
|
||||||
|
worse = worseLearner()
|
||||||
|
better.addEvidence(train_X,train_Y)
|
||||||
|
worse.addEvidence(train_X,train_Y)
|
||||||
|
better_pred = better.query(test_X)
|
||||||
|
worse_pred = worse.query(test_X)
|
||||||
|
better_err = np.linalg.norm(test_Y-better_pred)
|
||||||
|
worse_err = np.linalg.norm(test_Y-worse_pred)
|
||||||
|
worse_better_err.append( (worse_err,better_err) )
|
||||||
|
worse_better_err.sort(key=functools.cmp_to_key(lambda a,b: int((b[0]-b[1])-(a[0]-a[1]))))
|
||||||
|
better_wins_count = 0
|
||||||
|
for worse_err,better_err in worse_better_err:
|
||||||
|
if better_err < 0.9*worse_err:
|
||||||
|
better_wins_count = better_wins_count+1
|
||||||
|
points_earned += 5.0
|
||||||
|
if better_wins_count >= needed_wins:
|
||||||
|
break
|
||||||
|
incorrect = False
|
||||||
|
if (dataX.shape[0] < row_limits[0]) or (dataX.shape[0]>row_limits[1]):
|
||||||
|
incorrect = True
|
||||||
|
msgs.append(" Invalid number of rows. Should be between {}, found {}".format(row_limits,dataX.shape[0]))
|
||||||
|
points_earned = max(0,points_earned-20)
|
||||||
|
if (dataX.shape[1] < col_limits[0]) or (dataX.shape[1]>col_limits[1]):
|
||||||
|
incorrect = True
|
||||||
|
msgs.append(" Invalid number of columns. Should be between {}, found {}".format(col_limits,dataX.shape[1]))
|
||||||
|
points_earned = max(0,points_earned-20)
|
||||||
|
if better_wins_count < needed_wins:
|
||||||
|
incorrect = True
|
||||||
|
msgs.append(" Better learner did not exceed worse learner. Expected {}, found {}".format(needed_wins,better_wins_count))
|
||||||
|
if not(np.array_equal(same_dataY,dataY)) or not(np.array_equal(same_dataX,dataX)):
|
||||||
|
incorrect = True
|
||||||
|
msgs.append(" Did not produce the same data with the same seed.\n"+\
|
||||||
|
" First dataX:\n{}\n".format(dataX)+\
|
||||||
|
" Second dataX:\n{}\n".format(same_dataX)+\
|
||||||
|
" First dataY:\n{}\n".format(dataY)+\
|
||||||
|
" Second dataY:\n{}\n".format(same_dataY))
|
||||||
|
points_earned = max(0,points_earned-20)
|
||||||
|
if np.array_equal(diff_dataY,dataY) and np.array_equal(diff_dataX,dataX):
|
||||||
|
incorrect = True
|
||||||
|
msgs.append(" Did not produce different data with different seeds.\n"+\
|
||||||
|
" First dataX:\n{}\n".format(dataX)+\
|
||||||
|
" Second dataX:\n{}\n".format(diff_dataX)+\
|
||||||
|
" First dataY:\n{}\n".format(dataY)+\
|
||||||
|
" Second dataY:\n{}\n".format(diff_dataY))
|
||||||
|
points_earned = max(0,points_earned-20)
|
||||||
|
if incorrect:
|
||||||
|
if group=='author':
|
||||||
|
raise IncorrectOutput("Test failed on one or more criteria.\n {}".format('\n'.join(msgs)))
|
||||||
|
else:
|
||||||
|
inputs_str = " Residuals: {}".format(worse_better_err)
|
||||||
|
raise IncorrectOutput("Test failed on one or more output criteria.\n Inputs:\n{}\n Failures:\n{}".format(inputs_str, "\n".join(msgs)))
|
||||||
|
else:
|
||||||
|
if group != 'author':
|
||||||
|
avg_ratio = 0.0
|
||||||
|
worse_better_err.sort(key=functools.cmp_to_key(lambda a,b: int(np.sign((b[0]-b[1])-(a[0]-a[1])))))
|
||||||
|
for we,be in worse_better_err[:10]:
|
||||||
|
avg_ratio += (float(we) - float(be))
|
||||||
|
avg_ratio = avg_ratio/10.0
|
||||||
|
if group=="best4dt":
|
||||||
|
grader.add_performance(np.array([avg_ratio,0]))
|
||||||
|
else:
|
||||||
|
grader.add_performance(np.array([0,avg_ratio]))
|
||||||
|
except Exception as e:
|
||||||
|
# Test result: failed
|
||||||
|
msg = "Description: {} (group: {})\n".format(description, group)
|
||||||
|
|
||||||
|
# Generate a filtered stacktrace, only showing erroneous lines in student file(s)
|
||||||
|
tb_list = tb.extract_tb(sys.exc_info()[2])
|
||||||
|
for i in range(len(tb_list)):
|
||||||
|
row = tb_list[i]
|
||||||
|
tb_list[i] = (os.path.basename(row[0]), row[1], row[2], row[3]) # show only filename instead of long absolute path
|
||||||
|
tb_list = [row for row in tb_list if (row[0] == 'gen_data.py')]
|
||||||
|
if tb_list:
|
||||||
|
msg += "Traceback:\n"
|
||||||
|
msg += ''.join(tb.format_list(tb_list)) # contains newlines
|
||||||
|
elif 'grading_traceback' in dir(e):
|
||||||
|
msg += "Traceback:\n"
|
||||||
|
msg += ''.join(tb.format_list(e.grading_traceback))
|
||||||
|
msg += "{}: {}".format(e.__class__.__name__, str(e))
|
||||||
|
|
||||||
|
# Report failure result to grader, with stacktrace
|
||||||
|
grader.add_result(GradeResult(outcome='failed', points=points_earned, msg=msg))
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
# Test result: passed (no exceptions)
|
||||||
|
grader.add_result(GradeResult(outcome='passed', points=points_earned, msg=None))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main(["-s", __file__])
|
|
@ -0,0 +1,100 @@
|
||||||
|
"""
|
||||||
|
Test best4 data generator. (c) 2016 Tucker Balch
|
||||||
|
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
||||||
|
Atlanta, Georgia 30332
|
||||||
|
All Rights Reserved
|
||||||
|
|
||||||
|
Template code for CS 4646/7646
|
||||||
|
|
||||||
|
Georgia Tech asserts copyright ownership of this template and all derivative
|
||||||
|
works, including solutions to the projects assigned in this course. Students
|
||||||
|
and other users of this template code are advised not to share it with others
|
||||||
|
or to make it available on publicly viewable websites including repositories
|
||||||
|
such as github and gitlab. This copyright statement should not be removed
|
||||||
|
or edited.
|
||||||
|
|
||||||
|
We do grant permission to share solutions privately with non-students such
|
||||||
|
as potential employers. However, sharing with other current or future
|
||||||
|
students of CS 7646 is prohibited and subject to being investigated as a
|
||||||
|
GT honor code violation.
|
||||||
|
|
||||||
|
-----do not edit anything above this line---
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
import LinRegLearner as lrl
|
||||||
|
import DTLearner as dt
|
||||||
|
from gen_data import best4LinReg, best4DT
|
||||||
|
|
||||||
|
# compare two learners' rmse out of sample
|
||||||
|
def compare_os_rmse(learner1, learner2, X, Y):
|
||||||
|
|
||||||
|
# compute how much of the data is training and testing
|
||||||
|
train_rows = int(math.floor(0.6* X.shape[0]))
|
||||||
|
test_rows = X.shape[0] - train_rows
|
||||||
|
|
||||||
|
# separate out training and testing data
|
||||||
|
train = np.random.choice(X.shape[0], size=train_rows, replace=False)
|
||||||
|
test = np.setdiff1d(np.array(range(X.shape[0])), train)
|
||||||
|
trainX = X[train, :]
|
||||||
|
trainY = Y[train]
|
||||||
|
testX = X[test, :]
|
||||||
|
testY = Y[test]
|
||||||
|
|
||||||
|
# train the learners
|
||||||
|
learner1.addEvidence(trainX, trainY) # train it
|
||||||
|
learner2.addEvidence(trainX, trainY) # train it
|
||||||
|
|
||||||
|
# evaluate learner1 out of sample
|
||||||
|
predY = learner1.query(testX) # get the predictions
|
||||||
|
rmse1 = math.sqrt(((testY - predY) ** 2).sum()/testY.shape[0])
|
||||||
|
|
||||||
|
# evaluate learner2 out of sample
|
||||||
|
predY = learner2.query(testX) # get the predictions
|
||||||
|
rmse2 = math.sqrt(((testY - predY) ** 2).sum()/testY.shape[0])
|
||||||
|
|
||||||
|
return rmse1, rmse2
|
||||||
|
|
||||||
|
def test_code():
|
||||||
|
|
||||||
|
# create two learners and get data
|
||||||
|
lrlearner = lrl.LinRegLearner(verbose = False)
|
||||||
|
dtlearner = dt.DTLearner(verbose = False, leaf_size = 1)
|
||||||
|
X, Y = best4LinReg()
|
||||||
|
|
||||||
|
# compare the two learners
|
||||||
|
rmseLR, rmseDT = compare_os_rmse(lrlearner, dtlearner, X, Y)
|
||||||
|
|
||||||
|
# share results
|
||||||
|
print()
|
||||||
|
print("best4LinReg() results")
|
||||||
|
print(f"RMSE LR : {rmseLR}")
|
||||||
|
print(f"RMSE DT : {rmseDT}")
|
||||||
|
if rmseLR < 0.9 * rmseDT:
|
||||||
|
print("LR < 0.9 DT: pass")
|
||||||
|
else:
|
||||||
|
print("LR >= 0.9 DT: fail")
|
||||||
|
print
|
||||||
|
|
||||||
|
# get data that is best for a random tree
|
||||||
|
lrlearner = lrl.LinRegLearner(verbose = False)
|
||||||
|
dtlearner = dt.DTLearner(verbose = False, leaf_size = 1)
|
||||||
|
X, Y = best4DT()
|
||||||
|
|
||||||
|
# compare the two learners
|
||||||
|
rmseLR, rmseDT = compare_os_rmse(lrlearner, dtlearner, X, Y)
|
||||||
|
|
||||||
|
# share results
|
||||||
|
print()
|
||||||
|
print("best4RT() results")
|
||||||
|
print(f"RMSE LR : {rmseLR}")
|
||||||
|
print(f"RMSE DT : {rmseDT}")
|
||||||
|
if rmseDT < 0.9 * rmseLR:
|
||||||
|
print("DT < 0.9 LR: pass")
|
||||||
|
else:
|
||||||
|
print("DT >= 0.9 LR: fail")
|
||||||
|
print
|
||||||
|
|
||||||
|
if __name__=="__main__":
|
||||||
|
test_code()
|
Binary file not shown.
Loading…
Reference in New Issue