SQLAlchemy、array_agg和匹配输入lis

2024-04-19 18:55:34 发布

您现在位置:Python中文网/ 问答频道 /正文

我正试图更充分地使用SQLAlchemy,而不是一有困难就回到纯SQL。在本例中,我在Postgres数据库(9.5)中有一个表,它通过将单个项atom_id与组标识符group_id相关联,将一组整数存储为一个组。在

给出一个atom_ids的列表,我想知道这组atom_ids属于哪一个{},如果有的话。只需使用group_idatom_id列就可以解决这个问题。在

现在我试图将“group”概括为不仅由atom_ids列表组成,而且还由其他上下文组成。在下面的示例中,列表是通过包含sequence列来排序的,但是在概念上也可以使用其他列,例如weight列,它给每个atom_id一个表示原子在组中的“份额”的[0,1]浮点值。在

下面是一个单元测试的大部分,它演示了我的问题。在

首先,一些设置:

def test_multi_column_grouping(self):
    class MultiColumnGroups(base.Base):
        __tablename__ = 'multi_groups'

        group_id = Column(Integer)
        atom_id = Column(Integer)
        sequence = Column(Integer)  # arbitrary 'other' column.  In this case, an integer, but it could be a float (e.g. weighting factor)

    base.Base.metadata.create_all(self.engine)

    # Insert 6 rows representing 2 different 'groups' of values
    vals = [
        # Group 1
        {'group_id': 1, 'atom_id': 1, 'sequence': 1},
        {'group_id': 1, 'atom_id': 2, 'sequence': 2},
        {'group_id': 1, 'atom_id': 3, 'sequence': 3},
        # Group 2
        {'group_id': 2, 'atom_id': 1, 'sequence': 3},
        {'group_id': 2, 'atom_id': 2, 'sequence': 2},
        {'group_id': 2, 'atom_id': 3, 'sequence': 1},
    ]

    self.session.bulk_save_objects(
        [MultiColumnGroups(**x) for x in vals])
    self.session.flush()

    self.assertEqual(6, len(self.session.query(MultiColumnGroups).all()))

现在,我想查询上表以查找特定输入集属于哪个组。我使用一组(命名的)元组来表示查询参数。在

^{pr2}$

原始SQL解决方案。我宁愿不要依赖这个,因为这个练习的一部分是学习更多的sql炼金术。在

    r = self.session.execute('''
        select group_id
        from multi_groups
        group by group_id
        having array_agg((atom_id, sequence)) = :query_tuples
        ''', {'query_tuples': values_to_match}).fetchone()
    print(r)  # > (2,)
    self.assertEqual(2, r[0])

下面是上面的原始SQL解决方案相当直接地转换为 断开的SQLAlchemy查询。运行此操作将产生一个psycopg2错误:(psycopg2.ProgrammingError) operator does not exist: record[] = integer[]。我认为我需要将array_agg转换成int[]?只要分组列都是整数(如果需要的话,这是可以接受的限制),但理想情况下,这可以用于混合类型的输入元组/表列。在

    from sqlalchemy import tuple_
    from sqlalchemy.dialects.postgresql import array_agg

    existing_group = self.session.query(MultiColumnGroups).\
        with_entities(MultiColumnGroups.group_id).\
        group_by(MultiColumnGroups.group_id).\
        having(array_agg(tuple_(MultiColumnGroups.atom_id, MultiColumnGroups.sequence)) == values_to_match).\
        one_or_none()

    self.assertIsNotNone(existing_group)
    print('|{}|'.format(existing_group))

上面的session.query()关闭了吗?我是不是在这里瞎了眼,错过了一些可以用其他方式解决这个问题的非常明显的东西?在


Tags: selfidids列表sqlsessiongroupquery
2条回答

我发现你的回答也很有帮助。由于我没有足够的声誉来评论你的解决方案,我将发布我根据你的帮助所做的更改。在

我发现了双负sql来生成一些不太理想的sql,所以我从sql中反向寻找一些更干净的东西。在

