计算符合模式的元组中元素的数量
我有一个矩阵 m
,我想计算里面有多少个零。
m=((2,0,2,2),(4,4,5,4),(0,9,4,8),(2,2,0,0))
我现在的代码是这样的:
def zeroCount(M):
return [item for row in M for item in row].count(0)
# list of lists is flattened to form single list, and number of 0 are counted
有没有什么方法可以更快地做到这一点?目前,我在4x4的矩阵上执行这个函数20,000次需要0.4秒,而这些矩阵里零和非零的概率是一样的。
我尝试过一些可能的方向(但我没能让它们比我的代码更快),比如这些其他问题:计算numpy数组中非零元素的数量,查找非零元素的索引,还有计算可迭代对象中非零元素的数量。
6 个回答
使用 numpy:
import numpy
m=((2,0,2,2),(4,4,5,4),(0,9,4,8),(2,2,0,0))
numpy_m = numpy.array(m)
print numpy.sum(numpy_m == 0)
上面的代码是怎么回事呢?首先,你的“矩阵”会被转换成一个numpy数组(numpy.array(m)
)。接着,代码会检查每个元素是否等于零(numpy_m == 0
)。这样就会得到一个只有0和1的数组,0表示原数组中的元素不是零,1表示是零。然后,把这个只有0和1的数组加起来,就能算出原数组中有多少个零元素。
需要注意的是,对于更大的矩阵,numpy的效率会明显更高。像4x4这样的小矩阵,可能看不出和普通Python代码之间的性能差别,特别是当你像上面那样初始化一个Python“矩阵”时。
看看这个:
from itertools import chain, filterfalse # ifilterfalse for Python 2
def zeroCount(m):
total = 0
for x in filterfalse(bool, chain(*m)):
total += 1
return total
在Python 3.3.3上的性能测试:
from timeit import timeit
from itertools import chain, filterfalse
import functools
m = ((2,0,2,2),(4,4,5,4),(0,9,4,8),(2,2,0,0))
def zeroCountOP():
return [item for row in m for item in row].count(0)
def zeroCountTFE():
return len([item for row in m for item in row if item == 0])
def zeroCountJFS():
return sum(row.count(0) for row in m)
def zeroCountuser2931409():
# `reduce` is in `functools` in Py3k
return functools.reduce(lambda a, b: a + b, m).count(0)
def zeroCount():
total = 0
for x in filterfalse(bool, chain(*m)):
total += 1
return total
print('Original code ', timeit(zeroCountOP, number=100000))
print('@J.F.Sebastian ', timeit(zeroCountJFS, number=100000))
print('@thefourtheye ', timeit(zeroCountTFE, number=100000))
print('@user2931409 ', timeit(zeroCountuser2931409, number=100000))
print('@frostnational ', timeit(zeroCount, number=100000))
上面的测试给了我这些结果:
Original code 0.244224319984056
@thefourtheye 0.22169152169497108
@user2931409 0.19247795242092186
@frostnational 0.18846473728790825
@J.F.Sebastian 0.1439318853410907
@J.F.Sebastian的解决方案是最好的,我的方案是第二名(大约慢20%)。
适用于Python 2和Python 3的全面解决方案:
import sys
import itertools
if sys.version_info < (3, 0, 0):
filterfalse = getattr(itertools, 'ifilterfalse')
else:
filterfalse = getattr(itertools, 'filterfalse')
def countzeros(matrix):
''' Make a good use of `itertools.filterfalse`
(`itertools.ifilterfalse` in case of Python 2) to count
all 0s in `matrix`. '''
counter = 0
for _ in filterfalse(bool, itertools.chain(*matrix)):
counter += 1
return counter
if __name__ == '__main__':
# Benchmark
from timeit import repeat
print(repeat('countzeros(((2,0,2,2),(4,4,5,4),(0,9,4,8),(2,2,0,0)))',
'from __main__ import countzeros',
repeat=10,
number=100000))
你这个解决方案的问题在于,你需要再次遍历列表来计算数量,这样的时间复杂度是O(N)。而使用len
函数可以在O(1)的时间内直接获取数量。
你可以用下面的方法让这个过程快很多:
def zeroCount(M):
return len([item for row in M for item in row if item == 0])
这是我的回答。
reduce(lambda a, b: a + b, m).count(0)
时间:
%timeit count_zeros(m) #@J.F. Sebastian
1000000 loops, best of 3: 813 ns per loop
%timeit len([item for row in m for item in row if item == 0]) #@thefourtheye
1000000 loops, best of 3: 974 ns per loop
%timeit reduce(lambda a, b: a + b, m).count(0) #Mine
1000000 loops, best of 3: 1.02 us per loop
%timeit countzeros(m) #@frostnational
1000000 loops, best of 3: 1.07 us per loop
%timeit sum(row.count(0) for row in m) #@J.F. Sebastian
1000000 loops, best of 3: 1.28 us per loop
%timeit [item for row in m for item in row].count(0) #OP
1000000 loops, best of 3: 1.53 us per loop
@thefourtheye 的方法是最快的。这是因为它调用的函数比较少。
在我的环境中,@J.F. Sebastian 的方法是最快的。我也不知道为什么...
到目前为止,最快的方式是:
def count_zeros(matrix):
total = 0
for row in matrix:
total += row.count(0)
return total
对于二维元组,你可以 使用生成器表达式:
def count_zeros_gen(matrix):
return sum(row.count(0) for row in matrix)
时间比较:
%timeit [item for row in m for item in row].count(0) # OP
1000000 loops, best of 3: 1.15 µs per loop
%timeit len([item for row in m for item in row if item == 0]) # @thefourtheye
1000000 loops, best of 3: 913 ns per loop
%timeit sum(row.count(0) for row in m)
1000000 loops, best of 3: 1 µs per loop
%timeit count_zeros(m)
1000000 loops, best of 3: 775 ns per loop
作为基准:
def f(m): pass
%timeit f(m)
10000000 loops, best of 3: 110 ns per loop