概率解析器的内存使用情况

4 投票
3 回答
547 浏览
提问于 2025-04-16 14:07

我正在为一个范围连接语法写一个CKY解析器。我想用一个树库作为语法,所以这个语法会很大。我用Python写了一个原型,发现当我模拟几十个句子的树库时,它运行得还不错,但内存使用量太高了。我尝试用C++来写,但到目前为止这让我很沮丧,因为我之前从未使用过C++。这里有一些数据(n是语法所基于的句子数量):

n    mem
9    173M
18   486M
36   836M

这种增长模式是根据最佳优先算法预期的,但我担心的是开销的大小。根据heapy的报告,内存使用量比这些数字小十倍,valgrind也报告了类似的情况。这种差异是怎么造成的?在Python(或Cython)中有什么办法可以解决吗?可能是因为内存碎片化?或者是Python字典的开销?

一些背景信息:两个重要的数据结构是议程(将边映射到概率)和图表(一个字典,将非终结符和位置映射到边)。议程是用heapdict实现的(它内部使用一个字典和一个堆列表),图表是用一个字典将非终结符和位置映射到边。议程经常进行插入和删除,而图表只进行插入和查找。我用元组来表示边,像这样:

(("S", 111), ("NP", 010), ("VP", 100, 001))

这些字符串是来自语法的非终结符标签,位置用位掩码编码。当一个成分是不连续的时候,可以有多个位置。所以这个边可以表示“玛丽快乐吗”的分析,其中“是”和“快乐”都属于VP。图表字典是通过这个边的第一个元素来索引的,在这个例子中是(“S”,111)。在一个新版本中,我尝试转置这个表示,希望能通过重用来节省内存:

