pyplot:绘制热图速度很慢
我有一个循环,大约执行200次。在每次循环中,它会进行复杂的计算,然后为了调试,我想生成一个NxM矩阵的热图。但是,生成这个热图的速度太慢了,严重拖慢了本来就慢的算法。
我的代码大致是这样的:
import numpy
import matplotlib.pyplot as plt
for i in range(200):
matrix = complex_calculation()
plt.set_cmap("gray")
plt.imshow(matrix)
plt.savefig("frame{0}.png".format(i))
这个矩阵来自numpy,大小并不大——300 x 600的双精度浮点数。即使我不保存图像,而是更新屏幕上的图表,速度也更慢。
我肯定是用错了pyplot。(Matlab可以做到这一点,没问题。)我该如何加快这个过程呢?
2 个回答
3
我觉得这样会快一点:
import matplotlib.pyplot as plt
from matplotlib import cm
fig = plt.figure()
ax = fig.add_axes([0.1,0.1,0.8,0.8])
for i in range(200):
matrix = complex_calculation()
ax.imshow(matrix, cmap=cm.gray)
fig.savefig("frame{0}.png".format(i))
plt.imshow
这个函数会调用 gca
,然后 gca
又会调用 gcf
,接着 gcf
会检查是否已经有一个图形存在;如果没有,它就会创建一个新的图形。通过先手动创建图形,你就不需要执行这些步骤了。
5
试着在循环里加上 plt.clf()
,这样可以清空当前的图像:
for i in range(200):
matrix = complex_calculation()
plt.set_cmap("gray")
plt.imshow(matrix)
plt.savefig("frame{0}.png".format(i))
plt.clf()
如果不这样做,循环会变得很慢,因为计算机要不断分配更多的内存来处理图像。