Implement p3q1 value iteration.

This commit is contained in:
Felix Martin 2021-11-28 12:41:36 -05:00
parent 524362c5c5
commit d1a4735c5a

View File

@ -43,8 +43,18 @@ class ValueIterationAgent(ValueEstimationAgent):
self.iterations = iterations self.iterations = iterations
self.values = util.Counter() # A Counter is a dict with default 0 self.values = util.Counter() # A Counter is a dict with default 0
# Write value iteration code here for _ in range(iterations):
"*** YOUR CODE HERE ***" # update each state once
values_k1 = self.values.copy()
for state in mdp.getStates():
actions = mdp.getPossibleActions(state)
if not actions:
max_action_value = self.values[state]
else:
max_action_value = max([self.computeQValueFromValues(state, action)
for action in actions])
values_k1[state] = max_action_value
self.values = values_k1
def getValue(self, state): def getValue(self, state):
@ -53,14 +63,16 @@ class ValueIterationAgent(ValueEstimationAgent):
""" """
return self.values[state] return self.values[state]
def computeQValueFromValues(self, state, action): def computeQValueFromValues(self, state, action):
""" """
Compute the Q-value of action in state from the Compute the Q-value of action in state from the
value function stored in self.values. value function stored in self.values.
""" """
"*** YOUR CODE HERE ***" value = 0
util.raiseNotDefined() for next_state, prob in self.mdp.getTransitionStatesAndProbs(state, action):
reward = self.mdp.getReward(state, action, next_state)
value += prob * (reward + self.discount * self.values[next_state])
return value
def computeActionFromValues(self, state): def computeActionFromValues(self, state):
""" """
@ -71,8 +83,12 @@ class ValueIterationAgent(ValueEstimationAgent):
there are no legal actions, which is the case at the there are no legal actions, which is the case at the
terminal state, you should return None. terminal state, you should return None.
""" """
"*** YOUR CODE HERE ***" actions = self.mdp.getPossibleActions(state)
util.raiseNotDefined() if not actions:
return None
return max([(self.computeQValueFromValues(state, action), action)
for action in actions])[1]
def getPolicy(self, state): def getPolicy(self, state):
return self.computeActionFromValues(state) return self.computeActionFromValues(state)