Implement Q learner
This commit is contained in:
@@ -1,72 +1,115 @@
|
|||||||
"""
|
"""
|
||||||
Template for implementing QLearner (c) 2015 Tucker Balch
|
Template for implementing QLearner (c) 2015 Tucker Balch
|
||||||
|
|
||||||
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
||||||
Atlanta, Georgia 30332
|
Atlanta, Georgia 30332
|
||||||
All Rights Reserved
|
All Rights Reserved
|
||||||
|
|
||||||
Template code for CS 4646/7646
|
Template code for CS 4646/7646
|
||||||
|
|
||||||
Georgia Tech asserts copyright ownership of this template and all derivative
|
Georgia Tech asserts copyright ownership of this template and all derivative
|
||||||
works, including solutions to the projects assigned in this course. Students
|
works, including solutions to the projects assigned in this course. Students
|
||||||
and other users of this template code are advised not to share it with others
|
and other users of this template code are advised not to share it with others
|
||||||
or to make it available on publicly viewable websites including repositories
|
or to make it available on publicly viewable websites including repositories
|
||||||
such as github and gitlab. This copyright statement should not be removed
|
such as github and gitlab. This copyright statement should not be removed
|
||||||
or edited.
|
or edited.
|
||||||
|
|
||||||
We do grant permission to share solutions privately with non-students such
|
We do grant permission to share solutions privately with non-students such
|
||||||
as potential employers. However, sharing with other current or future
|
as potential employers. However, sharing with other current or future
|
||||||
students of CS 7646 is prohibited and subject to being investigated as a
|
students of CS 7646 is prohibited and subject to being investigated as a
|
||||||
GT honor code violation.
|
GT honor code violation.
|
||||||
|
|
||||||
-----do not edit anything above this line---
|
-----do not edit anything above this line---
|
||||||
|
"""
|
||||||
Student Name: Tucker Balch (replace with your name)
|
|
||||||
GT User ID: tb34 (replace with your User ID)
|
import numpy as np
|
||||||
GT ID: 900897987 (replace with your GT ID)
|
import random as rand
|
||||||
"""
|
|
||||||
|
|
||||||
import numpy as np
|
class QLearner(object):
    """Tabular Q-learner with random exploration and optional Dyna-Q replay.

    The Q-table is a ``(num_states, num_actions)`` array initialised to zero.
    Actions are picked at random with probability ``rar`` and greedily
    otherwise; ``rar`` is multiplied by ``radr`` after every ``query`` call.

    When ``dyna`` is greater than zero, every real experience is stored in a
    model that remembers the most recent ``(s_prime, r)`` seen for each
    ``(s, a)`` pair, and ``dyna`` remembered experiences are replayed through
    the same update rule after each real step.
    """

    def __init__(self,
                 num_states=100,
                 num_actions=4,
                 alpha=0.2,
                 gamma=0.9,
                 rar=0.5,
                 radr=0.99,
                 dyna=0,
                 verbose=False):
        """
        @summary: Create a learner with an all-zero Q-table
        @param num_states: Number of rows in the Q-table
        @param num_actions: Number of columns in the Q-table
        @param alpha: Learning rate used in the Q update
        @param gamma: Discount applied to the estimated future value
        @param rar: Probability of choosing a random action
        @param radr: Factor rar is multiplied by after each query
        @param dyna: Number of replayed experiences per real step (0 disables)
        @param verbose: Print debugging information when True
        """
        self.verbose = verbose
        self.num_actions = num_actions
        self.num_states = num_states
        self.s = 0
        self.a = 0
        self.alpha = alpha
        self.gamma = gamma
        self.rar = rar
        self.radr = radr
        self.dyna = dyna

        self.q = np.zeros((num_states, num_actions))

        # Dyna model: last observed outcome per (s, a), plus the list of
        # visited pairs so a replay sample can be drawn without rebuilding it.
        self._model = {}
        self._visited = []

    def _get_a(self, s):
        """Get best action for state. Considers rar."""
        if rand.random() < self.rar:
            return rand.randint(0, self.num_actions - 1)
        # int() keeps the return type the same as the random branch.
        return int(np.argmax(self.q[s]))

    def _update_q(self, s, a, s_prime, r):
        """Updates the Q table."""
        q_old = self.q[s][a]

        # estimate optimal future value
        a_max = np.argmax(self.q[s_prime])
        q_future = self.q[s_prime][a_max]

        # calculate new value and update table
        q_new = q_old + self.alpha * (r + self.gamma * q_future - q_old)
        self.q[s][a] = q_new

        if self.verbose:
            print(f"{q_old=} {q_future=} {q_new=}")

    def _remember(self, s, a, s_prime, r):
        """Store the latest observed outcome of taking a in s."""
        key = (int(s), int(a))
        if key not in self._model:
            self._visited.append(key)
        self._model[key] = (s_prime, r)

    def _replay(self):
        """Apply self.dyna Q updates to randomly chosen remembered pairs."""
        if not self._visited:
            return
        for _ in range(self.dyna):
            s, a = rand.choice(self._visited)
            s_prime, r = self._model[(s, a)]
            q_old = self.q[s][a]
            q_future = np.max(self.q[s_prime])
            self.q[s][a] = q_old + self.alpha * (
                r + self.gamma * q_future - q_old)

    def querysetstate(self, s):
        """
        @summary: Update the state without updating the Q-table
        @param s: The new state
        @returns: The selected action
        """
        a = self._get_a(s)
        if self.verbose:
            print(f"s = {s}, a = {a}")
        self.s = s
        self.a = a
        return self.a

    def query(self, s_prime, r):
        """
        @summary: Update the Q table and return an action
        @param s_prime: The new state
        @param r: The reward
        @returns: The selected action
        """
        self._update_q(self.s, self.a, s_prime, r)

        # Replay only when requested so dyna == 0 consumes no extra random
        # numbers and behaves exactly like the plain learner.
        if self.dyna > 0:
            self._remember(self.s, self.a, s_prime, r)
            self._replay()

        self.a = self._get_a(s_prime)
        self.s = s_prime
        if self.verbose:
            print(f"s = {s_prime}, a = {self.a}, r={r}")
        # Update random action rate
        self.rar = self.rar * self.radr
        return self.a

    def author(self):
        """Return the author identifier string."""
        return 'felixm'
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke run: set a state, feed one experience, then set the
    # follow-up state, printing the chosen actions along the way.
    learner = QLearner(verbose=True)
    first_action = learner.querysetstate(2)
    print(first_action)
    learner.query(15, 1.00)
    next_action = learner.querysetstate(15)
    print(next_action)
