Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Release: OpenAPI swagger fix + project grid guide #891

Merged
merged 4 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion api/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ def from_internal(
id=cohort_template_id_format(internal.id),
name=internal.name,
description=internal.description,
criteria=internal.criteria.to_external(project_names=project_names).dict(),
criteria=internal.criteria.to_external(
project_names=project_names
).model_dump(),
project_id=internal.project,
)

Expand Down
8 changes: 2 additions & 6 deletions api/routes/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import asyncio
import csv
import io
from datetime import date
from typing import Any, Generator

from fastapi import APIRouter
Expand Down Expand Up @@ -139,14 +138,11 @@ async def export_project_participants(
for row in prepare_participants_for_export(participants, fields=fields):
writer.writerow(row)

basefn = f'{connection.project}-project-summary-{connection.author}-{date.today().isoformat()}'

return StreamingResponse(
iter([output.getvalue()]),
media_type=export_type.get_mime_type(),
headers={
'Content-Disposition': f'filename={basefn}{export_type.get_extension()}'
},
# content-disposition doesn't work here :(
headers={},
)


Expand Down
2 changes: 1 addition & 1 deletion api/utils/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def openapi():
version=version,
routes=app.routes,
# update when FastAPI + swagger supports 3.1.0
# openapi_version='3.1.0'
openapi_version='3.0.2',
)

openapi_schema['servers'] = [{'url': url} for url in URLS]
Expand Down
11 changes: 8 additions & 3 deletions db/python/connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,17 @@ def _get_config():
return SMConnections._credentials

