Answer project 5 question 3 MIRA.
This commit is contained in:
@@ -60,8 +60,47 @@ class MiraClassifier:
|
|||||||
datum is a counter from features to values for those features
|
datum is a counter from features to values for those features
|
||||||
representing a vector of values.
|
representing a vector of values.
|
||||||
"""
|
"""
|
||||||
"*** YOUR CODE HERE ***"
|
|
||||||
util.raiseNotDefined()
|
def scaleCounter(counter, scalar):
|
||||||
|
counter = counter.copy()
|
||||||
|
scalar = float(scalar)
|
||||||
|
for key in counter:
|
||||||
|
counter[key] *= scalar
|
||||||
|
return counter
|
||||||
|
|
||||||
|
def updateWeights(weights, expectedLabel, guessedLabel, datum, c):
|
||||||
|
weightExpected = weights[expectedLabel]
|
||||||
|
weightGuessed = weights[guessedLabel]
|
||||||
|
tau = ((weightExpected - weightGuessed) * datum + 1.0) / ((datum * datum) * 2.0)
|
||||||
|
tau = min(c, tau)
|
||||||
|
weights[expectedLabel] = weights[expectedLabel] + scaleCounter(datum, tau)
|
||||||
|
weights[guessedLabel] = weights[guessedLabel] - scaleCounter(datum, tau)
|
||||||
|
|
||||||
|
def evaluateWeights(weights):
|
||||||
|
correct = 0
|
||||||
|
for datum, expectedLabel in zip(validationData, validationLabels):
|
||||||
|
guessedLabel = guessLabel(weights, datum)
|
||||||
|
if guessedLabel != expectedLabel:
|
||||||
|
correct += 1
|
||||||
|
return correct / float(len(validationData))
|
||||||
|
|
||||||
|
def guessLabel(weights, datum):
|
||||||
|
vectors = util.Counter()
|
||||||
|
for l in self.legalLabels:
|
||||||
|
vectors[l] = weights[l] * datum
|
||||||
|
return vectors.argMax()
|
||||||
|
|
||||||
|
allWeights = []
|
||||||
|
for c in Cgrid:
|
||||||
|
weights = self.weights.copy()
|
||||||
|
for iteration in range(self.max_iterations):
|
||||||
|
for datum, expectedLabel in zip(trainingData, trainingLabels):
|
||||||
|
guessedLabel = guessLabel(weights, datum)
|
||||||
|
if guessedLabel != expectedLabel:
|
||||||
|
updateWeights(weights, expectedLabel, guessedLabel, datum, c)
|
||||||
|
accuracy = evaluateWeights(weights)
|
||||||
|
allWeights.append((accuracy, weights))
|
||||||
|
self.weights = max(allWeights)[1]
|
||||||
|
|
||||||
def classify(self, data ):
|
def classify(self, data ):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user