intro2ai/p3_rl/textGridworldDisplay.py

325 lines
13 KiB
Python
Raw Normal View History

2021-11-27 16:16:51 +01:00
# textGridworldDisplay.py
# -----------------------
# Licensing Information: You are free to use or extend these projects for
# educational purposes provided that (1) you do not distribute or publish
# solutions, (2) you retain this notice, and (3) you provide clear
# attribution to UC Berkeley, including a link to http://ai.berkeley.edu.
#
# Attribution Information: The Pacman AI projects were developed at UC Berkeley.
# The core projects and autograders were primarily created by John DeNero
# (denero@cs.berkeley.edu) and Dan Klein (klein@cs.berkeley.edu).
# Student side autograding was added by Brad Miller, Nick Hay, and
# Pieter Abbeel (pabbeel@cs.berkeley.edu).
import util
class TextGridworldDisplay:
    """Text display for a gridworld: prints values, Q-values or the bare
    layout as ASCII tables on standard output.

    The instance only stores the gridworld; all rendering is delegated to
    the module-level prettyPrint* functions.
    """

    def __init__(self, gridworld):
        """Remember the gridworld whose states will be displayed."""
        self.gridworld = gridworld

    def start(self):
        """Do nothing; the text display needs no set-up."""
        pass

    def pause(self):
        """Do nothing; the text display never waits."""
        pass

    def displayValues(self, agent, currentState = None, message = None):
        """Print the agent's value and policy for every state.

        - agent: object providing getValue(state) and getPolicy(state).
        - currentState: state to mark in the table, or None.
        - message: optional text printed before the table.
        """
        if message is not None:
            # Single-argument call form: prints identically under the
            # Python 2 print statement and the Python 3 print function.
            print(message)
        values = util.Counter()
        policy = {}
        for state in self.gridworld.getStates():
            values[state] = agent.getValue(state)
            policy[state] = agent.getPolicy(state)
        prettyPrintValues(self.gridworld, values, policy, currentState)

    def displayNullValues(self, agent, currentState = None, message = None):
        """Print the grid layout without any agent values.

        `agent` is accepted for signature parity with the other display
        methods but is not used.
        """
        if message is not None:
            print(message)
        prettyPrintNullValues(self.gridworld, currentState)

    def displayQValues(self, agent, currentState = None, message = None):
        """Print the agent's Q-value for every (state, action) pair.

        - agent: object providing getQValue(state, action).
        - currentState: state to mark in the table, or None.
        - message: optional text printed before the table.
        """
        if message is not None:
            print(message)
        qValues = util.Counter()
        for state in self.gridworld.getStates():
            for action in self.gridworld.getPossibleActions(state):
                qValues[(state, action)] = agent.getQValue(state, action)
        prettyPrintQValues(self.gridworld, qValues, currentState)
