从列表末尾移除连续重复项 Python
我有一个numpy数组:
ar = np.array([True, False, True, True, True])
如果最后一个元素是True,我想把数组末尾所有连续的True元素去掉。比如说:
magic_func(ar) => [True, False]
如果 ar = [True, False, True, False, True]
,那么结果会是:
magic_func(ar) => [True, False, True, False]
如果 ar = [True, False, False]
,这个函数就什么都不做,因为最后一个元素是False。
有没有什么简单的一行代码可以用python来实现这个?可以使用numpy库或者其他的东西。
6 个回答
编辑:更新了实现方式,以解决numpy数组中没有.pop()的问题
def chop_array(ar, condition):
i = len(ar)
while (ar[i - 1] == condition and i > 0):
i = i - 1
return ar[0:i]
chop_array([True, False, True, False, True], True)
这段代码有点神奇,但看起来确实有效:
>>> ar = np.array([True, False, True, True, True])
>>> ar[np.bitwise_or.accumulate(~ar[::-1])[::-1]]
array([ True, False], dtype=bool)
要理解这个过程,首先我们对数组进行取反,把所有的 True
变成 False
,反之亦然。接着,我们把顺序反转,然后对数组进行“或”运算的累积:在遇到第一个 True
之前,结果会是 False
,之后就会一直是 True
。最后,再把这个数组反转,就得到了一个布尔索引数组,可以去掉所有末尾的 True
。
这个一行的函数应该可以用,但看起来很复杂,而且可能效率不高,哈哈。基本上,这个函数的想法是找到最右边的一个 False
,然后返回在这个 False
之前的所有值。
def magic_func(a):
return a[:len(a)-np.where(a[::-1]==False)[0][0]] if np.where(a[::-1]==False)[0].size>0 else a[:0]
>>> a = np.array([False, True, True, True, True])
>>> magic_func(a)
array([False], dtype=bool)
写一个高效的单行代码并不简单,不过这里有一段代码,来源于一个被删除的回答,作者是askewchan:
argmin = ar[::-1].argmin()
result = np.array([], dtype=bool) if ar[argmin] else ar[:len(ar)-argmin]
对于这个情况,速度比Jaime的基于Numpy的解决方案快了两倍多,这里的代码是针对ar = np.full(1000000, True, dtype=bool)
,然后可以选择:
ar[-10] = False
,- 或者
ar[10] = False
(这两种情况分别代表了最好和最坏的情况)。不过,就像Jaime的解决方案一样,找到最后一个False
时,NumPy(1.8.1)需要遍历整个数组,这样效率就不高了。不过,原则上,NumPy在使用argmin()
时并不需要这样,因为它可以在遇到第一个False
时就停止。
在我看来,这个解决方案是利用了NumPy的最佳性能来进行切割,从而得到最终的数组。原则上,NumPy在找到切割位置时也可以非常高效。
使用 itertools.dropwhile
和 np.fromiter
可以这样做。
from itertools import dropwhile
np.fromiter(dropwhile(lambda x: x, ar[::-1]), dtype=bool)[::-1]
编辑
这是一种更快的方法。(只需使用 itertools.takewhile
)
from itertools import takewhile
ar[:-sum(1 for i in takewhile(lambda x: x, reversed(ar)))]
时间:
ar = np.array([True, False, True, True, True])
#mine
%timeit ar[:-sum(1 for i in takewhile(lambda x: x, reversed(ar)))]
1000000 loops, best of 3: 1.84 us per loop
#mine
%timeit np.fromiter(dropwhile(lambda x: x, ar[::-1]), dtype=bool)[::-1]
100000 loops, best of 3: 2.93 us per loop
#@Jaime
%timeit ar[np.bitwise_or.accumulate(~ar[::-1])[::-1]]
100000 loops, best of 3: 3.63 us per loop
#@askewchan
%timeit ar[:len(ar)-np.argmin(ar[::-1])]
100000 loops, best of 3: 6.24 us per loop
#@xbb
%timeit ar[:len(ar)-np.where(ar[::-1]==False)[0][0]] if np.where(ar[::-1]==False)[0].size>0 else ar[:0]
100000 loops, best of 3: 7.61 us per loop
附言:
这不是魔法函数。
def no_magic_func(ar):
for i in xrange(ar.size-1, -1, -1):
if not ar[i]:
return ar[:i+1]
return ar[0:0]
时间:
ar = np.array([True, False, True, True, True])
%timeit no_magic_func(ar)
1000000 loops, best of 3: 954 ns per loop