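SQL to compute the Tanimoto coefficient for multiple vectors
I think it's easier to explain my problem with an example.
I have a table that lists the ingredients of various recipes, and I've written a function that computes the Tanimoto coefficient between two ingredients. Computing the coefficient for a single pair is reasonably fast (it takes 3 SQL queries), but it doesn't scale to computing the coefficients for every possible pair of ingredients. That requires N + (N*(N-1))/2 queries, which for 1000 ingredients comes to 500,500 queries. Is there a faster way? Here's what I have so far: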
```python
import sqlite3 as sqlite


class Filtering(object):
    def __init__(self):
        self._connection = sqlite.connect('database.db')

    def n_recipes(self, ingredient_id):
        # Number of recipes that contain this ingredient.
        cursor = self._connection.cursor()
        cursor.execute('''select count(recipe_id) from recipe_ingredient
                          where ingredient_id = ?''', (ingredient_id,))
        return cursor.fetchone()[0]

    def n_recipes_intersection(self, ingredient_a, ingredient_b):
        # Number of recipes that contain both ingredients.
        cursor = self._connection.cursor()
        cursor.execute('''select count(recipe_id) from recipe_ingredient
                          where ingredient_id = ? and recipe_id in (
                              select recipe_id from recipe_ingredient
                              where ingredient_id = ?)''',
                       (ingredient_a, ingredient_b))
        return cursor.fetchone()[0]

    def tanimoto(self, ingredient_a, ingredient_b):
        n_a, n_b = map(self.n_recipes, (ingredient_a, ingredient_b))
        n_ab = self.n_recipes_intersection(ingredient_a, ingredient_b)
        return float(n_ab) / (n_a + n_b - n_ab)
```
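For reference, this is the quantity the `tanimoto` method returns, with A and B the sets of recipes containing each ingredient (so `n_a` = |A|, `n_b` = |B|, `n_ab` = |A ∩ B|), together with the query count from the question. The count assumes, as the question's formula implies, that each ingredient's recipe count is fetched once and reused, plus one intersection query per unordered pair:

```latex
T(A, B) = \frac{|A \cap B|}{|A| + |B| - |A \cap B|}

Q(N) = N + \binom{N}{2} = N + \frac{N(N-1)}{2},
\qquad Q(1000) = 1000 + 499\,500 = 500\,500
```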
4 Answers
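If you have 1000 ingredients, 1000 queries are enough to map each ingredient to its set of recipes in memory. If, say, an ingredient typically appears in about 100 recipes, each set takes only a few KB, so the whole dictionary needs only a few MB, which fits comfortably in memory. Even if each ingredient appeared in ten times as many recipes, memory would not become a serious problem.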
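```python
# Hypothetical wrapper name: the answer showed only the loop body.
def recipe_sets_by_ingredient(cursor, all_ingredient_ids):
    # One query per ingredient: map ingredient_id -> set of recipe_ids.
    result = dict()
    for ing_id in all_ingredient_ids:
        cursor.execute('''select recipe_id from recipe_ingredient
                          where ingredient_id = ?''', (ing_id,))
        result[ing_id] = set(r[0] for r in cursor.fetchall())
    return result
```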
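Once those 1000 queries are done, the roughly 500,000 pairwise Tanimoto computations that follow all happen in memory. To speed them up, you can precompute each set's size once (for 0/1 vectors, the squared norm |A|² in the Tanimoto formula is simply the number of elements) and keep the sizes in another dictionary; the key term of each pair, "A dot B", is of course just the size of the intersection of the two sets.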
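A minimal sketch of that in-memory pass, assuming the input is a dict mapping each ingredient id to the set of recipe ids that contain it. The function name `all_pairs_tanimoto` and the sample data are illustrative, not from the answer. Set sizes are computed once, and "A dot B" is taken as the size of the set intersection:

```python
from itertools import combinations


def all_pairs_tanimoto(recipe_sets):
    """Tanimoto coefficient for every unordered pair of ingredients.

    recipe_sets: dict mapping ingredient_id -> set of recipe_ids (assumed shape).
    Returns a dict mapping (a, b), with a < b, to the coefficient.
    """
    # Precompute set sizes once; for 0/1 vectors this equals the squared norm.
    sizes = {ing: len(recipes) for ing, recipes in recipe_sets.items()}
    scores = {}
    for a, b in combinations(sorted(recipe_sets), 2):
        n_ab = len(recipe_sets[a] & recipe_sets[b])  # "A dot B"
        denom = sizes[a] + sizes[b] - n_ab
        # denom is 0 only when both sets are empty; treat that case as 0.
        scores[(a, b)] = n_ab / denom if denom else 0.0
    return scores


# Made-up data: ingredient id -> ids of the recipes that use it.
example = {1: {"r1", "r2"}, 2: {"r1", "r2", "r3"}, 3: {"r1", "r3"}}
print(all_pairs_tanimoto(example))
```

`itertools.combinations` yields each unordered pair exactly once, so for N ingredients the loop body runs N(N-1)/2 times, with no database access inside it.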
In case anyone is interested, here is the code I came up with based on Alex's and S.Lott's suggestions. Thanks to both of you.
```python
import sqlite3 as sqlite


class Filtering(object):
    def __init__(self):
        self._connection = sqlite.connect('database.db')
        self._counts = None          # ingredient_id -> number of recipes
        self._intersections = {}     # a -> {b: shared recipe count}, a > b

    def inc_intersections(self, ingredients):
        # Count every pair of ingredients that co-occur in one recipe.
        ingredients.sort()
        length = len(ingredients)
        for i in range(1, length):
            a = ingredients[i]
            for j in range(0, i):
                b = ingredients[j]
                if a not in self._intersections:
                    self._intersections[a] = {b: 1}
                elif b not in self._intersections[a]:
                    self._intersections[a][b] = 1
                else:
                    self._intersections[a][b] += 1

    def precompute_tanimoto(self):
        # A single query; rows are then walked recipe by recipe.
        counts = {}
        self._intersections = {}
        cursor = self._connection.cursor()
        cursor.execute('''select recipe_id, ingredient_id
                          from recipe_ingredient
                          order by recipe_id, ingredient_id''')
        rows = cursor.fetchall()
        last_recipe = None
        ingredients = []
        for recipe, ingredient in rows:
            if recipe != last_recipe:
                if last_recipe is not None:
                    self.inc_intersections(ingredients)
                last_recipe = recipe
                ingredients = [ingredient]
            else:
                ingredients.append(ingredient)
            if ingredient not in counts:
                counts[ingredient] = 1
            else:
                counts[ingredient] += 1
        if last_recipe is not None:
            # Flush the ingredients of the last recipe.
            self.inc_intersections(ingredients)
        self._counts = counts

    def tanimoto(self, ingredient_a, ingredient_b):
        if self._counts is None:
            self.precompute_tanimoto()
        # Pairs are stored with the larger id first.
        if ingredient_b > ingredient_a:
            ingredient_b, ingredient_a = ingredient_a, ingredient_b
        n_a, n_b = self._counts[ingredient_a], self._counts[ingredient_b]
        # Ingredients that never share a recipe have no entry: n_ab = 0.
        n_ab = self._intersections.get(ingredient_a, {}).get(ingredient_b, 0)
        return float(n_ab) / (n_a + n_b - n_ab)
```
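`inc_intersections` counts co-occurring pairs in Python, one recipe at a time. The same counting can also be pushed into SQLite, which is closer to what the title asks for. The sketch below is an alternative not proposed in this thread: it self-joins `recipe_ingredient` on `recipe_id`, assumes each (recipe_id, ingredient_id) pair appears at most once, and returns coefficients only for pairs that share at least one recipe (every other pair has a coefficient of 0). The names `PAIRWISE_TANIMOTO_SQL` and `pairwise_tanimoto` and the sample rows are made up for illustration.

```python
import sqlite3

# Per-ingredient recipe counts, per-pair shared-recipe counts, then the ratio.
# Assumes (recipe_id, ingredient_id) rows are unique in recipe_ingredient.
PAIRWISE_TANIMOTO_SQL = """
WITH counts AS (
    SELECT ingredient_id, COUNT(*) AS n
    FROM recipe_ingredient
    GROUP BY ingredient_id
),
pairs AS (
    SELECT a.ingredient_id AS ing_a, b.ingredient_id AS ing_b,
           COUNT(*) AS n_ab
    FROM recipe_ingredient AS a
    JOIN recipe_ingredient AS b
      ON a.recipe_id = b.recipe_id AND a.ingredient_id < b.ingredient_id
    GROUP BY a.ingredient_id, b.ingredient_id
)
SELECT p.ing_a, p.ing_b,
       CAST(p.n_ab AS REAL) / (ca.n + cb.n - p.n_ab) AS tanimoto
FROM pairs AS p
JOIN counts AS ca ON ca.ingredient_id = p.ing_a
JOIN counts AS cb ON cb.ingredient_id = p.ing_b
ORDER BY p.ing_a, p.ing_b
"""


def pairwise_tanimoto(connection):
    """Return {(ing_a, ing_b): T} for co-occurring pairs, with ing_a < ing_b."""
    rows = connection.execute(PAIRWISE_TANIMOTO_SQL).fetchall()
    return {(a, b): t for a, b, t in rows}


# Demo on an in-memory database with made-up (recipe_id, ingredient_id) rows.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE recipe_ingredient (recipe_id, ingredient_id)")
conn.executemany(
    "INSERT INTO recipe_ingredient VALUES (?, ?)",
    [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (3, 2), (3, 3)],
)
print(pairwise_tanimoto(conn))
```

The self-join still does work proportional to the sum of the squared recipe sizes, but it runs as one statement instead of hundreds of thousands of round trips; on a large table an index on `recipe_ingredient(recipe_id)` is likely to help.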
Why don't you simply load all the recipes into memory and compute the Tanimoto coefficients there?
It's simpler and much faster.
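A minimal sketch of the loading step this suggests, assuming the same `recipe_ingredient` table: a single query fetches every row, and the rows are grouped into a dict mapping each ingredient id to the set of recipe ids that use it. The function name `load_all_recipe_sets` and the demo rows are hypothetical. Once the dict exists, each pairwise coefficient needs only set sizes and a set intersection, with no further queries.

```python
import sqlite3
from collections import defaultdict


def load_all_recipe_sets(connection):
    """Fetch every (recipe_id, ingredient_id) row with one query and group
    them into {ingredient_id: set of recipe_ids}."""
    recipe_sets = defaultdict(set)
    for recipe_id, ingredient_id in connection.execute(
        "SELECT recipe_id, ingredient_id FROM recipe_ingredient"
    ):
        recipe_sets[ingredient_id].add(recipe_id)
    return dict(recipe_sets)


# Demo on an in-memory database with made-up rows.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE recipe_ingredient (recipe_id, ingredient_id)")
conn.executemany(
    "INSERT INTO recipe_ingredient VALUES (?, ?)",
    [(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (3, 2), (3, 3)],
)
recipe_sets_demo = load_all_recipe_sets(conn)
print(recipe_sets_demo)
```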