SQLAlchemy relationships with single-table inheritance

Asked 2025-04-18 03:21

I'm using SQLAlchemy's single-table inheritance for the classes Transaction, StudentTransaction and CompanyTransaction:

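from datetime import datetime

from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship

# Base, Staff and Company are defined elsewhere in the project.


class Transaction(Base):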
    __tablename__ = 'transaction'

    id = Column(Integer, primary_key=True)

    # Who paid? This depends on whether the Transaction is a
    # CompanyTransaction or a StudentTransaction. We use
    # SQLAlchemy's Single Table Inheritance to make this work.
    discriminator = Column('origin', String(50))
    __mapper_args__ = {'polymorphic_on': discriminator}

    # When?
    time = Column(DateTime, default=datetime.utcnow)

    # Who administered it?
    staff_id = Column(Integer, ForeignKey('staff.id'))
    staff = relationship(
        'Staff',
        primaryjoin='and_(Transaction.staff_id==Staff.id)'
    )

    # How much?
    amount = Column(Integer)  # Negative for refunds, includes the decimal part

    # Type of transaction
    type = Column(Enum(
        'cash',
        'card',
        'transfer'
    ))


class CompanyTransaction(Transaction):
    __mapper_args__ = {'polymorphic_identity': 'company'}

    company_id = Column(Integer, ForeignKey('company.id'))
    company = relationship(
        'Company',
        primaryjoin='and_(CompanyTransaction.company_id==Company.id)'
    )


class StudentTransaction(Transaction):
    __mapper_args__ = {'polymorphic_identity': 'student'}

    student_id = Column(Integer, ForeignKey('student.id'))
    student = relationship(
        'Student',
        primaryjoin='and_(StudentTransaction.student_id==Student.id)'
    )

I also have a Student class that has a one-to-many relationship with StudentTransaction:

class Student(Base):
    __tablename__ = 'student'

    id = Column(Integer, primary_key=True)

    transactions = relationship(
        'StudentTransaction',
        primaryjoin='and_(Student.id==StudentTransaction.student_id)',
        back_populates='student'
    )


    @hybrid_property
    def balance(self):
        return sum([transaction.amount for transaction in self.transactions])

The problem is that when I use the Student class, I get the error NotImplementedError: <built-in function getitem>, raised on the return line of the Student.balance hybrid property.

What am I doing wrong?

Thanks.

1 Answer


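A hybrid attribute is a construct that lets you create a Python descriptor that behaves differently at the instance level and at the class level. At the class level, we want it to produce a SQL expression, and plain Python functions like sum() or list comprehensions cannot produce SQL expressions. When no separate expression is defined, the same getter is called with the class in place of an instance, so self.transactions is the class-level relationship attribute rather than a list of loaded objects; iterating over it is what raises NotImplementedError: <built-in function getitem>.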
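To make the instance-level versus class-level distinction concrete without a database, here is a toy pure-Python descriptor. It is a hypothetical, heavily simplified imitation of hybrid_property (the names toy_hybrid, ClassLevelCollection, BrokenStudent and FixedStudent are invented for illustration) and not how SQLAlchemy implements it; it only mimics the fallback where, without an expression, the getter is called with the class.

```python
class toy_hybrid:
    """Hypothetical, heavily simplified imitation of hybrid_property."""

    def __init__(self, fget, expr=None):
        self.fget = fget
        self.expr = expr

    def expression(self, expr):
        # Return a new descriptor that keeps the instance getter and
        # adds a separate class-level function.
        return toy_hybrid(self.fget, expr)

    def __get__(self, instance, owner):
        if instance is None:
            # Class-level access: without an expression, fall back to
            # calling the instance getter with the class itself.
            return (self.expr or self.fget)(owner)
        return self.fget(instance)


class ClassLevelCollection:
    """Mimics a class-level relationship attribute: it cannot be iterated."""

    def __getitem__(self, index):
        # Python's legacy iteration protocol calls __getitem__(0) first
        raise NotImplementedError("<built-in function getitem>")


class BrokenStudent:
    transactions = ClassLevelCollection()

    def __init__(self, amounts):
        self.transactions = amounts  # instance level: a plain list

    @toy_hybrid
    def balance(self):
        return sum(amount for amount in self.transactions)


class FixedStudent(BrokenStudent):
    @toy_hybrid
    def balance(self):
        return sum(amount for amount in self.transactions)

    @balance.expression
    def balance(cls):
        # Placeholder string standing in for a real SQL expression object
        return "(SELECT SUM(amount) FROM transaction WHERE student_id = student.id)"


print(BrokenStudent([50, 180]).balance)  # 230
print(FixedStudent([50, 180]).balance)   # 230
print(FixedStudent.balance)              # the placeholder SQL string

try:
    BrokenStudent.balance  # class-level access with no expression defined
except NotImplementedError as exc:
    broken_error = str(exc)
print("class-level access failed:", broken_error)
```

The point of the toy is the branch in __get__: instance access runs ordinary Python, while class access needs its own function that returns something SQL can use.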

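For example, if I want to query the student table and filter on the sum of the amount column in the transaction table, I would use a correlated subquery with a SQL aggregate function. The SQL we want here looks something like this: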

SELECT * FROM student WHERE (
      SELECT SUM(amount) FROM transaction WHERE student_id=student.id) > 500
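As a hedged way to check this SQL on its own, the sketch below runs the same correlated subquery with Python's built-in sqlite3 module. The schema is a hand-written assumption mirroring the models (the discriminator column is named origin, as in the question), the sample amounts are the ones used in the answer's add_all() call, and the helper students_with_balance_over is hypothetical. The table name is quoted because TRANSACTION is an SQL keyword.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE student (id INTEGER PRIMARY KEY);
    CREATE TABLE "transaction" (
        id INTEGER PRIMARY KEY,
        origin VARCHAR(50),               -- single-table discriminator
        amount INTEGER,
        student_id INTEGER REFERENCES student (id)
    );
    INSERT INTO student (id) VALUES (1), (2), (3);
    INSERT INTO "transaction" (origin, amount, student_id) VALUES
        ('student', 50, 1), ('student', 180, 1),
        ('student', 600, 2), ('student', 180, 2),
        ('student', 25, 3), ('student', 400, 3);
""")


def students_with_balance_over(threshold):
    """Return ids of students whose summed transaction amounts exceed threshold."""
    rows = conn.execute(
        """
        SELECT student.id FROM student
        WHERE (SELECT SUM(amount) FROM "transaction"
               WHERE "transaction".student_id = student.id) > ?
        ORDER BY student.id
        """,
        (threshold,),
    )
    return [row[0] for row in rows]


print(students_with_balance_over(500))  # [2]
print(students_with_balance_over(400))  # [2, 3]
```

Because student_id is only set on student transactions here, the subquery does not need an extra origin = 'student' condition; a schema where other rows could carry a student_id would need one.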

Our hybrid needs to take control and produce this expression itself:

from sqlalchemy import (Column, ForeignKey, Integer, String, create_engine,
                        func, select)
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Session, declarative_base, relationship

Base = declarative_base()

class Transaction(Base):
    __tablename__ = 'transaction'

    id = Column(Integer, primary_key=True)
    discriminator = Column('origin', String(50))
    __mapper_args__ = {'polymorphic_on': discriminator}
    amount = Column(Integer)

class StudentTransaction(Transaction):
    __mapper_args__ = {'polymorphic_identity': 'student'}

    student_id = Column(Integer, ForeignKey('student.id'))
    student = relationship(
        'Student',
        primaryjoin='and_(StudentTransaction.student_id==Student.id)',
        back_populates='transactions'
    )

class Student(Base):
    __tablename__ = 'student'

    id = Column(Integer, primary_key=True)

    transactions = relationship(
        'StudentTransaction',
        primaryjoin='and_(Student.id==StudentTransaction.student_id)',
        back_populates='student'
    )

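    @hybrid_property
    def balance(self):
        # Instance level: plain Python over the loaded collection
        return sum(transaction.amount for transaction in self.transactions)

    @balance.expression
    def balance(cls):
        # Class level: a correlated scalar subquery that SQL can evaluate
        # (SQLAlchemy 1.4+ spelling; older versions used select([...]).as_scalar())
        return (
            select(func.sum(StudentTransaction.amount))
            .where(StudentTransaction.student_id == cls.id)
            .scalar_subquery()
        )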

e = create_engine("sqlite://", echo=True)
Base.metadata.create_all(e)
s = Session(e)

s.add_all([
    Student(transactions=[StudentTransaction(amount=50), StudentTransaction(amount=180)]),
    Student(transactions=[StudentTransaction(amount=600), StudentTransaction(amount=180)]),
    Student(transactions=[StudentTransaction(amount=25), StudentTransaction(amount=400)]),
])

print(s.query(Student).filter(Student.balance > 400).all())

The output is:

SELECT student.id AS student_id 
FROM student 
WHERE (SELECT sum("transaction".amount) AS sum_1 
FROM "transaction" 
WHERE "transaction".student_id = student.id) > ?
2014-04-19 19:38:10,866 INFO sqlalchemy.engine.base.Engine (400,)
[<__main__.Student object at 0x101f2e4d0>, <__main__.Student object at 0x101f2e6d0>]
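To see why exactly two Student objects come back, the balances of the three sample students can be worked out by hand. The short check below does the same arithmetic in plain Python, assuming the students receive ids 1, 2 and 3 in the order they were added; the variable names are illustrative only.

```python
# Transaction amounts per student, copied from the add_all() call above
amounts_per_student = {1: [50, 180], 2: [600, 180], 3: [25, 400]}

# Instance-level side of the hybrid: a plain Python sum per student
balances = {sid: sum(amounts) for sid, amounts in amounts_per_student.items()}
print(balances)  # {1: 230, 2: 780, 3: 425}

# Same condition as the query's filter(Student.balance > 400)
over_400 = [sid for sid, balance in balances.items() if balance > 400]
print(over_400)  # [2, 3]
```

Students 2 and 3 (balances 780 and 425) pass the filter, which matches the two objects in the printed list.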
