pyplot:绘制热图速度很慢

5 投票
2 回答
2933 浏览
提问于 2025-04-15 23:33

我有一个循环,大约执行200次。在每次循环中,它会进行复杂的计算,然后为了调试,我想生成一个NxM矩阵的热图。但是,生成这个热图的速度太慢了,严重拖慢了本来就慢的算法。

我的代码大致是这样的:

import numpy
import matplotlib.pyplot as plt
for i in range(200):
    matrix = complex_calculation()
    plt.set_cmap("gray")
    plt.imshow(matrix)
    plt.savefig("frame{0}.png".format(i))

这个矩阵来自numpy,大小并不大——300 x 600的双精度浮点数。即使我不保存图像,而是更新屏幕上的图表,速度也更慢。

我肯定是用错了pyplot。(Matlab可以做到这一点,没问题。)我该如何加快这个过程呢?

2 个回答

3

我觉得这样会快一点:

import matplotlib.pyplot as plt
from matplotlib import cm
fig = plt.figure()
ax = fig.add_axes([0.1,0.1,0.8,0.8])
for i in range(200):
    matrix = complex_calculation()
    ax.imshow(matrix, cmap=cm.gray)
    fig.savefig("frame{0}.png".format(i))

plt.imshow 这个函数会调用 gca,然后 gca 又会调用 gcf,接着 gcf 会检查是否已经有一个图形存在;如果没有,它就会创建一个新的图形。通过先手动创建图形,你就不需要执行这些步骤了。

5

试着在循环里加上 plt.clf(),这样可以清空当前的图像:

for i in range(200):
    matrix = complex_calculation()
    plt.set_cmap("gray")
    plt.imshow(matrix)
    plt.savefig("frame{0}.png".format(i))
    plt.clf()

如果不这样做,循环会变得很慢,因为计算机要不断分配更多的内存来处理图像。

撰写回答