SQLAlchemy:按多个参数进行深度过滤
我有一个数据库:
class Category(Base):
    """A product category.

    Filter rows point at it through ``Filter.category_id``; the reverse
    ``filters`` collection is created by the backref on ``Filter.category``.
    """
    __tablename__ = 'category'
    id = Column(Integer, primary_key=True, nullable=False)
    title = Column(Text, nullable=False)
class Filter(Base):
    """A filter dimension inside a category, e.g. "brand" or "pixel".

    ``transcription`` is the key the queries compare against
    (``Filter.transcription == "brand"``); the selectable values live in
    FilterParameter and are reachable through the ``fparams`` backref.
    """
    __tablename__ = 'filter'
    id = Column(Integer, primary_key=True, nullable=False)
    category_id = Column(ForeignKey("category.id"), nullable=False)
    title = Column(Text, nullable=False)
    transcription = Column(Text, nullable=False)
    # Also adds a ``filters`` collection on Category.
    category = relationship('Category', backref='filters')
class FilterParameter(Base):
    """One selectable value of a Filter, e.g. "acer" for the "brand" filter."""
    __tablename__ = 'fparam'
    id = Column(Integer, primary_key=True, nullable=False)
    value = Column(Text, nullable=False)
    # Despite the "feature" name, this is the foreign key to filter.id.
    feature_id = Column(ForeignKey("filter.id"), nullable=False)
    transcription = Column(Text, nullable=False)
    # Also adds an ``fparams`` collection on Filter.
    filter = relationship('Filter', backref="fparams")
class Product(Base):
    """A catalogue product.

    Its filter values are not stored here; they are linked through
    ProductParameter rows (exposed as ``pparams`` by that model's backref).
    """
    __tablename__ = 'product'
    id = Column(Integer, primary_key=True, nullable=False)
    title = Column(Text, nullable=False)
    pic = Column(Text)
    price = Column(Float, nullable=False)
    availability = Column(Boolean, nullable=False)
    keywords = Column(Text, nullable=False)
class ProductParameter(Base):
    """Association row: one Product carries one FilterParameter value.

    A product with several parameter values has one row per value.
    """
    __tablename__ = 'pparam'
    id = Column(Integer, primary_key=True, nullable=False)
    product_id = Column(ForeignKey("product.id"), nullable=False)
    # Raw foreign key only; there is no relationship() to FilterParameter.
    param_id = Column(ForeignKey("fparam.id"), nullable=False)
    # Also adds a ``pparams`` collection on Product.
    product = relationship('Product', backref="pparams")
我的目标是按多个条件筛选产品列表(条件之间是“与”的关系)。每个产品都对应多个参数(一对多),参数值存放在 FilterParameter 表中,并通过 ProductParameter 表与产品关联。
我觉得我已经快达到目标了,因为根据其中一种条件来筛选是可以成功的:
# Single-filter query: join Product -> ProductParameter -> FilterParameter -> Filter
# once, then keep rows whose parameter value is one of the listed brands.
# NOTE(review): `cursor` is defined outside this snippet; presumably it wraps a Session.
res = cursor.session.query(Product).join(ProductParameter, Product.id == ProductParameter.product_id)\
    .join(FilterParameter, FilterParameter.id == ProductParameter.param_id)\
    .join(Filter, Filter.id == FilterParameter.feature_id)
res = res.filter(FilterParameter.transcription.in_("acer,asus".split(',')),
                 Filter.transcription == "brand").all()
这个是可以工作的。
但是如果需要根据多种条件来筛选的话:
# Why this never matches: the query above joins FilterParameter and Filter only
# once, so every result row carries exactly one Filter. The row would need
# Filter.transcription == "brand" AND Filter.transcription == "pixel" at the
# same time, which is impossible. Each filter needs its own (aliased) join
# chain or its own subquery instead.
res = res.filter(and_(FilterParameter.transcription.in_("acer,asus".split(',')),
                      Filter.transcription == "brand"))\
    .filter(and_(FilterParameter.transcription.in_("1920x1080".split(',')),
                 Filter.transcription == "pixel")).all()
这样就不行了:查询总是返回空结果。
2 个回答
0
这个问题比看起来要复杂一些。原查询只连接了一次 FilterParameter 和 Filter,所以每一行结果只包含一个参数,不可能同时满足 “brand” 和 “pixel” 两个条件。你可以改用连接查询(每个过滤器单独连接一次)、子查询,或者数组列(如果你在用 PostgreSQL 的话)。
一些准备代码
设置模型并填充数据库
import os
from sqlalchemy import (
create_engine,
Column,
Integer,
ForeignKey,
Text,
)
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.sql import (
and_,
)
from sqlalchemy.orm import (
declarative_base,
Session,
relationship,
aliased,
)
def get_engine(env, *, echo=True):
    """Create a PostgreSQL (psycopg2) engine from ``DB_*`` connection settings.

    :param env: mapping providing ``DB_USER``, ``DB_PASSWORD``, ``DB_HOST``,
        ``DB_PORT`` and ``DB_NAME`` (for example ``os.environ``).
    :param echo: log every emitted SQL statement; defaults to ``True`` as before.
    :returns: a new SQLAlchemy ``Engine``.
    :raises KeyError: if one of the required keys is missing from ``env``.
    :raises ValueError: if ``DB_PORT`` is not an integer.
    """
    from sqlalchemy.engine import URL

    # URL.create keeps every component separate instead of parsing a string,
    # so characters such as '@', ':' or '/' in the password cannot be
    # mistaken for URL delimiters (a hand-built f-string URL breaks on them).
    url = URL.create(
        drivername="postgresql+psycopg2",
        username=env['DB_USER'],
        password=env['DB_PASSWORD'],
        host=env['DB_HOST'],
        port=int(env['DB_PORT']),
        database=env['DB_NAME'],
    )
    return create_engine(url, echo=echo)
