Python: simple list merging based on intersections

54 votes
19 answers
14,274 views
Asked 2025-04-17 12:11

Consider some lists of integers, as follows:

#--------------------------------------
0 [0,1,3]
1 [1,0,3,4,5,10,...]
2 [2,8]
3 [3,1,0,...]
...
n []
#--------------------------------------

The task is to merge all lists that share at least one common element. So for the part given above, the result would be:

#--------------------------------------
0 [0,1,3,4,5,10,...]
2 [2,8]
#--------------------------------------

What is the most efficient way to do this for large data (where the elements are just numbers)? Is a tree structure worth considering? My current approach is to convert the lists to sets and iterate over their intersections, but it is slow! Besides, it feels so elementary! On top of that, the implementation seems to be missing something (I don't know what), because some lists are sometimes left unmerged! Having said that, if you propose your own solution, please be generous and provide a simple sample code [Python is obviously my favorite :)] or pseudo-code.
Update 1: Here is the code I have been using:

#--------------------------------------
lsts = [[0,1,3],
        [1,0,3,4,5,10,11],
        [2,8],
        [3,1,0,16]];
#--------------------------------------

and the function is (buggy!!):

#--------------------------------------
def merge(lsts):
    sts = [set(l) for l in lsts]
    i = 0
    while i < len(sts):
        j = i+1
        while j < len(sts):
            if len(sts[i].intersection(sts[j])) > 0:
                sts[i] = sts[i].union(sts[j])
                sts.pop(j)          # don't advance j: the next set shifts into position j
            else: j += 1                        #---corrected
        i += 1
    lst = [list(s) for s in sts]
    return lst
#--------------------------------------

The result is:

#--------------------------------------
>>> merge(lsts)
[[0, 1, 3, 4, 5, 10, 11, 16], [8, 2]]
#--------------------------------------

Update 2: From my experience, the code given below by Niklas Baumstark turned out to be slightly faster for the simple cases. The method given by "Hooked" remains to be tested, since it is a completely different approach (fascinating, by the way). The testing procedure for all of these could be really hard, or even impossible, when it comes to guaranteeing the results. The real data set I am going to use is so large and complex that it is impossible to trace every error just by repeating the tests. That is, I need to be 100% satisfied with the reliability of the method before pushing it as a module into a large code base. Simply put, for now Niklas's method is faster, and the answer for simple sets is of course correct.
But how can I be sure that it works well for really large data sets? Since I will not be able to trace the errors visually!
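
In the meantime, the best I can do is to assert the expected invariants programmatically on any candidate output instead of eyeballing it; a minimal sketch (the helper name check_merge is mine):

#--------------------------------------
def check_merge(original, merged):
    from itertools import combinations
    # invariant 1: merged groups must be pairwise disjoint
    for a, b in combinations(merged, 2):
        assert set(a).isdisjoint(set(b))
    # invariant 2: merging must neither lose nor invent elements
    assert set().union(*map(set, original)) == set().union(*map(set, merged))
#--------------------------------------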

Update 3: Note that for this problem, the reliability of the method is of much more importance than its speed. I hope I will eventually be able to translate the Python code to Fortran for maximum performance.

Update 4:
There are many interesting points in this thread, generously given answers, and constructive comments. I would recommend reading all of them thoroughly. Please accept my appreciation for the development of the question, the amazing answers, and the constructive comments and discussion.

19 Answers

7

Using matrix manipulations

Let me preface this answer with a word of warning:

This is the wrong way to do it. It is prone to numerical instability, and it is much slower than the other methods presented; use at your own risk.

That being said, I couldn't resist solving the problem from a dynamical point of view (and I hope you'll get a fresh perspective on the problem). In theory this should always work, but eigenvalue calculations can often fail. The idea is to think of your lists as a flow from rows to columns. If two rows share a common value, there is a connecting flow between them. If we were to think of these flows as water, we would see that the flows cluster into little pools when there is a connecting path between them. For simplicity, I'm going to use a smaller set, though it works with your data set as well:

from numpy import where, newaxis, zeros
from scipy import linalg

X = [[0,1,3],[2],[3,1]]

We need to turn the data into a flow graph: if row i flows into value j, we put an entry into the matrix. Here we have 3 rows and 4 unique values:

A = zeros((4,len(X)), dtype=float)
for i,row in enumerate(X):
    for val in row: A[val,i] = 1

In general, you will need to make the 4 the number of unique values you have (if the set is a list of integers starting from 0, you can simply use the largest number plus one).
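
A minimal sketch of deriving that size from the data itself (assuming, as in this example, non-negative integers):

n_vals = max(max(row) for row in X if row) + 1
A = zeros((n_vals, len(X)), dtype=float)

Next, we take an eigenvalue decomposition. To be precise, a singular value decomposition (SVD), since our matrix is not square: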

S  = linalg.svd(A)

We only want to keep the 3x3 portion of this decomposition, since it represents the flow of the pools. In fact, we only care about the absolute values of this matrix; all that matters is whether there is flow in this cluster space.

M  = abs(S[2])

We can treat this matrix M as a Markov matrix, and make that explicit by row-normalizing it. Once we have that, we compute the (left) eigenvalue decomposition of the matrix:

M /=  M.sum(axis=1)[:,newaxis]
U,V = linalg.eig(M,left=True, right=False)
V = abs(V)

Now, a disconnected (non-ergodic) Markov matrix has the nice property that, for each disconnected cluster, there is an eigenvalue of unity. The eigenvectors associated with these unity eigenvalues are the ones we want:

idx = where(U.real > .999)[0]
C = V.T[idx] > 0

I have to use .999 because of the aforementioned numerical instability. At this point, we're done! Each independent cluster can now pull out the corresponding rows:

for cluster in C:
    print where(A[:,cluster].sum(axis=1))[0]

Which gives, as intended:

[0 1 3]
[2]

Changing X to your lst gives: [ 0 1 3 4 5 10 11 16] [2 8]


Addendum

Why might this be useful? I don't know where your underlying data comes from, but what happens when the connections are not absolute? Say row 1 has entry 3 80% of the time — how would you generalize the problem? The flow method above would still work just fine, and would be completely parameterized by that .999 value; the further it is from unity, the looser the association.
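
To make that concrete, here is a sketch of how a fractional connection could enter the flow matrix (the 0.8 entry is purely illustrative, not something from the question):

A = zeros((4, len(X)), dtype=float)
for i, row in enumerate(X):
    for val in row:
        A[val, i] = 1.0     # certain membership
A[3, 1] = 0.8               # value 3 flows into list 1 with 80% certainty

The rest of the pipeline stays the same; only the eigenvalue threshold decides how loose an association still counts as one pool.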


Visual representation

Since a picture is worth a thousand words, here are the plots of the matrices A and V for my example and for your lst, respectively. Notice how V splits into two clusters (after permutation, it is a block-diagonal matrix with two blocks), since for each example there were only two unique lists!

[Figures: matrix plots for my example and for your sample data]


A faster implementation

In hindsight, I realized that you can skip the SVD step and compute only a single decomposition:

from numpy import dot

M = dot(A.T, A)
M /= M.sum(axis=1)[:, newaxis]
U, V = linalg.eig(M, left=True, right=False)

The advantage of this method (besides speed) is that M is now symmetric, hence the computation can be faster and more accurate (no imaginary values to worry about).
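
As a side note (my own sketch, not part of the recipe above): the strictly symmetric piece here is the raw Gram matrix dot(A.T, A); the row normalization breaks exact symmetry, so a guaranteed-real solver such as eigh applies to the Gram matrix itself:

from numpy import dot
from scipy.linalg import eigh

G = dot(A.T, A)   # symmetric by construction
w, v = eigh(G)    # symmetric solver: real eigenvalues and eigenvectors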

16

I tried to summarize everything that has been said on this topic, both in this question and in the duplicate one.

I also tried to test and time every solution (all the code is here).

Testing

This is the TestCase from the testing module:

import json
import unittest
from copy import deepcopy

class MergeTestCase(unittest.TestCase):

    def setUp(self):
        with open('./lists/test_list.txt') as f:
            self.lsts = json.loads(f.read())
        self.merged = self.merge_func(deepcopy(self.lsts))

    def test_disjoint(self):
        """Check disjoint-ness of merged results"""
        from itertools import combinations
        for a,b in combinations(self.merged, 2):
            self.assertTrue(a.isdisjoint(b))

    def test_coverage(self):    # Credit to katrielalex
        """Check coverage original data"""
        merged_flat = set()
        for s in self.merged:
            merged_flat |= s

        original_flat = set()
        for lst in self.lsts:
            original_flat |= set(lst)

        self.assertTrue(merged_flat == original_flat)

    def test_subset(self):      # Credit to WolframH
        """Check that every original data is a subset"""
        for lst in self.lsts:
            self.assertTrue(any(set(lst) <= e for e in self.merged))
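
The snippet leaves merge_func abstract; a hypothetical subclass binding one concrete solution might look like this (merge_niklas stands for any of the benchmarked functions; staticmethod keeps it from being bound as a method):

class NiklasTestCase(MergeTestCase):
    merge_func = staticmethod(merge_niklas)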

This test assumes that the result is a list of sets, so I couldn't test a couple of solutions that only work with lists.

I couldn't test the following:

katrielalex
steabert

Among the ones I could test, two failed:

  -- Going to test: agf (optimized) --
Check disjoint-ness of merged results ... FAIL

  -- Going to test: robert king --
Check disjoint-ness of merged results ... FAIL

Timing

Performance is closely tied to the test data employed.

So far, three answers have tried to time their own and one another's solutions. Since they used different test data, they got different results.

  1. Niklas's benchmark is very tweakable. With it, one can run different tests by changing a few parameters.

    I used the same three sets of parameters he used in his own answer, and I put them in three different files:

    filename = './lists/timing_1.txt'
    class_count = 50,
    class_size = 1000,
    list_count_per_class = 100,
    large_list_sizes = (100, 1000),
    small_list_sizes = (0, 100),
    large_list_probability = 0.5,
    
    filename = './lists/timing_2.txt'
    class_count = 15,
    class_size = 1000,
    list_count_per_class = 300,
    large_list_sizes = (100, 1000),
    small_list_sizes = (0, 100),
    large_list_probability = 0.5,
    
    filename = './lists/timing_3.txt'
    class_count = 15,
    class_size = 1000,
    list_count_per_class = 300,
    large_list_sizes = (100, 1000),
    small_list_sizes = (0, 100),
    large_list_probability = 0.1,
    

    These are the results that I got:

    From file: timing_1.txt

    Timing with: >> Niklas << Benchmark
    Info: 5000 lists, average size 305, max size 999
    
    Timing Results:
    10.434  -- alexis
    11.476  -- agf
    11.555  -- Niklas B.
    13.622  -- Rik. Poggi
    14.016  -- agf (optimized)
    14.057  -- ChessMaster
    20.208  -- katrielalex
    21.697  -- steabert
    25.101  -- robert king
    76.870  -- Sven Marnach
    133.399  -- hochl
    

    From file: timing_2.txt

    Timing with: >> Niklas << Benchmark
    Info: 4500 lists, average size 305, max size 999
    
    Timing Results:
    8.247  -- Niklas B.
    8.286  -- agf
    8.637  -- Rik. Poggi
    8.967  -- alexis
    9.090  -- ChessMaster
    9.091  -- agf (optimized)
    18.186  -- katrielalex
    19.543  -- steabert
    22.852  -- robert king
    70.486  -- Sven Marnach
    104.405  -- hochl
    

    From file: timing_3.txt

    Timing with: >> Niklas << Benchmark
    Info: 4500 lists, average size 98, max size 999
    
    Timing Results:
    2.746  -- agf
    2.850  -- Niklas B.
    2.887  -- Rik. Poggi
    2.972  -- alexis
    3.077  -- ChessMaster
    3.174  -- agf (optimized)
    5.811  -- katrielalex
    7.208  -- robert king
    9.193  -- steabert
    23.536  -- Sven Marnach
    37.436  -- hochl
    
  2. With Sven's test data, I got the following results:

    Timing with: >> Sven << Benchmark
    Info: 200 lists, average size 10, max size 10
    
    Timing Results:
    2.053  -- alexis
    2.199  -- ChessMaster
    2.410  -- agf (optimized)
    3.394  -- agf
    3.398  -- Rik. Poggi
    3.640  -- robert king
    3.719  -- steabert
    3.776  -- Niklas B.
    3.888  -- hochl
    4.610  -- Sven Marnach
    5.018  -- katrielalex
    
  3. Finally, with agf's benchmark, I got:

    Timing with: >> Agf << Benchmark
    Info: 2000 lists, average size 246, max size 500
    
    Timing Results:
    3.446  -- Rik. Poggi
    3.500  -- ChessMaster
    3.520  -- agf (optimized)
    3.527  -- Niklas B.
    3.527  -- agf
    3.902  -- hochl
    5.080  -- alexis
    15.997  -- steabert
    16.422  -- katrielalex
    18.317  -- robert king
    1257.152  -- Sven Marnach
    

As I said at the beginning, all the code is available at this git repository. All the merging functions live in a file called core.py, and every function there whose name ends with _merge is auto-loaded during the tests, so it shouldn't be hard to add/test/improve your own solution.
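
For illustration only (the repository's actual loader may differ), such name-based discovery can be done with a couple of lines of introspection:

import core  # the module holding the *_merge functions

merge_funcs = [getattr(core, name)
               for name in dir(core)
               if name.endswith('_merge')]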

Let me know if something is wrong; it's been a lot of coding and I could use a couple of fresh eyes :)

30

My attempt:

def merge(lsts):
    sets = [set(lst) for lst in lsts if lst]
    merged = True
    while merged:
        merged = False
        results = []
        while sets:
            common, rest = sets[0], sets[1:]
            sets = []
            for x in rest:
                if x.isdisjoint(common):
                    sets.append(x)
                else:
                    merged = True
                    common |= x
            results.append(common)
        sets = results
    return sets

lst = [[65, 17, 5, 30, 79, 56, 48, 62],
       [6, 97, 32, 93, 55, 14, 70, 32],
       [75, 37, 83, 34, 9, 19, 14, 64],
       [43, 71],
       [],
       [89, 49, 1, 30, 28, 3, 63],
       [35, 21, 68, 94, 57, 94, 9, 3],
       [16],
       [29, 9, 97, 43],
       [17, 63, 24]]
print merge(lst)

Benchmark:

import random

# adapt parameters to your own usage scenario
class_count = 50
class_size = 1000
list_count_per_class = 100
large_list_sizes = list(range(100, 1000))
small_list_sizes = list(range(0, 100))
large_list_probability = 0.5

if False:  # change to true to generate the test data file (takes a while)
    with open("/tmp/test.txt", "w") as f:
        lists = []
        classes = [
            range(class_size * i, class_size * (i + 1)) for i in range(class_count)
        ]
        for c in classes:
            # distribute each class across list_count_per_class lists
            for i in xrange(list_count_per_class):
                lst = []
                if random.random() < large_list_probability:
                    size = random.choice(large_list_sizes)
                else:
                    size = random.choice(small_list_sizes)
                nums = set(c)
                for j in xrange(size):
                    x = random.choice(list(nums))
                    lst.append(x)
                    nums.remove(x)
                random.shuffle(lst)
                lists.append(lst)
        random.shuffle(lists)
        for lst in lists:
            f.write(" ".join(str(x) for x in lst) + "\n")

setup = """
# Niklas'
def merge_niklas(lsts):
    sets = [set(lst) for lst in lsts if lst]
    merged = 1
    while merged:
        merged = 0
        results = []
        while sets:
            common, rest = sets[0], sets[1:]
            sets = []
            for x in rest:
                if x.isdisjoint(common):
                    sets.append(x)
                else:
                    merged = 1
                    common |= x
            results.append(common)
        sets = results
    return sets

# Rik's
def merge_rik(data):
    sets = (set(e) for e in data if e)
    results = [next(sets)]
    for e_set in sets:
        to_update = []
        # collect indices (in descending order) of existing results that touch e_set
        for i, res in enumerate(results):
            if not e_set.isdisjoint(res):
                to_update.insert(0, i)

        if not to_update:
            results.append(e_set)
        else:
            last = results[to_update.pop(-1)]
            for i in to_update:
                last |= results[i]
                del results[i]
            last |= e_set
    return results

# katrielalex's
def pairs(lst):
    i = iter(lst)
    first = prev = item = i.next()
    for item in i:
        yield prev, item
        prev = item
    yield item, first
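
# Note: pairs() links each list's elements into a closed cycle, so every
# input list becomes a single connected subgraph; merging then reduces to
# finding the connected components of the whole graph.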

import networkx

def merge_katrielalex(lsts):
    g = networkx.Graph()
    for lst in lsts:
        for edge in pairs(lst):
            g.add_edge(*edge)
    return networkx.connected_components(g)

# agf's (optimized)
from collections import deque

def merge_agf_optimized(lists):
    sets = deque(set(lst) for lst in lists if lst)
    results = []
    disjoint = 0
    current = sets.pop()
    while True:
        merged = False
        newsets = deque()
        for _ in xrange(disjoint, len(sets)):
            this = sets.pop()
            if not current.isdisjoint(this):
                current.update(this)
                merged = True
                disjoint = 0
            else:
                newsets.append(this)
                disjoint += 1
        if sets:
            newsets.extendleft(sets)
        if not merged:
            results.append(current)
            try:
                current = newsets.pop()
            except IndexError:
                break
            disjoint = 0
        sets = newsets
    return results

# agf's (simple)
def merge_agf_simple(lists):
    newsets, sets = [set(lst) for lst in lists if lst], []
    # iterate to a fixpoint: stop when a full pass merges nothing
    while len(sets) != len(newsets):
        sets, newsets = newsets, []
        for aset in sets:
            for eachset in newsets:
                if not aset.isdisjoint(eachset):
                    eachset.update(aset)
                    break
            else:
                newsets.append(aset)
    return newsets

# alexis'
def merge_alexis(data):
    bins = range(len(data))  # Initialize each bin[n] == n
    nums = dict()

    data = [set(m) for m in data]  # Convert to sets
    for r, row in enumerate(data):
        for num in row:
            if num not in nums:
                # New number: tag it with a pointer to this row's bin
                nums[num] = r
                continue
            else:
                dest = locatebin(bins, nums[num])
                if dest == r:
                    continue  # already in the same bin

                if dest > r:
                    dest, r = r, dest  # always merge into the smallest bin

                data[dest].update(data[r])
                data[r] = None
                # Update our indices to reflect the move
                bins[r] = dest
                r = dest

    # Filter out the empty bins
    have = [m for m in data if m]
    return have

def locatebin(bins, n):
    while bins[n] != n:
        n = bins[n]
    return n
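
# Note: locatebin follows parent pointers without path compression; a
# union-find variant that re-points visited bins directly at the root
# would keep these chains short on very large inputs.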

lsts = []
size = 0
num = 0
max = 0
for line in open("/tmp/test.txt", "r"):
    lst = [int(x) for x in line.split()]
    size += len(lst)
    if len(lst) > max:
        max = len(lst)
    num += 1
    lsts.append(lst)
"""

setup += """
print "%i lists, {class_count} equally distributed classes, average size %i, max size %i" % (num, size/num, max)
""".format(class_count=class_count)

import timeit
print "niklas"
print timeit.timeit("merge_niklas(lsts)", setup=setup, number=3)
print "rik"
print timeit.timeit("merge_rik(lsts)", setup=setup, number=3)
print "katrielalex"
print timeit.timeit("merge_katrielalex(lsts)", setup=setup, number=3)
print "agf (1)"
print timeit.timeit("merge_agf_optimized(lsts)", setup=setup, number=3)
print "agf (2)"
print timeit.timeit("merge_agf_simple(lsts)", setup=setup, number=3)
print "alexis"
print timeit.timeit("merge_alexis(lsts)", setup=setup, number=3)

Obviously, these timings depend on the specific parameters of the benchmark, such as the number of classes, the number of lists, the list sizes, and so on. Adapt those parameters to your own needs to get more helpful results.

Here are some example outputs on my machine for different parameter sets. They show that all of the algorithms have their strengths and weaknesses, depending on the kind of input they get:

=====================
# many disjoint classes, large lists
class_count = 50
class_size = 1000
list_count_per_class = 100
large_list_sizes = list(range(100, 1000))
small_list_sizes = list(range(0, 100))
large_list_probability = 0.5
=====================

niklas
5000 lists, 50 equally distributed classes, average size 298, max size 999
4.80084705353
rik
5000 lists, 50 equally distributed classes, average size 298, max size 999
9.49251699448
katrielalex
5000 lists, 50 equally distributed classes, average size 298, max size 999
21.5317108631
agf (1)
5000 lists, 50 equally distributed classes, average size 298, max size 999
8.61671280861
agf (2)
5000 lists, 50 equally distributed classes, average size 298, max size 999
5.18117713928
=> alexis
=> 5000 lists, 50 equally distributed classes, average size 298, max size 999
=> 3.73504281044

===================
# fewer classes, large lists
class_count = 15
class_size = 1000
list_count_per_class = 300
large_list_sizes = list(range(100, 1000))
small_list_sizes = list(range(0, 100))
large_list_probability = 0.5
===================

niklas
4500 lists, 15 equally distributed classes, average size 296, max size 999
1.79993700981
rik
4500 lists, 15 equally distributed classes, average size 296, max size 999
2.58237695694
katrielalex
4500 lists, 15 equally distributed classes, average size 296, max size 999
19.5465381145
agf (1)
4500 lists, 15 equally distributed classes, average size 296, max size 999
2.75445604324
=> agf (2)
=> 4500 lists, 15 equally distributed classes, average size 296, max size 999
=> 1.77850699425
alexis
4500 lists, 15 equally distributed classes, average size 296, max size 999
3.23530197144

===================
# fewer classes, smaller lists
class_count = 15
class_size = 1000
list_count_per_class = 300
large_list_sizes = list(range(100, 1000))
small_list_sizes = list(range(0, 100))
large_list_probability = 0.1
===================

niklas
4500 lists, 15 equally distributed classes, average size 95, max size 997
0.773697137833
rik
4500 lists, 15 equally distributed classes, average size 95, max size 997
1.0523750782
katrielalex
4500 lists, 15 equally distributed classes, average size 95, max size 997
6.04466891289
agf (1)
4500 lists, 15 equally distributed classes, average size 95, max size 997
1.20285701752
=> agf (2)
=> 4500 lists, 15 equally distributed classes, average size 95, max size 997
=> 0.714507102966
alexis
4500 lists, 15 equally distributed classes, average size 95, max size 997
1.1286110878
