Start working on defeat learners assignment.
This commit is contained in:
64
defeat_learners/DTLearner.py
Normal file
64
defeat_learners/DTLearner.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
A simple wrapper for linear regression. (c) 2015 Tucker Balch
|
||||
Note, this is NOT a correct DTLearner; Replace with your own implementation.
|
||||
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
||||
Atlanta, Georgia 30332
|
||||
All Rights Reserved
|
||||
|
||||
Template code for CS 4646/7646
|
||||
|
||||
Georgia Tech asserts copyright ownership of this template and all derivative
|
||||
works, including solutions to the projects assigned in this course. Students
|
||||
and other users of this template code are advised not to share it with others
|
||||
or to make it available on publicly viewable websites including repositories
|
||||
such as github and gitlab. This copyright statement should not be removed
|
||||
or edited.
|
||||
|
||||
We do grant permission to share solutions privately with non-students such
|
||||
as potential employers. However, sharing with other current or future
|
||||
students of CS 7646 is prohibited and subject to being investigated as a
|
||||
GT honor code violation.
|
||||
|
||||
-----do not edit anything above this line---
|
||||
|
||||
Student Name: Tucker Balch (replace with your name)
|
||||
GT User ID: tb34 (replace with your User ID)
|
||||
GT ID: 900897987 (replace with your GT ID)
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import warnings
|
||||
|
||||
class DTLearner(object):
|
||||
|
||||
def __init__(self, leaf_size=1, verbose = False):
|
||||
warnings.warn("\n\n WARNING! THIS IS NOT A CORRECT DTLearner IMPLEMENTATION! REPLACE WITH YOUR OWN CODE\n")
|
||||
pass # move along, these aren't the drones you're looking for
|
||||
|
||||
def author(self):
|
||||
return 'tb34' # replace tb34 with your Georgia Tech username
|
||||
|
||||
def addEvidence(self,dataX,dataY):
|
||||
"""
|
||||
@summary: Add training data to learner
|
||||
@param dataX: X values of data to add
|
||||
@param dataY: the Y training values
|
||||
"""
|
||||
|
||||
# slap on 1s column so linear regression finds a constant term
|
||||
newdataX = np.ones([dataX.shape[0],dataX.shape[1]+1])
|
||||
newdataX[:,0:dataX.shape[1]]=dataX
|
||||
|
||||
# build and save the model
|
||||
self.model_coefs, residuals, rank, s = np.linalg.lstsq(newdataX, dataY, rcond=None)
|
||||
|
||||
def query(self,points):
|
||||
"""
|
||||
@summary: Estimate a set of test points given the model we built.
|
||||
@param points: should be a numpy array with each row corresponding to a specific query.
|
||||
@returns the estimated values according to the saved model.
|
||||
"""
|
||||
return (self.model_coefs[:-1] * points).sum(axis = 1) + self.model_coefs[-1]
|
||||
|
||||
if __name__=="__main__":
|
||||
print("the secret clue is 'zzyzx'")
|
||||
Reference in New Issue
Block a user