# Declarative base for every model below; Base.metadata collects their tables
# so populate() can create them with create_all().
Base = declarative_base()
class Filter(Base):
    """A filter dimension such as "brand" or "pixel" (see populate())."""
    __tablename__ = 'filter'
    id = Column(Integer, primary_key=True, nullable=False)
    # Key the filtering functions compare against, e.g. 'brand'.
    transcription = Column(Text, nullable=False)
class FilterParameter(Base):
    """One value of a Filter, e.g. "acer" for "brand" or "1920x1080" for "pixel"."""
    __tablename__ = 'fparam'
    id = Column(Integer, primary_key=True, nullable=False)
    # Despite the "feature" name, this is the foreign key to filter.id.
    feature_id = Column(ForeignKey("filter.id"), nullable=False)
    transcription = Column(Text, nullable=False)
    # Also adds an ``fparams`` collection on Filter.
    filter = relationship('Filter', backref="fparams")
class Product(Base):
    """A product; ``tags`` is only needed by the array-column strategy in this example."""
    __tablename__ = 'product'
    id = Column(Integer, primary_key=True, nullable=False)
    title = Column(Text, nullable=False)
    # PostgreSQL ARRAY of denormalised "<filter>--<value>" strings such as
    # 'brand--acer'; it must be kept in sync with the ProductParameter rows by hand.
    tags = Column(ARRAY(Text))
class ProductParameter(Base):
    """Association row linking a Product to one FilterParameter value."""
    __tablename__ = 'pparam'
    id = Column(Integer, primary_key=True, nullable=False)
    product_id = Column(ForeignKey("product.id"), nullable=False)
    # Raw foreign key only; there is no relationship() to FilterParameter,
    # so populate() sets it from an already-flushed FilterParameter.id.
    param_id = Column(ForeignKey("fparam.id"), nullable=False)
    # Also adds a ``pparams`` collection on Product.
    product = relationship('Product', backref="pparams")
def populate(engine):
    """Create the schema and seed it with a minimal sample catalogue.

    Inserts a "brand" filter (values acer, asus), a "pixel" filter
    (value 1920x1080) and one product "TV" carrying acer and 1920x1080, both
    as ProductParameter rows and in its ``tags`` array.
    """
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        brand_filter = Filter(transcription='brand')
        session.add(brand_filter)
        brand_values = [FilterParameter(filter=brand_filter, transcription=v) for v in ('acer', 'asus')]
        session.add_all(brand_values)

        pixel_filter = Filter(transcription='pixel')
        session.add(pixel_filter)
        pixel_values = [FilterParameter(filter=pixel_filter, transcription=v) for v in ('1920x1080',)]
        session.add_all(pixel_values)

        # Flush so the FilterParameter rows receive primary keys: the
        # ProductParameter rows below reference them by raw id.
        session.flush()

        tv = Product(title="TV", tags=['brand--acer', 'pixel--1920x1080'])
        session.add(tv)
        links = [
            ProductParameter(product=tv, param_id=chosen.id)
            for chosen in (brand_values[0], pixel_values[0])
        ]
        session.add_all(links)
        session.commit()
连接查询示例
这种方法会为每个过滤器分别连接一次 ProductParameter、FilterParameter 和 Filter(每次都使用新的别名)。我认为可以预先查出过滤器(例如 "brand")的 id 来优化,从而省去对 Filter 的连接。
def filter_products_with_joins(session, filters):
    """Return products matching every filter, using one aliased join chain per filter.

    :param session: an open Session.
    :param filters: iterable of ``(filter_name, "v1,v2,...")`` pairs; the values
        of one pair are OR-ed, separate pairs are AND-ed.
    :returns: list of matching Product objects.
    """
    query = session.query(Product)
    for name, csv_values in filters:
        # Fresh aliases per filter, so each filter constrains its own
        # ProductParameter / FilterParameter / Filter rows.
        link = aliased(ProductParameter)
        value = aliased(FilterParameter)
        feature = aliased(Filter)
        allowed = csv_values.split(',')
        query = (
            query
            .join(link, Product.id == link.product_id)
            .join(value, and_(value.id == link.param_id, value.transcription.in_(allowed)))
            .join(feature, and_(feature.id == value.feature_id, feature.transcription == name))
        )
    return query.all()
