概率解析器的内存使用情况
我正在为一个范围连接语法写一个CKY解析器。我想用一个树库作为语法,所以这个语法会很大。我用Python写了一个原型,发现当我模拟几十个句子的树库时,它运行得还不错,但内存使用量太高了。我尝试用C++来写,但到目前为止这让我很沮丧,因为我之前从未使用过C++。这里有一些数据(n是语法所基于的句子数量):
n mem
9 173M
18 486M
36 836M
这种增长模式是根据最佳优先算法预期的,但我担心的是开销的大小。根据heapy的报告,内存使用量比这些数字小十倍,valgrind也报告了类似的情况。这种差异是怎么造成的?在Python(或Cython)中有什么办法可以解决吗?可能是因为内存碎片化?或者是Python字典的开销?
一些背景信息:两个重要的数据结构是议程(将边映射到概率)和图表(一个字典,将非终结符和位置映射到边)。议程是用heapdict实现的(它内部使用一个字典和一个堆列表),图表是用一个字典将非终结符和位置映射到边。议程经常进行插入和删除,而图表只进行插入和查找。我用元组来表示边,像这样:
(("S", 111), ("NP", 010), ("VP", 100, 001))
这些字符串是来自语法的非终结符标签,位置用位掩码编码。当一个成分是不连续的时候,可以有多个位置。所以这个边可以表示“玛丽快乐吗”的分析,其中“是”和“快乐”都属于VP。图表字典是通过这个边的第一个元素来索引的,在这个例子中是(“S”,111)。在一个新版本中,我尝试转置这个表示,希望能通过重用来节省内存:
(("S", "NP", "VP), (111, 100, 011))
我想如果不同位置组合中出现的第一部分只存储一次,那Python应该会这样做,尽管我并不确定这是否真的成立。无论如何,这似乎没有什么区别。
所以我基本上想知道,继续追求我的Python实现是否值得,包括用Cython和不同的数据结构,还是从头开始用C++编写是唯一可行的选择。
更新:经过一些改进,我不再有内存使用的问题。我正在开发一个优化的Cython版本。我会把赏金给对提高代码效率最有帮助的建议。这里有一个带注释的版本:http://student.science.uva.nl/~acranenb/plcfrs_cython.html
1 https://github.com/andreasvc/disco-dop/ -- 运行test.py来解析一些句子。需要Python 2.6,nltk和heapdict
3 个回答
在这种情况下,首先要做的就是进行性能分析:
15147/297 0.032 0.000 0.041 0.000 tree.py:102(__eq__)
15400/200 0.031 0.000 0.106 0.001 tree.py:399(convert)
1 0.023 0.023 0.129 0.129 plcfrs_cython.pyx:52(parse)
6701/1143 0.022 0.000 0.043 0.000 heapdict.py:45(_min_heapify)
18212 0.017 0.000 0.023 0.000 plcfrs_cython.pyx:38(__richcmp__)
10975/10875 0.017 0.000 0.035 0.000 tree.py:75(__init__)
5772 0.016 0.000 0.050 0.000 tree.py:665(__init__)
960 0.016 0.000 0.025 0.000 plcfrs_cython.pyx:118(deduced_from)
46938 0.014 0.000 0.014 0.000 tree.py:708(_get_node)
25220/2190 0.014 0.000 0.016 0.000 tree.py:231(subtrees)
10975 0.013 0.000 0.023 0.000 tree.py:60(__new__)
49441 0.013 0.000 0.013 0.000 {isinstance}
16748 0.008 0.000 0.015 0.000 {hasattr}
我注意到的第一件事是,来自cython模块的函数非常少。大多数函数来自tree.py模块,可能这就是性能瓶颈所在。
专注于cython部分,我发现了richcmp函数:
我们可以通过在方法声明中添加值的类型来简单优化它。
def __richcmp__(ChartItem self, ChartItem other, int op):
....
这样可以降低值的计算时间。
ncalls tottime percall cumtime percall filename:lineno(function)
....
18212 0.011 0.000 0.015 0.000 plcfrs_cython.pyx:38(__richcmp__)
用elif语法替代单个if语句,可以启用cython的switch优化。
if op == 0: return self.label < other.label or self.vec < other.vec
elif op == 1: return self.label <= other.label or self.vec <= other.vec
elif op == 2: return self.label == other.label and self.vec == other.vec
elif op == 3: return self.label != other.label or self.vec != other.vec
elif op == 4: return self.label > other.label or self.vec > other.vec
elif op == 5: return self.label >= other.label or self.vec >= other.vec
这样可以得到:
17963 0.002 0.000 0.002 0.000 plcfrs_cython.pyx:38(__richcmp__)
在试图找出tree.py:399的convert函数来源时,我发现dopg.py中的这个函数耗时很长。
def removeids(tree):
""" remove unique IDs introduced by the Goodman reduction """
result = Tree.convert(tree)
for a in result.subtrees(lambda t: '@' in t.node):
a.node = a.node.rsplit('@', 1)[0]
if isinstance(tree, ImmutableTree): return result.freeze()
return result
现在我不确定树中的每个节点是否都是ChartItem,以及getitem的值是否在其他地方被使用,但添加这些更改:
cdef class ChartItem:
cdef public str label
cdef public str root
cdef public long vec
cdef int _hash
__slots__ = ("label", "vec", "_hash")
def __init__(ChartItem self, label, int vec):
self.label = intern(label) #.rsplit('@', 1)[0])
self.root = intern(label.rsplit('@', 1)[0])
self.vec = vec
self._hash = hash((self.label, self.vec))
def __hash__(self):
return self._hash
def __richcmp__(ChartItem self, ChartItem other, int op):
if op == 0: return self.label < other.label or self.vec < other.vec
elif op == 1: return self.label <= other.label or self.vec <= other.vec
elif op == 2: return self.label == other.label and self.vec == other.vec
elif op == 3: return self.label != other.label or self.vec != other.vec
elif op == 4: return self.label > other.label or self.vec > other.vec
elif op == 5: return self.label >= other.label or self.vec >= other.vec
def __getitem__(ChartItem self, int n):
if n == 0: return self.root
elif n == 1: return self.vec
def __repr__(self):
#would need bitlen for proper padding
return "%s[%s]" % (self.label, bin(self.vec)[2:][::-1])
并在mostprobableparse内部:
from libc cimport pow
def mostprobableparse...
...
cdef dict parsetrees = <dict>defaultdict(float)
cdef float prob
m = 0
for n,(a,prob) in enumerate(derivations):
parsetrees[a] += pow(e,prob)
m += 1
我得到了:
189345 function calls (173785 primitive calls) in 0.162 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
6701/1143 0.025 0.000 0.037 0.000 heapdict.py:45(_min_heapify)
1 0.023 0.023 0.120 0.120 plcfrs_cython.pyx:54(parse)
960 0.018 0.000 0.030 0.000 plcfrs_cython.pyx:122(deduced_from)
5190/198 0.011 0.000 0.015 0.000 tree.py:102(__eq__)
6619 0.006 0.000 0.006 0.000 heapdict.py:67(_swap)
9678 0.006 0.000 0.008 0.000 plcfrs_cython.pyx:137(concat)
接下来的步骤是优化heapify和deduced_from。
deduce_from可以进一步优化:
cdef inline deduced_from(ChartItem Ih, double x, pyCx, pyunary, pylbinary, pyrbinary, int bitlen):
cdef str I = Ih.label
cdef int Ir = Ih.vec
cdef list result = []
cdef dict Cx = <dict>pyCx
cdef dict unary = <dict>pyunary
cdef dict lbinary = <dict>pylbinary
cdef dict rbinary = <dict>pyrbinary
cdef ChartItem Ilh
cdef double z
cdef double y
cdef ChartItem I1h
for rule, z in unary[I]:
result.append((ChartItem(rule[0][0], Ir), ((x+z,z), (Ih,))))
for rule, z in lbinary[I]:
for I1h, y in Cx[rule[0][2]].items():
if concat(rule[1], Ir, I1h.vec, bitlen):
result.append((ChartItem(rule[0][0], Ir ^ I1h.vec), ((x+y+z, z), (Ih, I1h))))
for rule, z in rbinary[I]:
for I1h, y in Cx[rule[0][1]].items():
if concat(rule[1], I1h.vec, Ir, bitlen):
result.append((ChartItem(rule[0][0], I1h.vec ^ Ir), ((x+y+z, z), (I1h, Ih))))
return result
我在这里先停一下,虽然我相信随着对问题的深入了解,我们可以继续优化。
一系列单元测试会很有用,以确保每次优化不会引入任何细微的错误。
另外,建议使用空格而不是制表符。
我以为如果Python在不同的位置组合中只出现一次,那么它只会存储第一次出现的部分。
其实不一定:
>>> ("S", "NP", "VP") is ("S", "NP", "VP")
False
你可能想要把所有指向非终结符的字符串都用intern
处理一下,因为你似乎在rcgrules.py
中创建了很多这样的字符串。如果你想要对一个元组进行intern
,那么首先要把它转换成字符串:
>>> intern("S NP VP") is intern(' '.join('S', 'NP', 'VP'))
True
否则,你就得“复制”这些元组,而不是重新构建它们。
(如果你对C++不太熟悉,那么在C++中重写这样的算法可能不会带来太多内存上的好处。你需要先评估各种哈希表的实现方式,并了解它在容器中的复制行为。我发现boost::unordered_map
在处理很多小哈希表时是相当浪费的。)