diff --git a/p3_rl/valueIterationAgents.py b/p3_rl/valueIterationAgents.py index eb17ea6..c4b080e 100644 --- a/p3_rl/valueIterationAgents.py +++ b/p3_rl/valueIterationAgents.py @@ -43,8 +43,18 @@ class ValueIterationAgent(ValueEstimationAgent): self.iterations = iterations self.values = util.Counter() # A Counter is a dict with default 0 - # Write value iteration code here - "*** YOUR CODE HERE ***" + for _ in range(iterations): + # update each state once + values_k1 = self.values.copy() + for state in mdp.getStates(): + actions = mdp.getPossibleActions(state) + if not actions: + max_action_value = self.values[state] + else: + max_action_value = max([self.computeQValueFromValues(state, action) + for action in actions]) + values_k1[state] = max_action_value + self.values = values_k1 def getValue(self, state): @@ -53,14 +63,16 @@ class ValueIterationAgent(ValueEstimationAgent): """ return self.values[state] - def computeQValueFromValues(self, state, action): """ Compute the Q-value of action in state from the value function stored in self.values. """ - "*** YOUR CODE HERE ***" - util.raiseNotDefined() + value = 0 + for next_state, prob in self.mdp.getTransitionStatesAndProbs(state, action): + reward = self.mdp.getReward(state, action, next_state) + value += prob * (reward + self.discount * self.values[next_state]) + return value def computeActionFromValues(self, state): """ @@ -71,8 +83,12 @@ class ValueIterationAgent(ValueEstimationAgent): there are no legal actions, which is the case at the terminal state, you should return None. """ - "*** YOUR CODE HERE ***" - util.raiseNotDefined() + actions = self.mdp.getPossibleActions(state) + if not actions: + return None + return max([(self.computeQValueFromValues(state, action), action) + for action in actions])[1] + def getPolicy(self, state): return self.computeActionFromValues(state)