|
||||||
|
|||||||
@@ -1,189 +1,189 @@
|
|||||||
"""
|
"""
|
||||||
Test a Q Learner in a navigation problem. (c) 2015 Tucker Balch
|
Test a Q Learner in a navigation problem. (c) 2015 Tucker Balch
|
||||||
2016-10-20 Added "quicksand" and uncertain actions.
|
2016-10-20 Added "quicksand" and uncertain actions.
|
||||||
|
|
||||||
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
Copyright 2018, Georgia Institute of Technology (Georgia Tech)
|
||||||
Atlanta, Georgia 30332
|
Atlanta, Georgia 30332
|
||||||
All Rights Reserved
|
All Rights Reserved
|
||||||
|
|
||||||
Template code for CS 4646/7646
|
Template code for CS 4646/7646
|
||||||
|
|
||||||
Georgia Tech asserts copyright ownership of this template and all derivative
|
Georgia Tech asserts copyright ownership of this template and all derivative
|
||||||
works, including solutions to the projects assigned in this course. Students
|
works, including solutions to the projects assigned in this course. Students
|
||||||
and other users of this template code are advised not to share it with others
|
and other users of this template code are advised not to share it with others
|
||||||
or to make it available on publicly viewable websites including repositories
|
or to make it available on publicly viewable websites including repositories
|
||||||
such as github and gitlab. This copyright statement should not be removed
|
such as github and gitlab. This copyright statement should not be removed
|
||||||
or edited.
|
or edited.
|
||||||
|
|
||||||
We do grant permission to share solutions privately with non-students such
|
We do grant permission to share solutions privately with non-students such
|
||||||
as potential employers. However, sharing with other current or future
|
as potential employers. However, sharing with other current or future
|
||||||
students of CS 7646 is prohibited and subject to being investigated as a
|
students of CS 7646 is prohibited and subject to being investigated as a
|
||||||
GT honor code violation.
|
GT honor code violation.
|
||||||
|
|
||||||
-----do not edit anything above this line---
|
-----do not edit anything above this line---
|
||||||
|
|
||||||
Student Name: Tucker Balch (replace with your name)
|
Student Name: Tucker Balch (replace with your name)
|
||||||
GT User ID: tb34 (replace with your User ID)
|
GT User ID: tb34 (replace with your User ID)
|
||||||
GT ID: 900897987 (replace with your GT ID)
|
GT ID: 900897987 (replace with your GT ID)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random as rand
|
import random as rand
|
||||||
import time
|
import time
|
||||||
import math
|
import math
|
||||||
import QLearner as ql
|
import QLearner as ql
|
||||||
|
|
||||||
# print out the map
|
# print out the map
|
||||||
def printmap(data):
    """Print a text rendering of the 2-D grid ``data`` between two rulers.

    Each cell whose value equals one of the known codes is printed as its
    glyph followed by a space; cells with any other value print nothing.
    """
    # (cell code, glyph): empty, obstacle, robot, goal, trail,
    # quicksand, stepped-in quicksand.
    legend = (
        (0, " "),
        (1, "O"),
        (2, "*"),
        (3, "X"),
        (4, "."),
        (5, "~"),
        (6, "@"),
    )
    ruler = "--------------------"

    print(ruler)
    n_rows, n_cols = data.shape[0], data.shape[1]
    for r in range(n_rows):
        for c in range(n_cols):
            for code, glyph in legend:
                if data[r, c] == code:
                    print(glyph, end=' ')
        print()
    print(ruler)
