Skip to content

Commit

Permalink
supporting literals as tag type (#635)
Browse files Browse the repository at this point in the history
* supporting literals as tag type

* fixing key-prefix issue
  • Loading branch information
slorello89 committed Aug 2, 2024
1 parent b20e887 commit f245488
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 16 deletions.
47 changes: 31 additions & 16 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ClassVar,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -141,10 +142,10 @@ def embedded(cls):

def is_supported_container_type(typ: Optional[type]) -> bool:
# TODO: Wait, why don't we support indexing sets?
if typ == list or typ == tuple:
if typ == list or typ == tuple or typ == Literal:
return True
unwrapped = get_origin(typ)
return unwrapped == list or unwrapped == tuple
return unwrapped == list or unwrapped == tuple or unwrapped == Literal


def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]):
Expand Down Expand Up @@ -1414,6 +1415,8 @@ def outer_type_or_annotation(field):
if not isinstance(field.annotation, type):
raise AttributeError(f"could not extract outer type from field {field}")
return field.annotation
elif get_origin(field.annotation) == Literal:
return str
else:
return field.annotation.__args__[0]

Expand Down Expand Up @@ -2057,21 +2060,33 @@ def schema_for_type(
# find any values marked as indexed.
if is_container_type and not is_vector:
field_type = get_origin(typ)
embedded_cls = get_args(typ)
if not embedded_cls:
log.warning(
"Model %s defined an empty list or tuple field: %s", cls, name
if field_type == Literal:
path = f"{json_path}.{name}"
return cls.schema_for_type(
path,
name,
name_prefix,
str,
field_info,
parent_type=field_type,
)
else:
embedded_cls = get_args(typ)
if not embedded_cls:
log.warning(
"Model %s defined an empty list or tuple field: %s", cls, name
)
return ""
path = f"{json_path}.{name}[*]"
embedded_cls = embedded_cls[0]
return cls.schema_for_type(
path,
name,
name_prefix,
embedded_cls,
field_info,
parent_type=field_type,
)
return ""
embedded_cls = embedded_cls[0]
return cls.schema_for_type(
f"{json_path}.{name}[*]",
name,
name_prefix,
embedded_cls,
field_info,
parent_type=field_type,
)
elif field_is_model:
name_prefix = f"{name_prefix}_{name}" if name_prefix else name
sub_fields = []
Expand Down
22 changes: 22 additions & 0 deletions tests/test_hash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,3 +917,25 @@ class TestUpdate(HashModel):

rematerialized = await TestUpdate.find(TestUpdate.pk == t.pk).first()
assert rematerialized.age == 34


@py_test_mark_asyncio
async def test_literals():
from typing import Literal

class TestLiterals(HashModel):
flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple")

schema = TestLiterals.redisearch_schema()

key_prefix = TestLiterals.make_key(
TestLiterals._meta.primary_key_pattern.format(pk="")
)
assert schema == (
f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | flavor TAG SEPARATOR |"
)
await Migrator().run()
item = TestLiterals(flavor="pumpkin")
await item.save()
rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first()
assert rematerialized.pk == item.pk
24 changes: 24 additions & 0 deletions tests/test_json_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,7 @@ class ModelWithIntPk(JsonModel):
m = await ModelWithIntPk.find(ModelWithIntPk.my_id == 42).first()
assert m.my_id == 42


@py_test_mark_asyncio
async def test_pagination():
class Test(JsonModel):
Expand All @@ -1121,3 +1122,26 @@ async def get_page(cls, offset, limit):
res = await Test.get_page(10, 30)
assert len(res) == 30
assert res[0].num == 10


@py_test_mark_asyncio
async def test_literals():
from typing import Literal

class TestLiterals(JsonModel):
flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple")

schema = TestLiterals.redisearch_schema()

key_prefix = TestLiterals.make_key(
TestLiterals._meta.primary_key_pattern.format(pk="")
)
assert schema == (
f"ON JSON PREFIX 1 {key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | "
"$.flavor AS flavor TAG SEPARATOR |"
)
await Migrator().run()
item = TestLiterals(flavor="pumpkin")
await item.save()
rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first()
assert rematerialized.pk == item.pk

0 comments on commit f245488

Please sign in to comment.