Skip to content

Commit

Permalink
Refactor REST Api using routers and increase test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
giffels committed Oct 8, 2021
1 parent 90267b6 commit 3f76538
Show file tree
Hide file tree
Showing 22 changed files with 279 additions and 144 deletions.
4 changes: 2 additions & 2 deletions docs/source/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
.. Created by changelog.py at 2021-10-06, command
'/Users/giffler/.cache/pre-commit/repor6pnmwlm/py_env-python3.9/bin/changelog docs/source/changes compile --output=docs/source/changelog.rst'
.. Created by changelog.py at 2021-10-08, command
'/Users/giffler/.cache/pre-commit/repor6pnmwlm/py_env-default/bin/changelog docs/source/changes compile --output=docs/source/changelog.rst'
based on the format of 'https://keepachangelog.com/'
#########
Expand Down
22 changes: 0 additions & 22 deletions tardis/rest/app.py

This file was deleted.

Empty file added tardis/rest/app/__init__.py
Empty file.
18 changes: 18 additions & 0 deletions tardis/rest/app/crud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
async def get_resource_state(sql_registry, drone_uuid: str):
sql_query = """
SELECT R.drone_uuid, RS.state
FROM Resources R
JOIN ResourceStates RS ON R.state_id = RS.state_id
WHERE R.drone_uuid = :drone_uuid"""
return await sql_registry.async_execute(sql_query, dict(drone_uuid=drone_uuid))


async def get_resources(sql_registry):
sql_query = """
SELECT R.remote_resource_uuid , RS.state, R.drone_uuid, S.site_name,
MT.machine_type, R.created, R.updated
FROM Resources R
JOIN ResourceStates RS ON R.state_id = RS.state_id
JOIN Sites S ON R.site_id = S.site_id
JOIN MachineTypes MT ON R.machine_type_id = MT.machine_type_id"""
return await sql_registry.async_execute(sql_query, {})
2 changes: 1 addition & 1 deletion tardis/rest/database.py → tardis/rest/app/database.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ..plugins.sqliteregistry import SqliteRegistry
from ...plugins.sqliteregistry import SqliteRegistry


def get_sql_registry():
Expand Down
6 changes: 6 additions & 0 deletions tardis/rest/app/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .routers import resources
from fastapi import FastAPI

app = FastAPI()

app.include_router(resources.router)
Empty file.
29 changes: 29 additions & 0 deletions tardis/rest/app/routers/resources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from .. import security, crud, database
from ....plugins.sqliteregistry import SqliteRegistry
from fastapi import APIRouter, Depends, HTTPException, Path, Security


router = APIRouter(prefix="/resources")


@router.get("/state/{drone_uuid}")
async def get_resource_state(
drone_uuid: str = Path(..., regex=r"^\S+-[A-Fa-f0-9]{10}$"),
sql_registry: SqliteRegistry = Depends(database.get_sql_registry()),
_: str = Security(security.check_authorization, scopes=["user:read"]),
):
query_result = await crud.get_resource_state(sql_registry, drone_uuid)
try:
query_result = query_result[0]
except IndexError:
raise HTTPException(status_code=404, detail="Drone not found") from None
return query_result


@router.get("/")
async def get_resources(
sql_registry: SqliteRegistry = Depends(database.get_sql_registry()),
_: str = Security(security.check_authorization, scopes=["user:read"]),
):
query_result = await crud.get_resources(sql_registry)
return query_result
4 changes: 2 additions & 2 deletions tardis/rest/security.py → tardis/rest/app/security.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ..configuration.configuration import Configuration
from ..exceptions.tardisexceptions import TardisError
from ...configuration.configuration import Configuration
from ...exceptions.tardisexceptions import TardisError

from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, SecurityScopes
Expand Down
7 changes: 0 additions & 7 deletions tardis/rest/crud.py

This file was deleted.

2 changes: 1 addition & 1 deletion tardis/rest/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(self, secret_key: str, algorithm: str = "HS256", **fast_api_args):
# necessary to avoid that the TARDIS' logger configuration is overwritten!
if "log_config" not in fast_api_args:
fast_api_args["log_config"] = None
self._config = Config("tardis.rest.app:app", **fast_api_args)
self._config = Config("tardis.rest.app.main:app", **fast_api_args)

@property
@lru_cache(maxsize=16)
Expand Down
2 changes: 1 addition & 1 deletion tardis/rest/token_generator/generate_token.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ..security import create_access_token
from ..app.security import create_access_token
from ...utilities.utils import disable_logging
from cobald.daemon.core.config import load

