From 68e782b17426507f42d1bb3a7c818bc8761b91a6 Mon Sep 17 00:00:00 2001 From: Logan Connolly Date: Tue, 5 Jul 2022 22:52:26 +0200 Subject: [PATCH] test(api): do not share state and rollback db --- api/tests/api/test_info.py | 15 ++++----- api/tests/api/test_posts.py | 61 ++++++++++-------------------------- api/tests/api/test_reddit.py | 13 ++++---- api/tests/conftest.py | 55 ++++++++++++++++++++++++++++---- 4 files changed, 80 insertions(+), 64 deletions(-) diff --git a/api/tests/api/test_info.py b/api/tests/api/test_info.py index c6622cd..f68b1c0 100644 --- a/api/tests/api/test_info.py +++ b/api/tests/api/test_info.py @@ -1,15 +1,16 @@ -from starlette.status import HTTP_200_OK +from httpx import AsyncClient +from starlette import status from app.core.config import settings -def test_get_reddit_info(client): +async def test_get_reddit_info(async_client: AsyncClient): """Test that reddit user info is retrieved""" - resp = client.get(f"{settings.api_version}/info/account/") - assert resp.status_code == HTTP_200_OK + resp = await async_client.get(f"{settings.api_version}/info/account/") + assert resp.status_code == status.HTTP_200_OK -def test_get_reddit_labels(client): +async def test_get_reddit_labels(async_client: AsyncClient): """Test that label counts are returned for the reddit posts""" - resp = client.get(f"{settings.api_version}/info/label/") - assert resp.status_code == HTTP_200_OK + resp = await async_client.get(f"{settings.api_version}/info/label/") + assert resp.status_code == status.HTTP_200_OK diff --git a/api/tests/api/test_posts.py b/api/tests/api/test_posts.py index 87e1bb9..e19512a 100644 --- a/api/tests/api/test_posts.py +++ b/api/tests/api/test_posts.py @@ -1,51 +1,22 @@ -from json import dumps - -from starlette.status import HTTP_200_OK, HTTP_201_CREATED +from httpx import AsyncClient +from starlette import status from app.core.config import settings - -POST = dict(reddit_id="xkl123", title="AITA?", label="NTA", text="Once upon a time") -POST_ID = None - -# TODO: don't make fixtures mutability - should work asyncronously as well -# TODO: look again for wrappers/solutions for creating DB sandbox - - -def test_add_post(client): - """Test that post is added to DB""" - global POST_ID - resp = client.post(f"{settings.api_version}/posts/", data=dumps(POST)) - POST_ID = resp.json()["id"] - POST.update({"id": POST_ID}) - assert resp.status_code == HTTP_201_CREATED - assert resp.json() == POST - - -def test_get_post(client): - """Test that post can be retrieved from DB by id""" - resp = client.get(f"{settings.api_version}/posts/{POST_ID}/") - assert resp.status_code == HTTP_200_OK - assert resp.json() == POST - - -def test_get_posts(client): - """Test that a list of posts can be retrieved from DB""" - resp = client.get(f"{settings.api_version}/posts/") - assert resp.status_code == HTTP_200_OK - assert resp.json()["total"] >= 1 +from app.db.tables import Post -def test_update_post(client): - """Test that dummy post is properly updated""" - POST["label"] = "YTA" - payload = dumps(POST) - resp = client.put(f"{settings.api_version}/posts/{POST_ID}/", data=payload) - assert resp.status_code == HTTP_200_OK - assert resp.json() == POST +async def test_post_create(async_client: AsyncClient): + """Test that ingredient can be created""" + url = f"{settings.api_version}/posts/" + id_ = str(Post.generate_post_id("xkl123")) + payload = {"id": id_, "title": "AITA?", "label": "NTA", "text": "test"} + response = await async_client.post(url, json=payload) -def test_remove_post(client): - """Test that dummy post is deleted from DB""" - resp = client.delete(f"{settings.api_version}/posts/{POST_ID}/") - assert resp.status_code == HTTP_200_OK - assert resp.json() == POST + assert response.status_code == status.HTTP_201_CREATED + assert response.json() == { + "id": response.json()["id"], + "title": payload["title"], + "label": payload["label"], + "text": payload["text"], + } diff --git a/api/tests/api/test_reddit.py b/api/tests/api/test_reddit.py index f354add..8dc2c09 100644 --- a/api/tests/api/test_reddit.py +++ b/api/tests/api/test_reddit.py @@ -1,15 +1,16 @@ -from starlette.status import HTTP_200_OK, HTTP_201_CREATED +from httpx import AsyncClient +from starlette.status import HTTP_201_CREATED, HTTP_404_NOT_FOUND from app.core.config import settings -def test_trigger_reddit_sync(client): +async def test_trigger_reddit_sync(async_client: AsyncClient): """Test that post sync is instantiated""" - resp = client.post(f"{settings.api_version}/sync/?filter=top&limit=50") + resp = await async_client.post(f"{settings.api_version}/sync/?filter=top&limit=10") assert resp.status_code == HTTP_201_CREATED -def test_get_reddit_sync_info(client): +async def test_get_reddit_sync_info(async_client: AsyncClient): """Test that info regarding last sync is returned""" - resp = client.get(f"{settings.api_version}/sync/") - assert resp.status_code == HTTP_200_OK + resp = await async_client.get(f"{settings.api_version}/sync/") + assert resp.status_code == HTTP_404_NOT_FOUND diff --git a/api/tests/conftest.py b/api/tests/conftest.py index 6e0798b..99ead1a 100644 --- a/api/tests/conftest.py +++ b/api/tests/conftest.py @@ -1,10 +1,53 @@ +import asyncio +from typing import AsyncGenerator, Callable, Generator + import pytest -from fastapi.testclient import TestClient +from asgi_lifespan import LifespanManager +from httpx import AsyncClient +from sqlalchemy.ext.asyncio.session import AsyncSession + +from app.db.session import async_session, engine +from app.db.tables.base import Base + + +@pytest.fixture(scope="session") +def event_loop() -> Generator: + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture +async def db_session() -> AsyncGenerator[AsyncSession, None]: + """Create test database session that will then be reverted after test run""" + async with engine.begin() as connection: + await connection.run_sync(Base.metadata.drop_all) + await connection.run_sync(Base.metadata.create_all) + async with async_session(bind=connection) as session: + yield session + await session.flush() + await session.rollback() + + +@pytest.fixture +def override_get_db(db_session: AsyncSession) -> Callable: + """Make database session an async callable to pass to get_db dependency""" + + async def _override_get_db() -> AsyncGenerator[AsyncSession, None]: + yield db_session + + return _override_get_db + -from app.main import app +@pytest.fixture() +async def async_client(override_get_db: Callable) -> AsyncGenerator: + """Create test client to be used to test api endpoints""" + from app.db.session import get_db + from app.main import app + app.dependency_overrides[get_db] = override_get_db -@pytest.fixture(scope="module") -def client(): - with TestClient(app) as client: - yield client + async with LifespanManager(app): + async with AsyncClient(app=app, base_url="http://testserver") as client: + yield client