(("S", "NP", "VP), (111, 100, 011))

我想如果不同位置组合中出现的第一部分只存储一次,那Python应该会这样做,尽管我并不确定这是否真的成立。无论如何,这似乎没有什么区别。

所以我基本上想知道,继续追求我的Python实现是否值得,包括用Cython和不同的数据结构,还是从头开始用C++编写是唯一可行的选择。

更新:经过一些改进,我不再有内存使用的问题。我正在开发一个优化的Cython版本。我会把赏金给对提高代码效率最有帮助的建议。这里有一个带注释的版本:http://student.science.uva.nl/~acranenb/plcfrs_cython.html

1 https://github.com/andreasvc/disco-dop/ -- 运行test.py来解析一些句子。需要Python 2.6,nltkheapdict

3 个回答

1

在这种情况下,首先要做的就是进行性能分析:

15147/297    0.032    0.000    0.041    0.000 tree.py:102(__eq__)
15400/200    0.031    0.000    0.106    0.001 tree.py:399(convert)
        1    0.023    0.023    0.129    0.129 plcfrs_cython.pyx:52(parse)
6701/1143    0.022    0.000    0.043    0.000 heapdict.py:45(_min_heapify)
    18212    0.017    0.000    0.023    0.000 plcfrs_cython.pyx:38(__richcmp__)
10975/10875    0.017    0.000    0.035    0.000 tree.py:75(__init__)
     5772    0.016    0.000    0.050    0.000 tree.py:665(__init__)
      960    0.016    0.000    0.025    0.000 plcfrs_cython.pyx:118(deduced_from)
    46938    0.014    0.000    0.014    0.000 tree.py:708(_get_node)
25220/2190    0.014    0.000    0.016    0.000 tree.py:231(subtrees)
    10975    0.013    0.000    0.023    0.000 tree.py:60(__new__)
    49441    0.013    0.000    0.013    0.000 {isinstance}
    16748    0.008    0.000    0.015    0.000 {hasattr}

我注意到的第一件事是,来自cython模块的函数非常少。大多数函数来自tree.py模块,可能这就是性能瓶颈所在。

专注于cython部分,我发现了richcmp函数:

我们可以通过在方法声明中添加值的类型来简单优化它。

def __richcmp__(ChartItem self, ChartItem other, int op):
        ....

这样可以降低值的计算时间。

ncalls  tottime  percall  cumtime  percall filename:lineno(function)
....
18212    0.011    0.000    0.015    0.000 plcfrs_cython.pyx:38(__richcmp__)

用elif语法替代单个if语句,可以启用cython的switch优化

    if op == 0: return self.label < other.label or self.vec < other.vec
    elif op == 1: return self.label <= other.label or self.vec <= other.vec
    elif op == 2: return self.label == other.label and self.vec == other.vec
    elif op == 3: return self.label != other.label or self.vec != other.vec
    elif op == 4: return self.label > other.label or self.vec > other.vec
    elif op == 5: return self.label >= other.label or self.vec >= other.vec

这样可以得到:

17963    0.002    0.000    0.002    0.000 plcfrs_cython.pyx:38(__richcmp__)

在试图找出tree.py:399的convert函数来源时,我发现dopg.py中的这个函数耗时很长。

  def removeids(tree):
""" remove unique IDs introduced by the Goodman reduction """
result = Tree.convert(tree)
for a in result.subtrees(lambda t: '@' in t.node):
    a.node = a.node.rsplit('@', 1)[0]
if isinstance(tree, ImmutableTree): return result.freeze()
return result

现在我不确定树中的每个节点是否都是ChartItem,以及getitem的值是否在其他地方被使用,但添加这些更改:

cdef class ChartItem:
cdef public str label
cdef public str root
cdef public long vec
cdef int _hash
__slots__ = ("label", "vec", "_hash")
def __init__(ChartItem self, label, int vec):
    self.label = intern(label) #.rsplit('@', 1)[0])
    self.root = intern(label.rsplit('@', 1)[0])
    self.vec = vec
    self._hash = hash((self.label, self.vec))
def __hash__(self):
    return self._hash
def __richcmp__(ChartItem self, ChartItem other, int op):
    if op == 0: return self.label < other.label or self.vec < other.vec
    elif op == 1: return self.label <= other.label or self.vec <= other.vec
    elif op == 2: return self.label == other.label and self.vec == other.vec
    elif op == 3: return self.label != other.label or self.vec != other.vec
    elif op == 4: return self.label > other.label or self.vec > other.vec
    elif op == 5: return self.label >= other.label or self.vec >= other.vec
def __getitem__(ChartItem self, int n):
    if n == 0: return self.root
    elif n == 1: return self.vec
def __repr__(self):
    #would need bitlen for proper padding
    return "%s[%s]" % (self.label, bin(self.vec)[2:][::-1]) 

并在mostprobableparse内部:

from libc cimport pow
def mostprobableparse...
            ...
    cdef dict parsetrees = <dict>defaultdict(float)
    cdef float prob
    m = 0
    for n,(a,prob) in enumerate(derivations):
        parsetrees[a] += pow(e,prob)
        m += 1

我得到了:

         189345 function calls (173785 primitive calls) in 0.162 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
6701/1143    0.025    0.000    0.037    0.000 heapdict.py:45(_min_heapify)
        1    0.023    0.023    0.120    0.120 plcfrs_cython.pyx:54(parse)
      960    0.018    0.000    0.030    0.000 plcfrs_cython.pyx:122(deduced_from)
 5190/198    0.011    0.000    0.015    0.000 tree.py:102(__eq__)
     6619    0.006    0.000    0.006    0.000 heapdict.py:67(_swap)
     9678    0.006    0.000    0.008    0.000 plcfrs_cython.pyx:137(concat)

接下来的步骤是优化heapify和deduced_from。

deduce_from可以进一步优化:

cdef inline deduced_from(ChartItem Ih, double x, pyCx, pyunary, pylbinary, pyrbinary, int bitlen):
cdef str I = Ih.label
cdef int Ir = Ih.vec
cdef list result = []
cdef dict Cx = <dict>pyCx
cdef dict unary = <dict>pyunary
cdef dict lbinary = <dict>pylbinary
cdef dict rbinary = <dict>pyrbinary
cdef ChartItem Ilh
cdef double z
cdef double y
cdef ChartItem I1h 
for rule, z in unary[I]:
    result.append((ChartItem(rule[0][0], Ir), ((x+z,z), (Ih,))))
for rule, z in lbinary[I]:
    for I1h, y in Cx[rule[0][2]].items():
        if concat(rule[1], Ir, I1h.vec, bitlen):
            result.append((ChartItem(rule[0][0], Ir ^ I1h.vec), ((x+y+z, z), (Ih, I1h))))
for rule, z in rbinary[I]:
    for I1h, y in Cx[rule[0][1]].items():
        if concat(rule[1], I1h.vec, Ir, bitlen):
            result.append((ChartItem(rule[0][0], I1h.vec ^ Ir), ((x+y+z, z), (I1h, Ih))))
return result

我在这里先停一下,虽然我相信随着对问题的深入了解,我们可以继续优化。

一系列单元测试会很有用,以确保每次优化不会引入任何细微的错误。

另外,建议使用空格而不是制表符。

2

你有没有试过用PyPy来运行你的应用,而不是用CPython呢?

PyPy在识别常见情况和避免不必要的内存占用方面,比CPython聪明得多。

反正试试看也是值得的:http://pypy.org/

2

我以为如果Python在不同的位置组合中只出现一次,那么它只会存储第一次出现的部分。

其实不一定:

>>> ("S", "NP", "VP") is ("S", "NP", "VP")
False

你可能想要把所有指向非终结符的字符串都用intern处理一下,因为你似乎在rcgrules.py中创建了很多这样的字符串。如果你想要对一个元组进行intern,那么首先要把它转换成字符串:

>>> intern("S NP VP") is intern(' '.join('S', 'NP', 'VP'))
True

否则,你就得“复制”这些元组,而不是重新构建它们。

(如果你对C++不太熟悉,那么在C++中重写这样的算法可能不会带来太多内存上的好处。你需要先评估各种哈希表的实现方式,并了解它在容器中的复制行为。我发现boost::unordered_map在处理很多小哈希表时是相当浪费的。)

撰写回答