Python神经网络与股票价格:输入该用什么?
我正在使用一个用Python写的反向传播神经网络,具体代码可以在这里找到。这个网络在处理简单的异或(XOR)例子时效果很好。
不过,我想用它做一些更复杂的事情:尝试预测股票价格。我知道神经网络不一定是最好的选择,可能也不太准确,但我还是想试试看。
我第一次尝试是获取某只股票(比如GOOG)的过去10天的收盘价。我希望用这些数据来训练神经网络,然后预测下一天的收盘价,但我意识到一个问题:我只有一个输入值,而在预测时没有任何输入可以提供。这就是我困惑的根源,涉及到输入、隐藏和输出节点的数量。
在一篇论文中提到,可以使用过去d
天股票的最低价、最高价和平均价作为输入。这就有3个输入(或者4个?如果算上d
的话),但为了预测下一天的价格,你却无法提供这些作为输入(也许只有d
可以?)。
在用神经网络进行训练和预测时,如何处理输入数量的变化?我是不是漏掉了神经网络的一些基本概念和用法?谢谢!
3 个回答
你在使用滚动平均的时候,实际上是把很多信息都给省略掉了。还有其他几种方法可以把时间序列数据提供给神经网络,比如滑动窗口的方法。
比如说,你用过去3天的数据来预测第4天的情况。与其把这3天的数据平均起来,不如把每一天的数据单独放到输入节点里。然后把这个3天的数据窗口在你数据的前半部分滑动,用来训练你的模型。测试的时候,就把你想要预测的那天之前的3天价格放进去。比如:
训练集
[[day 1 price, day 2 price, day 3 price], day 4 price]
[[day 2 price, day 3 price, day 4 price], day 5 price]
[[day 3 price, day 4 price, day 5 price], day 6 price]
[[day 4 price, day 5 price, day 6 price], day 7 price]
测试
[day 5 price, day 6 price, day 7 price]
监督式机器学习是一种算法,它会处理一堆案例,这些案例包含了特征(一组数字,作为输入)和一个结果(输出)。
你需要一个训练数据集,比如说几个月的时间序列数据,这些数据的结果是你已经知道的。一旦你的网络训练好了,你就可以用最近几天的股票值(这些是已知的,因为它们已经发生过)来预测明天会发生什么,这样你就知道该买什么了。
最后,d
不是输入,它是一个常量。而输入和输出的数量是独立的(只要你有足够的输入特征就行)。理论上来说,特征越多,预测的准确性就越高,但处理时间会更长,需要更大的训练数据集,并且可能会出现过拟合的问题。
@anana的评论让我明白了神经网络应该如何工作。正如她所说,我可以把过去d天(在我这里是5天)股票的平均值作为输入,来尝试进行预测。
这意味着我的训练输入格式是:
[[rollingAverage, rollingMinimum, rollingMaximum], normalizedClosePrice]
,这是针对过去五天的数据(因为使用了滚动窗口,所以总共分析了9天的数据)。
当我想在训练后进行预测时,我只需要提供最近5天的输入节点,格式是:
[rollingAverage, rollingMinimum, rollingMaximum]
。
下面是所有相关的逻辑,结合了我在最初问题中提到的神经网络:
## ================================================================
def normalizePrice(price, minimum, maximum):
return ((2*price - (maximum + minimum)) / (maximum - minimum))
def denormalizePrice(price, minimum, maximum):
return (((price*(maximum-minimum))/2) + (maximum + minimum))/2
## ================================================================
def rollingWindow(seq, windowSize):
it = iter(seq)
win = [it.next() for cnt in xrange(windowSize)] # First window
yield win
for e in it: # Subsequent windows
win[:-1] = win[1:]
win[-1] = e
yield win
def getMovingAverage(values, windowSize):
movingAverages = []
for w in rollingWindow(values, windowSize):
movingAverages.append(sum(w)/len(w))
return movingAverages
def getMinimums(values, windowSize):
minimums = []
for w in rollingWindow(values, windowSize):
minimums.append(min(w))
return minimums
def getMaximums(values, windowSize):
maximums = []
for w in rollingWindow(values, windowSize):
maximums.append(max(w))
return maximums
## ================================================================
def getTimeSeriesValues(values, window):
movingAverages = getMovingAverage(values, window)
minimums = getMinimums(values, window)
maximums = getMaximums(values, window)
returnData = []
# build items of the form [[average, minimum, maximum], normalized price]
for i in range(0, len(movingAverages)):
inputNode = [movingAverages[i], minimums[i], maximums[i]]
price = normalizePrice(values[len(movingAverages) - (i + 1)], minimums[i], maximums[i])
outputNode = [price]
tempItem = [inputNode, outputNode]
returnData.append(tempItem)
return returnData
## ================================================================
def getHistoricalData(stockSymbol):
historicalPrices = []
# login to API
urllib2.urlopen("http://api.kibot.com/?action=login&user=guest&password=guest")
# get 14 days of data from API (business days only, could be < 10)
url = "http://api.kibot.com/?action=history&symbol=" + stockSymbol + "&interval=daily&period=14&unadjusted=1®ularsession=1"
apiData = urllib2.urlopen(url).read().split("\n")
for line in apiData:
if(len(line) > 0):
tempLine = line.split(',')
price = float(tempLine[1])
historicalPrices.append(price)
return historicalPrices
## ================================================================
def getTrainingData(stockSymbol):
historicalData = getHistoricalData(stockSymbol)
# reverse it so we're using the most recent data first, ensure we only have 9 data points
historicalData.reverse()
del historicalData[9:]
# get five 5-day moving averages, 5-day lows, and 5-day highs, associated with the closing price
trainingData = getTimeSeriesValues(historicalData, 5)
return trainingData