Python、字典和卡方列联表
这是我思考了很久的问题,任何帮助都会很有用。我有一个文件,里面有几行数据,格式是这样的:一个单词、这个单词出现的时间,以及在这个时间段内包含这个单词的文档数量。下面是这个输入文件的一个例子。
#inputfile
<word, time, frequency>
apple, 1, 3
banana, 1, 2
apple, 2, 1
banana, 2, 4
orange, 3, 1
我有一个Python类,用来创建二维字典,存储上面提到的文件,使用单词作为键,频率作为值:
class Ddict(dict):
'''
2D dictionary class
'''
def __init__(self, default=None):
self.default = default
def __getitem__(self, key):
if not self.has_key(key):
self[key] = self.default()
return dict.__getitem__(self, key)
wordtime=Ddict(dict) # Store each inputfile entry with a <word,time> key
timeword=Ddict(dict) # Store each inputfile entry with a <time,word> key
# Loop over every line of the inputfile
for line in open('inputfile'):
word,time,count=line.split(',')
# If <word,time> already a key, increment count
try:
wordtime[word][time]+=count
# Otherwise, create the key
except KeyError:
wordtime[word][time]=count
# If <time,word> already a key, increment count
try:
timeword[time][word]+=count
# Otherwise, create the key
except KeyError:
timeword[time][word]=count
我想问的问题是,在遍历这个二维字典的条目时,如何计算一些特定的内容。对于每个单词'w'和每个时间't',需要计算:
- 在时间't'内,包含单词'w'的文档数量。(a)
- 在时间't'内,不包含单词'w'的文档数量。(b)
- 在时间't'外,包含单词'w'的文档数量。(c)
- 在时间't'外,不包含单词'w'的文档数量。(d)
以上每一项都代表了每个单词和时间的卡方列联表中的一个单元格。所有这些计算能在一个循环中完成吗,还是需要逐个计算?
理想情况下,我希望输出的结果如下,其中a、b、c、d都是上面计算的结果:
print "%s, %s, %s, %s" %(a,b,c,d)
对于上面的输入文件,尝试找到单词'apple'在时间'1'的列联表的结果是(3,2,1,6)
。我来解释一下每个单元格是如何计算的:
- '3'个文档在时间'1'内包含'apple'。
- 在时间'1'内,有'2'个文档不包含'apple'。
- 在时间'1'外,有'1'个文档包含'apple'。
- 在时间'1'外,有'6'个文档不包含'apple'(1+4+1)。
1 个回答
2
你的4个关于“苹果/1”的数字加起来是12,这个数字比总的观察次数(11)还要多!实际上,只有5个文档在时间'1'之外,并且不包含“苹果”这个词。
你需要把观察结果分成4个不重叠的部分:
a: 包含“苹果”和“1”的部分 => 3
b: 不包含“苹果”和“1”的部分 => 2
c: 包含“苹果”但不包含“1”的部分 => 1
d: 不包含“苹果”也不包含“1”的部分 => 5
这里有一段代码,展示了一种实现方法:
from collections import defaultdict
class Crosstab(object):
def __init__(self):
self.count = defaultdict(lambda: defaultdict(int))
self.row_tot = defaultdict(int)
self.col_tot = defaultdict(int)
self.grand_tot = 0
def add(self, r, c, n):
self.count[r][c] += n
self.row_tot[r] += n
self.col_tot[c] += n
self.grand_tot += n
def load_data(line_iterator, conv_funcs):
ct = Crosstab()
for line in line_iterator:
r, c, n = [func(s) for func, s in zip(conv_funcs, line.split(','))]
ct.add(r, c, n)
return ct
def display_all_2x2_tables(crosstab):
for rx in crosstab.row_tot:
for cx in crosstab.col_tot:
a = crosstab.count[rx][cx]
b = crosstab.col_tot[cx] - a
c = crosstab.row_tot[rx] - a
d = crosstab.grand_tot - a - b - c
assert all(x >= 0 for x in (a, b, c, d))
print ",".join(str(x) for x in (rx, cx, a, b, c, d))
if __name__ == "__main__":
# inputfile
# <word, time, frequency>
lines = """\
apple, 1, 3
banana, 1, 2
apple, 2, 1
banana, 2, 4
orange, 3, 1""".splitlines()
ct = load_data(lines, (str.strip, int, int))
display_all_2x2_tables(ct)
这是输出结果:
orange,1,0,5,1,5
orange,2,0,5,1,5
orange,3,1,0,0,10
apple,1,3,2,1,5
apple,2,1,4,3,3
apple,3,0,1,4,6
banana,1,2,3,4,2
banana,2,4,1,2,4
banana,3,0,1,6,4