Python神经网络与股票价格:输入该用什么?

4 投票
3 回答
3961 浏览
提问于 2025-04-18 04:40

我正在使用一个用Python写的反向传播神经网络,具体代码可以在这里找到。这个网络在处理简单的异或(XOR)例子时效果很好。

不过,我想用它做一些更复杂的事情:尝试预测股票价格。我知道神经网络不一定是最好的选择,可能也不太准确,但我还是想试试看。

我第一次尝试是获取某只股票(比如GOOG)的过去10天的收盘价。我希望用这些数据来训练神经网络,然后预测下一天的收盘价,但我意识到一个问题:我只有一个输入值,而在预测时没有任何输入可以提供。这就是我困惑的根源,涉及到输入、隐藏和输出节点的数量。

在一篇论文中提到,可以使用过去d天股票的最低价、最高价和平均价作为输入。这就有3个输入(或者4个?如果算上d的话),但为了预测下一天的价格,你却无法提供这些作为输入(也许只有d可以?)。

在用神经网络进行训练和预测时,如何处理输入数量的变化?我是不是漏掉了神经网络的一些基本概念和用法?谢谢!

3 个回答

0

你在使用滚动平均的时候,实际上是把很多信息都给省略掉了。还有其他几种方法可以把时间序列数据提供给神经网络,比如滑动窗口的方法。

比如说,你用过去3天的数据来预测第4天的情况。与其把这3天的数据平均起来,不如把每一天的数据单独放到输入节点里。然后把这个3天的数据窗口在你数据的前半部分滑动,用来训练你的模型。测试的时候,就把你想要预测的那天之前的3天价格放进去。比如:

训练集

[[day 1 price, day 2 price, day 3 price], day 4 price]
[[day 2 price, day 3 price, day 4 price], day 5 price]
[[day 3 price, day 4 price, day 5 price], day 6 price]
[[day 4 price, day 5 price, day 6 price], day 7 price] 

测试

[day 5 price, day 6 price, day 7 price] 
0

监督式机器学习是一种算法,它会处理一堆案例,这些案例包含了特征(一组数字,作为输入)和一个结果(输出)。

你需要一个训练数据集,比如说几个月的时间序列数据,这些数据的结果是你已经知道的。一旦你的网络训练好了,你就可以用最近几天的股票值(这些是已知的,因为它们已经发生过)来预测明天会发生什么,这样你就知道该买什么了。

最后,d 不是输入,它是一个常量。而输入和输出的数量是独立的(只要你有足够的输入特征就行)。理论上来说,特征越多,预测的准确性就越高,但处理时间会更长,需要更大的训练数据集,并且可能会出现过拟合的问题。

3

@anana的评论让我明白了神经网络应该如何工作。正如她所说,我可以把过去d天(在我这里是5天)股票的平均值作为输入,来尝试进行预测。

这意味着我的训练输入格式是:

[[rollingAverage, rollingMinimum, rollingMaximum], normalizedClosePrice],这是针对过去五天的数据(因为使用了滚动窗口,所以总共分析了9天的数据)。

当我想在训练后进行预测时,我只需要提供最近5天的输入节点,格式是:

[rollingAverage, rollingMinimum, rollingMaximum]

下面是所有相关的逻辑,结合了我在最初问题中提到的神经网络:

## ================================================================

def normalizePrice(price, minimum, maximum):
    return ((2*price - (maximum + minimum)) / (maximum - minimum))

def denormalizePrice(price, minimum, maximum):
    return (((price*(maximum-minimum))/2) + (maximum + minimum))/2

## ================================================================

def rollingWindow(seq, windowSize):
    it = iter(seq)
    win = [it.next() for cnt in xrange(windowSize)] # First window
    yield win
    for e in it: # Subsequent windows
        win[:-1] = win[1:]
        win[-1] = e
        yield win

def getMovingAverage(values, windowSize):
    movingAverages = []

    for w in rollingWindow(values, windowSize):
        movingAverages.append(sum(w)/len(w))

    return movingAverages

def getMinimums(values, windowSize):
    minimums = []

    for w in rollingWindow(values, windowSize):
        minimums.append(min(w))

    return minimums

def getMaximums(values, windowSize):
    maximums = []

    for w in rollingWindow(values, windowSize):
        maximums.append(max(w))

    return maximums

## ================================================================

def getTimeSeriesValues(values, window):
    movingAverages = getMovingAverage(values, window)
    minimums = getMinimums(values, window)
    maximums = getMaximums(values, window)

    returnData = []

    # build items of the form [[average, minimum, maximum], normalized price]
    for i in range(0, len(movingAverages)):
        inputNode = [movingAverages[i], minimums[i], maximums[i]]
        price = normalizePrice(values[len(movingAverages) - (i + 1)], minimums[i], maximums[i])
        outputNode = [price]
        tempItem = [inputNode, outputNode]
        returnData.append(tempItem)

    return returnData

## ================================================================

def getHistoricalData(stockSymbol):
    historicalPrices = []

    # login to API
    urllib2.urlopen("http://api.kibot.com/?action=login&user=guest&password=guest")

    # get 14 days of data from API (business days only, could be < 10)
    url = "http://api.kibot.com/?action=history&symbol=" + stockSymbol + "&interval=daily&period=14&unadjusted=1&regularsession=1"
    apiData = urllib2.urlopen(url).read().split("\n")
    for line in apiData:
        if(len(line) > 0):
            tempLine = line.split(',')
            price = float(tempLine[1])
            historicalPrices.append(price)

    return historicalPrices

## ================================================================

def getTrainingData(stockSymbol):
    historicalData = getHistoricalData(stockSymbol)

    # reverse it so we're using the most recent data first, ensure we only have 9 data points
    historicalData.reverse()
    del historicalData[9:]

    # get five 5-day moving averages, 5-day lows, and 5-day highs, associated with the closing price
    trainingData = getTimeSeriesValues(historicalData, 5)

    return trainingData

撰写回答