如何在numpy中执行scatter/gather操作
假设我有一些数组:
a = array((1,2,3,4,5))
indices = array((1,1,1,1))
然后我进行了一些操作:
a[indices] += 1
得到的结果是
array([1, 3, 3, 4, 5])
换句话说,indices
中的重复项被忽略了
如果我想让重复项不被忽略,结果变成:
array([1, 6, 3, 4, 5])
我该怎么做呢?
上面的例子有点简单,接下来就是我真正想做的事情:
def inflate(self,pressure):
faceforces = pressure * cross(self.verts[self.faces[:,1]]-self.verts[self.faces[:,0]], self.verts[self.faces[:,2]]-self.verts[self.faces[:,0]])
self.verts[self.faces[:,0]] += faceforces
self.verts[self.faces[:,1]] += faceforces
self.verts[self.faces[:,2]] += faceforces
def constrain_lengths(self):
vectors = self.verts[self.constraints[:,1]] - self.verts[self.constraints[:,0]]
lengths = sqrt(sum(square(vectors), axis=1))
correction = 0.5 * (vectors.T * (1 - (self.restlengths / lengths))).T
self.verts[self.constraints[:,0]] += correction
self.verts[self.constraints[:,1]] -= correction
def compute_normals(self):
facenormals = cross(self.verts[self.faces[:,1]]-self.verts[self.faces[:,0]], self.verts[self.faces[:,2]]-self.verts[self.faces[:,0]])
self.normals.fill(0)
self.normals[self.faces[:,0]] += facenormals
self.normals[self.faces[:,1]] += facenormals
self.normals[self.faces[:,2]] += facenormals
lengths = sqrt(sum(square(self.normals), axis=1))
self.normals = (self.normals.T / lengths).T
由于在我的索引赋值操作中忽略了重复项,我得到了很多很奇怪的结果。
3 个回答
我不知道有什么方法比这个更快:
for face in self.faces[:,0]:
self.verts[face] += faceforces
你也可以把self.faces变成一个包含3个字典的数组,每个字典的键对应一个面,而值则是需要添加的次数。这样你就能得到像这样的代码:
for face in self.faces[0]:
self.verts[face] += self.faces[0][face]*faceforces
这可能会更快。我希望有人能想出更好的方法,因为我今天早些时候在帮别人加速代码时就想做这个。
虽然我来得有点晚,但考虑到这个操作非常常见,而且它似乎还没有成为标准numpy的一部分,我在这里分享我的解决方案,供大家参考:
def scatter(rowidx, vals, target):
"""compute target[rowidx] += vals, allowing for repeated values in rowidx"""
rowidx = np.ravel(rowidx)
vals = np.ravel(vals)
cols = len(vals)
data = np.ones(cols)
colidx = np.arange(cols)
rows = len(target)
from scipy.sparse import coo_matrix
M = coo_matrix((data,(rowidx,colidx)), shape=(rows, cols))
target += M*vals
def gather(idx, vals):
"""for symmetry with scatter"""
return vals[idx]
在numpy中使用自定义的C语言程序可以让这个操作快上两倍,因为它可以省去不必要的分配和与1的乘法,这样一来,性能就会大幅提升,相比于在python中使用循环,效果会好很多。
除了性能方面的考虑,使用散布操作在风格上也更符合其他numpy的向量化代码,而不是在代码中硬塞一些for循环。
编辑:
好了,忘掉上面的内容吧。根据最新的1.8版本,现在numpy直接支持散布操作,并且效率非常高。
def scatter(idx, vals, target):
"""target[idx] += vals, but allowing for repeats in idx"""
np.add.at(target, idx.ravel(), vals.ravel())
numpy
的 histogram
函数是一种分散操作。
a += histogram(indices, bins=a.size, range=(0, a.size))[0]
你需要注意,如果 indices
里包含整数,可能会因为小的四舍五入误差导致某些值放错了地方。为了避免这种情况,可以使用:
a += histogram(indices, bins=a.size, range=(-0.5, a.size-0.5))[0]
这样可以确保每个索引都能准确放到每个区间的中心。
更新:这个方法有效。不过我建议使用 @Eelco Hoogendoorn 的答案,基于 numpy.add.at
的方法。