-
Notifications
You must be signed in to change notification settings - Fork 0
/
conftest.py
90 lines (73 loc) · 2.78 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from typing import AsyncIterable
import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncEngine,
AsyncSession,
create_async_engine,
)
from app.config import settings
from app.database import Base, get_db_session
from app.main import app
from app.tests.utils import authentication_token_from_email, get_superuser_token_headers
# SQLite in-memory database
@pytest.fixture(scope="session", autouse=True)
async def engine() -> AsyncIterable[AsyncEngine]:
from sqlalchemy.pool import StaticPool
_engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
async with _engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield _engine
await _engine.dispose()
# Postgres temporary database
# @pytest.fixture(scope="session", autouse=True)
# async def engine() -> AsyncIterable[AsyncEngine]:
# import testing.postgresql
# from sqlalchemy.pool import NullPool
# with testing.postgresql.Postgresql() as postgresql:
# _engine = create_async_engine(
# postgresql.url().replace("postgresql://", "postgresql+asyncpg://"),
# poolclass=NullPool,
# )
# async with _engine.begin() as conn:
# await conn.run_sync(Base.metadata.create_all)
# yield _engine
# await _engine.dispose()
@pytest.fixture(scope="function")
async def session(engine: AsyncEngine) -> AsyncIterable[AsyncSession]:
"""Create a new database session with a rollback at the end of the test.
Inspired by https://dev.to/jbrocher/fastapi-testing-a-database-5ao5
"""
connection: AsyncConnection = await engine.connect()
transaction = await connection.begin()
session = AsyncSession(bind=connection)
yield session
await transaction.rollback()
await connection.close()
@pytest.fixture(scope="function")
async def client(session: AsyncSession) -> AsyncIterable[AsyncClient]:
"""Create a test client that uses the override_get_db fixture to return
a session.
"""
app.dependency_overrides[get_db_session] = lambda: session
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as test_client:
yield test_client
@pytest.fixture(scope="function")
async def superuser_token_headers(
client: AsyncClient, session: AsyncSession
) -> dict[str, str]:
return await get_superuser_token_headers(client, session)
@pytest.fixture(scope="function")
async def normal_user_token_headers(
client: AsyncClient, session: AsyncSession
) -> dict[str, str]:
return await authentication_token_from_email(
client=client, email=settings.EMAIL_TEST_USER, session=session
)