Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix _get_type - check for NoneType #254

Merged
merged 5 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions piccolo_api/fastapi/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@
ANNOTATIONS: t.DefaultDict = defaultdict(dict)


try:
# Python 3.10 and above
from types import UnionType # type: ignore
except ImportError:

class UnionType: # type: ignore
...
Dismissed Show dismissed Hide dismissed


class HTTPMethod(str, Enum):
get = "GET"
delete = "DELETE"
Expand Down Expand Up @@ -83,25 +92,30 @@

For example::

>>> get_type(Optional[int])
>>> _get_type(Optional[int])
int

>>> get_type(int)
>>> _get_type(int | None)
int

>>> get_type(list[str])
>>> _get_type(int)
int

>>> _get_type(list[str])
list[str]

"""
origin = t.get_origin(type_)

# Note: even if `t.Optional` is passed in, the origin is still a
# `t.Union`.
if origin is t.Union:
args = t.get_args(type_)
# `t.Union` or `UnionType` depending on the Python version.
if any(origin is i for i in (t.Union, UnionType)):
union_args = t.get_args(type_)

NoneType = type(None)

if len(args) == 2 and None in args:
return [i for i in args if i is not None][0]
if len(union_args) == 2 and NoneType in union_args:
return [i for i in union_args if i is not NoneType][0]

return type_

Expand Down
29 changes: 28 additions & 1 deletion tests/fastapi/test_fastapi_endpoints.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import sys
import typing as t
from unittest import TestCase

import pytest
from fastapi import FastAPI
from piccolo.columns import ForeignKey, Integer, Varchar
from piccolo.columns.readable import Readable
from piccolo.table import Table
from starlette.testclient import TestClient

from piccolo_api.crud.endpoints import PiccoloCRUD
from piccolo_api.fastapi.endpoints import FastAPIWrapper
from piccolo_api.fastapi.endpoints import FastAPIWrapper, _get_type


class Movie(Table):
Expand Down Expand Up @@ -246,3 +249,27 @@ def test_patch(self):
self.assertEqual(
response.json(), {"id": 1, "name": "Star Wars", "rating": 90}
)


class TestGetType(TestCase):
def test_get_type(self):
"""
If we pass in an optional type, it should return the non-optional type.
"""
# Should return the underlying type, as they're all optional:
self.assertIs(_get_type(t.Optional[str]), str)
self.assertIs(_get_type(t.Optional[t.List[str]]), t.List[str])
self.assertIs(_get_type(t.Union[str, None]), str)

# Should be returned as is, because it's not optional:
self.assertIs(_get_type(t.List[str]), t.List[str])

@pytest.mark.skipif(
sys.version_info < (3, 10), reason="Union syntax not available"
)
def test_new_union_syntax(self):
"""
Make sure it works with the new syntax added in Python 3.10.
"""
self.assertIs(_get_type(str | None), str) # type: ignore
self.assertIs(_get_type(None | str), str) # type: ignore
Loading