Skip to content

Commit 430f305

Browse files
committed
feat: add static evaluation of activators
Classify activators as STATIC/DYNAMIC and evaluate static ones at container creation time, stripping inactive providers from the dependency graph.
1 parent 9dc9bea commit 430f305

File tree

6 files changed

+857
-4
lines changed

6 files changed

+857
-4
lines changed

src/dishka/activator_classifier.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from enum import Enum, auto
5+
from typing import TYPE_CHECKING
6+
7+
from dishka.entities.factory_type import FactoryType
8+
from dishka.entities.key import DependencyKey
9+
from dishka.entities.marker import Marker
10+
11+
if TYPE_CHECKING:
12+
from dishka.dependency_source.activator import Activator
13+
from dishka.dependency_source.factory import Factory
14+
from dishka.entities.scope import BaseScope
15+
from dishka.registry import Registry
16+
17+
18+
class ActivatorType(Enum):
19+
STATIC = auto()
20+
DYNAMIC = auto()
21+
22+
23+
@dataclass(frozen=True, slots=True)
24+
class ClassifiedActivator:
25+
key: DependencyKey
26+
activator: Activator
27+
type: ActivatorType
28+
dependencies: frozenset[DependencyKey]
29+
30+
31+
class ActivatorClassifier:
32+
def __init__(
33+
self,
34+
registries: tuple[Registry, ...],
35+
activators: dict[DependencyKey, Activator],
36+
root_scope: BaseScope,
37+
) -> None:
38+
self._registries = registries
39+
self._activators = activators
40+
self._root_scope = root_scope
41+
self._all_factories: dict[DependencyKey, Factory] = {}
42+
self._context_keys_at_root: set[DependencyKey] = set()
43+
self._build_factory_index()
44+
45+
def _build_factory_index(self) -> None:
46+
for registry in self._registries:
47+
for key, factory in registry.factories.items():
48+
self._all_factories[key] = factory
49+
if (
50+
factory.type == FactoryType.CONTEXT
51+
and factory.scope == self._root_scope
52+
):
53+
self._context_keys_at_root.add(key)
54+
55+
def _is_async_factory(self, factory: Factory) -> bool:
56+
return factory.type in (
57+
FactoryType.ASYNC_FACTORY,
58+
FactoryType.ASYNC_GENERATOR,
59+
)
60+
61+
def _is_marker_dependency(
62+
self,
63+
activator: Activator,
64+
dep: DependencyKey,
65+
) -> bool:
66+
"""Check if dependency is the activator's marker (auto-injected)."""
67+
return (
68+
dep.type_hint is activator.marker_type
69+
or dep.type_hint is Marker
70+
)
71+
72+
def _get_activator_dependencies(
73+
self,
74+
activator: Activator,
75+
) -> frozenset[DependencyKey]:
76+
factory = activator.factory
77+
all_deps = list(factory.dependencies) + list(
78+
factory.kw_dependencies.values(),
79+
)
80+
return frozenset(
81+
dep for dep in all_deps
82+
if dep in self._activators
83+
and not self._is_marker_dependency(activator, dep)
84+
)
85+
86+
def _get_all_dependencies(
87+
self,
88+
activator: Activator,
89+
) -> list[DependencyKey]:
90+
factory = activator.factory
91+
all_deps = list(factory.dependencies) + list(
92+
factory.kw_dependencies.values(),
93+
)
94+
return [
95+
dep for dep in all_deps
96+
if not self._is_marker_dependency(activator, dep)
97+
]
98+
99+
def _is_root_context_dep(self, dep: DependencyKey) -> bool:
100+
return dep in self._context_keys_at_root
101+
102+
def _is_registered(self, dep: DependencyKey) -> bool:
103+
return dep in self._all_factories or dep in self._activators
104+
105+
def _topological_sort(
106+
self,
107+
activator_deps: dict[DependencyKey, frozenset[DependencyKey]],
108+
) -> list[DependencyKey]:
109+
result: list[DependencyKey] = []
110+
visited: set[DependencyKey] = set()
111+
112+
def visit(key: DependencyKey) -> None:
113+
if key in visited:
114+
return
115+
for dep in activator_deps.get(key, frozenset()):
116+
if dep in activator_deps:
117+
visit(dep)
118+
visited.add(key)
119+
result.append(key)
120+
121+
for key in activator_deps:
122+
visit(key)
123+
124+
return result
125+
126+
def classify(self) -> dict[DependencyKey, ClassifiedActivator]:
127+
activator_deps: dict[DependencyKey, frozenset[DependencyKey]] = {}
128+
for key, activator in self._activators.items():
129+
activator_deps[key] = self._get_activator_dependencies(activator)
130+
131+
eval_order = self._topological_sort(activator_deps)
132+
133+
classification: dict[DependencyKey, ClassifiedActivator] = {}
134+
135+
for key in eval_order:
136+
activator = self._activators[key]
137+
activator_type = self._classify_single(
138+
activator,
139+
activator_deps[key],
140+
classification,
141+
)
142+
classification[key] = ClassifiedActivator(
143+
key=key,
144+
activator=activator,
145+
type=activator_type,
146+
dependencies=activator_deps[key],
147+
)
148+
149+
return classification
150+
151+
def _classify_single(
152+
self,
153+
activator: Activator,
154+
activator_dependencies: frozenset[DependencyKey],
155+
already_classified: dict[DependencyKey, ClassifiedActivator],
156+
) -> ActivatorType:
157+
factory = activator.factory
158+
159+
if self._is_async_factory(factory):
160+
return ActivatorType.DYNAMIC
161+
162+
all_deps = self._get_all_dependencies(activator)
163+
164+
if not all_deps:
165+
return ActivatorType.STATIC
166+
167+
for dep in activator_dependencies:
168+
classified = already_classified.get(dep)
169+
if classified and classified.type == ActivatorType.DYNAMIC:
170+
return ActivatorType.DYNAMIC
171+
172+
non_activator_deps = [
173+
dep for dep in all_deps if dep not in self._activators
174+
]
175+
176+
for dep in non_activator_deps:
177+
if self._is_root_context_dep(dep):
178+
continue
179+
if not self._is_registered(dep):
180+
continue
181+
return ActivatorType.DYNAMIC
182+
183+
return ActivatorType.STATIC

