Answer project 5 question 3 MIRA.
This commit is contained in:
@@ -60,8 +60,47 @@ class MiraClassifier:
|
||||
datum is a counter from features to values for those features
|
||||
representing a vector of values.
|
||||
"""
|
||||
"*** YOUR CODE HERE ***"
|
||||
util.raiseNotDefined()
|
||||
|
||||
def scaleCounter(counter, scalar):
|
||||
counter = counter.copy()
|
||||
scalar = float(scalar)
|
||||
for key in counter:
|
||||
counter[key] *= scalar
|
||||
return counter
|
||||
|
||||
def updateWeights(weights, expectedLabel, guessedLabel, datum, c):
|
||||
weightExpected = weights[expectedLabel]
|
||||
weightGuessed = weights[guessedLabel]
|
||||
tau = ((weightExpected - weightGuessed) * datum + 1.0) / ((datum * datum) * 2.0)
|
||||
tau = min(c, tau)
|
||||
weights[expectedLabel] = weights[expectedLabel] + scaleCounter(datum, tau)
|
||||
weights[guessedLabel] = weights[guessedLabel] - scaleCounter(datum, tau)
|
||||
|
||||
def evaluateWeights(weights):
|
||||
correct = 0
|
||||
for datum, expectedLabel in zip(validationData, validationLabels):
|
||||
guessedLabel = guessLabel(weights, datum)
|
||||
if guessedLabel != expectedLabel:
|
||||
correct += 1
|
||||
return correct / float(len(validationData))
|
||||
|
||||
def guessLabel(weights, datum):
|
||||
vectors = util.Counter()
|
||||
for l in self.legalLabels:
|
||||
vectors[l] = weights[l] * datum
|
||||
return vectors.argMax()
|
||||
|
||||
allWeights = []
|
||||
for c in Cgrid:
|
||||
weights = self.weights.copy()
|
||||
for iteration in range(self.max_iterations):
|
||||
for datum, expectedLabel in zip(trainingData, trainingLabels):
|
||||
guessedLabel = guessLabel(weights, datum)
|
||||
if guessedLabel != expectedLabel:
|
||||
updateWeights(weights, expectedLabel, guessedLabel, datum, c)
|
||||
accuracy = evaluateWeights(weights)
|
||||
allWeights.append((accuracy, weights))
|
||||
self.weights = max(allWeights)[1]
|
||||
|
||||
def classify(self, data ):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user