Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions examples/fastapi/_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,23 @@ async def test_user_list(self, client: AsyncClient) -> None: # nosec
await self.user_list(client)


@pytest.mark.anyio
async def test_404(client: AsyncClient) -> None:
response = await client.get("/404")
assert response.status_code == 404, response.text
data = response.json()
assert isinstance(data["detail"], str)


@pytest.mark.anyio
async def test_422(client: AsyncClient) -> None:
response = await client.get("/422")
assert response.status_code == 422, response.text
data = response.json()
assert isinstance(data["detail"], list)
assert isinstance(data["detail"][0], dict)


class TestUserEast(UserTester):
timezone = "Asia/Shanghai"
delta_hours = 8
Expand Down Expand Up @@ -123,6 +140,23 @@ async def test_user_list(self, client_east: AsyncClient) -> None: # nosec
assert item.model_dump()["created_at"].hour == created_at.hour


@pytest.mark.anyio
async def test_404_east(client_east: AsyncClient) -> None:
response = await client_east.get("/404")
assert response.status_code == 404, response.text
data = response.json()
assert isinstance(data["detail"], str)


@pytest.mark.anyio
async def test_422_east(client_east: AsyncClient) -> None:
response = await client_east.get("/422")
assert response.status_code == 422, response.text
data = response.json()
assert isinstance(data["detail"], list)
assert isinstance(data["detail"][0], dict)


def query_without_app(pk: int) -> int:
async def runner() -> bool:
async with register_orm():
Expand Down
1 change: 0 additions & 1 deletion examples/fastapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@
db_url=os.getenv("DB_URL", "sqlite://db.sqlite3"),
modules={"models": ["models"]},
generate_schemas=True,
add_exception_handlers=True,
)
9 changes: 6 additions & 3 deletions examples/fastapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from examples.fastapi.config import register_orm
from tortoise import Tortoise, generate_config
from tortoise.contrib.fastapi import RegisterTortoise
from tortoise.contrib.fastapi import RegisterTortoise, tortoise_exception_handlers


@asynccontextmanager
Expand All @@ -23,7 +23,6 @@ async def lifespan_test(app: FastAPI) -> AsyncGenerator[None, None]:
app=app,
config=config,
generate_schemas=True,
add_exception_handlers=True,
_create_db=True,
):
# db connected
Expand All @@ -47,5 +46,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# db connections closed


app = FastAPI(title="Tortoise ORM FastAPI example", lifespan=lifespan)
app = FastAPI(
title="Tortoise ORM FastAPI example",
lifespan=lifespan,
exception_handlers=tortoise_exception_handlers(),
)
app.include_router(users_router, prefix="")
1 change: 1 addition & 0 deletions examples/fastapi/main_custom_timezone.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
app,
use_tz=False,
timezone="Asia/Shanghai",
add_exception_handlers=True,
):
# db connected
yield
Expand Down
11 changes: 11 additions & 0 deletions examples/fastapi/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,14 @@ async def delete_user(user_id: int):
if not deleted_count:
raise HTTPException(status_code=404, detail=f"User {user_id} not found")
return Status(message=f"Deleted user {user_id}")


@router.get("/404")
async def get_404():
await Users.get(id=0)


@router.get("/422")
async def get_422():
obj = await Users.create(username="foo")
await Users.create(username=obj.username)
114 changes: 58 additions & 56 deletions tortoise/contrib/fastapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,36 @@
from types import ModuleType
from typing import TYPE_CHECKING

from fastapi.responses import JSONResponse
from pydantic import BaseModel # pylint: disable=E0611
from starlette.routing import _DefaultLifespan

from tortoise import Tortoise, connections
from tortoise.exceptions import DoesNotExist, IntegrityError
from tortoise.log import logger

if TYPE_CHECKING:
from fastapi import FastAPI, Request


if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class HTTPNotFoundError(BaseModel):
detail: str
def tortoise_exception_handlers() -> dict:
from fastapi.responses import JSONResponse

async def doesnotexist_exception_handler(request: "Request", exc: DoesNotExist):
return JSONResponse(status_code=404, content={"detail": str(exc)})

async def integrityerror_exception_handler(request: "Request", exc: IntegrityError):
return JSONResponse(
status_code=422,
content={"detail": [{"loc": [], "msg": str(exc), "type": "IntegrityError"}]},
)

