如何加速 Matplotlib?
我在这里看到,matplotlib在处理大数据集方面表现不错。但是我正在写一个数据处理应用,把matplotlib的图嵌入到wx中,结果发现matplotlib在处理大量数据时表现得非常糟糕,既慢又占内存。有没有人知道除了降低输入数据的采样率之外,还有什么方法可以加快matplotlib的速度(减少内存占用)?
为了说明matplotlib在内存方面的糟糕表现,看看这段代码:
import pylab
import numpy
a = numpy.arange(int(1e7)) # only 10,000,000 32-bit integers (~40 Mb in memory)
# watch your system memory now...
pylab.plot(a) # this uses over 230 ADDITIONAL Mb of memory
3 个回答
2
我想保留一个日志采样图的一侧,所以我想出了这个方法:
(downsample是我第一次尝试的结果)
def downsample(x, y, target_length=1000, preserve_ends=0):
assert len(x.shape) == 1
assert len(y.shape) == 1
data = np.vstack((x, y))
if preserve_ends > 0:
l, data, r = np.split(data, (preserve_ends, -preserve_ends), axis=1)
interval = int(data.shape[1] / target_length) + 1
data = data[:, ::interval]
if preserve_ends > 0:
data = np.concatenate([l, data, r], axis=1)
return data[0, :], data[1, :]
def geom_ind(stop, num=50):
geo_num = num
ind = np.geomspace(1, stop, dtype=int, num=geo_num)
while len(set(ind)) < num - 1:
geo_num += 1
ind = np.geomspace(1, stop, dtype=int, num=geo_num)
return np.sort(list(set(ind) | {0}))
def log_downsample(x, y, target_length=1000, flip=False):
assert len(x.shape) == 1
assert len(y.shape) == 1
data = np.vstack((x, y))
if flip:
data = np.fliplr(data)
data = data[:, geom_ind(data.shape[1], num=target_length)]
if flip:
data = np.fliplr(data)
return data[0, :], data[1, :]
这个方法让我更好地保留了图的一侧:
newx, newy = downsample(x, y, target_length=1000, preserve_ends=50)
newlogx, newlogy = log_downsample(x, y, target_length=1000)
f = plt.figure()
plt.gca().set_yscale("log")
plt.step(x, y, label="original")
plt.step(newx, newy, label="downsample")
plt.step(newlogx, newlogy, label="log_downsample")
plt.legend()
2
我也经常对极端值感兴趣,所以在绘制大量数据之前,我会这样做:
import numpy as np
s = np.random.normal(size=(1e7,))
decimation_factor = 10
s = np.max(s.reshape(-1,decimation_factor),axis=1)
# To check the final size
s.shape
当然,np.max
只是一个极端值计算的例子。
附注:使用 numpy
的“步幅技巧”,应该可以在调整数据形状时避免复制数据。
7
降采样在这里是个不错的解决办法——在matplotlib中绘制1000万点会消耗大量的内存和时间。如果你知道自己能接受多少内存,那么就可以根据这个量来进行降采样。例如,假设100万点需要额外的23MB内存,而你觉得这个内存占用在可接受范围内,那么你就应该降采样,确保点数始终低于100万:
if(len(a) > 1M):
a = scipy.signal.decimate(a, int(len(a)/1M)+1)
pylab.plot(a)
或者像上面的代码片段那样(上面的代码可能降采样得有点过于激进,不太符合你的口味)。