Implement Q learner
@@ -1,72 +1,115 @@
"""
Template for implementing QLearner (c) 2015 Tucker Balch

Copyright 2018, Georgia Institute of Technology (Georgia Tech)
Atlanta, Georgia 30332
All Rights Reserved

Template code for CS 4646/7646

Georgia Tech asserts copyright ownership of this template and all derivative
works, including solutions to the projects assigned in this course. Students
and other users of this template code are advised not to share it with others
or to make it available on publicly viewable websites including repositories
such as github and gitlab. This copyright statement should not be removed
or edited.

We do grant permission to share solutions privately with non-students such
as potential employers. However, sharing with other current or future
students of CS 7646 is prohibited and subject to being investigated as a
GT honor code violation.

-----do not edit anything above this line---

Student Name: Tucker Balch (replace with your name)
GT User ID: tb34 (replace with your User ID)
GT ID: 900897987 (replace with your GT ID)
"""

import numpy as np
import random as rand


class QLearner(object):

    def __init__(self, \
        num_states=100, \
        num_actions = 4, \
        alpha = 0.2, \
        gamma = 0.9, \
        rar = 0.5, \
        radr = 0.99, \
        dyna = 0, \
        verbose = False):

        self.verbose = verbose
        self.num_actions = num_actions
        self.s = 0
        self.a = 0

    def querysetstate(self, s):
        """
        @summary: Update the state without updating the Q-table
        @param s: The new state
        @returns: The selected action
        """
        self.s = s
        action = rand.randint(0, self.num_actions-1)
        if self.verbose: print(f"s = {s}, a = {action}")
        return action

    def query(self,s_prime,r):
        """
        @summary: Update the Q table and return an action
        @param s_prime: The new state
        @param r: The reward
        @returns: The selected action
        """
        action = rand.randint(0, self.num_actions-1)
        if self.verbose: print(f"s = {s_prime}, a = {action}, r={r}")
        return action

if __name__=="__main__":
    print("Remember Q from Star Trek? Well, this isn't him")

"""
|
||||
Template for implementing QLearner (c) 2015 Tucker Balch
|
||||
|
||||
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
||||
Atlanta, Georgia 30332
|
||||
All Rights Reserved
|
||||
|
||||
Template code for CS 4646/7646
|
||||
|
||||
Georgia Tech asserts copyright ownership of this template and all derivative
|
||||
works, including solutions to the projects assigned in this course. Students
|
||||
and other users of this template code are advised not to share it with others
|
||||
or to make it available on publicly viewable websites including repositories
|
||||
such as github and gitlab. This copyright statement should not be removed
|
||||
or edited.
|
||||
|
||||
We do grant permission to share solutions privately with non-students such
|
||||
as potential employers. However, sharing with other current or future
|
||||
students of CS 7646 is prohibited and subject to being investigated as a
|
||||
GT honor code violation.
|
||||
|
||||
-----do not edit anything above this line---
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import random as rand
|
||||
|
||||
|
||||
class QLearner(object):

    def __init__(self,
                 num_states=100,
                 num_actions=4,
                 alpha=0.2,
                 gamma=0.9,
                 rar=0.5,
                 radr=0.99,
                 dyna=0,
                 verbose=False):

        self.verbose = verbose
        self.num_actions = num_actions
        self.num_states = num_states
        self.s = 0
        self.a = 0
        self.alpha = alpha
        self.gamma = gamma
        self.rar = rar
        self.radr = radr
        self.dyna = dyna
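
        # Q-table: expected return for each (state, action) pair; initialized
        # to zeros (the commented-out random initialization is an alternative)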
        # self.q = np.random.random((num_states, num_actions))
        self.q = np.zeros((num_states, num_actions))

    def _get_a(self, s):
        """Select an action epsilon-greedily: with probability rar pick a
        random action, otherwise the greedy action from the Q-table."""
        if rand.random() < self.rar:
            a = rand.randint(0, self.num_actions - 1)
        else:
            a = np.argmax(self.q[s])
        return a

    def _update_q(self, s, a, s_prime, r):
        """Updates the Q table."""
        q_old = self.q[s][a]

        # estimate optimal future value
        a_max = np.argmax(self.q[s_prime])
        q_future = self.q[s_prime][a_max]

        # calculate new value and update table
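        # standard Q-learning update:
        #   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))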
        q_new = q_old + self.alpha * (r + self.gamma * q_future - q_old)
        self.q[s][a] = q_new

        if self.verbose:
            print(f"{q_old=} {q_future=} {q_new=}")

    def querysetstate(self, s):
        """
        @summary: Update the state without updating the Q-table
        @param s: The new state
        @returns: The selected action
        """
        a = self._get_a(s)
        if self.verbose:
            print(f"s = {s}, a = {a}")
        self.s = s
        self.a = a
        return self.a

    def query(self, s_prime, r):
        """
        @summary: Update the Q table and return an action
        @param s_prime: The new state
        @param r: The reward
        @returns: The selected action
        """
        self._update_q(self.s, self.a, s_prime, r)
        self.a = self._get_a(s_prime)
        self.s = s_prime
        if self.verbose:
            print(f"s = {s_prime}, a = {self.a}, r={r}")
        # decay the random action rate so exploration tapers off over time
        self.rar = self.rar * self.radr
        return self.a

    def author(self):
        return 'felixm'


if __name__ == "__main__":
    q = QLearner(verbose=True)
    print(q.querysetstate(2))
    q.query(15, 1.00)
    print(q.querysetstate(15))
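
For context, a minimal sketch of how a caller typically drives this class
(not part of the commit; the corridor environment below is an illustrative
assumption): each episode is seeded with querysetstate, after which every
(s', r) observation is fed back through query.

    # Hypothetical 10-state corridor: action 0 moves left, action 1 moves
    # right; reaching state 9 pays +1, every other step pays -0.01.
    learner = QLearner(num_states=10, num_actions=2)
    for episode in range(500):
        s = 0
        a = learner.querysetstate(s)
        for _ in range(100):
            s = max(0, s - 1) if a == 0 else min(9, s + 1)
            r = 1.0 if s == 9 else -0.01
            a = learner.query(s, r)
            if s == 9:
                break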