src/dishka/async_container.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .provider import BaseProvider, make_root_context_provider
3131
from .registry import Registry
3232
from .registry_builder import RegistryBuilder
33+
from .static_evaluator import apply_static_evaluation
3334

3435
T = TypeVar("T")
3536

@@ -307,14 +308,19 @@ def make_async_container(
307308
) -> AsyncContainer:
308309
context_provider = make_root_context_provider(providers, context, scopes)
309310
has_provider = HasProvider()
310-
registries = RegistryBuilder(
311+
builder = RegistryBuilder(
311312
scopes=scopes,
312313
container_key=CONTAINER_KEY,
313314
multicomponent_providers=[has_provider],
314315
providers=(*providers, context_provider),
315316
skip_validation=skip_validation,
316317
validation_settings=validation_settings,
317-
).build()
318+
)
319+
registries = builder.build()
320+
registries = apply_static_evaluation(
321+
registries, builder.activators, context, start_scope,
322+
)
323+
318324
container = AsyncContainer(
319325
*registries,
320326
context=context,

src/dishka/container.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .provider import BaseProvider, make_root_context_provider
3131
from .registry import Registry
3232
from .registry_builder import RegistryBuilder
33+
from .static_evaluator import apply_static_evaluation
3334

3435
T = TypeVar("T")
3536

@@ -304,14 +305,19 @@ def make_container(
304305
) -> Container:
305306
context_provider = make_root_context_provider(providers, context, scopes)
306307
has_provider = HasProvider()
307-
registries = RegistryBuilder(
308+
builder = RegistryBuilder(
308309
scopes=scopes,
309310
container_key=CONTAINER_KEY,
310311
multicomponent_providers=[has_provider],
311312
providers=(*providers, context_provider),
312313
skip_validation=skip_validation,
313314
validation_settings=validation_settings,
314-
).build()
315+
)
316+
registries = builder.build()
317+
registries = apply_static_evaluation(
318+
registries, builder.activators, context, start_scope,
319+
)
320+
315321
container = Container(
316322
*registries,
317323
context=context,

src/dishka/registry_filter.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
from dishka.entities.component import DEFAULT_COMPONENT
6+
from dishka.entities.key import DependencyKey
7+
from dishka.entities.marker import (
8+
AndMarker,
9+
BaseMarker,
10+
BoolMarker,
11+
Marker,
12+
NotMarker,
13+
OrMarker,
14+
)
15+
from dishka.registry import Registry
16+
17+
if TYPE_CHECKING:
18+
from dishka.dependency_source.factory import Factory
19+
20+
21+
class RegistryFilter:
22+
def __init__(
23+
self,
24+
activation_results: dict[DependencyKey, bool],
25+
) -> None:
26+
self._activation_results = activation_results
27+
28+
def _eval_simple_marker(
29+
self,
30+
marker: Marker,
31+
provides_key: DependencyKey | None,
32+
) -> bool | None:
33+
key = DependencyKey(
34+
marker,
35+
provides_key.component if provides_key else DEFAULT_COMPONENT,
36+
)
37+
return self._activation_results.get(key)
38+
39+
def _eval_not_marker(
40+
self,
41+
marker: NotMarker,
42+
provides_key: DependencyKey | None,
43+
) -> bool | None:
44+
inner = self._is_marker_active(marker.marker, provides_key)
45+
return None if inner is None else not inner
46+
47+
def _eval_or_marker(
48+
self,
49+
marker: OrMarker,
50+
provides_key: DependencyKey | None,
51+
) -> bool | None:
52+
left = self._is_marker_active(marker.left, provides_key)
53+
right = self._is_marker_active(marker.right, provides_key)
54+
if left is True or right is True:
55+
return True
56+
if left is None or right is None:
57+
return None
58+
return False
59+
60+
def _eval_and_marker(
61+
self,
62+
marker: AndMarker,
63+
provides_key: DependencyKey | None,
64+
) -> bool | None:
65+
left = self._is_marker_active(marker.left, provides_key)
66+
right = self._is_marker_active(marker.right, provides_key)
67+
if left is False or right is False:
68+
return False
69+
if left is None or right is None:
70+
return None
71+
return True
72+
73+
def _is_marker_active(
74+
self,
75+
marker: BaseMarker | None,
76+
provides_key: DependencyKey | None,
77+
) -> bool | None:
78+
result: bool | None
79+
match marker:
80+
case None:
81+
result = True
82+
case BoolMarker():
83+
result = None if not marker.value else True
84+
case NotMarker():
85+
result = self._eval_not_marker(marker, provides_key)
86+
case OrMarker():
87+
result = self._eval_or_marker(marker, provides_key)
88+
case AndMarker():
89+
result = self._eval_and_marker(marker, provides_key)
90+
case Marker():
91+
result = self._eval_simple_marker(marker, provides_key)
92+
case _:
93+
result = None
94+
return result
95+
96+
def _should_include_factory(self, factory: Factory) -> bool:
97+
if factory.when_active is None:
98+
return True
99+
100+
result = self._is_marker_active(
101+
factory.when_active,
102+
factory.provides,
103+
)
104+
105+
if result is None:
106+
return True
107+
return result
108+
109+
def filter(
110+
self,
111+
registries: tuple[Registry, ...],
112+
) -> tuple[Registry, ...]:
113+
if not self._activation_results:
114+
return registries
115+
116+
filtered: list[Registry] = []
117+
118+
for registry in registries:
119+
new_registry = Registry(
120+
registry.scope,
121+
has_fallback=registry.has_fallback,
122+
)
123+
124+
for key, factory in registry.factories.items():
125+
if self._should_include_factory(factory):
126+
new_registry.add_factory(factory, key)
127+
128+
filtered.append(new_registry)
129+
130+
return tuple(filtered)

0 commit comments

Comments
 (0)