计算多个向量的Tanimoto系数的SQL

4 投票
4 回答
1610 浏览
提问于 2025-04-15 17:33

我觉得用一个例子来解释我的问题会更简单。

我有一个表格,里面列出了食谱的各种原料。我已经写了一个函数,用来计算原料之间的Tanimoto系数。这个计算两个原料之间的系数的速度还不错(需要3个SQL查询),但是当要计算所有可能的原料组合时,就不太好用了。要计算所有原料组合的系数,需要进行N + (N*(N-1))/2个查询,对于1000种原料来说,这就需要500500个查询。有没有更快的方法呢?这是我目前的进展:

class Filtering():
  def __init__(self):
    self._connection=sqlite.connect('database.db')

  def n_recipes(self, ingredient_id):
    cursor = self._connection.cursor()
    cursor.execute('''select count(recipe_id) from recipe_ingredient
        where ingredient_id = ? ''', (ingredient_id, ))
    return cursor.fetchone()[0]

  def n_recipes_intersection(self, ingredient_a, ingredient_b):
    cursor = self._connection.cursor()
    cursor.execute('''select count(drink_id) from recipe_ingredient where
        ingredient_id = ? and recipe_id in (
        select recipe_id from recipe_ingredient
        where ingredient_id = ?) ''', (ingredient_a, ingredient_b))
    return cursor.fetchone()[0]

  def tanimoto(self, ingredient_a, ingredient_b):
    n_a, n_b = map(self.n_recipes, (ingredient_a, ingredient_b))
    n_ab = self.n_recipes_intersection(ingredient_a, ingredient_b)
    return float(n_ab) / (n_a + n_b - n_ab)

4 个回答

1

如果你有1000种食材,那么用1000个查询就可以把每种食材和一组食谱在内存中对应起来。比如说,一种食材通常会出现在大约100个食谱中,那么每组食谱大概只需要几KB的空间,所以整个字典也就只需要几MB的内存——这完全没问题,整个内容都能放在内存里。如果每种食材对应的食谱数量增加十倍,内存也不会出现严重问题。

result = dict()
for ing_id in all_ingredient_ids:
    cursor.execute('''select recipe_id from recipe_ingredient
        where ingredient_id = ?''', (ing_id,))
    result[ing_id] = set(r[0] for r in cursor.fetchall())
return result

在完成这1000个查询后,接下来需要的50万个成对的Tanimoto系数计算显然都是在内存中完成的。你可以提前计算出各种集合长度的平方,这样可以加快速度(并把结果放在另一个字典里),而每对的关键部分“A点乘B”当然就是这两个集合交集的长度。

3

如果有人感兴趣的话,这是我根据Alex和S.Lotts的建议写出来的代码。谢谢你们。

def __init__(self):
    self._connection=sqlite.connect('database.db')
    self._counts = None
    self._intersections = {}

def inc_intersections(self, ingredients):
    ingredients.sort()
    lenght = len(ingredients)
    for i in xrange(1, lenght):
        a = ingredients[i]
        for j in xrange(0, i):
            b = ingredients[j]
            if a not in self._intersections:
                self._intersections[a] = {b: 1}
            elif b not in self._intersections[a]:
                self._intersections[a][b] = 1
            else:
                self._intersections[a][b] += 1


def precompute_tanimoto(self):
    counts = {}
    self._intersections = {}

    cursor = self._connection.cursor()
    cursor.execute('''select recipe_id, ingredient_id
        from recipe_ingredient
        order by recipe_id, ingredient_id''')
    rows = cursor.fetchall()            

    print len(rows)

    last_recipe = None
    for recipe, ingredient in rows:
        if recipe != last_recipe:
            if last_recipe != None:
                self.inc_intersections(ingredients)
            last_recipe = recipe
            ingredients = [ingredient]
        else:
            ingredients.append(ingredient)

        if ingredient not in counts:
            counts[ingredient] = 1
        else:
            counts[ingredient] += 1

    self.inc_intersections(ingredients)

    self._counts = counts

def tanimoto(self, ingredient_a, ingredient_b):
    if self._counts == None:
        self.precompute_tanimoto()

    if ingredient_b > ingredient_a:
        ingredient_b, ingredient_a = ingredient_a, ingredient_b

    n_a, n_b = self._counts[ingredient_a], self._counts[ingredient_b]
    n_ab = self._intersections[ingredient_a][ingredient_b]

    print n_a, n_b, n_ab

    return float(n_ab) / (n_a + n_b - n_ab)
4

你为什么不直接把所有的食谱都加载到内存里,然后在内存中计算Tanimoto系数呢?

这样做更简单,而且速度快得多。

撰写回答