Skip to content

Commit 840f5c3

Browse files
committed
support schema union in query
1 parent 9e1e5f6 commit 840f5c3

File tree

2 files changed

+124
-38
lines changed

2 files changed

+124
-38
lines changed

ninja/signature/details.py

Lines changed: 60 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,13 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
187187
arg_names: Any = {}
188188
for arg in args:
189189
if is_pydantic_model(arg.annotation):
190-
for name, path in self._model_flatten_map(arg.annotation, arg.alias):
191-
if name in flatten_map:
190+
for name, path, is_union_descendant in self._model_flatten_map(
191+
arg.annotation, arg.alias
192+
):
193+
model = arg.annotation
194+
if get_origin(model) is Annotated:
195+
model = get_args(model)[0]
196+
if not is_union_descendant and name in flatten_map:
192197
raise ConfigError(
193198
f"Duplicated name: '{name}' in params: '{arg_names[name]}' & '{arg.name}'"
194199
)
@@ -205,15 +210,26 @@ def _args_flatten_map(self, args: List[FuncParam]) -> Dict[str, Tuple[str, ...]]
205210

206211
return flatten_map
207212

208-
def _model_flatten_map(self, model: TModel, prefix: str) -> Generator:
213+
def _model_flatten_map(
214+
self, model: TModel, prefix: str, is_union_descendant: bool = False
215+
) -> Generator[Tuple[str, str, bool], None, None]:
209216
field: FieldInfo
210-
for attr, field in model.model_fields.items():
211-
field_name = field.alias or attr
212-
name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}"
213-
if is_pydantic_model(field.annotation):
214-
yield from self._model_flatten_map(field.annotation, name) # type: ignore
215-
else:
216-
yield field_name, name
217+
if get_origin(model) is Annotated:
218+
model = get_args(model)[0]
219+
if get_origin(model) in UNION_TYPES:
220+
# If the model is a union type, process each type in the union
221+
for arg in get_args(model):
222+
yield from self._model_flatten_map(arg, prefix, True)
223+
else:
224+
for attr, field in model.model_fields.items():
225+
field_name = field.alias or attr
226+
name = f"{prefix}{self.FLATTEN_PATH_SEP}{field_name}"
227+
if is_pydantic_model(field.annotation):
228+
yield from self._model_flatten_map(
229+
field.annotation, name, is_union_descendant
230+
) # type: ignore
231+
else:
232+
yield field_name, name, is_union_descendant
217233

