如何加速 Matplotlib?

17 投票
3 回答
8288 浏览
提问于 2025-04-16 11:41

我在这里看到,matplotlib在处理大数据集方面表现不错。但是我正在写一个数据处理应用,把matplotlib的图嵌入到wx中,结果发现matplotlib在处理大量数据时表现得非常糟糕,既慢又占内存。有没有人知道除了降低输入数据的采样率之外,还有什么方法可以加快matplotlib的速度(减少内存占用)?

为了说明matplotlib在内存方面的糟糕表现,看看这段代码:

import pylab
import numpy
a = numpy.arange(int(1e7)) # only 10,000,000 32-bit integers (~40 Mb in memory)
# watch your system memory now...
pylab.plot(a) # this uses over 230 ADDITIONAL Mb of memory

3 个回答

2

我想保留一个日志采样图的一侧,所以我想出了这个方法:
(downsample是我第一次尝试的结果)

def downsample(x, y, target_length=1000, preserve_ends=0):
    assert len(x.shape) == 1
    assert len(y.shape) == 1
    data = np.vstack((x, y))
    if preserve_ends > 0:
        l, data, r = np.split(data, (preserve_ends, -preserve_ends), axis=1)
    interval = int(data.shape[1] / target_length) + 1
    data = data[:, ::interval]
    if preserve_ends > 0:
        data = np.concatenate([l, data, r], axis=1)
    return data[0, :], data[1, :]

def geom_ind(stop, num=50):
    geo_num = num
    ind = np.geomspace(1, stop, dtype=int, num=geo_num)
    while len(set(ind)) < num - 1:
        geo_num += 1
        ind = np.geomspace(1, stop, dtype=int, num=geo_num)
    return np.sort(list(set(ind) | {0}))

def log_downsample(x, y, target_length=1000, flip=False):
    assert len(x.shape) == 1
    assert len(y.shape) == 1
    data = np.vstack((x, y))
    if flip:
        data = np.fliplr(data)
    data = data[:, geom_ind(data.shape[1], num=target_length)]
    if flip:
        data = np.fliplr(data)
    return data[0, :], data[1, :]

这个方法让我更好地保留了图的一侧:

newx, newy = downsample(x, y, target_length=1000, preserve_ends=50)
newlogx, newlogy = log_downsample(x, y, target_length=1000)
f = plt.figure()
plt.gca().set_yscale("log")
plt.step(x, y, label="original")
plt.step(newx, newy, label="downsample")
plt.step(newlogx, newlogy, label="log_downsample")
plt.legend()

test

2

我也经常对极端值感兴趣,所以在绘制大量数据之前,我会这样做:

import numpy as np

s = np.random.normal(size=(1e7,))
decimation_factor = 10 
s = np.max(s.reshape(-1,decimation_factor),axis=1)

# To check the final size
s.shape

当然,np.max 只是一个极端值计算的例子。

附注:使用 numpy 的“步幅技巧”,应该可以在调整数据形状时避免复制数据。

7

降采样在这里是个不错的解决办法——在matplotlib中绘制1000万点会消耗大量的内存和时间。如果你知道自己能接受多少内存,那么就可以根据这个量来进行降采样。例如,假设100万点需要额外的23MB内存,而你觉得这个内存占用在可接受范围内,那么你就应该降采样,确保点数始终低于100万:

if(len(a) > 1M):
   a = scipy.signal.decimate(a, int(len(a)/1M)+1)
pylab.plot(a)

或者像上面的代码片段那样(上面的代码可能降采样得有点过于激进,不太符合你的口味)。

撰写回答