|
||||||
|
|
||||||
# find where the robot is in the map
|
# find where the robot is in the map
|
||||||
def getrobotpos(data):
    """Return ``(row, col)`` of the robot marker (value 2) in ``data``.

    If several cells hold the marker, the last one in row-major order wins.
    If none does, a warning is printed and ``(-999, -999)`` is returned.
    """
    matches = np.argwhere(data == 2)
    if len(matches) == 0:
        print("warning: start location not defined")
        return -999, -999
    # argwhere lists hits in row-major order, so the final entry is the one
    # a full scan of the grid would have kept.
    last_row, last_col = matches[-1]
    return int(last_row), int(last_col)
|
||||||
|
|
||||||
# find where the goal is in the map
|
# find where the goal is in the map
|
||||||
def getgoalpos(data):
    """Return ``(row, col)`` of the goal marker (value 3) in ``data``.

    If several cells hold the marker, the last one in row-major order wins.
    If none does, a warning is printed and ``(-999, -999)`` is returned.
    """
    matches = np.argwhere(data == 3)
    if len(matches) == 0:
        print("warning: goal location not defined")
        return (-999, -999)
    # argwhere lists hits in row-major order, so the final entry is the one
    # a full scan of the grid would have kept.
    last_row, last_col = matches[-1]
    return (int(last_row), int(last_col))