子查询示例
这种方法和连接查询类似,但可能更快,也更灵活、更易读。我们不再逐个连接,而是为每个过滤器生成一个子查询,再用 AND 把它们组合起来。也就是说,我要找的产品,其 ID 既在符合这些品牌的产品 ID 列表中,又在符合这些像素的产品 ID 列表中,依此类推。
def filter_products_with_subqueries(session, filters):
    """Return products matching every filter, using one ``IN (subquery)`` per filter.

    :param session: an open Session.
    :param filters: iterable of ``(filter_name, "v1,v2,...")`` pairs; the values
        of one pair are OR-ed, separate pairs are AND-ed.
    :returns: list of matching Product objects (all products when ``filters``
        is empty).
    """
    conditions = []
    for filter_name, filter_params in filters:
        matching_ids = (
            session.query(Product.id)
            .join(ProductParameter, Product.id == ProductParameter.product_id)
            .join(FilterParameter, and_(FilterParameter.id == ProductParameter.param_id,
                                        FilterParameter.transcription.in_(filter_params.split(','))))
            .join(Filter, and_(Filter.id == FilterParameter.feature_id,
                               Filter.transcription == filter_name))
        )
        # scalar_subquery() hands IN() a SELECT directly. Passing .subquery()
        # (a FROM-clause object) makes SQLAlchemy 1.4+ warn "Coercing Subquery
        # object into a select() for use in IN()" before coercing it anyway.
        conditions.append(Product.id.in_(matching_ids.scalar_subquery()))
    return session.query(Product).filter(*conditions).all()
使用数组列的示例(我想只有Postgres支持)
这种方式查询起来简单得多,但维护起来比较麻烦。这里我们把过滤器和过滤参数合并成标签(格式为 "过滤器--参数值"),给产品设置这些标签,然后按标签列表查询。也就是说,我要找的产品,其标签至少包含一个所选品牌的标签,同时至少包含一个所选像素的标签,依此类推。
def filter_products_with_tags(session, filters):
    """Return products whose ``tags`` array overlaps each filter's tag set.

    Each requested value becomes the tag ``"<filter_name>--<value>"``. ``overlap``
    requires at least one tag of a filter to be present; the per-filter
    conditions passed to ``filter()`` are AND-ed.
    """
    conditions = [
        Product.tags.overlap([f'{name}--{value}' for value in csv_values.split(',')])
        for name, csv_values in filters
    ]
    return session.query(Product).filter(*conditions).all()
其余的代码
测试和主要调用。
def query(engine):
    """Smoke-test every filtering strategy against the data seeded by populate()."""
    # (filters, expected number of matching products)
    cases = [
        ([('brand', 'acer,asus')], 1),
        ([('brand', 'acer,asus'), ('pixel', '1920x1080')], 1),
        ([('brand', 'acer,asus,sony'), ('pixel', '1920x1080')], 1),
        ([('brand', 'acer,asus,sony'), ('pixel', '480x360')], 0),
        ([('brand', 'acer,asus,sony'), ('pixel', '1920x1080,480x360')], 1),
    ]
    with Session(engine) as session:
        strategies = (filter_products_with_joins, filter_products_with_subqueries, filter_products_with_tags)
        for strategy in strategies:
            for filters, expected in cases:
                assert len(strategy(session, filters)) == expected
def main():
    """Build the engine from DB_* environment variables, seed the database, run the checks."""
    db_engine = get_engine(os.environ)
    # Seed first, then verify: query() expects exactly the rows populate() inserts.
    for step in (populate, query):
        step(db_engine)


if __name__ == '__main__':
    main()
0
我这样解决了这个问题,在我的情况下效果很好:
def dot_filters(query, args: dict):
    """Narrow ``query`` to products that satisfy every filter in ``args``.

    :param query: a query over Product exposing ``.filter()`` (a legacy Query or
        a ``select()``).
    :param args: maps ``Filter.transcription`` (e.g. ``"brand"``) to a
        comma-separated string of ``FilterParameter.transcription`` values
        (e.g. ``"acer,asus"``). Values within one entry are OR-ed; separate
        entries are AND-ed.
    :returns: ``query`` with one ``Product.id IN (...)`` condition per entry;
        returned unchanged when ``args`` is empty.
    """
    # Session-free select() constructs: the subqueries only build SQL, so they
    # do not need the global ``cursor.session`` the previous version relied on.
    from sqlalchemy import select

    for filter_name, csv_values in args.items():
        # ids of the Filter rows whose transcription equals this key
        filter_ids = (
            select(Filter.id)
            .where(Filter.transcription == filter_name)
            .scalar_subquery()
        )
        # ids of products carrying at least one requested value of that filter
        product_ids = (
            select(ProductParameter.product_id)
            .join(FilterParameter)
            .where(
                FilterParameter.feature_id.in_(filter_ids),
                FilterParameter.transcription.in_(csv_values.split(',')),
            )
            .scalar_subquery()
        )
        query = query.filter(Product.id.in_(product_ids))
    return query
最后是这样使用的:
# Apply the filters to an existing Product query, e.g. with
# args = {"brand": "acer,asus", "pixel": "1920x1080"}.
query = dot_filters(query, args)