如何利用NumPy的功能修复和优化这段简单的“生命游戏”代码?
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import animation
from random import randint
arraySize = 50
Z = np.array([[randint(0, 1) for x in range(arraySize)] for y in range(arraySize)])
def computeNeighbours(Z):
rows, cols = len(Z), len(Z[0])
N = np.zeros(np.shape(Z))
for x in range(rows):
for y in range(cols):
Q = [q for q in [x-1, x, x+1] if ((q >= 0) and (q < cols))]
R = [r for r in [y-1, y, y+1] if ((r >= 0) and (r < rows))]
S = [Z[q][r] for q in Q for r in R if (q, r) != (x, y)]
N[x][y] = sum(S)
return N
def iterate(Z):
rows, cols = len(Z), len(Z[0])
N = computeNeighbours(Z)
for x in range(rows):
for y in range(cols):
if Z[x][y] == 1:
if (N[x][y] < 2) or (N[x][y] > 3):
Z[x][y] = 0
else:
if (N[x][y] == 3):
Z[x][y] = 1
return Z
fig = plt.figure()
Zs = [Z]
ims = []
for i in range(0, 100):
im = plt.imshow(Zs[len(Zs)-1], interpolation = 'nearest', cmap='binary')
ims.append([im])
Zs.append(iterate(Z))
ani = animation.ArtistAnimation(fig, ims, interval=250, blit=True)
plt.show()
一开始,我用标准的Python工具写了一个简单的“生命游戏”实现。我把它画出来,运行得很正常,动画效果也很好。
接下来,我尝试把数组转换成NumPy数组,这就是现在代码的样子。不过,动画似乎不再工作了,我还没搞清楚原因。(更新:这个bug已经修复了!)
接下来,我想利用NumPy来优化我的代码。目前为止,我已经把之前用的Python数组转换成了NumPy数组。虽然我相信性能上有一些提升,但并不明显。
我想知道在这种应用中应该进行什么样的优化步骤,这样我就能更好地利用NumPy的强大功能来处理我现在的项目,这个项目是一个(可能是)三维的细胞自动机,包含很多规则。
以下对代码的修改修复了动画错误:
1) 修改iterate
,让它创建Z的深拷贝,然后对这个深拷贝进行修改。新的iterate
:
def iterate(Z):
Zprime = Z.copy()
rows, cols = len(Zprime), len(Zprime[0])
N = computeNeighbours(Zprime)
for x in range(rows):
for y in range(cols):
if Zprime[x][y] == 1:
if (N[x][y] < 2) or (N[x][y] > 3):
Zprime[x][y] = 0
else:
if (N[x][y] == 3):
Zprime[x][y] = 1
return Zprime
2) 由于第1点的原因,修改这段代码:
for i in range(0, 100):
im = plt.imshow(Zs[len(Zs)-1], interpolation = 'nearest', cmap='binary')
ims.append([im])
Zs.append(iterate(Z))
为:
for i in range(0, 100):
im = plt.imshow(Zs[len(Zs)-1], interpolation = 'nearest', cmap='binary')
ims.append([im])
Zs.append(iterate(Zs[len(Zs)-1]))
1 个回答
2
以下内容:
for x in range(rows):
for y in range(cols):
if Z[x][y] == 1:
if (N[x][y] < 2) or (N[x][y] > 3):
Z[x][y] = 0
else:
if (N[x][y] == 3):
Z[x][y] = 1
可以用下面的代码替代:
set_zero_idxs = (Z==1) & ((N<2) | (N>3))
set_one_idxs = (Z!=1) & (N==3)
Z[set_zero_idxs] = 0
Z[set_one_idxs] = 1
虽然这样做的操作次数会比你原来的循环多,但我预计它会更快。
编辑:
我刚刚对这两种方案进行了性能测试,结果并不意外,使用numpy的版本快了180倍:
In [49]: %timeit no_loop(z,n)
1000 loops, best of 3: 177 us per loop
In [50]: %timeit loop(z,n)
10 loops, best of 3: 31.2 ms per loop
编辑2:
我认为这个循环:
for x in range(rows):
for y in range(cols):
Q = [q for q in [x-1, x, x+1] if ((q >= 0) and (q < cols))]
R = [r for r in [y-1, y, y+1] if ((r >= 0) and (r < rows))]
S = [Z[q][r] for q in Q for r in R if (q, r) != (x, y)]
N[x][y] = sum(S)
可以用下面的代码替代:
N = np.roll(Z,1,axis=1) + np.roll(Z,-1,axis=1) + np.roll(Z,1,axis=0) + np.roll(Z,-1,axis=0)
这里有一个隐含的假设,就是数组是没有边界的,并且 x[-1]
是紧挨着 x[0]
的。如果这会造成问题,你可以在数组周围加一圈零作为缓冲,方法是:
shape = Z.shape
new_shape = (shape[0]+2,shape[1]+2)
b_z = np.zeros(new_shape)
b_z[1:-1,1:-1] = Z
b_n = np.roll(b_z,1,axis=1) + np.roll(b_z,-1,axis=1) + np.roll(b_z,1,axis=0) + np.roll(b_z,-1,axis=0)
N = b_n[1:-1,1:-1]
而对于性能测试:
In [4]: %timeit computeNeighbours(z)
10 loops, best of 3: 140 ms per loop
In [5]: %timeit noloop_computeNeighbours(z)
10000 loops, best of 3: 133 us per loop
In [6]: %timeit noloop_with_buffer_computeNeighbours(z)
10000 loops, best of 3: 170 us per loop
结果显示只小幅提升了1052倍。为Numpy欢呼吧!