Expand Down
Empty file added tests/rest_t/app_t/__init__.py
Empty file.
50 changes: 48 additions & 2 deletions tests/rest_t/test_crutd.py → tests/rest_t/app_t/test_crutd.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from tardis.rest import crud
from tardis.rest.app import crud
from tardis.plugins.sqliteregistry import SqliteRegistry
from ..utilities.utilities import run_async
from tests.utilities.utilities import run_async

from unittest import TestCase
from unittest.mock import MagicMock
Expand Down Expand Up @@ -60,3 +60,49 @@ async def mocked_async_execute(sql_query: str, bind_parameters: dict):
WHERE R.drone_uuid = :drone_uuid""",
{"drone_uuid": "test-available-01234567ab"},
)

def test_get_resources(self):
full_expected_resources = [
{
"remote_resource_uuid": "14fa5640a7c146e482e8be41ec5dffea",
"state": "AvailableState",
"drone_uuid": "test-0125bc9fd8",
"site_name": "Test",
"machine_type": "m1.test",
"created": "2021-10-08T12:42:16.354400",
"updated": "2021-10-08T12:42:28.382025",
},
{
"remote_resource_uuid": "b3efcc5bc8b741af9222987e0434ca61",
"state": "AvailableState",
"drone_uuid": "test-6af3cfef14",
"site_name": "Test",
"machine_type": "m1.test",
"created": "2021-10-08T12:42:16.373454",
"updated": "2021-10-08T12:42:30.648325",
},
]

async def mocked_async_execute(sql_query: str, bind_parameters: dict):
return full_expected_resources

self.sql_registry_mock.async_execute.side_effect = mocked_async_execute

self.assertEqual(
full_expected_resources,
run_async(
crud.get_resources,
sql_registry=self.sql_registry_mock,
),
)

self.sql_registry_mock.async_execute.assert_called_with(
"""
SELECT R.remote_resource_uuid , RS.state, R.drone_uuid, S.site_name,
MT.machine_type, R.created, R.updated
FROM Resources R
JOIN ResourceStates RS ON R.state_id = RS.state_id
JOIN Sites S ON R.site_id = S.site_id
JOIN MachineTypes MT ON R.machine_type_id = MT.machine_type_id""",
{},
)
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from tardis.rest.database import get_sql_registry
from tardis.rest.app.database import get_sql_registry

from unittest import TestCase
from unittest.mock import patch


class TestDatabase(TestCase):
@patch("tardis.rest.database.SqliteRegistry")
@patch("tardis.rest.app.database.SqliteRegistry")
def test_get_sql_registry(self, mocked_sqlite_registry):
self.assertEqual(get_sql_registry()(), mocked_sqlite_registry())
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from tardis.exceptions.tardisexceptions import TardisError
from tardis.rest.security import (
from tardis.rest.app.security import (
create_access_token,
check_authorization,
get_algorithm,
Expand All @@ -21,7 +21,7 @@ class TestSecurity(TestCase):

@classmethod
def setUpClass(cls) -> None:
cls.mock_config_patcher = patch("tardis.rest.security.Configuration")
cls.mock_config_patcher = patch("tardis.rest.app.security.Configuration")
cls.mock_config = cls.mock_config_patcher.start()

@classmethod
Expand All @@ -46,7 +46,7 @@ def clear_lru_cache():
get_algorithm.cache_clear()
get_secret_key.cache_clear()

@patch("tardis.rest.security.datetime")
@patch("tardis.rest.app.security.datetime")
def test_create_access_token(self, mocked_datetime):
self.clear_lru_cache()

Expand Down Expand Up @@ -111,7 +111,7 @@ def test_check_authorization(self):
self.assertEqual(he.exception.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(he.exception.detail, "Could not validate credentials")

@patch("tardis.rest.security.jwt")
@patch("tardis.rest.app.security.jwt")
def test_check_authorization_jwt_error(self, mocked_jwt):
mocked_jwt.decode.side_effect = JWTError

Expand Down
Empty file.
59 changes: 59 additions & 0 deletions tests/rest_t/routers_t/base_test_case_routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from tardis.rest.app.security import get_algorithm, get_secret_key
from tests.utilities.utilities import run_async

from httpx import AsyncClient

from unittest import TestCase
from unittest.mock import patch


class TestCaseRouters(TestCase):
mock_sqlite_registry_patcher = None
mock_crud_patcher = None
mock_config_patcher = None

@classmethod
def setUpClass(cls) -> None:
cls.mock_sqlite_registry_patcher = patch(
"tardis.rest.app.database.SqliteRegistry"
)
cls.mock_crud_patcher = patch("tardis.rest.app.routers.resources.crud")
cls.mock_config_patcher = patch("tardis.rest.app.security.Configuration")
cls.mock_sqlite_registry = cls.mock_sqlite_registry_patcher.start()
cls.mock_crud = cls.mock_crud_patcher.start()
cls.mock_config = cls.mock_config_patcher.start()

@classmethod
def tearDownClass(cls) -> None:
cls.mock_sqlite_registry_patcher.stop()
cls.mock_crud_patcher.stop()
cls.mock_config_patcher.stop()

def setUp(self) -> None:
secret_key = "63328dc6b8524bf08b0ba151e287edb498852b77b97f837088de4d17247d032c"
algorithm = "HS256"

config = self.mock_config.return_value
config.Services.restapi.secret_key = secret_key
config.Services.restapi.algorithm = algorithm

from tardis.rest.app.main import (
app,
) # has to be imported after SqliteRegistry patch

self.client = AsyncClient(app=app, base_url="http://test")

def tearDown(self) -> None:
run_async(self.client.aclose)

@staticmethod
def clear_lru_cache():
get_algorithm.cache_clear()
get_secret_key.cache_clear()

@property
def headers(
self,
token="Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ0YXJkaXMiLCJzY29wZXMiOlsidXNlcjpyZWFkIl19.l2xDqxEQOLYQq6cDX7RGDcT1XvyupRcBUpvvW1l4yeM", # noqa B950
):
return {"accept": "application/json", "Authorization": token}
83 changes: 83 additions & 0 deletions tests/rest_t/routers_t/test_resources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from tests.rest_t.routers_t.base_test_case_routers import TestCaseRouters
from tests.utilities.utilities import async_return, run_async


class TestResources(TestCaseRouters):
# Reminder: When defining `setUp`, `setUpClass`, `tearDown` and `tearDownClass`
# in router tests the corresponding super().function() needs to be called as well.
def test_get_resource_state(self):
self.clear_lru_cache()
self.mock_crud.get_resource_state.return_value = async_return(
return_value=[{"drone_uuid": "test-0123456789", "state": "AvailableState"}]
)

response = run_async(
self.client.get, "/resources/state/test-0123456789", headers=self.headers
)
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json(),
{"drone_uuid": "test-0123456789", "state": "AvailableState"},
)

self.mock_crud.get_resource_state.return_value = async_return(return_value=[])
response = run_async(
self.client.get, "/resources/state/test-1234567890", headers=self.headers
)
self.assertEqual(response.status_code, 404)
self.assertEqual(response.json(), {"detail": "Drone not found"})

response = run_async(
self.client.get, "/resources/state/test-invalid", headers=self.headers
)
self.assertEqual(response.status_code, 422)
self.assertEqual(
response.json(),
{
"detail": [
{
"ctx": {"pattern": "^\\S+-[A-Fa-f0-9]{10}$"},
"loc": ["path", "drone_uuid"],
"msg": 'string does not match regex "^\\S+-[A-Fa-f0-9]{10}$"',
"type": "value_error.str.regex",
}
]
},
)

response = run_async(self.client.get, "/resources/state", headers=self.headers)
self.assertEqual(response.status_code, 404)
self.assertEqual(response.json(), {"detail": "Not Found"})

def test_get_resources(self):
self.clear_lru_cache()
full_expected_resources = [
{
"remote_resource_uuid": "14fa5640a7c146e482e8be41ec5dffea",
"state": "AvailableState",
"drone_uuid": "test-0125bc9fd8",
"site_name": "Test",
"machine_type": "m1.test",
"created": "2021-10-08T12:42:16.354400",
"updated": "2021-10-08T12:42:28.382025",
},
{
"remote_resource_uuid": "b3efcc5bc8b741af9222987e0434ca61",
"state": "AvailableState",
"drone_uuid": "test-6af3cfef14",
"site_name": "Test",
"machine_type": "m1.test",
"created": "2021-10-08T12:42:16.373454",
"updated": "2021-10-08T12:42:30.648325",
},
]
self.mock_crud.get_resources.return_value = async_return(
return_value=full_expected_resources
)

response = run_async(self.client.get, "/resources", headers=self.headers)
self.assertEqual(response.status_code, 200)
self.assertEqual(
response.json(),
full_expected_resources,
)
Loading

0 comments on commit 3f76538

Please sign in to comment.