218234
def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
219235
# _EMPTY = self.signature.empty
@@ -336,24 +352,40 @@ def detect_collection_fields(
336352
for path in (p for p in flatten_map.values() if len(p) > 1):
337353
annotation_or_field: Any = args_d[path[0]].annotation
338354
for attr in path[1:]:
339-
if hasattr(annotation_or_field, "annotation"):
340-
annotation_or_field = annotation_or_field.annotation
341-
annotation_or_field = next(
342-
(
343-
a
344-
for a in annotation_or_field.model_fields.values()
345-
if a.alias == attr
346-
),
347-
annotation_or_field.model_fields.get(attr),
348-
) # pragma: no cover
355+
if get_origin(annotation_or_field) is Annotated:
356+
annotation_or_field = get_args(annotation_or_field)[0]
357+
358+
# check union types
359+
if get_origin(annotation_or_field) in UNION_TYPES:
360+
for arg in get_args(annotation_or_field):
361+
annotation_or_field = _detect_collection_fields(
362+
arg, attr, path, result
363+
)
364+
else:
365+
annotation_or_field = _detect_collection_fields(
366+
annotation_or_field, attr, path, result
367+
)
368+
return result
349369

350-
annotation_or_field = getattr(
351-
annotation_or_field, "outer_type_", annotation_or_field
352-
)
353370

354-
# if hasattr(annotation_or_field, "annotation"):
355-
annotation_or_field = annotation_or_field.annotation
371+
def _detect_collection_fields(
372+
annotation_or_field: Any,
373+
attr: str,
374+
path: Tuple[str, ...],
375+
result: List[Any],
376+
) -> Any:
377+
annotation_or_field = next(
378+
(a for a in annotation_or_field.model_fields.values() if a.alias == attr),
379+
annotation_or_field.model_fields.get(attr),
380+
) # pragma: no cover
356381

357-
if is_collection_type(annotation_or_field):
358-
result.append(path[-1])
359-
return result
382+
annotation_or_field = getattr(
383+
annotation_or_field, "outer_type_", annotation_or_field
384+
)
385+
386+
# if hasattr(annotation_or_field, "annotation"):
387+
annotation_or_field = annotation_or_field.annotation
388+
389+
if is_collection_type(annotation_or_field):
390+
return result.append(path[-1])
391+
return annotation_or_field

tests/test_discriminator.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from pydantic import Field
44
from typing_extensions import Annotated, Literal
55

6-
from ninja import NinjaAPI, Schema
6+
from ninja import NinjaAPI, Query, Schema
77
from ninja.testing import TestClient
88

99

@@ -37,21 +37,43 @@ def create_example_regular(request, payload: RegularUnion):
3737
return {"data": payload.model_dump(), "type": payload.__class__.__name__}
3838

3939

40+
@api.get("/descr-union")
41+
def get_example(request, payload: UnionDiscriminator = Query(...)):
42+
return {}
43+
44+
45+
@api.get("/regular-union")
46+
def get_example_regular(request, payload: RegularUnion = Query(...)):
47+
return {}
48+
49+
4050
client = TestClient(api)
4151

4252

4353
def test_schema():
4454
schema = api.get_openapi_schema()
45-
detail1 = schema["paths"]["/api/descr-union"]["post"]["requestBody"]["content"][
46-
"application/json"
47-
]["schema"]
48-
detail2 = schema["paths"]["/api/regular-union"]["post"]["requestBody"]["content"][
49-
"application/json"
50-
]["schema"]
55+
post_detail1 = schema["paths"]["/api/descr-union"]["post"]["requestBody"][
56+
"content"
57+
]["application/json"]["schema"]
58+
post_detail2 = schema["paths"]["/api/regular-union"]["post"]["requestBody"][
59+
"content"
60+
]["application/json"]["schema"]
61+
get_detail1 = schema["paths"]["/api/descr-union"]["get"]["parameters"][0]["schema"]
62+
get_detail2 = schema["paths"]["/api/regular-union"]["get"]["parameters"][0][
63+
"schema"
64+
]
5165

5266
# First method should have 'discriminator' in OpenAPI api
53-
assert "discriminator" in detail1
54-
assert detail1["discriminator"] == {
67+
assert "discriminator" in post_detail1
68+
assert "discriminator" in get_detail1
69+
assert post_detail1["discriminator"] == {
70+
"mapping": {
71+
"ONE": "#/components/schemas/Example1",
72+
"TWO": "#/components/schemas/Example2",
73+
},
74+
"propertyName": "label",
75+
}
76+
assert get_detail1["discriminator"] == {
5577
"mapping": {
5678
"ONE": "#/components/schemas/Example1",
5779
"TWO": "#/components/schemas/Example2",
@@ -60,7 +82,8 @@ def test_schema():
6082
}
6183

6284
# Second method should NOT have 'discriminator'
63-
assert "discriminator" not in detail2
85+
assert "discriminator" not in post_detail2
86+
assert "discriminator" not in get_detail2
6487

6588

6689
def test_annotated_union_with_discriminator():
@@ -108,3 +131,34 @@ def test_regular_union():
108131
"data": {"label": "TWO", "value": 123},
109132
"type": "Example2",
110133
}
134+
135+
136+
def test_annotated_union_with_discriminator_get():
137+
# Test Example1
138+
response = client.get(
139+
"/descr-union",
140+
query_params={"label": "ONE", "value": "42"},
141+
)
142+
assert response.status_code == 200
143+
144+
# Test Example2
145+
response = client.get(
146+
"/descr-union",
147+
query_params={"label": "TWO", "value": "42"},
148+
)
149+
assert response.status_code == 200
150+
151+
152+
def test_regular_union_get():
153+
# Test that regular unions still work
154+
response = client.get(
155+
"/regular-union",
156+
query_params={"label": "ONE", "value": "2025"},
157+
)
158+
assert response.status_code == 200
159+
160+
response = client.get(
161+
"/regular-union",
162+
query_params={"label": "TWO", "value": 123},
163+
)
164+
assert response.status_code == 200

0 commit comments

Comments
 (0)