def prettyPrintValues(gridWorld, values, policy=None, currentState = None):
    """Print a table with one cell per grid coordinate showing its value.

    - gridWorld: object exposing `grid` (with `width`, `height` and
      `grid[x][y]` indexing) and `getPossibleActions(state)`.
    - values: mapping indexed by the (x, y) state; it is looked up for
      every coordinate of the grid.
    - policy: optional mapping from state to action; 'north', 'south',
      'east' and 'west' are drawn as arrows in the cell.
    - currentState: state whose cell gets a '*' marker, or None.

    Returns None; the table is written to standard output.
    """
    grid = gridWorld.grid
    maxLen = 11
    newRows = []
    for y in range(grid.height):
        newRow = []
        for x in range(grid.width):
            state = (x, y)
            value = values[state]
            action = None
            if policy is not None and state in policy:
                action = policy[state]
            actions = gridWorld.getPossibleActions(state)
            # A state that offers 'exit' is shown as an exit cell unless the
            # policy already names one of its legal actions.
            if action not in actions and 'exit' in actions:
                action = 'exit'
            if action == 'exit':
                valString = border('%.2f' % value)
            else:
                valString = '\n\n%.2f\n\n' % value
                valString += ' '*maxLen
            # Start and wall cells override whatever was chosen above.
            if grid[x][y] == 'S':
                valString = '\n\nS: %.2f\n\n' % value
                valString += ' '*maxLen
            if grid[x][y] == '#':
                valString = '\n#####\n#####\n#####\n'
                valString += ' '*maxLen
            # Every variant above splits into five lines (indices 0-4).
            text = valString.split('\n')
            if currentState == state:
                l = len(text[1])
                if l == 0:
                    text[1] = '*'
                else:
                    # Floor division keeps the Python 2 integer arithmetic
                    # on every interpreter version.
                    text[1] = "|" + ' ' * ((l-1)//2-1) + '*' + ' ' * (l//2-1) + "|"
            if action == 'east':
                text[2] = ' ' + text[2] + ' >'
            elif action == 'west':
                text[2] = '< ' + text[2] + ' '
            elif action == 'north':
                text[0] = ' ' * (maxLen//2) + '^' +' ' * (maxLen//2)
            elif action == 'south':
                text[4] = ' ' * (maxLen//2) + 'v' +' ' * (maxLen//2)
            newRow.append("\n".join(text))
        newRows.append(newRow)
    numCols = grid.width
    # Label each row with its y index, then reverse so the highest y is
    # printed first.
    for rowNum, row in enumerate(newRows):
        row.insert(0,"\n\n"+str(rowNum))
    newRows.reverse()
    colLabels = [str(colNum) for colNum in range(numCols)]
    colLabels.insert(0,' ')
    finalRows = [colLabels] + newRows
    # Single-argument call form prints the same on Python 2 and Python 3.
    print(indent(finalRows,separateRows=True,delim='|', prefix='|',postfix='|', justify='center',hasHeader=True))
def prettyPrintNullValues(gridWorld, currentState = None):
    """Print a table showing only the grid layout, with no agent values.

    - gridWorld: object exposing `grid` (with `width`, `height` and
      `grid[x][y]` indexing).
    - currentState: state whose cell gets a '*' marker, or None.

    Cell contents depend on `grid[x][y]`: 'S' is shown as "S", '#' as a
    block of hashes, a number inside a border, and anything else as an
    empty bordered cell.  No policy is available here, so no action arrows
    are drawn.

    Returns None; the table is written to standard output.
    """
    grid = gridWorld.grid
    maxLen = 11
    newRows = []
    for y in range(grid.height):
        newRow = []
        for x in range(grid.width):
            state = (x, y)
            cell = grid[x][y]
            if cell == 'S':
                valString = '\n\nS\n\n'
                valString += ' '*maxLen
            elif cell == '#':
                valString = '\n#####\n#####\n#####\n'
                valString += ' '*maxLen
            elif isinstance(cell, (float, int)):
                valString = border('%.2f' % float(cell))
            else:
                valString = border(' ')
            # Every variant above splits into five lines (indices 0-4).
            text = valString.split('\n')
            if currentState == state:
                l = len(text[1])
                if l == 0:
                    text[1] = '*'
                else:
                    # Floor division keeps the Python 2 integer arithmetic
                    # on every interpreter version.
                    text[1] = "|" + ' ' * ((l-1)//2-1) + '*' + ' ' * (l//2-1) + "|"
            newRow.append("\n".join(text))
        newRows.append(newRow)
    numCols = grid.width
    # Label each row with its y index, then reverse so the highest y is
    # printed first.
    for rowNum, row in enumerate(newRows):
        row.insert(0,"\n\n"+str(rowNum))
    newRows.reverse()
    colLabels = [str(colNum) for colNum in range(numCols)]
    colLabels.insert(0,' ')
    finalRows = [colLabels] + newRows
    # Single-argument call form prints the same on Python 2 and Python 3.
    print(indent(finalRows,separateRows=True,delim='|', prefix='|',postfix='|', justify='center',hasHeader=True))
def prettyPrintQValues(gridWorld, qValues, currentState=None):
    """Print a table with one cell per grid coordinate showing its Q-values.

    - gridWorld: object exposing `grid` (with `width`, `height` and
      `grid[x][y]` indexing), `getPossibleActions(state)` and
      `getStartState()`.
    - qValues: mapping indexed by (state, action) pairs.
    - currentState: state whose cell gets a '*' marker, or None.

    Each cell shows the north/south/east/west/exit Q-values; the actions
    whose Q-value equals the cell's maximum are decorated ('/..\\', '\\../',
    '<', '>', '[ .. ]').

    Returns None; the table is written to standard output.
    """
    grid = gridWorld.grid
    maxLen = 11
    newRows = []
    for y in range(grid.height):
        newRow = []
        for x in range(grid.width):
            state = (x, y)
            actions = gridWorld.getPossibleActions(state)
            if not actions:
                # No legal action: look the state up under the None action
                # so max() below still has one entry to work with.
                actions = [None]
            bestQ = max([qValues[(state, action)] for action in actions])
            bestActions = [action for action in actions if qValues[(state, action)] == bestQ]

            # display cell
            qStrings = {}
            for action in actions:
                qStrings[action] = "%.2f" % qValues[(state, action)]
            # Formatted values are never empty, so a plain default is
            # equivalent to the old "in ... and ... or ' '" idiom.
            northString = qStrings.get('north', ' ')
            southString = qStrings.get('south', ' ')
            eastString = qStrings.get('east', ' ')
            westString = qStrings.get('west', ' ')
            exitString = qStrings.get('exit', ' ')

            # Pad the shorter of east/west so both have the same width.
            eastLen = len(eastString)
            westLen = len(westString)
            if eastLen < westLen:
                eastString = ' '*(westLen-eastLen)+eastString
            if westLen < eastLen:
                westString = westString+' '*(eastLen-westLen)

            if 'north' in bestActions:
                northString = '/'+northString+'\\'
            if 'south' in bestActions:
                southString = '\\'+southString+'/'
            if 'east' in bestActions:
                eastString = eastString+'>'
            else:
                eastString = eastString+' '
            if 'west' in bestActions:
                westString = '<'+westString
            else:
                westString = ' '+westString
            if 'exit' in bestActions:
                exitString = '[ '+exitString+' ]'

            isCurrent = (state == currentState)
            isStart = (state == gridWorld.getStartState())
            ewString = westString + " " + eastString
            if isCurrent:
                ewString = westString + " * " + eastString
            if isStart:
                ewString = westString + " S " + eastString
            if isCurrent and isStart:
                ewString = westString + " S:* " + eastString

            text = [northString, "\n"+exitString, ewString, ' '*maxLen+"\n", southString]
            if grid[x][y] == '#':
                text = ['', '\n#####\n#####\n#####', '']

            newRow.append("\n".join(text))
        newRows.append(newRow)
    numCols = grid.width
    # Label each row with its y index, then reverse so the highest y is
    # printed first.
    for rowNum, row in enumerate(newRows):
        row.insert(0,"\n\n\n"+str(rowNum))
    newRows.reverse()
    colLabels = [str(colNum) for colNum in range(numCols)]
    colLabels.insert(0,' ')
    finalRows = [colLabels] + newRows
    # Single-argument call form prints the same on Python 2 and Python 3.
    print(indent(finalRows,separateRows=True,delim='|',prefix='|',postfix='|', justify='center',hasHeader=True))
def border(text):
    """Return `text` framed as a five-line string.

    The result is, top to bottom: a rule of dashes, a blank row between
    bars, the text between bars, another blank row and a closing rule.
    The rows are not all the same width; callers here centre them.
    """
    inner = len(text) + 2
    rule = '-' * inner
    blank = '|' + ' ' * inner + '|'
    middle = ' | ' + text + ' | '
    return '\n'.join((rule, blank, middle, blank, rule))
# INDENTING CODE
# Indenting code based on a post from George Sakkis
# (http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/267662)
import cStringIO,operator
def indent(rows, hasHeader=False, headerChar='-', delim=' | ', justify='left',
           separateRows=False, prefix='', postfix='', wrapfunc=lambda x:x):
    """Indents a table by column and returns it as one string.

    - rows: A sequence of sequences of items, one sequence per row.
    - hasHeader: True if the first row consists of the columns' names.
    - headerChar: Character to be used for the row separator line
      (if hasHeader==True or separateRows==True).
    - delim: The column delimiter.
    - justify: Determines how are data justified in their column.
      Valid values are 'left','right' and 'center' (case-insensitive);
      any other value raises KeyError.
    - separateRows: True if rows are to be separated by a line
      of 'headerChar's.
    - prefix: A string prepended to each printed row.
    - postfix: A string appended to each printed row.
    - wrapfunc: A function f(text) for wrapping text; each element in
      the table is first wrapped by this function.

    Every output line ends with a newline.  An empty table yields ''.
    """
    # closure for breaking logical rows to physical, using wrapfunc
    def rowWrapper(row):
        cells = [wrapfunc(item).split('\n') for item in row]
        height = max([len(cell) for cell in cells]) if cells else 0
        # Pad short cells explicitly.  The former map(None, *cells) is the
        # identity for a single sequence, which split the text of a
        # one-column row into separate characters.
        return [[cell[i] if i < len(cell) else '' for cell in cells]
                for i in range(height)]

    # break each logical row into one or more physical ones
    logicalRows = [rowWrapper(row) for row in rows]
    if not logicalRows:
        return ''

    # get the maximum of each column by the string length of its items;
    # a row with fewer cells simply does not contribute to later columns
    maxWidths = []
    for physicalRows in logicalRows:
        for row in physicalRows:
            for col, item in enumerate(row):
                width = len(str(item))
                if col == len(maxWidths):
                    maxWidths.append(width)
                elif width > maxWidths[col]:
                    maxWidths[col] = width

    rowSeparator = headerChar * (len(prefix) + len(postfix) + sum(maxWidths) +
                                 len(delim)*(len(maxWidths)-1))
    # select the appropriate justify method
    align = {'center':str.center, 'right':str.rjust, 'left':str.ljust}[justify.lower()]

    lines = []
    if separateRows:
        lines.append(rowSeparator)
    for physicalRows in logicalRows:
        for row in physicalRows:
            lines.append(prefix
                         + delim.join([align(str(item),width) for (item,width) in zip(row,maxWidths)])
                         + postfix)
        if separateRows or hasHeader:
            lines.append(rowSeparator)
            # The header separator is emitted once, after the first row.
            hasHeader = False
    return ''.join([line + '\n' for line in lines])
import math
def wrap_always(text, width):
    """A simple word-wrap function that wraps text on exactly width characters.
    It doesn't split the text in words.

    Returns the chunks of `text`, each `width` characters long (the last
    one may be shorter), joined by newlines.  A width of 0 raises
    ZeroDivisionError.
    """
    # Integer ceiling of len(text) / width, without float rounding.
    numChunks = -(-len(text) // width)
    return '\n'.join([text[width*i:width*(i+1)] for i in range(numChunks)])
# TEST OF DISPLAY CODE
if __name__ == '__main__':
    # Manual smoke test: print a value table and a Q-value table for a
    # sample grid with constant values.
    import gridworld, util

    grid = gridworld.getCliffGrid3()
    print(grid.getStates())

    policy = dict([(state,'east') for state in grid.getStates()])
    values = util.Counter(dict([(state,1000.23) for state in grid.getStates()]))
    prettyPrintValues(grid, values, policy, currentState = (0,0))

    # Every (state, action) pair, flattened directly rather than via the
    # Python-2-only builtin reduce.
    qStates = [(state, action)
               for state in grid.getStates()
               for action in grid.getPossibleActions(state)]
    qValues = util.Counter(dict([((state, action), 10.5) for state, action in qStates]))
    prettyPrintQValues(grid, qValues, currentState = (0,0))