如何利用NumPy的功能修复和优化这段简单的“生命游戏”代码?

4 投票
1 回答
552 浏览
提问于 2025-04-18 00:58
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation
from random import randint

arraySize = 50
Z = np.array([[randint(0, 1) for x in range(arraySize)] for y in range(arraySize)])


def computeNeighbours(Z):
    rows, cols = len(Z), len(Z[0])
    N = np.zeros(np.shape(Z))

    for x in range(rows):
        for y in range(cols):
            Q = [q for q in [x-1, x, x+1] if ((q >= 0) and (q < cols))]
            R = [r for r in [y-1, y, y+1] if ((r >= 0) and (r < rows))]
            S = [Z[q][r] for q in Q for r in R if (q, r) != (x, y)]
            N[x][y] = sum(S)

    return N

def iterate(Z):
    rows, cols = len(Z), len(Z[0])
    N = computeNeighbours(Z)

    for x in range(rows):
        for y in range(cols):
            if Z[x][y] == 1:
                if (N[x][y] < 2) or (N[x][y] > 3):
                    Z[x][y] = 0
            else:
                if (N[x][y] == 3):
                    Z[x][y] = 1

    return Z

fig = plt.figure()

Zs = [Z]
ims = []

for i in range(0, 100):
    im = plt.imshow(Zs[len(Zs)-1], interpolation = 'nearest', cmap='binary')
    ims.append([im])
    Zs.append(iterate(Z))

ani = animation.ArtistAnimation(fig, ims, interval=250, blit=True)
plt.show()

一开始,我用标准的Python工具写了一个简单的“生命游戏”实现。我把它画出来,运行得很正常,动画效果也很好。

接下来,我尝试把数组转换成NumPy数组,这就是现在代码的样子。不过,动画似乎不再工作了,我还没搞清楚原因。更新:这个bug已经修复了!)

接下来,我想利用NumPy来优化我的代码。目前为止,我已经把之前用的Python数组转换成了NumPy数组。虽然我相信性能上有一些提升,但并不明显。

我想知道在这种应用中应该进行什么样的优化步骤,这样我就能更好地利用NumPy的强大功能来处理我现在的项目,这个项目是一个(可能是)三维的细胞自动机,包含很多规则。


以下对代码的修改修复了动画错误:

1) 修改iterate,让它创建Z的深拷贝,然后对这个深拷贝进行修改。新的iterate

def iterate(Z):
    Zprime = Z.copy()
    rows, cols = len(Zprime), len(Zprime[0])
    N = computeNeighbours(Zprime)

    for x in range(rows):
        for y in range(cols):
            if Zprime[x][y] == 1:
                if (N[x][y] < 2) or (N[x][y] > 3):
                    Zprime[x][y] = 0
            else:
                if (N[x][y] == 3):
                    Zprime[x][y] = 1

    return Zprime

2) 由于第1点的原因,修改这段代码:

for i in range(0, 100):
    im = plt.imshow(Zs[len(Zs)-1], interpolation = 'nearest', cmap='binary')
    ims.append([im])
    Zs.append(iterate(Z))

为:

for i in range(0, 100):
    im = plt.imshow(Zs[len(Zs)-1], interpolation = 'nearest', cmap='binary')
    ims.append([im])
    Zs.append(iterate(Zs[len(Zs)-1]))

1 个回答

2

以下内容:

for x in range(rows):
        for y in range(cols):
            if Z[x][y] == 1:
                if (N[x][y] < 2) or (N[x][y] > 3):
                    Z[x][y] = 0
            else:
                if (N[x][y] == 3):
                    Z[x][y] = 1

可以用下面的代码替代:

set_zero_idxs = (Z==1) & ((N<2) | (N>3))
set_one_idxs = (Z!=1) & (N==3)
Z[set_zero_idxs] = 0
Z[set_one_idxs] = 1

虽然这样做的操作次数会比你原来的循环多,但我预计它会更快。

编辑:

我刚刚对这两种方案进行了性能测试,结果并不意外,使用numpy的版本快了180倍:

In [49]: %timeit no_loop(z,n)
1000 loops, best of 3: 177 us per loop

In [50]: %timeit loop(z,n)
10 loops, best of 3: 31.2 ms per loop

编辑2:

我认为这个循环:

for x in range(rows):
        for y in range(cols):
            Q = [q for q in [x-1, x, x+1] if ((q >= 0) and (q < cols))]
            R = [r for r in [y-1, y, y+1] if ((r >= 0) and (r < rows))]
            S = [Z[q][r] for q in Q for r in R if (q, r) != (x, y)]
            N[x][y] = sum(S)

可以用下面的代码替代:

N = np.roll(Z,1,axis=1) + np.roll(Z,-1,axis=1) + np.roll(Z,1,axis=0) + np.roll(Z,-1,axis=0)

这里有一个隐含的假设,就是数组是没有边界的,并且 x[-1] 是紧挨着 x[0] 的。如果这会造成问题,你可以在数组周围加一圈零作为缓冲,方法是:

shape = Z.shape
new_shape = (shape[0]+2,shape[1]+2)
b_z = np.zeros(new_shape)
b_z[1:-1,1:-1] = Z
b_n = np.roll(b_z,1,axis=1) + np.roll(b_z,-1,axis=1) + np.roll(b_z,1,axis=0) + np.roll(b_z,-1,axis=0)
N = b_n[1:-1,1:-1]

而对于性能测试:

In [4]: %timeit computeNeighbours(z)
10 loops, best of 3: 140 ms per loop 

In [5]: %timeit noloop_computeNeighbours(z)
10000 loops, best of 3: 133 us per loop

In [6]: %timeit noloop_with_buffer_computeNeighbours(z)
10000 loops, best of 3: 170 us per loop

结果显示只小幅提升了1052倍。为Numpy欢呼吧!

撰写回答