return {
DoesNotExist: doesnotexist_exception_handler,
IntegrityError: integrityerror_exception_handler,
}


class RegisterTortoise(AbstractAsyncContextManager):
Expand Down Expand Up @@ -122,17 +133,22 @@ def __init__(
self._create_db = _create_db

if add_exception_handlers and app is not None:
from starlette.middleware.exceptions import ExceptionMiddleware

warnings.warn(
"Setting `add_exception_handlers` to be true is deprecated, "
"use `FastAPI(exception_handlers=tortoise_exception_handlers())` instead."
"See more about it on https://tortoise.github.io/examples/fastapi",
DeprecationWarning,
)
original_call_func = ExceptionMiddleware.__call__

@app.exception_handler(DoesNotExist)
async def doesnotexist_exception_handler(request: "Request", exc: DoesNotExist):
return JSONResponse(status_code=404, content={"detail": str(exc)})
async def wrap_middleware_call(self, *args, **kw) -> None:
if DoesNotExist not in self._exception_handlers:
self._exception_handlers.update(tortoise_exception_handlers())
await original_call_func(self, *args, **kw)

@app.exception_handler(IntegrityError)
async def integrityerror_exception_handler(request: "Request", exc: IntegrityError):
return JSONResponse(
status_code=422,
content={"detail": [{"loc": [], "msg": str(exc), "type": "IntegrityError"}]},
)
ExceptionMiddleware.__call__ = wrap_middleware_call # type:ignore

async def init_orm(self) -> None: # pylint: disable=W0612
await Tortoise.init(
Expand Down Expand Up @@ -166,8 +182,7 @@ async def __aexit__(self, *args, **kw) -> None:

def __await__(self) -> Generator[None, None, Self]:
async def _self() -> Self:
await self.init_orm()
return self
return await self.__aenter__()

return _self().__await__()

Expand All @@ -182,8 +197,9 @@ def register_tortoise(
add_exception_handlers: bool = False,
) -> None:
"""
Registers ``startup`` and ``shutdown`` events to set-up and tear-down Tortoise-ORM
inside a FastAPI application.
Registers Tortoise-ORM with set-up at the beginning of FastAPI application's lifespan
(which allow user to read/write data from/to db inside the lifespan function),
and tear-down at the end of that lifespan.

You can configure using only one of ``config``, ``config_file``
and ``(db_url, modules)``.
Expand Down Expand Up @@ -245,40 +261,26 @@ def register_tortoise(
ConfigurationError
For any configuration error
"""
orm = RegisterTortoise(
app,
config,
config_file,
db_url,
modules,
generate_schemas,
add_exception_handlers,
)
if isinstance(lifespan := app.router.lifespan_context, _DefaultLifespan):
# Leave on_event here to compare with old versions
# So people can upgrade tortoise-orm in running project without changing any code

@app.on_event("startup")
async def init_orm() -> None: # pylint: disable=W0612
await orm.init_orm()

@app.on_event("shutdown")
async def close_orm() -> None: # pylint: disable=W0612
await orm.close_orm()

else:
# If custom lifespan was passed to app, register tortoise in it
warnings.warn(
"`register_tortoise` function is deprecated, "
"use the `RegisterTortoise` class instead."
"See more about it on https://tortoise.github.io/examples/fastapi",
DeprecationWarning,
)

@asynccontextmanager
async def orm_lifespan(app_instance: "FastAPI"):
async with orm:
async with lifespan(app_instance):
yield

app.router.lifespan_context = orm_lifespan
from fastapi.routing import _merge_lifespan_context

# Leave this function here to compare with old versions
# So people can upgrade tortoise-orm in running project without changing any code

@asynccontextmanager
async def orm_lifespan(app_instance: "FastAPI"):
async with RegisterTortoise(
app_instance,
config,
config_file,
db_url,
modules,
generate_schemas,
):
yield

original_lifespan = app.router.lifespan_context
app.router.lifespan_context = _merge_lifespan_context(orm_lifespan, original_lifespan)

if add_exception_handlers:
for exp_type, endpoint in tortoise_exception_handlers().items():
app.exception_handler(exp_type)(endpoint)