Skip to content

Commit 0de9885

Browse files
authored
Merge pull request #1574 from jceipek/fix-annotated-in-pathex
Fix use of `PathEx` with `Annotated` types
2 parents cd273d0 + 7abcae0 commit 0de9885

File tree

4 files changed

+199
-9
lines changed

4 files changed

+199
-9
lines changed

ninja/signature/details.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
import warnings
33
from collections import defaultdict, namedtuple
4+
from sys import version_info
45
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
56

67
import pydantic
@@ -222,9 +223,24 @@ def _get_param_type(self, name: str, arg: inspect.Parameter) -> FuncParam:
222223

223224
if get_origin(annotation) is Annotated:
224225
args = get_args(annotation)
225-
if isinstance(args[1], Param):
226+
if isinstance(args[-1], Param):
226227
prev_default = default
227-
annotation, default = args
228+
if len(args) == 2:
229+
annotation, default = args
230+
else:
231+
# TODO: Remove version check once support for <=3.8 is dropped.
232+
# Annotated[] is only available at runtime in 3.9+ per
233+
# https://docs.python.org/3/library/typing.html#typing.Annotated
234+
if version_info >= (3, 9):
235+
# NOTE: Annotated[args[:-1]] seems to have the same runtime
236+
# behavior as Annotated[*args[:-1]], but the latter is
237+
# invalid in Python < 3.11 because star expressions
238+
# were not allowed in index expressions.
239+
annotation, default = Annotated[args[:-1]], args[-1]
240+
else: # pragma: no cover -- requires specific Python versions
241+
raise NotImplementedError(
242+
"This definition requires Python version 3.9+"
243+
)
228244
if prev_default != self.signature.empty:
229245
default.default = prev_default
230246

tests/main.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
from sys import version_info
12
from typing import List, Optional
23
from uuid import UUID
34

5+
import pydantic
6+
import pytest
47
from django.urls import register_converter
8+
from typing_extensions import Annotated
59

6-
from ninja import Field, Path, Query, Router, Schema
10+
from ninja import Field, P, Path, PathEx, Query, Router, Schema
711

812
router = Router()
913

@@ -38,6 +42,46 @@ def get_bool_id(request, item_id: bool):
3842
return item_id
3943

4044

45+
def custom_validator(value: int) -> int:
46+
if value != 42:
47+
raise ValueError("Input should pass this custom validator")
48+
return value
49+
50+
51+
CustomValidatedInt = Annotated[
52+
int,
53+
pydantic.AfterValidator(custom_validator),
54+
pydantic.WithJsonSchema({
55+
"type": "int",
56+
"example": "42",
57+
}),
58+
]
59+
60+
# TODO: Remove this condition once support for <= 3.8 is dropped
61+
if version_info >= (3, 9):
62+
63+
@router.get("/path/param_ex/{item_id}")
64+
def get_path_param_ex_id(
65+
request,
66+
item_id: PathEx[CustomValidatedInt, P(description="path_ex description")],
67+
):
68+
return item_id
69+
70+
else:
71+
72+
def test_annotated_path_ex_unsupported():
73+
with pytest.raises(NotImplementedError, match="3.9+"):
74+
75+
@router.get("/path/param_ex/{item_id}")
76+
def get_path_param_ex_id(
77+
request,
78+
item_id: PathEx[
79+
CustomValidatedInt, P(description="path_ex description")
80+
],
81+
):
82+
return item_id
83+
84+
4185
@router.get("/path/param/{item_id}")
4286
def get_path_param_id(request, item_id: str = Path(None)):
4387
return item_id

tests/test_openapi_schema.py

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,26 @@
11
import sys
2+
from sys import version_info
23
from typing import Any, List, Union
34
from unittest.mock import Mock
45

6+
import pydantic
57
import pytest
68
from django.contrib.admin.views.decorators import staff_member_required
79
from django.test import Client, override_settings
8-
from pydantic import ConfigDict
9-
10-
from ninja import Body, Field, File, Form, NinjaAPI, Query, Schema, UploadedFile
10+
from typing_extensions import Annotated
11+
12+
from ninja import (
13+
Body,
14+
Field,
15+
File,
16+
Form,
17+
NinjaAPI,
18+
P,
19+
PathEx,
20+
Query,
21+
Schema,
22+
UploadedFile,
23+
)
1124
from ninja.openapi.urls import get_openapi_urls
1225
from ninja.pagination import PaginationBase, paginate
1326
from ninja.renderers import JSONRenderer
@@ -28,13 +41,23 @@ class TypeB(Schema):
2841
b: str
2942

3043

44+
AnnotatedStr = Annotated[
45+
str,
46+
pydantic.WithJsonSchema({
47+
"type": "string",
48+
"format": "custom-format",
49+
"example": "example_string",
50+
}),
51+
]
52+
53+
3154
def to_camel(string: str) -> str:
3255
words = string.split("_")
3356
return words[0].lower() + "".join(word.capitalize() for word in words[1:])
3457

