检查列表中所有元素是否相同
我需要一个函数,这个函数接收一个 list
(列表),如果列表中的所有元素都相等,就返回 True
,否则返回 False
。
我觉得最好的方法是遍历这个列表,比较相邻的元素,然后把所有得到的布尔值用 AND
连接起来。但我不太确定用Python的方式怎么做最合适。
33 个回答
[编辑:这个回答是针对目前投票最高的 itertools.groupby
答案的,那个答案也很好。]
在不重写程序的情况下,最有效率且最易读的方法如下:
all(x==myList[0] for x in myList)
(是的,这个方法即使在空列表的情况下也能工作!这是因为在这种情况下,Python 采用了懒惰求值的方式。)
这个方法会在最早的时刻就失败,所以它在理论上是最优的(预期时间大约是 O(#uniques),而不是 O(N),但最坏情况下的时间仍然是 O(N))。这假设你之前没有见过这些数据……
(如果你在意性能,但又不是特别在意,你可以先做一些常规的优化,比如把 myList[0]
这个常量提到循环外面,并为边界情况添加一些复杂的逻辑,虽然这可能是 Python 编译器最终会学会做的事情,所以除非绝对必要,不然不建议这样做,因为这样会降低可读性,收益却很小。)
如果你对性能稍微在意一点,这个方法比上面的快一倍,但稍微复杂一些:
def allEqual(iterable):
iterator = iter(iterable)
try:
firstItem = next(iterator)
except StopIteration:
return True
for x in iterator:
if x!=firstItem:
return False
return True
如果你对性能要求更高(但又不想重写程序),可以使用目前投票最高的 itertools.groupby
答案,这个方法比 allEqual
快一倍,因为它可能是经过优化的 C 代码。(根据文档,它应该(和这个答案类似)没有任何内存开销,因为懒惰生成器从未被评估为列表……这可能让人担心,但伪代码显示分组的“列表”实际上是懒惰生成器。)
如果你对性能要求更高,请继续阅读……
关于性能的附注,因为其他答案出于某种未知原因在讨论这个:
……如果你之前见过这些数据,并且可能使用某种集合数据结构,并且你真的很在意性能,你可以通过在你的结构中增加一个 Counter
来免费获得 .isAllEqual()
,这个 Counter
会在每次插入/删除等操作时更新,只需检查它是否是 {something:someCount}
的形式,即 len(counter.keys())==1
;或者你可以在一个单独的变量中保持一个 Counter。这证明比其他任何方法都要好,最多只差一个常数因子。 也许你还可以使用 Python 的 FFI 和 ctypes
,结合你选择的方法,或许还可以用一些启发式的方法(比如如果是一个有 getitem 的序列,先检查第一个元素、最后一个元素,然后再按顺序检查元素)。
当然,可读性也是很重要的。
一种比使用set()更快的方法,适用于序列(而不是可迭代对象),就是简单地统计第一个元素的数量。这假设列表不是空的(不过这很容易检查,你可以自己决定如果列表是空的该怎么处理)。
x.count(x[0]) == len(x)
一些简单的基准测试:
>>> timeit.timeit('len(set(s1))<=1', 's1=[1]*5000', number=10000)
1.4383411407470703
>>> timeit.timeit('len(set(s1))<=1', 's1=[1]*4999+[2]', number=10000)
1.4765670299530029
>>> timeit.timeit('s1.count(s1[0])==len(s1)', 's1=[1]*5000', number=10000)
0.26274609565734863
>>> timeit.timeit('s1.count(s1[0])==len(s1)', 's1=[1]*4999+[2]', number=10000)
0.25654196739196777
使用 itertools.groupby
(可以参考 这个 itertools
的食谱):
from itertools import groupby
def all_equal(iterable):
g = groupby(iterable)
return next(g, True) and not next(g, False)
或者不使用 groupby
:
def all_equal(iterator):
iterator = iter(iterator)
try:
first = next(iterator)
except StopIteration:
return True
return all(first == x for x in iterator)
你可以考虑一些其他的简单方法:
把输入转换成一个集合,然后检查这个集合里只有一个或零个(如果输入是空的)元素
def all_equal2(iterator): return len(set(iterator)) <= 1
把输入列表的第一个元素去掉后进行比较
def all_equal3(lst): return lst[:-1] == lst[1:]
-
def all_equal_ivo(lst): return not lst or lst.count(lst[0]) == len(lst)
-
def all_equal_6502(lst): return not lst or [lst[0]]*len(lst) == lst
不过这些方法也有一些缺点,具体来说:
all_equal
和all_equal2
可以使用任何迭代器,但其他方法必须输入一个序列,通常是像列表或元组这样的具体容器。all_equal
和all_equal3
一旦发现不同就会停止(这叫做“短路”),而其他方法需要遍历整个列表,即使你只看前两个元素就能判断答案是False
。- 在
all_equal2
中,内容必须是 可哈希的。比如说,列表里面再放列表会引发TypeError
。 all_equal2
(在最坏的情况下)和all_equal_6502
会创建列表的副本,这意味着你需要使用双倍的内存。
在 Python 3.9 中,使用 perfplot
,我们得到了这些时间(运行时间越低越好):