Implement Dyna-Q to finish project 7
This commit is contained in:
@@ -22,8 +22,8 @@ GT honor code violation.
|
|||||||
-----do not edit anything above this line---
|
-----do not edit anything above this line---
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random as rand
|
|
||||||
|
|
||||||
|
|
||||||
class QLearner(object):
|
class QLearner(object):
|
||||||
@@ -49,27 +49,32 @@ class QLearner(object):
|
|||||||
self.radr = radr
|
self.radr = radr
|
||||||
self.dyna = dyna
|
self.dyna = dyna
|
||||||
|
|
||||||
|
if self.dyna > 0:
|
||||||
|
self.model = {}
|
||||||
|
self.state_action_list = []
|
||||||
|
|
||||||
# self.q = np.random.random((num_states, num_actions))
|
# self.q = np.random.random((num_states, num_actions))
|
||||||
self.q = np.zeros((num_states, num_actions))
|
self.q = np.zeros((num_states, num_actions))
|
||||||
|
|
||||||
def _get_a(self, s):
|
def _get_a(self, s):
|
||||||
"""Get best action for state. Considers rar."""
|
"""Get best action for state. Considers rar."""
|
||||||
if rand.random() < self.rar:
|
if random.random() < self.rar:
|
||||||
a = rand.randint(0, self.num_actions - 1)
|
a = random.randint(0, self.num_actions - 1)
|
||||||
else:
|
else:
|
||||||
a = np.argmax(self.q[s])
|
a = np.argmax(self.q[s])
|
||||||
return a
|
return a
|
||||||
|
|
||||||
def _update_q(self, s, a, s_prime, r):
|
def _update_q(self, s, a, r, s_prime):
|
||||||
"""Updates the Q table."""
|
"""Updates the Q table."""
|
||||||
q_old = self.q[s][a]
|
q_old = self.q[s][a]
|
||||||
|
alpha = self.alpha
|
||||||
|
|
||||||
# estimate optimal future value
|
# estimate optimal future value
|
||||||
a_max = np.argmax(self.q[s_prime])
|
a_max = np.argmax(self.q[s_prime])
|
||||||
q_future = self.q[s_prime][a_max]
|
q_future = self.q[s_prime][a_max]
|
||||||
|
|
||||||
# calculate new value and update table
|
# calculate new value and update table
|
||||||
q_new = q_old + self.alpha * (r + self.gamma * q_future - q_old)
|
q_new = (1 - alpha) * q_old + alpha * (r + self.gamma * q_future)
|
||||||
self.q[s][a] = q_new
|
self.q[s][a] = q_new
|
||||||
|
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
@@ -95,21 +100,39 @@ class QLearner(object):
|
|||||||
@param r: The reward
|
@param r: The reward
|
||||||
@returns: The selected action
|
@returns: The selected action
|
||||||
"""
|
"""
|
||||||
self._update_q(self.s, self.a, s_prime, r)
|
self._update_q(self.s, self.a, r, s_prime)
|
||||||
self.a = self._get_a(s_prime)
|
a = self._get_a(s_prime)
|
||||||
self.s = s_prime
|
|
||||||
if self.verbose:
|
|
||||||
print(f"s = {s_prime}, a = {self.a}, r={r}")
|
|
||||||
# Update random action rate
|
# Update random action rate
|
||||||
self.rar = self.rar * self.radr
|
self.rar = self.rar * self.radr
|
||||||
|
|
||||||
|
if self.dyna > 0:
|
||||||
|
self._update_model(self.s, self.a, r, s_prime)
|
||||||
|
self._dyna_q()
|
||||||
|
|
||||||
|
self.a = a
|
||||||
|
self.s = s_prime
|
||||||
return self.a
|
return self.a
|
||||||
|
|
||||||
|
def _update_model(self, s, a, r, s_prime):
    """Record a real experience tuple in the learned world model.

    The model maps each observed ``(s, a)`` pair to the most recent
    ``(r, s_prime)`` outcome seen for it.

    @param s: The state the action was taken from
    @param a: The action that was taken
    @param r: The reward that was received
    @param s_prime: The state that was reached
    """
    state_action = (s, a)
    # Keep the sampling list free of duplicates so every observed pair is
    # drawn with equal probability during planning.
    if state_action not in self.model:
        self.state_action_list.append(state_action)
    # Always overwrite: the model should reflect the latest observation
    # instead of staying frozen on the first outcome ever seen.
    self.model[state_action] = (r, s_prime)
|
||||||
|
|
||||||
|
def _dyna_q(self):
    """Run ``self.dyna`` simulated Q-updates from the learned model.

    Each planning step samples a previously observed ``(s, a)`` pair
    uniformly at random, looks up the recorded ``(r, s_prime)`` for it and
    feeds that tuple to ``_update_q``.
    """
    # Nothing to plan with until at least one real experience has been
    # recorded; random.choice would raise IndexError on an empty list.
    if self.dyna <= 0 or not self.state_action_list:
        return

    for _ in range(self.dyna):
        s, a = random.choice(self.state_action_list)
        r, s_prime = self.model[(s, a)]
        self._update_q(s, a, r, s_prime)
|
||||||
|
|
||||||
def author(self):
|
def author(self):
|
||||||
return 'felixm'
|
return 'felixm'
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
q = QLearner(verbose=True)
|
q = QLearner(verbose=True, dyna=2)
|
||||||
print(q.querysetstate(2))
|
q.querysetstate(2)
|
||||||
q.query(15, 1.00)
|
q.query(15, 1.00)
|
||||||
print(q.querysetstate(15))
|
q.querysetstate(15)
|
||||||
|
q.query(17, 0.10)
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ def test(map, epochs, learner, verbose):
|
|||||||
# run the code to test a learner
|
# run the code to test a learner
|
||||||
def test_code():
|
def test_code():
|
||||||
|
|
||||||
verbose = True # print lots of debug stuff if True
|
verbose = False # print lots of debug stuff if True
|
||||||
|
|
||||||
# read in the map
|
# read in the map
|
||||||
filename = 'testworlds/world01.csv'
|
filename = 'testworlds/world01.csv'
|
||||||
|
|||||||
Reference in New Issue
Block a user