"""
|
|
Template for implementing QLearner (c) 2015 Tucker Balch
|
|
|
|
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
|
Atlanta, Georgia 30332
|
|
All Rights Reserved
|
|
|
|
Template code for CS 4646/7646
|
|
|
|
Georgia Tech asserts copyright ownership of this template and all derivative
|
|
works, including solutions to the projects assigned in this course. Students
|
|
and other users of this template code are advised not to share it with others
|
|
or to make it available on publicly viewable websites including repositories
|
|
such as github and gitlab. This copyright statement should not be removed
|
|
or edited.
|
|
|
|
We do grant permission to share solutions privately with non-students such
|
|
as potential employers. However, sharing with other current or future
|
|
students of CS 7646 is prohibited and subject to being investigated as a
|
|
GT honor code violation.
|
|
|
|
-----do not edit anything above this line---
|
|
"""

import random

import numpy as np


class QLearner(object):

    def __init__(self,
                 num_states=100,
                 num_actions=4,
                 alpha=0.2,       # learning rate for Q updates
                 gamma=0.9,       # discount factor on future rewards
                 rar=0.5,         # random action rate (exploration epsilon)
                 radr=0.99,       # random action decay rate, applied per query
                 dyna=0,          # Dyna-Q planning updates per real step
                 verbose=False):

        self.verbose = verbose
        self.num_actions = num_actions
        self.num_states = num_states
        self.s = 0
        self.a = 0
        self.alpha = alpha
        self.gamma = gamma
        self.rar = rar
        self.radr = radr
        self.dyna = dyna

        if self.dyna > 0:
            # Deterministic model for Dyna-Q: maps (s, a) -> (r, s'). The
            # list of seen (s, a) pairs supports uniform random sampling.
            self.model = {}
            self.state_action_list = []

        # self.q = np.random.random((num_states, num_actions))
        self.q = np.zeros((num_states, num_actions))

    def _get_a(self, s):
        """Return an epsilon-greedy action for state s: a random action with
        probability self.rar, otherwise the greedy argmax action."""
        if random.random() < self.rar:
            a = random.randint(0, self.num_actions - 1)
        else:
            a = np.argmax(self.q[s])
        return a
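
    # Note: np.argmax breaks ties by returning the first maximal index, so
    # with a freshly zeroed Q table the greedy branch always selects action
    # 0 until updates differentiate the actions.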

    def _update_q(self, s, a, r, s_prime):
        """Apply one Q-learning update for the transition (s, a, r, s'):
        Q(s, a) <- (1 - alpha) * Q(s, a) + alpha * (r + gamma * max Q(s', :))
        """
        q_old = self.q[s][a]
        alpha = self.alpha

        # estimate optimal future value
        a_max = np.argmax(self.q[s_prime])
        q_future = self.q[s_prime][a_max]

        # calculate new value and update table
        q_new = (1 - alpha) * q_old + alpha * (r + self.gamma * q_future)
        self.q[s][a] = q_new

        if self.verbose:
            print(f"{q_old=} {q_future=} {q_new=}")
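
    # Worked example (numbers are illustrative only): with alpha=0.2,
    # gamma=0.9, Q[s][a] = 1.0, r = 2.0, and max Q[s'] = 3.0,
    #   q_new = 0.8 * 1.0 + 0.2 * (2.0 + 0.9 * 3.0) = 1.74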

    def querysetstate(self, s):
        """
        @summary: Update the state without updating the Q-table
        @param s: The new state
        @returns: The selected action
        """
        a = self._get_a(s)
        if self.verbose:
            print(f"s = {s}, a = {a}")
        self.s = s
        self.a = a
        return self.a

    def query(self, s_prime, r):
        """
        @summary: Update the Q table and return an action
        @param s_prime: The new state
        @param r: The reward
        @returns: The selected action
        """
        self._update_q(self.s, self.a, r, s_prime)
        a = self._get_a(s_prime)

        # Decay the random action rate after every real experience.
        self.rar = self.rar * self.radr

        if self.dyna > 0:
            self._update_model(self.s, self.a, r, s_prime)
            self._dyna_q()

        self.a = a
        self.s = s_prime
        return self.a
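
    # Exploration decay (illustrative): after n calls to query() the random
    # action rate is rar * radr**n, so with the defaults rar=0.5, radr=0.99
    # it falls to roughly 0.5 * 0.99**100 ≈ 0.18 after 100 steps.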

    def _update_model(self, s, a, r, s_prime):
        """Record the observed transition (s, a) -> (r, s') in the model."""
        state_action = (s, a)
        if state_action not in self.model:
            self.state_action_list.append(state_action)
        # Store the latest observation; in a deterministic environment this
        # is equivalent to keeping the first one.
        self.model[state_action] = (r, s_prime)

    def _dyna_q(self):
        """Run self.dyna planning updates on transitions sampled uniformly
        from the learned model (the Dyna-Q planning step)."""
        for _ in range(self.dyna):
            s, a = random.choice(self.state_action_list)
            r, s_prime = self.model[(s, a)]
            self._update_q(s, a, r, s_prime)
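
    # The planning loop replays remembered transitions without touching the
    # environment, trading extra computation per real step for faster
    # convergence in terms of real experience (Sutton's Dyna-Q).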

    def author(self):
        return 'felixm'

if __name__ == "__main__":
    q = QLearner(verbose=True, dyna=2)
    q.querysetstate(2)
    q.query(15, 1.00)
    q.querysetstate(15)
    q.query(17, 0.10)
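
    # A fuller usage sketch (not part of the template): train the learner on
    # a toy 10-state corridor where action 1 moves right, action 0 moves
    # left, reaching state 9 pays 1.0, and every other step costs -0.1. The
    # corridor, rewards, and episode count are illustrative assumptions, not
    # the course environment.
    learner = QLearner(num_states=10, num_actions=2, dyna=50)
    for _ in range(100):                           # episodes
        s = 0
        a = learner.querysetstate(s)               # set start state, no Q update
        for _ in range(50):                        # cap steps per episode
            s = max(0, min(9, s + (1 if a == 1 else -1)))
            r = 1.0 if s == 9 else -0.1
            a = learner.query(s, r)                # Q update + next action
            if s == 9:
                break
    print(f"greedy policy: {np.argmax(learner.q, axis=1)}")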