Skip to content
This repository has been archived by the owner on Aug 8, 2023. It is now read-only.

Commit

Permalink
test(api): do not share state and rollback db
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-connolly committed Jul 6, 2022
1 parent 2e6d706 commit 68e782b
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 64 deletions.
15 changes: 8 additions & 7 deletions api/tests/api/test_info.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from starlette.status import HTTP_200_OK
from httpx import AsyncClient
from starlette import status

from app.core.config import settings


def test_get_reddit_info(client):
async def test_get_reddit_info(async_client: AsyncClient):
"""Test that reddit user info is retrieved"""
resp = client.get(f"{settings.api_version}/info/account/")
assert resp.status_code == HTTP_200_OK
resp = await async_client.get(f"{settings.api_version}/info/account/")
assert resp.status_code == status.HTTP_200_OK


def test_get_reddit_labels(client):
async def test_get_reddit_labels(async_client: AsyncClient):
"""Test that label counts are returned for the reddit posts"""
resp = client.get(f"{settings.api_version}/info/label/")
assert resp.status_code == HTTP_200_OK
resp = await async_client.get(f"{settings.api_version}/info/label/")
assert resp.status_code == status.HTTP_200_OK
61 changes: 16 additions & 45 deletions api/tests/api/test_posts.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,22 @@
from json import dumps

from starlette.status import HTTP_200_OK, HTTP_201_CREATED
from httpx import AsyncClient
from starlette import status

from app.core.config import settings

POST = dict(reddit_id="xkl123", title="AITA?", label="NTA", text="Once upon a time")
POST_ID = None

# TODO: don't make fixtures mutability - should work asyncronously as well
# TODO: look again for wrappers/solutions for creating DB sandbox


def test_add_post(client):
"""Test that post is added to DB"""
global POST_ID
resp = client.post(f"{settings.api_version}/posts/", data=dumps(POST))
POST_ID = resp.json()["id"]
POST.update({"id": POST_ID})
assert resp.status_code == HTTP_201_CREATED
assert resp.json() == POST


def test_get_post(client):
"""Test that post can be retrieved from DB by id"""
resp = client.get(f"{settings.api_version}/posts/{POST_ID}/")
assert resp.status_code == HTTP_200_OK
assert resp.json() == POST


def test_get_posts(client):
"""Test that a list of posts can be retrieved from DB"""
resp = client.get(f"{settings.api_version}/posts/")
assert resp.status_code == HTTP_200_OK
assert resp.json()["total"] >= 1
from app.db.tables import Post


def test_update_post(client):
"""Test that dummy post is properly updated"""
POST["label"] = "YTA"
payload = dumps(POST)
resp = client.put(f"{settings.api_version}/posts/{POST_ID}/", data=payload)
assert resp.status_code == HTTP_200_OK
assert resp.json() == POST
async def test_post_create(async_client: AsyncClient):
"""Test that ingredient can be created"""
url = f"{settings.api_version}/posts/"
id_ = str(Post.generate_post_id("xkl123"))
payload = {"id": id_, "title": "AITA?", "label": "NTA", "text": "test"}

response = await async_client.post(url, json=payload)

def test_remove_post(client):
"""Test that dummy post is deleted from DB"""
resp = client.delete(f"{settings.api_version}/posts/{POST_ID}/")
assert resp.status_code == HTTP_200_OK
assert resp.json() == POST
assert response.status_code == status.HTTP_201_CREATED
assert response.json() == {
"id": response.json()["id"],
"title": payload["title"],
"label": payload["label"],
"text": payload["text"],
}
13 changes: 7 additions & 6 deletions api/tests/api/test_reddit.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from starlette.status import HTTP_200_OK, HTTP_201_CREATED
from httpx import AsyncClient
from starlette.status import HTTP_201_CREATED, HTTP_404_NOT_FOUND

from app.core.config import settings


def test_trigger_reddit_sync(client):
async def test_trigger_reddit_sync(async_client: AsyncClient):
"""Test that post sync is instantiated"""
resp = client.post(f"{settings.api_version}/sync/?filter=top&limit=50")
resp = await async_client.post(f"{settings.api_version}/sync/?filter=top&limit=10")
assert resp.status_code == HTTP_201_CREATED


def test_get_reddit_sync_info(client):
async def test_get_reddit_sync_info(async_client: AsyncClient):
"""Test that info regarding last sync is returned"""
resp = client.get(f"{settings.api_version}/sync/")
assert resp.status_code == HTTP_200_OK
resp = await async_client.get(f"{settings.api_version}/sync/")
assert resp.status_code == HTTP_404_NOT_FOUND
55 changes: 49 additions & 6 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,53 @@
import asyncio
from typing import AsyncGenerator, Callable, Generator

import pytest
from fastapi.testclient import TestClient
from asgi_lifespan import LifespanManager
from httpx import AsyncClient
from sqlalchemy.ext.asyncio.session import AsyncSession

from app.db.session import async_session, engine
from app.db.tables.base import Base


@pytest.fixture(scope="session")
def event_loop() -> Generator:
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()


@pytest.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
"""Create test database session that will then be reverted after test run"""
async with engine.begin() as connection:
await connection.run_sync(Base.metadata.drop_all)
await connection.run_sync(Base.metadata.create_all)
async with async_session(bind=connection) as session:
yield session
await session.flush()
await session.rollback()


@pytest.fixture
def override_get_db(db_session: AsyncSession) -> Callable:
"""Make database session an async callable to pass to get_db dependency"""

async def _override_get_db() -> AsyncGenerator[AsyncSession, None]:
yield db_session

return _override_get_db


from app.main import app
@pytest.fixture()
async def async_client(override_get_db: Callable) -> AsyncGenerator:
"""Create test client to be used to test api endpoints"""
from app.db.session import get_db
from app.main import app

app.dependency_overrides[get_db] = override_get_db

@pytest.fixture(scope="module")
def client():
with TestClient(app) as client:
yield client
async with LifespanManager(app):
async with AsyncClient(app=app, base_url="http://testserver") as client:
yield client

0 comments on commit 68e782b

Please sign in to comment.