Implement p3q1 value iteration.

This commit is contained in:
Felix Martin 2021-11-28 12:41:36 -05:00
parent 524362c5c5
commit d1a4735c5a

View File

@ -43,8 +43,18 @@ class ValueIterationAgent(ValueEstimationAgent):
self.iterations = iterations
self.values = util.Counter() # A Counter is a dict with default 0
# Write value iteration code here
"*** YOUR CODE HERE ***"
for _ in range(iterations):
# update each state once
values_k1 = self.values.copy()
for state in mdp.getStates():
actions = mdp.getPossibleActions(state)
if not actions:
max_action_value = self.values[state]
else:
max_action_value = max([self.computeQValueFromValues(state, action)
for action in actions])
values_k1[state] = max_action_value
self.values = values_k1
def getValue(self, state):
@ -53,14 +63,16 @@ class ValueIterationAgent(ValueEstimationAgent):
"""
return self.values[state]
def computeQValueFromValues(self, state, action):
"""
Compute the Q-value of action in state from the
value function stored in self.values.
"""
"*** YOUR CODE HERE ***"
util.raiseNotDefined()
value = 0
for next_state, prob in self.mdp.getTransitionStatesAndProbs(state, action):
reward = self.mdp.getReward(state, action, next_state)
value += prob * (reward + self.discount * self.values[next_state])
return value
def computeActionFromValues(self, state):
"""
@ -71,8 +83,12 @@ class ValueIterationAgent(ValueEstimationAgent):
there are no legal actions, which is the case at the
terminal state, you should return None.
"""
"*** YOUR CODE HERE ***"
util.raiseNotDefined()
actions = self.mdp.getPossibleActions(state)
if not actions:
return None
return max([(self.computeQValueFromValues(state, action), action)
for action in actions])[1]
def getPolicy(self, state):
return self.computeActionFromValues(state)