Answer project 5 question 3 MIRA.

This commit is contained in:
2022-01-03 12:46:37 -05:00
parent d6e89e8c00
commit a73b2b35ac

View File

@@ -60,8 +60,47 @@ class MiraClassifier:
datum is a counter from features to values for those features
representing a vector of values.
"""
"*** YOUR CODE HERE ***"
util.raiseNotDefined()
def scaleCounter(counter, scalar):
counter = counter.copy()
scalar = float(scalar)
for key in counter:
counter[key] *= scalar
return counter
def updateWeights(weights, expectedLabel, guessedLabel, datum, c):
weightExpected = weights[expectedLabel]
weightGuessed = weights[guessedLabel]
tau = ((weightExpected - weightGuessed) * datum + 1.0) / ((datum * datum) * 2.0)
tau = min(c, tau)
weights[expectedLabel] = weights[expectedLabel] + scaleCounter(datum, tau)
weights[guessedLabel] = weights[guessedLabel] - scaleCounter(datum, tau)
def evaluateWeights(weights):
correct = 0
for datum, expectedLabel in zip(validationData, validationLabels):
guessedLabel = guessLabel(weights, datum)
if guessedLabel != expectedLabel:
correct += 1
return correct / float(len(validationData))
def guessLabel(weights, datum):
vectors = util.Counter()
for l in self.legalLabels:
vectors[l] = weights[l] * datum
return vectors.argMax()
allWeights = []
for c in Cgrid:
weights = self.weights.copy()
for iteration in range(self.max_iterations):
for datum, expectedLabel in zip(trainingData, trainingLabels):
guessedLabel = guessLabel(weights, datum)
if guessedLabel != expectedLabel:
updateWeights(weights, expectedLabel, guessedLabel, datum, c)
accuracy = evaluateWeights(weights)
allWeights.append((accuracy, weights))
self.weights = max(allWeights)[1]
def classify(self, data ):
"""