"""MLT: Utility code.

Copyright 2017, Georgia Tech Research Corporation
Atlanta, Georgia 30332-0415
All Rights Reserved
"""

import os

import pandas as pd


def symbol_to_path(symbol, base_dir=None):
    """Return the CSV file path for a ticker symbol.

    Parameters
    ----------
    symbol : object
        Ticker symbol; converted with ``str()`` to build the file name.
    base_dir : str, optional
        Directory holding the CSV files. When ``None``, the value of the
        ``MARKET_DATA_DIR`` environment variable is used, falling back to
        ``'../data/'``.

    Returns
    -------
    str
        ``os.path.join(base_dir, "<symbol>.csv")``.
    """
    if base_dir is None:
        base_dir = os.environ.get("MARKET_DATA_DIR", '../data/')
    return os.path.join(base_dir, "{}.csv".format(str(symbol)))


def get_data(symbols, dates, addSPY=True, colname='Adj Close', datecol='Date'):
    """Read one price column per symbol from CSV files into a DataFrame.

    Each symbol is read from ``symbol_to_path(symbol)``. The selected column
    is renamed to the symbol and left-joined onto an index of ``dates``.
    Dates missing from a file therefore become NaN.

    Two file layouts are supported:

    * Symbols containing ``'BTC'`` or ``'ETH'``, and ``'SPY'`` itself, are
      read from a ``'close'`` column indexed by a ``'time'`` column. That
      column is interpreted as epoch seconds and normalized to midnight.
    * All other symbols are read from ``colname`` indexed by ``datecol``.

    Parameters
    ----------
    symbols : iterable
        Ticker symbols; may be a list or a numpy array of dtype ``object``.
    dates : index-like
        Dates that form the index of the returned DataFrame.
    addSPY : bool, default True
        Prepend ``'SPY'`` to ``symbols`` when it is not already present.
    colname : str, default 'Adj Close'
        Price column to read for symbols using the default file layout.
    datecol : str, default 'Date'
        Date column used as the index for symbols using the default layout.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``dates``, with one column per symbol.
    """
    df = pd.DataFrame(index=dates)
    if addSPY and 'SPY' not in symbols:
        # Add SPY for reference if absent. list() also handles the case
        # where ``symbols`` is a numpy array of dtype 'object'.
        symbols = ['SPY'] + list(symbols)

    for symbol in symbols:
        name = str(symbol)
        if 'BTC' in name or 'ETH' in name or name == 'SPY':
            # These files use a lowercase 'close' column keyed by 'time'.
            value_col, index_col = 'close', 'time'
        else:
            # Honor the caller's column choices. Previously these parameters
            # were overwritten inside the loop and had no effect.
            value_col, index_col = colname, datecol

        df_temp = pd.read_csv(symbol_to_path(symbol), index_col=index_col,
                              parse_dates=True,
                              usecols=[index_col, value_col],
                              na_values=['nan'])
        df_temp = df_temp.rename(columns={value_col: symbol})

        if index_col == 'time':
            # Convert epoch seconds to datetimes and truncate to midnight so
            # the index can line up with daily ``dates``.
            # NOTE(review): if a file holds several rows per day, this yields
            # duplicate dates and the join below will duplicate rows; confirm
            # the data is daily.
            df_temp.index = (pd.to_datetime(df_temp.index, unit='s')
                             .normalize()
                             .rename('date'))

        df = df.join(df_temp)

    # NOTE: dropping dates on which SPY did not trade
    # (``df.dropna(subset=["SPY"])``) is intentionally disabled; every row
    # of ``dates`` is kept.
    return df


def plot_data(df, title="Stock prices", xlabel="Date", ylabel="Price"):
    """Plot stock prices with a custom title and meaningful axis labels."""
    # Local import: matplotlib is only needed when actually plotting.
    import matplotlib.pyplot as plt

    ax = df.plot(title=title, fontsize=12)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    plt.show()


def get_orders_data_file(basefilename):
    """Open ``basefilename`` in the orders data directory for reading.

    The directory is ``$ORDERS_DATA_DIR``, defaulting to ``'orders/'``.
    Returns an open file object; the caller is responsible for closing it.
    """
    return open(os.path.join(os.environ.get("ORDERS_DATA_DIR", 'orders/'),
                             basefilename))


def get_learner_data_file(basefilename):
    """Open ``basefilename`` in the learner data directory for reading.

    The directory is ``$LEARNER_DATA_DIR``, defaulting to ``'Data/'``.
    Returns an open file object; the caller is responsible for closing it.
    """
    return open(os.path.join(os.environ.get("LEARNER_DATA_DIR", 'Data/'),
                             basefilename), 'r')


def get_robot_world_file(basefilename):
    """Open ``basefilename`` in the robot worlds directory for reading.

    The directory is ``$ROBOT_WORLDS_DIR``, defaulting to ``'testworlds/'``.
    Returns an open file object; the caller is responsible for closing it.
    """
    return open(os.path.join(os.environ.get("ROBOT_WORLDS_DIR", 'testworlds/'),
                             basefilename))