Improve optimizer for polymorphism and connections: post-fetch prefetching, InheritanceManager support, and fewer N+1s #808
base: main
Changes from all commits
@@ -277,6 +277,21 @@ async def async_resolver():

```python
        if "info" not in kwargs:
            kwargs["info"] = info

        # If we have a prefetched cache for this reverse accessor on the source instance,
        # pass a hint so the queryset hook can short-circuit to the cached list without
        # re-optimizing/requerying.
        attname = self.django_name or self.python_name
        if source is not None:
            cache = getattr(source, "_prefetched_objects_cache", None)
            if isinstance(cache, dict) and attname in cache:
                kwargs = dict(kwargs)
                kwargs["__use_prefetched_cache__"] = attname

        # Provide source instance in kwargs so the queryset hook can use prefetched cache
        if source is not None:
            kwargs = dict(kwargs)
            kwargs["__source__"] = source

        result = django_resolver(
            self.get_queryset_hook(**kwargs),
            qs_hook=lambda qs: qs,
```
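For context on the hint above: when `prefetch_related` (or a `Prefetch` generated by the optimizer) runs, Django stores the evaluated results on each instance in the private `_prefetched_objects_cache` dict, keyed by the accessor name. A minimal sketch of the behavior the resolver relies on, using hypothetical `Author`/`Book` models with a reverse accessor `books`:

```python
authors = list(Author.objects.prefetch_related("books"))

author = authors[0]
cache = getattr(author, "_prefetched_objects_cache", {})

if "books" in cache:
    # The cached queryset is already evaluated, so this issues no query;
    # it is the data the __use_prefetched_cache__ hint short-circuits to.
    books = list(cache["books"])
else:
    # Without the cache this would hit the database once per author
    # (the per-node N+1 the PR is trying to avoid).
    books = list(author.books.all())
```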
@@ -286,13 +301,54 @@ async def async_resolver():

```python
    def get_queryset_hook(self, info: Info, **kwargs):
        if self.is_connection or self.is_paginated:
            # We don't want to fetch results yet, those will be done by the connection/pagination
            # For connections/paginated fields, avoid DB hits when we already have
            # a prefetched cache for this reverse accessor on the source instance.
            # Otherwise, just return the queryset and let the connection handle it.
            def qs_hook(qs: models.QuerySet):  # type: ignore
                # If the resolver passed a hint to use the source's prefetched cache,
                # return the cached list so the connection can operate on a Python list
                # (preventing per-node LIMIT queries on reverse relations).
                use_cache_key = kwargs.pop("__use_prefetched_cache__", None)
                source_obj = kwargs.pop("__source__", None)
                if use_cache_key and source_obj is not None:
                    cache = getattr(source_obj, "_prefetched_objects_cache", None)
                    if isinstance(cache, dict) and use_cache_key in cache:
                        return cache[use_cache_key]
                qs2 = self.get_queryset(qs, info, **kwargs)
                # If the connection queryset carries parent-level postfetch branches,
                # we need to trigger evaluation now so the optimizer's postfetch hook
                # can batch nested reverse relations across the page. This will not
                # bypass pagination because the queryset already carries LIMIT/OFFSET.
                from strawberry_django.queryset import (
                    get_queryset_config as _get_qs_cfg,
                )

                cfg = _get_qs_cfg(qs2)
                if getattr(cfg, "parent_postfetch_branches", None):
                    from strawberry_django.resolvers import default_qs_hook as _dqsh

                    qs2 = _dqsh(qs2)
                return qs2

        elif self.is_list:

            def qs_hook(qs: models.QuerySet):  # type: ignore
                # If the source instance has a prefetched cache for this accessor, short-circuit
                use_cache_key = kwargs.pop("__use_prefetched_cache__", None)
                source_obj = kwargs.pop("__source__", None)
                if use_cache_key and source_obj is not None:
                    cache = getattr(source_obj, "_prefetched_objects_cache", None)
                    if isinstance(cache, dict) and use_cache_key in cache:
                        # Only short-circuit to the cache if the queryset does NOT carry
                        # postfetch hints. If it does, we must run the default_qs_hook so
                        # nested postfetch can execute on this queryset.
                        from strawberry_django.queryset import (
                            get_queryset_config as _get_qs_cfg,
                        )

                        cfg = _get_qs_cfg(qs)
                        if not getattr(cfg, "postfetch_prefetch", None):
                            return cache[use_cache_key]
```
Member review comment on lines +345 to +351: nitpick: Ditto. Also, can't we just import those at module level? (don't remember if they would cause import loops)
```python
                qs = self.get_queryset(qs, info, **kwargs)
                if not self.disable_fetch_list_results:
                    qs = default_qs_hook(qs)
```
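The "trigger evaluation now" branch above exists so that batching can happen per page rather than per node. As a rough illustration of the underlying idea (not the optimizer's actual postfetch hook), Django's `prefetch_related_objects` can batch a nested reverse relation across an already-sliced, already-evaluated page; `Author`/`Book` are hypothetical models:

```python
from django.db.models import prefetch_related_objects

# A page of parent rows: the slice bakes LIMIT into the SQL, so evaluating
# the queryset here does not bypass pagination.
page = list(Author.objects.order_by("pk")[:10])

# One batched query fetches the nested reverse relation for the whole page,
# instead of one query per Author while the connection is serialized.
prefetch_related_objects(page, "books")

for author in page:
    # Served from the per-instance prefetch cache; no additional queries.
    titles = [book.title for book in author.books.all()]
```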
@@ -464,13 +520,75 @@ def resolve(

```python
        assert self.connection_type is not None
        nodes = cast("Iterable[relay.Node]", next_(source, info, **kwargs))

        # Helper to apply early SQL pagination for simple forward root connections
        def _apply_early_pagination(qs: models.QuerySet):
            import contextlib
```
Member review comment: nitpick: this should be imported at module level
```python
            # Only for root connections (source is None), forward pagination with `first`,
            # and no cursors
            if source is not None:
                return qs
            if (
                first in {None, 0}
                or before is not None
                or after is not None
                or last is not None
            ):
                return qs

            # Cap by max_results when provided
            page_limit = first
            if isinstance(self.max_results, int) and page_limit is not None:
                page_limit = min(page_limit, self.max_results)

            # Ensure deterministic ordering
            if not qs.ordered:
                qs = qs.order_by("pk")

            # Build ORDER BY expressions as Django does for window pagination
            from django.db import DEFAULT_DB_ALIAS
```
Member review comment on lines +548 to +549: nitpick: ditto
```python
            # Use the public `db` property instead of the private `_db` attribute
            with contextlib.suppress(Exception):
                compiler = qs.query.get_compiler(using=(qs.db or DEFAULT_DB_ALIAS))
                order_by_exprs = [expr for expr, _ in compiler.get_order_by()]

                # Annotate row number and total count (global partition)
                from django.db.models import Count, Window
                from django.db.models.functions import RowNumber
```
Member review comment on lines +556 to +558: nitpick: ditto
```python
                qs = qs.annotate(
                    _strawberry_row_number=Window(
                        RowNumber(),
                        partition_by=None,
                        order_by=order_by_exprs,
                    ),
                    _strawberry_total_count=Window(
                        Count(1),
                        partition_by=None,
                    ),
                )
```
Member review comment on lines +560 to +570: thought: this is duplicating some code that we have in the pagination module. Also, maybe this whole function should live there, as it is not related to the base field itself.
```python
                # Fetch one extra row so super().resolve_connection can compute hasNextPage
                return (
                    qs.filter(_strawberry_row_number__lte=(page_limit + 1))
                    if page_limit is not None
                    else qs
                )

            # Best-effort: on any failure above, return original queryset
            return qs

        # We have a single resolver for both sync and async, so we need to check if
        # nodes is awaitable or not and resolve it accordingly
        if inspect.isawaitable(nodes):

            async def async_resolver():
                awaited = await nodes
                if isinstance(awaited, models.QuerySet):
                    awaited = _apply_early_pagination(awaited)
                resolved = self.connection_type.resolve_connection(
                    awaited,
                    info=info,
                    before=before,
                    after=after,
```
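For reference, the row-number technique used above can be reproduced with plain ORM calls outside this field. A standalone sketch, assuming a hypothetical `Author` model; note that filtering on a window annotation requires Django 4.2+ (older versions raise NotSupportedError):

```python
from django.db.models import Count, F, Window
from django.db.models.functions import RowNumber

first = 10  # the connection's forward pagination argument

page = (
    Author.objects.order_by("pk")
    .annotate(
        # Position of each row under the connection's ordering.
        row_number=Window(RowNumber(), order_by=[F("pk").asc()]),
        # Total number of matching rows, computed in the same query (totalCount).
        total_count=Window(Count("pk")),
    )
    # Keep one extra row so hasNextPage can be derived without a second query.
    .filter(row_number__lte=first + 1)
)
```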
@@ -486,6 +604,9 @@ async def async_resolver():

```python
            return async_resolver()

        if isinstance(nodes, models.QuerySet):
            nodes = _apply_early_pagination(nodes)

        return self.connection_type.resolve_connection(
            nodes,
            info=info,
```
Review comment: question: why the aliases here? I think you can just import and use their names directly
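On the import nitpicks and the alias question: assuming `strawberry_django.queryset` and `strawberry_django.resolvers` do not import this field module back (not verified here, which is presumably why the diff uses function-local imports), the aliased imports could become plain module-level imports and the call sites could use the names directly. A hedged sketch with a hypothetical helper `_evaluate_postfetch`:

```python
# Hypothetical top-of-module imports; only safe if they do not create an import cycle.
import contextlib

from strawberry_django.queryset import get_queryset_config
from strawberry_django.resolvers import default_qs_hook


def _evaluate_postfetch(qs):
    # Call sites then read without the _get_qs_cfg / _dqsh aliases.
    cfg = get_queryset_config(qs)
    if getattr(cfg, "parent_postfetch_branches", None):
        qs = default_qs_hook(qs)
    return qs
```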