Compare commits

..

2 Commits

Author SHA1 Message Date
d112dce5f5 Implement dyna-q to finish project 7 2020-10-19 08:56:24 -04:00
22022c3780 Update readme for project 7 2020-10-18 14:48:15 -04:00
3 changed files with 38 additions and 15 deletions

View File

@@ -38,6 +38,6 @@ unzip -n zips/*.zip -d ./
- [Report 3](./assess_learners/assess_learners.md) - [Report 3](./assess_learners/assess_learners.md)
- No reports for projects 4 (defeat learners) and 5 (marketsim) - No reports for projects 4 (defeat learners) and 5 (marketsim)
- [Report 6](./manual_strategy/manual_strategy.md) - [Report 6](./manual_strategy/manual_strategy.md)
- [Report 7](#) - No report for project 7
- [Report 8](#) - [Report 8](#)

View File

@@ -22,8 +22,8 @@ GT honor code violation.
-----do not edit anything above this line--- -----do not edit anything above this line---
""" """
import random
import numpy as np import numpy as np
import random as rand
class QLearner(object): class QLearner(object):
@@ -49,27 +49,32 @@ class QLearner(object):
self.radr = radr self.radr = radr
self.dyna = dyna self.dyna = dyna
if self.dyna > 0:
self.model = {}
self.state_action_list = []
# self.q = np.random.random((num_states, num_actions)) # self.q = np.random.random((num_states, num_actions))
self.q = np.zeros((num_states, num_actions)) self.q = np.zeros((num_states, num_actions))
def _get_a(self, s): def _get_a(self, s):
"""Get best action for state. Considers rar.""" """Get best action for state. Considers rar."""
if rand.random() < self.rar: if random.random() < self.rar:
a = rand.randint(0, self.num_actions - 1) a = random.randint(0, self.num_actions - 1)
else: else:
a = np.argmax(self.q[s]) a = np.argmax(self.q[s])
return a return a
def _update_q(self, s, a, s_prime, r): def _update_q(self, s, a, r, s_prime):
"""Updates the Q table.""" """Updates the Q table."""
q_old = self.q[s][a] q_old = self.q[s][a]
alpha = self.alpha
# estimate optimal future value # estimate optimal future value
a_max = np.argmax(self.q[s_prime]) a_max = np.argmax(self.q[s_prime])
q_future = self.q[s_prime][a_max] q_future = self.q[s_prime][a_max]
# calculate new value and update table # calculate new value and update table
q_new = q_old + self.alpha * (r + self.gamma * q_future - q_old) q_new = (1 - alpha) * q_old + alpha * (r + self.gamma * q_future)
self.q[s][a] = q_new self.q[s][a] = q_new
if self.verbose: if self.verbose:
@@ -95,21 +100,39 @@ class QLearner(object):
@param r: The reward @param r: The reward
@returns: The selected action @returns: The selected action
""" """
self._update_q(self.s, self.a, s_prime, r) self._update_q(self.s, self.a, r, s_prime)
self.a = self._get_a(s_prime) a = self._get_a(s_prime)
self.s = s_prime
if self.verbose:
print(f"s = {s_prime}, a = {self.a}, r={r}")
# Update random action rate # Update random action rate
self.rar = self.rar * self.radr self.rar = self.rar * self.radr
if self.dyna > 0:
self._update_model(self.s, self.a, r, s_prime)
self._dyna_q()
self.a = a
self.s = s_prime
return self.a return self.a
def _update_model(self, s, a, r, s_prime):
    """Record the observed transition (s, a) -> (r, s_prime) in the model.

    ``self.model`` maps each visited (state, action) pair to the most
    recently observed (reward, next state).  Every visited pair is also
    kept exactly once in ``self.state_action_list`` so that one of them
    can be drawn at random for replay.

    @param s: The state the action was taken in
    @param a: The action that was taken
    @param r: The reward that was received
    @param s_prime: The state that was reached
    """
    state_action = (s, a)
    # Only a pair seen for the first time joins the sampling list, so the
    # list never contains duplicates.
    if state_action not in self.model:
        self.state_action_list.append(state_action)
    # Always store the latest observation: keeping only the first one
    # would make planning replay an outdated reward/successor forever.
    self.model[state_action] = (r, s_prime)
def _dyna_q(self):
    """Run the planning step: replay remembered transitions.

    Performs ``self.dyna`` simulated updates.  Each one draws a recorded
    (state, action) pair at random from ``self.state_action_list``, looks
    up the (reward, next state) stored for it in ``self.model`` and passes
    that experience to ``self._update_q``.  Does nothing while no
    transition has been recorded.
    """
    # random.choice raises IndexError on an empty sequence, so there is
    # nothing to replay until at least one transition has been recorded.
    if not self.state_action_list:
        return
    for _ in range(self.dyna):
        state_action = random.choice(self.state_action_list)
        s, a = state_action
        r, s_prime = self.model[state_action]
        self._update_q(s, a, r, s_prime)
def author(self): def author(self):
return 'felixm' return 'felixm'
if __name__ == "__main__": if __name__ == "__main__":
q = QLearner(verbose=True) q = QLearner(verbose=True, dyna=2)
print(q.querysetstate(2)) q.querysetstate(2)
q.query(15, 1.00) q.query(15, 1.00)
print(q.querysetstate(15)) q.querysetstate(15)
q.query(17, 0.10)

View File

@@ -171,7 +171,7 @@ def test(map, epochs, learner, verbose):
# run the code to test a learner # run the code to test a learner
def test_code(): def test_code():
verbose = True # print lots of debug stuff if True verbose = False # print lots of debug stuff if True
# read in the map # read in the map
filename = 'testworlds/world01.csv' filename = 'testworlds/world01.csv'