Compare commits
d5aa22e9dd...d112dce5f5

2 Commits

| Author | SHA1 | Date |
|---|---|---|
| | d112dce5f5 | |
| | 22022c3780 | |
@@ -38,6 +38,6 @@ unzip -n zips/*.zip -d ./
 - [Report 3](./assess_learners/assess_learners.md)
 - No reports for projects 4 (defeat learners) and 5 (marketsim)
 - [Report 6](./manual_strategy/manual_strategy.md)
-- [Report 7](#)
+- No report for project 7
 - [Report 8](#)
@@ -22,8 +22,8 @@ GT honor code violation.
 -----do not edit anything above this line---
 """
 
+import random
 import numpy as np
-import random as rand
 
 
 class QLearner(object):
@@ -49,27 +49,32 @@ class QLearner(object):
         self.radr = radr
         self.dyna = dyna
 
+        if self.dyna > 0:
+            self.model = {}
+            self.state_action_list = []
+
         # self.q = np.random.random((num_states, num_actions))
         self.q = np.zeros((num_states, num_actions))
 
     def _get_a(self, s):
         """Get best action for state. Considers rar."""
-        if rand.random() < self.rar:
-            a = rand.randint(0, self.num_actions - 1)
+        if random.random() < self.rar:
+            a = random.randint(0, self.num_actions - 1)
         else:
             a = np.argmax(self.q[s])
         return a
 
-    def _update_q(self, s, a, s_prime, r):
+    def _update_q(self, s, a, r, s_prime):
         """Updates the Q table."""
         q_old = self.q[s][a]
+        alpha = self.alpha
 
         # estimate optimal future value
         a_max = np.argmax(self.q[s_prime])
         q_future = self.q[s_prime][a_max]
 
         # calculate new value and update table
-        q_new = q_old + self.alpha * (r + self.gamma * q_future - q_old)
+        q_new = (1 - alpha) * q_old + alpha * (r + self.gamma * q_future)
         self.q[s][a] = q_new
 
         if self.verbose:
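Worth noting about the `_update_q` change: the removed and added update rules are algebraically the same, since `q_old + alpha * (r + gamma * q_future - q_old)` expands to `(1 - alpha) * q_old + alpha * (r + gamma * q_future)`, so this hunk only changes which form of the Q-learning update is written out. A quick standalone check (example values are arbitrary, not from the repo):

```python
# Verify that the two Q-update forms from this hunk agree numerically.
alpha, gamma = 0.2, 0.9            # arbitrary learning rate and discount
q_old, r, q_future = 0.5, 1.0, 0.7  # arbitrary current Q, reward, future value

old_form = q_old + alpha * (r + gamma * q_future - q_old)
new_form = (1 - alpha) * q_old + alpha * (r + gamma * q_future)

assert abs(old_form - new_form) < 1e-12
print(old_form, new_form)  # both are approximately 0.726
```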
@@ -95,21 +100,39 @@ class QLearner(object):
         @param r: The reward
         @returns: The selected action
         """
-        self._update_q(self.s, self.a, s_prime, r)
-        self.a = self._get_a(s_prime)
-        self.s = s_prime
-        if self.verbose:
-            print(f"s = {s_prime}, a = {self.a}, r={r}")
+        self._update_q(self.s, self.a, r, s_prime)
+        a = self._get_a(s_prime)
+
+        # Update random action rate
+        self.rar = self.rar * self.radr
+
+        if self.dyna > 0:
+            self._update_model(self.s, self.a, r, s_prime)
+            self._dyna_q()
+
+        self.a = a
+        self.s = s_prime
+        return self.a
+
+    def _update_model(self, s, a, r, s_prime):
+        state_action = (s, a)
+        if not state_action in self.model:
+            self.model[state_action] = (r, s_prime)
+            self.state_action_list.append(state_action)
+
+    def _dyna_q(self):
+        for _ in range(self.dyna):
+            s, a = random.choice(self.state_action_list)
+            r, s_prime = self.model[(s, a)]
+            self._update_q(s, a, r, s_prime)
 
     def author(self):
         return 'felixm'
 
 
 if __name__ == "__main__":
-    q = QLearner(verbose=True)
-    print(q.querysetstate(2))
+    q = QLearner(verbose=True, dyna=2)
+    q.querysetstate(2)
+    q.query(15, 1.00)
-    print(q.querysetstate(15))
+    q.querysetstate(15)
+    q.query(17, 0.10)
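The Dyna-Q additions are deliberately simple: `_update_model` memorizes the first `(r, s_prime)` observed for each `(s, a)` pair, which implicitly assumes a deterministic world, and `_dyna_q` replays `self.dyna` randomly chosen remembered transitions through the regular Q update. A minimal sketch of how this learner's API is typically driven, assuming the `querysetstate`/`query` contract shown in the `__main__` block; the toy `env` function and the constants are placeholders, not part of the diff:

```python
import random

from qlearner import QLearner  # module name assumed; the class shown above

def env(s, a, num_states=100):
    """Placeholder world: jump to a random state; reward 1.0 only in state 0."""
    s_prime = random.randrange(num_states)
    return s_prime, (1.0 if s_prime == 0 else -1.0)

learner = QLearner(num_states=100, num_actions=4, dyna=200)

s = 42                             # arbitrary start state
a = learner.querysetstate(s)       # set initial state, get first action (no Q update)
for _ in range(1000):
    s_prime, r = env(s, a)         # the world applies action a in state s
    a = learner.query(s_prime, r)  # Q (and Dyna model) update, returns next action
    s = s_prime
```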
@@ -171,7 +171,7 @@ def test(map, epochs, learner, verbose):
 # run the code to test a learner
 def test_code():
 
-    verbose = True  # print lots of debug stuff if True
+    verbose = False  # print lots of debug stuff if True
 
     # read in the map
     filename = 'testworlds/world01.csv'
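For context on this last hunk: `testworlds/world01.csv` is a comma-separated grid map. The actual reader in the test harness is outside this diff, but a map like that is typically loaded along these lines (sketch only, not the repo's code):

```python
import numpy as np

# Read the CSV grid world into an integer matrix; each cell encodes a tile
# type (e.g. open space, obstacle, start, goal).
world = np.genfromtxt('testworlds/world01.csv', delimiter=',', dtype=int)
print(world.shape)
```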