@staticmethod
def make_connection(config: DatabaseConfiguration):
def make_connection(
config: DatabaseConfiguration, log_database_queries: bool | None = None
):
"""Create connection from dbname"""
# the connection string will prepare pooling automatically
return databases.Database(
config.get_connection_string(), echo=LOG_DATABASE_QUERIES
_should_log = (
log_database_queries
if log_database_queries is not None
else LOG_DATABASE_QUERIES
)
return databases.Database(config.get_connection_string(), echo=_should_log)

@staticmethod
async def connect():
Expand Down
2 changes: 1 addition & 1 deletion models/models/billing.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ class BillingTotalCostQueryModel(SMBase):

def __hash__(self):
"""Create hash for this object to use in caching"""
return hash(self.json())
return hash(self.model_dump_json())

def to_filter(self) -> BillingFilter:
"""
Expand Down
20 changes: 19 additions & 1 deletion models/models/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ class ProjectParticipantGridFilterType(Enum):
neq = 'neq'
startswith = 'startswith'
icontains = 'icontains'
gt = 'gt'
gte = 'gte'
lt = 'lt'
lte = 'lte'


class ProjectParticipantGridField(SMBase):
Expand Down Expand Up @@ -249,9 +253,23 @@ def update_d_from_meta(d: dict[str, bool], meta: dict[str, Any]):
)
participant_fields = [
Field(
key='external_ids',
key='id',
label='Participant ID',
is_visible=True,
filter_key='id',
filter_types=[
ProjectParticipantGridFilterType.eq,
ProjectParticipantGridFilterType.neq,
ProjectParticipantGridFilterType.gt,
ProjectParticipantGridFilterType.gte,
ProjectParticipantGridFilterType.lt,
ProjectParticipantGridFilterType.lte,
],
),
Field(
key='external_ids',
label='External Participant ID',
is_visible=True,
filter_key='external_id',
),
Field(
Expand Down
6 changes: 3 additions & 3 deletions test/test_cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,11 @@ async def test_internal_external(self):
cc_external = CohortCriteria(**cc_external_dict)
cc_internal = cc_external.to_internal(projects_internal=[self.project_id])
self.assertIsInstance(cc_internal, CohortCriteriaInternal)
self.assertDictEqual(cc_internal.dict(), cc_internal_dict)
self.assertDictEqual(cc_internal.model_dump(), cc_internal_dict)

cc_ext_trip = cc_internal.to_external(project_names=[self.project_name])
self.assertIsInstance(cc_ext_trip, CohortCriteria)
self.assertDictEqual(cc_ext_trip.dict(), cc_external_dict)
self.assertDictEqual(cc_ext_trip.model_dump(), cc_external_dict)

ctpl_internal_dict = {
'id': 496,
Expand All @@ -289,7 +289,7 @@ async def test_internal_external(self):
criteria_projects=[self.project_id], template_project=self.project_id
)
self.assertIsInstance(ctpl_internal, CohortTemplateInternal)
self.assertDictEqual(ctpl_internal.dict(), ctpl_internal_dict)
self.assertDictEqual(ctpl_internal.model_dump(), ctpl_internal_dict)

@run_as_sync
async def test_create_cohort_by_sgs(self):
Expand Down
16 changes: 15 additions & 1 deletion test/test_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,23 @@

DEFAULT_PARTICIPANT_FIELDS = [
ProjectParticipantGridField(
key='external_ids',
key='id',
label='Participant ID',
is_visible=True,
filter_key='id',
filter_types=[
ProjectParticipantGridFilterType.eq,
ProjectParticipantGridFilterType.neq,
ProjectParticipantGridFilterType.gt,
ProjectParticipantGridFilterType.gte,
ProjectParticipantGridFilterType.lt,
ProjectParticipantGridFilterType.lte,
],
),
ProjectParticipantGridField(
key='external_ids',
label='External Participant ID',
is_visible=True,
filter_key='external_id',
),
ProjectParticipantGridField(
Expand Down
87 changes: 64 additions & 23 deletions test/testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import subprocess
import unittest
from functools import wraps
from typing import Dict

import databases.core
import nest_asyncio
Expand Down Expand Up @@ -47,14 +46,16 @@
logging.getLogger(lname).setLevel(logging.WARNING)


loop = asyncio.new_event_loop()


def find_free_port():
"""Find free port to run tests on"""
s = socket.socket()
s.bind(('', 0)) # Bind to a free port provided by the host.
return s.getsockname()[1] # Return the port number assigned.


loop = asyncio.new_event_loop()
free_port_number = s.getsockname()[1] # Return the port number assigned.
s.close() # free the port so we can immediately use
return free_port_number


def run_as_sync(f):
Expand All @@ -75,8 +76,9 @@ class DbTest(unittest.TestCase):

# store connections here, so they can be created PER-CLASS
# and don't get recreated per test.
dbs: Dict[str, MySqlContainer] = {}
connections: Dict[str, databases.Database] = {}
dbs: MySqlContainer | None = None
connections: dict[str, databases.Database] = {}

author: str
project_id: ProjectId
project_name: str
Expand All @@ -99,20 +101,51 @@ async def setup():
os.environ['SM_ENVIRONMENT'] = 'test'
logger = logging.getLogger()
try:
db = MySqlContainer('mariadb:11.2.2', password='test')
port_to_expose = find_free_port()
# override the default port to map the container to
db.with_bind_ports(db.port, port_to_expose)
logger.disabled = True
db.start()
logger.disabled = False
cls.dbs[cls.__name__] = db
db = cls.dbs
if not cls.dbs:
db = MySqlContainer('mariadb:11.2.2', password='test')

port_to_expose = find_free_port()

# override the default port to map the container to
db.with_bind_ports(db.port, port_to_expose)
logger.disabled = True
db.start()
logger.disabled = False
cls.dbs = db

db_prefix = 'db'
if am_i_in_test_environment:
db_prefix = '../db'

lcon_string = f'jdbc:mariadb://{db.get_container_host_ip()}:{port_to_expose}/{db.dbname}'
if not db:
raise ValueError('No database container found')

# create the database
db_name = str(cls.__name__) + 'Db'

_root_connection = SMConnections.make_connection(
CredentialedDatabaseConfiguration(
host=db.get_container_host_ip(),
port=str(port_to_expose),
username='root',
password=db.password,
dbname=db.dbname,
),
)

# create the database for each test class, and give permissions
await _root_connection.connect()
await _root_connection.execute(f'CREATE DATABASE {db_name};')
await _root_connection.execute(
f"GRANT ALL PRIVILEGES ON `{db_name}`.* TO {db.username}@'%';"
)
await _root_connection.execute('FLUSH PRIVILEGES;')
await _root_connection.disconnect()

# mfranklin -> future dancoates: if you work out how to copy the
# database instead of running liquibase, that would be great
lcon_string = f'jdbc:mariadb://{db.get_container_host_ip()}:{port_to_expose}/{db_name}'
# apply the liquibase schema
command = [
'liquibase',
Expand All @@ -127,17 +160,18 @@ async def setup():
]
subprocess.check_output(command, stderr=subprocess.STDOUT)

cls.author = 'testuser'

sm_db = SMConnections.make_connection(
CredentialedDatabaseConfiguration(
host=db.get_container_host_ip(),
port=port_to_expose,
port=str(port_to_expose),
username='root',
password=db.password,
dbname=db.dbname,
dbname=db_name,
)
)
await sm_db.connect()
cls.author = 'testuser'

cls.connections[cls.__name__] = sm_db
formed_connection = Connection(
Expand Down Expand Up @@ -213,10 +247,17 @@ async def setup():

@classmethod
def tearDownClass(cls) -> None:
db = cls.dbs.get(cls.__name__)
if db:
db.exec(f'DROP DATABASE {db.dbname};')
db.stop()
# remove from active_tests
@run_as_sync
async def tearDown():
connection = cls.connections.pop(cls.__name__, None)
if connection:
await connection.disconnect()

if len(cls.connections) == 0 and cls.dbs:
cls.dbs.stop()

tearDown()

def setUp(self) -> None:
self._connection = self.connections[self.__class__.__name__]
Expand Down
13 changes: 13 additions & 0 deletions web/src/index.css
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,19 @@ blockquote {
border: 1px solid var(--color-border-default);
}

.ui .modal > .content {
background: var(--color-bg-card);
}
.ui .modal > .header {
background: var(--color-bg-card);
color: var(--color-text-primary);
}

.ui .modal > .actions {
background: var(--color-bg-card);
color: var(--color-text-primary);
}

textarea {
background: var(--color-bg);
color: var(--color-text-primary);
Expand Down
25 changes: 18 additions & 7 deletions web/src/pages/project/DictEditor.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ export type DictEditorInput = { [key: string]: InputValue } | string

interface DictEditorProps {
input: DictEditorInput
height?: string
readonly?: boolean
onChange: (json: object) => void
}

Expand Down Expand Up @@ -38,7 +40,12 @@ const parseString = (str: string) => {
}
}

export const DictEditor: React.FunctionComponent<DictEditorProps> = ({ input, onChange }) => {
export const DictEditor: React.FunctionComponent<DictEditorProps> = ({
input,
onChange,
height,
readonly,
}) => {
const [textValue, setInnerTextValue] = React.useState<string>(getStringFromValue(input))
const theme = React.useContext(ThemeContext)

Expand Down Expand Up @@ -75,25 +82,29 @@ export const DictEditor: React.FunctionComponent<DictEditorProps> = ({ input, on
>
<Editor
value={textValue}
height="200px"
height={height || '200px'}
theme={theme.theme === 'dark-mode' ? 'vs-dark' : 'vs-light'}
language="yaml"
onChange={(value) => handleChange(value || '')}
options={{
minimap: { enabled: false },
automaticLayout: true,
readOnly: readonly,
scrollBeyondLastLine: false,
}}
/>
{error && (
<p>
<em style={{ color: 'var(--color-text-red)' }}>{error}</em>
</p>
)}
<p>
<Button onClick={submit} disabled={!!error}>
Apply
</Button>
</p>
{!readonly && (
<p>
<Button onClick={submit} disabled={!!error}>
Apply
</Button>
</p>
)}
</div>
)
}
Loading