1. 程式人生 > 有趣的應用 | 使用RNN預測股票價格系列一

有趣的應用 | 使用RNN預測股票價格系列一

import os
import pandas as pd
import pprint
import tensorflow as tf
import tensorflow.contrib.slim as slim
 from data_model import StockDataSet
from model_rnn import LstmRNN

# ---------------------------------------------------------------------------
# Command-line flags. The bracketed value in each help string repeats the
# default given as the second argument.
# ---------------------------------------------------------------------------
flags = tf.app.flags

# Data / model shape.
flags.DEFINE_integer("stock_count", 100, "Stock count [100]")
flags.DEFINE_integer("input_size", 5, "Input size [5]")
flags.DEFINE_integer("num_steps", 30, "Num of steps [30]")
flags.DEFINE_integer("num_layers", 1, "Num of layer [1]")
flags.DEFINE_integer("lstm_size", 128, "Size of one LSTM cell [128]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")

# Regularisation and learning-rate schedule.
flags.DEFINE_float(
    "keep_prob", 0.8,
    "Keep probability of dropout layer. [0.8]")
flags.DEFINE_float(
    "init_learning_rate", 0.001,
    "Initial learning rate at early stage. [0.001]")
flags.DEFINE_float(
    "learning_rate_decay", 0.99,
    "Decay rate of learning rate. [0.99]")
flags.DEFINE_integer(
    "init_epoch", 5,
    "Num. of epoches considered as early stage. [5]")
flags.DEFINE_integer(
    "max_epoch", 50,
    "Total training epoches. [50]")

# Optional embedding, target selection and run mode.
flags.DEFINE_integer(
    "embed_size", None,
    "If provided, use embedding vector of this size. [None]")
flags.DEFINE_string(
    "stock_symbol", None,
    "Target stock symbol [None]")
flags.DEFINE_integer(
    "sample_size", 4,
    "Number of stocks to plot during training. [4]")
flags.DEFINE_boolean(
    "train", False,
    "True for training, False for testing [False]")

FLAGS = flags.FLAGS

# Shared pretty-printer instance for dumping structures to stdout.
pp = pprint.PrettyPrinter()

# Create the "logs" directory at import time if nothing exists at that path;
# load_sp500() writes logs/metadata.tsv into it.
if not os.path.exists("logs"):
    os.mkdir("logs")
def show_all_variables():
    """Print an analysis of the trainable variables.

    Collects ``tf.trainable_variables()`` and hands them to
    ``slim.model_analyzer.analyze_vars`` with ``print_info=True``.
    """
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


def load_sp500(input_size, num_steps, k=None, target_symbol=None, test_ratio=0.05):
    """Build the list of ``StockDataSet`` objects to work with.

    Args:
        input_size: Forwarded to every ``StockDataSet`` as ``input_size``.
        num_steps: Forwarded to every ``StockDataSet`` as ``num_steps``.
        k: If not None, keep only the first ``k`` stocks after sorting by
            ``market_cap`` in descending order.
        target_symbol: If not None, skip the metadata file entirely and
            return a one-element list for this symbol only.
        test_ratio: Forwarded to every ``StockDataSet`` as ``test_ratio``.

    Returns:
        A list of ``StockDataSet`` instances: one for ``target_symbol`` when
        it is given, otherwise one per remaining row of the metadata table.

    Side effects (multi-stock path only):
        Reads ``data/constituents-financials.csv`` (which must provide
        ``symbol``, ``market_cap`` and ``sector`` columns once the column
        names are lower-cased and spaces replaced by underscores), prints
        two summaries to stdout, and writes ``logs/metadata.tsv``.
    """
    if target_symbol is not None:
        return [
            StockDataSet(
                target_symbol,
                input_size=input_size,
                num_steps=num_steps,
                test_ratio=test_ratio)
        ]

    # Load metadata of S&P 500 stocks and normalise the column names.
    info = pd.read_csv("data/constituents-financials.csv")
    info = info.rename(
        columns={col: col.lower().replace(' ', '_') for col in info.columns})

    # Flag the symbols that have a per-stock price file under data/.
    info['file_exists'] = info['symbol'].map(
        lambda x: os.path.exists("data/{}.csv".format(x)))
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print(info['file_exists'].value_counts().to_dict())

    # Keep only stocks whose data file is present; astype(bool) guarantees a
    # boolean mask even when the table is empty.
    info = info[info['file_exists'].astype(bool)].reset_index(drop=True)

    # DataFrame.sort() was removed from pandas; sort_values() is the
    # supported spelling of the same operation.
    info = info.sort_values('market_cap', ascending=False).reset_index(drop=True)

    if k is not None:
        info = info.head(k)

    # Formatting into one string keeps the output identical to the original
    # Python 2 ``print "...\n", info.head()`` on both Python 2 and 3.
    print("Head of S&P 500 info:\n{}".format(info.head()))

    # Generate embedding meta file.
    info[['symbol', 'sector']].to_csv(
        os.path.join("logs", "metadata.tsv"), sep='\t', index=False)

    # Forward the caller's test_ratio here as well, so both branches honour
    # the parameter (its default matches the previously hard-coded 0.05).
    return [
        StockDataSet(
            symbol,
            input_size=input_size,
            num_steps=num_steps,
            test_ratio=test_ratio)
        for symbol in info['symbol']
    ]


def main(_):
    pp.pprint(flags.FLAGS.__flags)
# gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)   run_config = tf.ConfigProto()  
run_config.gpu_options.allow_growth = True  
with tf.Session(config=run_config) as sess:      
rnn_model = LstmRNN(