|
||||||
|
|
||||||
# move the robot and report reward
|
# move the robot and report reward
|
||||||
def movebot(data, oldpos, a):
    """Try to move the robot one cell and report where it ends up.

    With probability ``randomrate`` the requested action is replaced by a
    random one. Actions 0..3 mean north, east, south, west; any other value
    leaves the candidate cell equal to ``oldpos``. A move off the grid or
    into an obstacle (1) leaves the robot in place. Stepping on quicksand
    (5 or 6) marks the cell as 6 in ``data`` and yields the quicksand
    penalty; stepping on the goal (3) yields 1; everything else yields -1.

    Returns ``((row, col), reward)``.
    """
    randomrate = 0.20  # how often do we move randomly
    quicksandreward = -100  # penalty for stepping on quicksand

    # Possibly discard the requested action in favour of a random one.
    if rand.uniform(0.0, 1.0) <= randomrate:
        a = rand.randint(0, 3)

    # Row/column offsets indexed by action: north, east, south, west.
    offsets = ((-1, 0), (0, 1), (1, 0), (0, -1))
    d_row, d_col = 0, 0
    for code, delta in enumerate(offsets):
        if a == code:
            d_row, d_col = delta
            break

    row, col = oldpos
    cand_row, cand_col = row + d_row, col + d_col

    reward = -1  # default reward is negative one

    # Off the map in any direction: stay put.
    on_grid = (0 <= cand_row < data.shape[0]
               and 0 <= cand_col < data.shape[1])
    if not on_grid:
        return (row, col), reward

    cell = data[cand_row, cand_col]
    if cell == 1:  # obstacle: stay put
        return (row, col), reward

    if cell == 5 or cell == 6:  # fresh or already-marked quicksand
        reward = quicksandreward
        data[cand_row, cand_col] = 6  # mark the event
    elif cell == 3:  # the goal
        reward = 1

    return (cand_row, cand_col), reward
|
||||||
|
|
||||||
# convert the location to a single integer
|
# convert the location to a single integer
|
||||||
def discretize(pos, num_cols=10):
    """Convert a ``(row, col)`` location into a single integer state.

    The state is ``row * num_cols + col``. ``num_cols`` defaults to 10, the
    row width this function previously hard-coded, so existing callers get
    the same numbers; pass the grid's real width for other grid sizes.
    """
    row, col = pos[0], pos[1]
    return row * num_cols + col
|
||||||
|
|
||||||
def test(map, epochs, learner, verbose, max_steps=10000):
    """Run ``learner`` through ``epochs`` trips on ``map`` and score it.

    Each epoch starts from a fresh copy of ``map`` with the robot at its
    start cell and ends when the robot reaches the goal or ``max_steps``
    moves have been made. The epoch score is the sum of the step rewards
    reported by ``movebot``.

    map: 2-D grid array (the name shadows the builtin but is kept because
        it is part of the existing signature).
    epochs: number of trips to run.
    learner: object providing ``querysetstate(state)`` and
        ``query(state, reward)``, each returning an action.
    verbose: when True, print the final grid and score of every epoch.
    max_steps: per-epoch step limit; defaults to the previous fixed limit.

    Returns the median epoch score.
    """
    # each epoch involves one trip to the goal
    startpos = getrobotpos(map)  # find where the robot starts
    goalpos = getgoalpos(map)  # find where the goal is
    scores = np.zeros((epochs, 1))
    for epoch in range(1, epochs + 1):
        total_reward = 0
        data = map.copy()
        robopos = startpos
        state = discretize(robopos)  # convert the location to a state
        action = learner.querysetstate(state)  # set the state and get first action
        count = 0
        while robopos != goalpos and count < max_steps:

            # move to new location according to action and then get a new action
            newpos, stepreward = movebot(data, robopos, action)
            if newpos == goalpos:
                r = 1  # reward for reaching the goal
            else:
                r = stepreward  # negative reward for not being at the goal
            state = discretize(newpos)
            action = learner.query(state, r)

            if data[robopos] != 6:
                data[robopos] = 4  # mark where we've been for map printing
            if data[newpos] != 6:
                data[newpos] = 2  # move to new location
            robopos = newpos  # update the location
            total_reward += stepreward
            count = count + 1

        # The loop only ends without reaching the goal when the step limit
        # was hit. The previous check compared against 100000, which the
        # 10000-step loop could never reach, so the message never printed.
        if robopos != goalpos:
            print("timeout")
        if verbose:
            printmap(data)
        if verbose:
            print(f"{epoch}, {total_reward}")
        scores[epoch - 1, 0] = total_reward
    return np.median(scores)
