为什么这两个数学函数返回的结果不同?

0 投票
2 回答
779 浏览
提问于 2025-04-15 16:17

我正在尝试使用一种叫做“花式索引”的方法来代替循环,以加快Numpy中的一个函数。根据我的理解,我已经正确地实现了花式索引的版本。但问题是,这两个函数(循环和花式索引的)返回的结果不一样。我不太明白为什么。值得一提的是,如果使用一个较小的数组(比如20 x 20 x 20),这两个函数的结果是相同的。

下面我提供了重现这个错误所需的所有内容。如果这两个函数的结果相同,那么这一行代码find_maxdiff(data) - find_maxdiff_fancy(data)应该返回一个全是零的数组。

from numpy import *

def rms(data, axis=0):
    return sqrt(mean(data ** 2, axis))

def find_maxdiff(data):
    samples, channels, epochs = shape(data)
    window_size = 50
    maxdiff = zeros(epochs)
    for epoch in xrange(epochs):
        signal = rms(data[:, :, epoch], axis=1)
        for t in xrange(window_size, alen(signal) - window_size):
            amp_a = mean(signal[t-window_size:t], axis=0)
            amp_b = mean(signal[t:t+window_size], axis=0)
            the_diff = abs(amp_b - amp_a)
            if the_diff > maxdiff[epoch]: 
                maxdiff[epoch] = the_diff

    return maxdiff

def find_maxdiff_fancy(data):
    samples, channels, epochs = shape(data)
    window_size = 50
    maxdiff = zeros(epochs)
    signal = rms(data, axis=1)
    for t in xrange(window_size, alen(signal) - window_size):
        amp_a = mean(signal[t-window_size:t], axis=0)
        amp_b = mean(signal[t:t+window_size], axis=0)
        the_diff = abs(amp_b - amp_a)
        maxdiff[the_diff > maxdiff] = the_diff

    return maxdiff

data = random.random((600, 20, 100))
find_maxdiff(data) - find_maxdiff_fancy(data)

data = random.random((20, 20, 20))
find_maxdiff(data) - find_maxdiff_fancy(data)

2 个回答

0

首先,如果我理解得没错,你的信号现在是二维的。所以我觉得明确地给它加上索引会更清楚一些,比如可以写成 amp_a = mean(signal[t-window_size:t,:], axis=0)。同样的,关于 alen(signal),在这两种情况下它应该只是样本数量,所以用这个会更清晰。

t 循环中,只要你真的在做事情,就会出错。当 samples < window_length 时,比如在 20x20x20 的例子中,这个循环根本不会执行。一旦这个循环执行超过一次(也就是说 samples > 2 * window_length + 1),错误就会出现。不过我不太确定为什么会这样——在我看来它们是等价的。

3

问题出在这一行:

maxdiff[the_diff > maxdiff] = the_diff

左边只选择了maxdiff中的一些元素,而右边却包含了the_diff的所有元素。应该这样做:

replaceElements = the_diff > maxdiff
maxdiff[replaceElements] = the_diff[replaceElements]

或者简单点:

maxdiff = maximum(maxdiff, the_diff)

至于为什么20x20x20的大小似乎有效:这是因为你的窗口大小太大了,所以什么都没有执行。

撰写回答