这里有一些简单的数据。该示例已稍作修改,以使用文本字段角色而不是序列字段。这也可以推广到其他类型:

drop table if exists multi_groups;
create table multi_groups (group_id, atom_id, role) as
values
  (1, 1, 'referrer'),
  (1, 2, 'rendering'),
  (1, 3, 'attending'),
  (2, 1, 'attending'),
  (2, 2, 'rendering'),
  (2, 3, 'referrer');

原始解决方案生成的sql类似于:

^{pr2}$

我使用了这个方法,并对sql进行了一些处理,以获得:

with vtm as (
  select
    unnest(array[1, 2, 3]) as atom_id,
    unnest(array['attending', 'rendering', 'referrer']) as role
),
matched as (
  select
    dim_staging.multi_groups.group_id as group_id,
    vtm.atom_id as atom_id,
    3 as cnt
  from dim_staging.multi_groups
  full outer join vtm
    on (vtm.atom_id, vtm.role) = (dim_staging.multi_groups.atom_id, dim_staging.multi_groups.role)
)
select matched.group_id
from matched
where not (
  exists (
    select *
    from matched
    where matched.group_id is null
  )
)
group by matched.group_id
having count(1) filter (where matched.atom_id is null) = 0
  and count(1) = matched.cnt;

下面是一个完整的测试脚本来演示如何创建上面的sql

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import declarative_base
import os
from sqlalchemy import (
    Column,
    Integer,
    Text
)
from sqlalchemy.sql.expression import func, select, tuple_, exists, join, literal, label
from sqlalchemy.dialects import postgresql
from collections import namedtuple


db_url = os.getenv('DB_URL', 'postgresql://localhost:5432/dw')
engine = create_engine(db_url, echo=False)
Session = sessionmaker(bind=engine)
session = Session()


Base = declarative_base()


class MultiColumnGroups(Base):
    __tablename__ = 'multi_groups'
    id = Column(Integer, primary_key=True)
    group_id = Column(Integer)
    atom_id = Column(Integer)
    role = Column(Text)


Base.metadata.drop_all(engine, [MultiColumnGroups.__table__])
Base.metadata.create_all(engine, [MultiColumnGroups.__table__])

vals = [
    # Group 1
    {'group_id': 1, 'atom_id': 1, 'role': 'referrer'},
    {'group_id': 1, 'atom_id': 2, 'role': 'rendering'},
    {'group_id': 1, 'atom_id': 3, 'role': 'attending'},
    # Group 2
    {'group_id': 2, 'atom_id': 1, 'role': 'attending'},
    {'group_id': 2, 'atom_id': 2, 'role': 'rendering'},
    {'group_id': 2, 'atom_id': 3, 'role': 'referrer'},
]

session.bulk_save_objects(
    [MultiColumnGroups(**x) for x in vals]
)
session.commit()

Entity = namedtuple('Entity', ['atom_id', 'role'])
values_to_match = [
    # (atom_id, role)
    # Entity(1, 'referrer'),
    # Entity(2, 'rendering'),
    # Entity(3, 'attending'),
    Entity(1, 'attending'),
    Entity(2, 'rendering'),
    Entity(3, 'referrer'),
]

vtm = select(
    [
        func.unnest(
            postgresql.array([
                getattr(e, f) for e in values_to_match
                ]
            )
        ).label(f)
        for f in Entity._fields
    ]
).cte(name='vtm')

j = join(
    MultiColumnGroups, vtm,
    tuple_(vtm.c.atom_id, vtm.c.role) == tuple_(MultiColumnGroups.atom_id, MultiColumnGroups.role),
    full=True
)
matched = select([
  MultiColumnGroups.group_id,
  vtm.c.atom_id,
  label(
    'cnt',
    literal(len(values_to_match),type_=Integer
   )
)]).select_from(j).cte(name='matched')