|
||||||
|
|
||||||
# run the code to test a learner
|
# run the code to test a learner
|
||||||
def test_code():
|
def test_code():
|
||||||
|
|
||||||
verbose = True # print lots of debug stuff if True
|
verbose = True # print lots of debug stuff if True
|
||||||
|
|
||||||
# read in the map
|
# read in the map
|
||||||
filename = 'testworlds/world01.csv'
|
filename = 'testworlds/world01.csv'
|
||||||
inf = open(filename)
|
inf = open(filename)
|
||||||
data = np.array([list(map(float,s.strip().split(','))) for s in inf.readlines()])
|
data = np.array([list(map(float,s.strip().split(','))) for s in inf.readlines()])
|
||||||
originalmap = data.copy() #make a copy so we can revert to the original map later
|
originalmap = data.copy() #make a copy so we can revert to the original map later
|
||||||
|
|
||||||
if verbose: printmap(data)
|
if verbose: printmap(data)
|
||||||
|
|
||||||
rand.seed(5)
|
rand.seed(5)
|
||||||
|
|
||||||
######## run non-dyna test ########
|
######## run non-dyna test ########
|
||||||
learner = ql.QLearner(num_states=100,\
|
learner = ql.QLearner(num_states=100,\
|
||||||
num_actions = 4, \
|
num_actions = 4, \
|
||||||
alpha = 0.2, \
|
alpha = 0.2, \
|
||||||
@@ -191,14 +191,14 @@ def test_code():
|
|||||||
rar = 0.98, \
|
rar = 0.98, \
|
||||||
radr = 0.999, \
|
radr = 0.999, \
|
||||||
dyna = 0, \
|
dyna = 0, \
|
||||||
verbose=False) #initialize the learner
|
verbose=False) #initialize the learner
|
||||||
epochs = 500
|
epochs = 500
|
||||||
total_reward = test(data, epochs, learner, verbose)
|
total_reward = test(data, epochs, learner, verbose)
|
||||||
print(f"{epochs}, median total_reward {total_reward}")
|
print(f"{epochs}, median total_reward {total_reward}")
|
||||||
print()
|
print()
|
||||||
non_dyna_score = total_reward
|
non_dyna_score = total_reward
|
||||||
|
|
||||||
######## run dyna test ########
|
######## run dyna test ########
|
||||||
learner = ql.QLearner(num_states=100,\
|
learner = ql.QLearner(num_states=100,\
|
||||||
num_actions = 4, \
|
num_actions = 4, \
|
||||||
alpha = 0.2, \
|
alpha = 0.2, \
|
||||||
@@ -206,18 +206,18 @@ def test_code():
|
|||||||
rar = 0.5, \
|
rar = 0.5, \
|
||||||
radr = 0.99, \
|
radr = 0.99, \
|
||||||
dyna = 200, \
|
dyna = 200, \
|
||||||
verbose=False) #initialize the learner
|
verbose=False) #initialize the learner
|
||||||
epochs = 50
|
epochs = 50
|
||||||
data = originalmap.copy()
|
data = originalmap.copy()
|
||||||
total_reward = test(data, epochs, learner, verbose)
|
total_reward = test(data, epochs, learner, verbose)
|
||||||
print(f"{epochs}, median total_reward {total_reward}")
|
print(f"{epochs}, median total_reward {total_reward}")
|
||||||
dyna_score = total_reward
|
dyna_score = total_reward
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print()
|
print()
|
||||||
print(f"results for {filename}")
|
print(f"results for {filename}")
|
||||||
print(f"non_dyna_score: {non_dyna_score}")
|
print(f"non_dyna_score: {non_dyna_score}")
|
||||||
print(f"dyna_score : {dyna_score}")
|
print(f"dyna_score : {dyna_score}")
|
||||||
|
|
||||||
# Run the demonstration only when this file is executed as a script.
if __name__ == "__main__":
    test_code()
|
||||||
|
|||||||
Reference in New Issue
Block a user