3558

3659
class Response(Schema):
37-
model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)
60+
model_config = pydantic.ConfigDict(alias_generator=to_camel, populate_by_name=True)
3861
i: int
3962
f: float = Field(..., title="f title", description="f desc")
4063

@@ -65,10 +88,42 @@ def method_body_schema(request, data: Payload):
6588

6689

6790
@api.get("/test-path/{int:i}/{f}", response=Response)
68-
def method_path(request, i: int, f: float):
91+
def method_path(
92+
request,
93+
i: int,
94+
f: float,
95+
):
6996
return dict(i=i, f=f)
7097

7198

99+
# This definition is only possible in Python 3.9+
100+
# TODO: Drop this condition once support for <= 3.8 is dropped
101+
if version_info >= (3, 9):
102+
103+
@api.get("/test-pathex/{path_ex}", response=AnnotatedStr)
104+
def method_pathex(
105+
request,
106+
path_ex: PathEx[
107+
AnnotatedStr,
108+
P(description="path_ex description"),
109+
],
110+
):
111+
return path_ex
112+
113+
else:
114+
with pytest.raises(NotImplementedError, match="3.9+"):
115+
116+
@api.get("/test-pathex/{path_ex}", response=AnnotatedStr)
117+
def method_pathex(
118+
request,
119+
path_ex: PathEx[
120+
AnnotatedStr,
121+
P(description="path_ex description"),
122+
],
123+
):
124+
return path_ex
125+
126+
72127
@api.post("/test-form", response=Response)
73128
def method_form(request, data: Payload = Form(...)):
74129
return dict(i=data.i, f=data.f)
@@ -434,6 +489,49 @@ def test_schema_path(schema):
434489
}
435490

436491

492+
@pytest.mark.skipif(
493+
version_info < (3, 9),
494+
reason="requires py3.9+ for Annotated[] at the route definition site",
495+
)
496+
def test_schema_pathex(schema):
497+
method_list = schema["paths"]["/api/test-pathex/{path_ex}"]["get"]
498+
499+
assert "requestBody" not in method_list
500+
501+
assert method_list["parameters"] == [
502+
{
503+
"in": "path",
504+
"name": "path_ex",
505+
"schema": {
506+
"title": "Path Ex",
507+
"type": "string",
508+
"format": "custom-format",
509+
"description": "path_ex description",
510+
"example": "example_string",
511+
},
512+
"required": True,
513+
"example": "example_string",
514+
"description": "path_ex description",
515+
},
516+
]
517+
518+
assert method_list["responses"] == {
519+
200: {
520+
"content": {
521+
"application/json": {
522+
"schema": {
523+
"example": "example_string",
524+
"format": "custom-format",
525+
"title": "Response",
526+
"type": "string",
527+
},
528+
},
529+
},
530+
"description": "OK",
531+
}
532+
}
533+
534+
437535
def test_schema_form(schema):
438536
method_list = schema["paths"]["/api/test-form"]["post"]
439537

@@ -606,7 +704,7 @@ def test_schema_title_description(schema):
606704
"schema": {
607705
"properties": {
608706
"file": {
609-
"description": "file " "param " "desc",
707+
"description": "file param desc",
610708
"format": "binary",
611709
"title": "File",
612710
"type": "string",

tests/test_path.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from sys import version_info
2+
13
import pytest
24
from main import router
35

@@ -33,6 +35,17 @@ def test_text_get():
3335
]
3436
}
3537

38+
response_not_valid_custom = {
39+
"detail": [
40+
{
41+
"ctx": {"error": "Input should pass this custom validator"},
42+
"loc": ["path", "item_id"],
43+
"msg": "Value error, Input should pass this custom validator",
44+
"type": "value_error",
45+
}
46+
]
47+
}
48+
3649
response_not_valid_int_float = {
3750
"detail": [
3851
{
@@ -274,6 +287,25 @@ def test_get_path(path, expected_status, expected_response):
274287
assert response.json() == expected_response
275288

276289

290+
@pytest.mark.skipif(
291+
version_info < (3, 9),
292+
reason="requires py3.9+ for Annotated[] at the route definition site",
293+
)
294+
@pytest.mark.parametrize(
295+
"path,expected_status,expected_response",
296+
[
297+
("/path/param_ex/True", 422, response_not_valid_int),
298+
("/path/param_ex/0", 422, response_not_valid_custom),
299+
("/path/param_ex/42", 200, 42),
300+
],
301+
)
302+
def test_get_pathex(path, expected_status, expected_response):
303+
response = client.get(path)
304+
print(path, response.json())
305+
assert response.status_code == expected_status
306+
assert response.json() == expected_response
307+
308+
277309
@pytest.mark.parametrize(
278310
"path,expected_status,expected_response",
279311
[

0 commit comments

Comments
 (0)