group_id = session.query(matched.c.group_id).\
    filter(
        ~exists().
        select_from(matched).
        where(matched.c.group_id == None)
    ).\
    group_by(matched.c.group_id).\
    having(func.count(1).filter(matched.c.atom_id == None) == 0).\
    having(func.count(1) == matched.c.cnt).one().group_id

print(group_id)

编辑:子组是否存在导致多个匹配的情况可以通过在查询中包含被比较的值的数量作为一个计数来解决,并检查匹配的分组计数是否等于值的数目。抱歉疏忽了。在

我认为您的解决方案会产生不确定的结果,因为组中的行是以未指定的顺序排列的,因此数组聚合与给定数组之间的比较可能会产生true或false,基于此: 在

[local]:5432 u@sopython*=> select group_id
[local] u@sopython- > from multi_groups 
[local] u@sopython- > group by group_id
[local] u@sopython- > having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
 group_id 
     
        2
(1 row)

[local]:5432 u@sopython*=> update multi_groups set atom_id = atom_id where atom_id = 2;
UPDATE 2
[local]:5432 u@sopython*=> select group_id                                             
from multi_groups 
group by group_id
having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
 group_id 
     
(0 rows)

您可以对两者都应用排序,或者尝试完全不同的方法:可以使用relational division代替数组比较。在

为了分割,您必须从Entity记录列表中形成一个临时关系。同样,有很多方法可以做到这一点。下面是一个使用非嵌套数组的方法:

^{pr2}$

另一个使用联合:

In [114]: vtm = union_all(*[
     ...:     select([literal(e.atom_id).label('atom_id'),
     ...:             literal(e.sequence).label('sequence')])
     ...:     for e in values_to_match
     ...: ]).alias()

一张临时桌子也可以。在

有了新的关系,您需要找到“查找那些不在组中的实体的multi_groups的答案”。这是一个可怕的句子,但有道理:

In [117]: mg = aliased(MultiColumnGroups)

In [119]: session.query(MultiColumnGroups.group_id).\
     ...:     filter(~exists().
     ...:         select_from(vtm).
     ...:         where(~exists().
     ...:             where(MultiColumnGroups.group_id == mg.group_id).
     ...:             where(tuple_(vtm.c.atom_id, vtm.c.sequence) ==
     ...:                   tuple_(mg.atom_id, mg.sequence)).
     ...:             correlate_except(mg))).\
     ...:     distinct().\
     ...:     all()
     ...: 
Out[119]: [(2)]

另一方面,您也可以选择组与给定实体的交集:

In [19]: gs = intersect(*[
    ...:     session.query(MultiColumnGroups.group_id).
    ...:         filter(MultiColumnGroups.atom_id == vtm.atom_id,
    ...:                MultiColumnGroups.sequence == vtm.sequence)
    ...:     for vtm in values_to_match
    ...: ])

In [20]: session.execute(gs).fetchall()
Out[20]: [(2,)]

错误

ProgrammingError: (psycopg2.ProgrammingError) operator does not exist: record[] = integer[]
LINE 3: ...gg((multi_groups.atom_id, multi_groups.sequence)) = ARRAY[AR...
                                                             ^
HINT:  No operator matches the given name and argument type(s). You might need to add explicit type casts.
 [SQL: 'SELECT multi_groups.group_id AS multi_groups_group_id \nFROM multi_groups GROUP BY multi_groups.group_id \nHAVING array_agg((multi_groups.atom_id, multi_groups.sequence)) = %(array_agg_1)s'] [parameters: {'array_agg_1': [[1, 3], [2, 2], [3, 1]]}] (Background on this error at: http://sqlalche.me/e/f405)

是由于您的values_to_match首先转换为列表列表列表(原因未知),然后是converted to an array by your DB-API driver。结果是一个整数数组,而不是一个记录数组(int,int)。使用raw DB-API connection和游标,传递元组列表的工作与您预期的一样。在

在SQLAlchemy中,如果将列表values_to_match^{}包装在一起,那么它将按照您的预期工作,不过请记住,结果是不确定的。在

相关问题 更多 >