Compare commits

...

2 Commits

Author SHA1 Message Date
d112dce5f5 Implement dyna-q to finish project 7 2020-10-19 08:56:24 -04:00
22022c3780 Update readme for project 7 2020-10-18 14:48:15 -04:00
3 changed files with 38 additions and 15 deletions

View File

@ -38,6 +38,6 @@ unzip -n zips/*.zip -d ./
- [Report 3](./assess_learners/assess_learners.md)
- No reports for projects 4 (defeat learners) and 5 (marketsim)
- [Report 6](./manual_strategy/manual_strategy.md)
- [Report 7](#)
- No report for project 7
- [Report 8](#)

View File

@ -22,8 +22,8 @@ GT honor code violation.
-----do not edit anything above this line---
"""
import random
import numpy as np
import random as rand
class QLearner(object):
@ -49,27 +49,32 @@ class QLearner(object):
self.radr = radr
self.dyna = dyna
if self.dyna > 0:
self.model = {}
self.state_action_list = []
# self.q = np.random.random((num_states, num_actions))
self.q = np.zeros((num_states, num_actions))
def _get_a(self, s):
"""Get best action for state. Considers rar."""
if rand.random() < self.rar:
a = rand.randint(0, self.num_actions - 1)
if random.random() < self.rar:
a = random.randint(0, self.num_actions - 1)
else:
a = np.argmax(self.q[s])
return a
def _update_q(self, s, a, s_prime, r):
def _update_q(self, s, a, r, s_prime):
"""Updates the Q table."""
q_old = self.q[s][a]
alpha = self.alpha
# estimate optimal future value
a_max = np.argmax(self.q[s_prime])
q_future = self.q[s_prime][a_max]
# calculate new value and update table
q_new = q_old + self.alpha * (r + self.gamma * q_future - q_old)
q_new = (1 - alpha) * q_old + alpha * (r + self.gamma * q_future)
self.q[s][a] = q_new
if self.verbose:
@ -95,21 +100,39 @@ class QLearner(object):
@param r: The reward
@returns: The selected action
"""
self._update_q(self.s, self.a, s_prime, r)
self.a = self._get_a(s_prime)
self.s = s_prime
if self.verbose:
print(f"s = {s_prime}, a = {self.a}, r={r}")
self._update_q(self.s, self.a, r, s_prime)
a = self._get_a(s_prime)
# Update random action rate
self.rar = self.rar * self.radr
if self.dyna > 0:
self._update_model(self.s, self.a, r, s_prime)
self._dyna_q()
self.a = a
self.s = s_prime
return self.a
def _update_model(self, s, a, r, s_prime):
state_action = (s, a)
if not state_action in self.model:
self.model[state_action] = (r, s_prime)
self.state_action_list.append(state_action)
def _dyna_q(self):
for _ in range(self.dyna):
s, a = random.choice(self.state_action_list)
r, s_prime = self.model[(s, a)]
self._update_q(s, a, r, s_prime)
def author(self):
return 'felixm'
if __name__ == "__main__":
q = QLearner(verbose=True)
print(q.querysetstate(2))
q = QLearner(verbose=True, dyna=2)
q.querysetstate(2)
q.query(15, 1.00)
print(q.querysetstate(15))
q.querysetstate(15)
q.query(17, 0.10)

View File

@ -171,7 +171,7 @@ def test(map, epochs, learner, verbose):
# run the code to test a learner
def test_code():
verbose = True # print lots of debug stuff if True
verbose = False # print lots of debug stuff if True
# read in the map
filename = 'testworlds/world01.csv'