Pytest:内存数据在夹具间不持久化

0 投票
1 回答
19 浏览
提问于 2025-04-12 15:05

这是我的项目结构:

project/
  app.py
  test_app.py

app.py

from pydantic import BaseModel
from sqlalchemy.orm import Session
from fastapi import Depends, FastAPI, HTTPException


class UserCreateModel(BaseModel):
    """Request body accepted by the POST /api/v1/users route."""

    username: str
    password: str


def get_database_session():
    """Dependency that provides the database session to the routes.

    Placeholder implementation: it creates nothing and yields ``None``.
    The tests replace it through ``app.dependency_overrides``.
    """
    yield None


class UserRepo:
    """User repository used by the routes via ``Depends()``.

    Stub: every method currently returns ``None``.
    """

    def create_user(self, user_create_model: UserCreateModel, session: Session):
        """Store a new user built from *user_create_model* (not implemented yet)."""
        return None

    def get_user_by_id(self, id: int, session: Session):
        """Look up the user with the given *id* (not implemented yet)."""
        return None


def create_app():
    """Build the FastAPI application and register the user routes.

    Both routes receive a ``UserRepo`` and a session through dependency
    injection, so either can be replaced via ``app.dependency_overrides``.
    """
    application = FastAPI()

    @application.post("/api/v1/users", status_code=201)
    async def create_user(
        user_create_model: UserCreateModel,
        user_repo: UserRepo = Depends(),
        session=Depends(get_database_session),
    ):
        return user_repo.create_user(user_create_model=user_create_model, session=session)

    @application.get("/api/v1/users/{id}", status_code=200)
    async def get_user_by_id(id: int, user_repo: UserRepo = Depends(), session=Depends(get_database_session)):
        found = user_repo.get_user_by_id(id=id, session=session)
        if found:
            return found
        # Any falsy lookup result (e.g. None) is reported as "not found".
        raise HTTPException(status_code=404)

    return application


# Module-level application instance built by create_app().
# NOTE(review): presumably the object an ASGI server imports (e.g. "app:app") — confirm.
app = create_app()

test_app.py

from app import UserCreateModel, UserRepo, create_app, get_database_session
from dataclasses import dataclass
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pytest import fixture
from sqlalchemy.orm import Session


@fixture(scope="session")
def app():
    """Session-wide FastAPI app whose DB-session dependency resolves to ``True``.

    ``True`` stands in for a live session: the mock repository only checks
    that the session it receives is truthy.
    """
    application = create_app()
    application.dependency_overrides[get_database_session] = lambda: True
    yield application


@fixture(scope="session")
def client(app: FastAPI):
    """Session-wide HTTP test client bound to *app*.

    Entering ``TestClient`` as a context manager does two things the bare
    constructor does not. It runs the app's lifespan (startup/shutdown)
    handlers, and it closes the underlying HTTP client when the test session
    ends, instead of leaking it.
    """
    with TestClient(app=app) as http_client:
        yield http_client


@fixture(scope="function")
def user_repo(app: FastAPI):
    """Override ``UserRepo`` on *app* with an in-memory mock for one test.

    Yields ``None``. There is no teardown after ``yield``, so the override
    stays registered on the session-scoped app after the test finishes.
    """
    print("created")

    @dataclass
    class User:
        id: int
        username: str
        password: str

    class MockUserRepo:
        """In-memory stand-in for UserRepo; storage lives on each instance."""

        def __init__(self):
            # Per-instance storage: every new MockUserRepo() starts empty.
            self.database = []

        def create_user(self, user_create_model: UserCreateModel, session: Session) -> User | None:
            # Returns None implicitly when the session is falsy.
            if session:
                user = User(
                    id=len(self.database) + 1,
                    username=user_create_model.username,
                    password=user_create_model.password,
                )
                self.database.append(user)
                return user

        def get_user_by_id(self, id: int, session: Session) -> User | None:
            print(self.database)
            if session:
                for user in self.database:
                    if user.id == id:
                        return user

                return None


    # NOTE(review): the class itself is registered, so FastAPI calls
    # MockUserRepo() to resolve the dependency on every request. A POST and
    # the GET that follows it in the same test therefore see two different
    # instances, each with its own ``self.database``. This is why the GET
    # finds an empty list.
    app.dependency_overrides[UserRepo] = MockUserRepo
    yield


def test_create_user(client: TestClient, user_repo):
    """POST /api/v1/users returns 201 and echoes the stored fields plus an id."""
    payload = {"username": "mike", "password": "123"}
    response = client.post("/api/v1/users", json=payload)

    assert response.status_code == 201
    body = response.json()
    assert body["username"] == "mike"
    assert body["password"] == "123"
    assert "id" in body

def test_get_user(client: TestClient, user_repo):
    """A user created via POST can be fetched back via GET within the same test."""
    data = {"username": "nick", "password": "123"}
    res = client.post("/api/v1/users", json=data)
    assert res.status_code == 201
    user = res.json()
    user_id = user["id"]

    res = client.get(f"/api/v1/users/{user_id}")

    assert res.status_code == 200
    # A status check alone would accept any non-404 payload. The body must be
    # exactly the user that was just created.
    assert res.json() == user

在 test_get_user() 这个测试中,我先发送一个 POST 请求,把 {"username": "nick", "password": "123"} 添加到 self.database 里;然后再发送一个 GET 请求。但奇怪的是,尽管 user_repo 夹具的作用域是函数级别的(scope="function"),GET 请求时 self.database 仍然是空的。有没有人知道这是怎么回事?

1 个回答

1

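FastAPI 会为每个请求重新解析依赖项。由于 dependency_overrides 中注册的是 MockUserRepo 类本身,每次请求都会调用 MockUserRepo() 创建一个新实例:POST 请求写入的是一个实例的 self.database,而随后的 GET 请求拿到的是另一个实例,它的 self.database 是空列表。如果你想让数据在同一个测试内保持不变,可以把列表移到夹具函数的局部变量中,让所有实例通过闭包共享它(由于夹具是函数级别的,每个测试仍会得到一个新的列表):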

@fixture(scope="function")
def user_repo(app: FastAPI):
    """Override ``UserRepo`` on *app* with an in-memory mock for one test.

    FastAPI resolves the override by calling ``MockUserRepo()`` on every
    request, so the storage cannot live on the instance. It lives in this
    fixture's local ``database`` list instead. Every instance created during
    one test closes over the same list, and each test gets a fresh list
    because the fixture is function-scoped.

    Yields the shared ``database`` list so a test can inspect it directly.
    The override is removed again after the test.
    """
    print("created")

    @dataclass
    class User:
        id: int
        username: str
        password: str

    # Shared by all MockUserRepo instances created during this test.
    database: list[User] = []

    class MockUserRepo:
        def create_user(self, user_create_model: UserCreateModel, session: Session) -> User | None:
            # Returns None implicitly when the session is falsy.
            if session:
                user = User(
                    id=len(database) + 1,
                    username=user_create_model.username,
                    password=user_create_model.password,
                )
                database.append(user)
                return user

        def get_user_by_id(self, id: int, session: Session) -> User | None:
            print(database)
            if session:
                for user in database:
                    if user.id == id:
                        return user

                return None

    app.dependency_overrides[UserRepo] = MockUserRepo
    yield database
    # The app fixture is session-scoped. Remove the override so the mock does
    # not leak into tests that never requested this fixture. pytest runs this
    # teardown even when the test itself fails.
    app.dependency_overrides.pop(UserRepo, None)

撰写回答