Skip to content

Commit bf86954

Browse files
author
Savannah Norem
authored
Merge branch 'main' into testFindQuery
2 parents 94702ac + f245488 commit bf86954

File tree

3 files changed

+76
-16
lines changed

3 files changed

+76
-16
lines changed

aredis_om/model/model.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ClassVar,
1515
Dict,
1616
List,
17+
Literal,
1718
Mapping,
1819
Optional,
1920
Sequence,
@@ -141,10 +142,10 @@ def embedded(cls):
141142

142143
def is_supported_container_type(typ: Optional[type]) -> bool:
143144
# TODO: Wait, why don't we support indexing sets?
144-
if typ == list or typ == tuple:
145+
if typ == list or typ == tuple or typ == Literal:
145146
return True
146147
unwrapped = get_origin(typ)
147-
return unwrapped == list or unwrapped == tuple
148+
return unwrapped == list or unwrapped == tuple or unwrapped == Literal
148149

149150

150151
def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]):
@@ -1423,6 +1424,8 @@ def outer_type_or_annotation(field):
14231424
if not isinstance(field.annotation, type):
14241425
raise AttributeError(f"could not extract outer type from field {field}")
14251426
return field.annotation
1427+
elif get_origin(field.annotation) == Literal:
1428+
return str
14261429
else:
14271430
return field.annotation.__args__[0]
14281431

@@ -2066,21 +2069,33 @@ def schema_for_type(
20662069
# find any values marked as indexed.
20672070
if is_container_type and not is_vector:
20682071
field_type = get_origin(typ)
2069-
embedded_cls = get_args(typ)
2070-
if not embedded_cls:
2071-
log.warning(
2072-
"Model %s defined an empty list or tuple field: %s", cls, name
2072+
if field_type == Literal:
2073+
path = f"{json_path}.{name}"
2074+
return cls.schema_for_type(
2075+
path,
2076+
name,
2077+
name_prefix,
2078+
str,
2079+
field_info,
2080+
parent_type=field_type,
2081+
)
2082+
else:
2083+
embedded_cls = get_args(typ)
2084+
if not embedded_cls:
2085+
log.warning(
2086+
"Model %s defined an empty list or tuple field: %s", cls, name
2087+
)
2088+
return ""
2089+
path = f"{json_path}.{name}[*]"
2090+
embedded_cls = embedded_cls[0]
2091+
return cls.schema_for_type(
2092+
path,
2093+
name,
2094+
name_prefix,
2095+
embedded_cls,
2096+
field_info,
2097+
parent_type=field_type,
20732098
)
2074-
return ""
2075-
embedded_cls = embedded_cls[0]
2076-
return cls.schema_for_type(
2077-
f"{json_path}.{name}[*]",
2078-
name,
2079-
name_prefix,
2080-
embedded_cls,
2081-
field_info,
2082-
parent_type=field_type,
2083-
)
20842099
elif field_is_model:
20852100
name_prefix = f"{name_prefix}_{name}" if name_prefix else name
20862101
sub_fields = []

tests/test_hash_model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,3 +929,25 @@ class TestUpdate(HashModel):
929929

930930
rematerialized = await TestUpdate.find(TestUpdate.pk == t.pk).first()
931931
assert rematerialized.age == 34
932+
933+
934+
@py_test_mark_asyncio
935+
async def test_literals():
936+
from typing import Literal
937+
938+
class TestLiterals(HashModel):
939+
flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple")
940+
941+
schema = TestLiterals.redisearch_schema()
942+
943+
key_prefix = TestLiterals.make_key(
944+
TestLiterals._meta.primary_key_pattern.format(pk="")
945+
)
946+
assert schema == (
947+
f"ON HASH PREFIX 1 {key_prefix} SCHEMA pk TAG SEPARATOR | flavor TAG SEPARATOR |"
948+
)
949+
await Migrator().run()
950+
item = TestLiterals(flavor="pumpkin")
951+
await item.save()
952+
rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first()
953+
assert rematerialized.pk == item.pk

tests/test_json_model.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,3 +1134,26 @@ async def get_page(cls, offset, limit):
11341134
res = await Test.get_page(10, 30)
11351135
assert len(res) == 30
11361136
assert res[0].num == 10
1137+
1138+
1139+
@py_test_mark_asyncio
1140+
async def test_literals():
1141+
from typing import Literal
1142+
1143+
class TestLiterals(JsonModel):
1144+
flavor: Literal["apple", "pumpkin"] = Field(index=True, default="apple")
1145+
1146+
schema = TestLiterals.redisearch_schema()
1147+
1148+
key_prefix = TestLiterals.make_key(
1149+
TestLiterals._meta.primary_key_pattern.format(pk="")
1150+
)
1151+
assert schema == (
1152+
f"ON JSON PREFIX 1 {key_prefix} SCHEMA $.pk AS pk TAG SEPARATOR | "
1153+
"$.flavor AS flavor TAG SEPARATOR |"
1154+
)
1155+
await Migrator().run()
1156+
item = TestLiterals(flavor="pumpkin")
1157+
await item.save()
1158+
rematerialized = await TestLiterals.find(TestLiterals.flavor == "pumpkin").first()
1159+
assert rematerialized.pk == item.pk

0 commit comments

Comments
 (0)