Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .coveragerc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[report]
exclude_lines =
exclude_lines =
pragma: not covered
@overload
[run]
Expand Down
170 changes: 170 additions & 0 deletions src/dishka/activator_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum, auto

from dishka.dependency_source.activator import Activator
from dishka.dependency_source.factory import Factory
from dishka.entities.factory_type import FactoryType
from dishka.entities.key import DependencyKey
from dishka.entities.marker import Marker
from dishka.entities.scope import BaseScope
from dishka.factory_index import FactoryIndex


class ActivatorType(Enum):
STATIC = auto()
DYNAMIC = auto()


@dataclass(frozen=True, slots=True)
class ClassifiedActivator:
key: DependencyKey
activator: Activator
type: ActivatorType
dependencies: frozenset[DependencyKey]


class ActivatorClassifier:
def __init__(
self,
factory_index: FactoryIndex,
activators: dict[DependencyKey, Activator],
root_scope: BaseScope,
) -> None:
self._factory_index = factory_index
self._activators = activators
self._root_scope = root_scope

def _is_async_factory(self, factory: Factory) -> bool:
return factory.type in (
FactoryType.ASYNC_FACTORY,
FactoryType.ASYNC_GENERATOR,
)

def _is_marker_dependency(
self,
activator: Activator,
dep: DependencyKey,
) -> bool:
return (
dep.type_hint is activator.marker_type
or dep.type_hint is Marker
)

def _get_factory_deps(self, factory: Factory) -> list[DependencyKey]:
return list(factory.dependencies) + list(
factory.kw_dependencies.values(),
)

def _get_activator_dependencies(
self,
activator: Activator,
) -> frozenset[DependencyKey]:
all_deps = self._get_factory_deps(activator.factory)
return frozenset(
dep for dep in all_deps
if dep in self._activators
and not self._is_marker_dependency(activator, dep)
)

def _get_all_dependencies(
self,
activator: Activator,
) -> list[DependencyKey]:
all_deps = self._get_factory_deps(activator.factory)
return [
dep for dep in all_deps
if not self._is_marker_dependency(activator, dep)
]

def _is_root_context_dep(self, dep: DependencyKey) -> bool:
return dep in self._factory_index.context_keys_at_root

def _is_registered(self, dep: DependencyKey) -> bool:
return dep in self._factory_index or dep in self._activators

def _topological_sort(
self,
activator_deps: dict[DependencyKey, frozenset[DependencyKey]],
) -> list[DependencyKey]:
result: list[DependencyKey] = []
visited: set[DependencyKey] = set()

def visit(key: DependencyKey) -> None:
if key in visited:
return
for dep in activator_deps.get(key, frozenset()):
if dep in activator_deps:
visit(dep)
visited.add(key)
result.append(key)

for key in activator_deps:
visit(key)

return result

def classify(self) -> dict[DependencyKey, ClassifiedActivator]:
"""Classify activators as STATIC or DYNAMIC.

Returns dict ordered by dependency topology (dependencies before
dependents). Callers may rely on this ordering invariant.
"""
activator_deps: dict[DependencyKey, frozenset[DependencyKey]] = {}
for key, activator in self._activators.items():
activator_deps[key] = self._get_activator_dependencies(activator)

eval_order = self._topological_sort(activator_deps)

classification: dict[DependencyKey, ClassifiedActivator] = {}

for key in eval_order:
activator = self._activators[key]
activator_type = self._classify_single(
activator,
activator_deps[key],
classification,
)
classification[key] = ClassifiedActivator(
key=key,
activator=activator,
type=activator_type,
dependencies=activator_deps[key],
)

return classification

def _classify_single(
self,
activator: Activator,
activator_dependencies: frozenset[DependencyKey],
already_classified: dict[DependencyKey, ClassifiedActivator],
) -> ActivatorType:
factory = activator.factory

if self._is_async_factory(factory):
return ActivatorType.DYNAMIC

all_deps = self._get_all_dependencies(activator)

if not all_deps:
return ActivatorType.STATIC

for dep in activator_dependencies:
classified = already_classified.get(dep)
if classified and classified.type == ActivatorType.DYNAMIC:
return ActivatorType.DYNAMIC

non_activator_deps = [
dep for dep in all_deps if dep not in self._activators
]

for dep in non_activator_deps:
if self._is_root_context_dep(dep):
continue
if not self._is_registered(dep):
continue
return ActivatorType.DYNAMIC

return ActivatorType.STATIC
6 changes: 4 additions & 2 deletions src/dishka/async_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,14 +307,16 @@ def make_async_container(
) -> AsyncContainer:
context_provider = make_root_context_provider(providers, context, scopes)
has_provider = HasProvider()
registries = RegistryBuilder(
builder = RegistryBuilder(
scopes=scopes,
container_key=CONTAINER_KEY,
multicomponent_providers=[has_provider],
providers=(*providers, context_provider),
skip_validation=skip_validation,
validation_settings=validation_settings,
).build()
)
registries = builder.build(context=context, start_scope=start_scope)

container = AsyncContainer(
*registries,
context=context,
Expand Down
6 changes: 4 additions & 2 deletions src/dishka/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,14 +304,16 @@ def make_container(
) -> Container:
context_provider = make_root_context_provider(providers, context, scopes)
has_provider = HasProvider()
registries = RegistryBuilder(
builder = RegistryBuilder(
scopes=scopes,
container_key=CONTAINER_KEY,
multicomponent_providers=[has_provider],
providers=(*providers, context_provider),
skip_validation=skip_validation,
validation_settings=validation_settings,
).build()
)
registries = builder.build(context=context, start_scope=start_scope)

container = Container(
*registries,
context=context,
Expand Down
4 changes: 2 additions & 2 deletions src/dishka/dependency_source/activator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dishka.entities.component import Component
from dishka.entities.key import DependencyKey, const_dependency_key
from dishka.entities.marker import Marker
from dishka.entities.marker import BaseMarker, Marker
from dishka.entities.scope import BaseScope
from .factory import Factory

Expand Down Expand Up @@ -35,7 +35,7 @@ def _replace_dep(
if (
dependency.type_hint is self.marker_type or
dependency.type_hint is Marker or
dependency.type_hint is Marker
dependency.type_hint is BaseMarker
):
return const_dependency_key(marker)
return dependency
Expand Down
44 changes: 44 additions & 0 deletions src/dishka/factory_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from dataclasses import dataclass

from dishka.dependency_source.factory import Factory
from dishka.entities.factory_type import FactoryType
from dishka.entities.key import DependencyKey
from dishka.entities.scope import BaseScope


@dataclass(frozen=True, slots=True)
class FactoryIndex:
factories_by_key: dict[DependencyKey, Factory]
context_keys_at_root: frozenset[DependencyKey]

@classmethod
def from_processed_factories(
cls,
processed_factories: dict[DependencyKey, list[Factory]],
root_scope: BaseScope,
) -> FactoryIndex:
factories_by_key: dict[DependencyKey, Factory] = {}
context_keys: set[DependencyKey] = set()

for key, factory_list in processed_factories.items():
if factory_list:
factory = factory_list[-1] # Last wins (override order)
factories_by_key[key] = factory
if (
factory.type == FactoryType.CONTEXT
and factory.scope == root_scope
):
context_keys.add(key)

return cls(
factories_by_key=factories_by_key,
context_keys_at_root=frozenset(context_keys),
)

def __contains__(self, key: DependencyKey) -> bool:
return key in self.factories_by_key

def get(self, key: DependencyKey) -> Factory | None:
return self.factories_by_key.get(key)
118 changes: 118 additions & 0 deletions src/dishka/processed_factory_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

from dishka.dependency_source.factory import Factory
from dishka.entities.component import DEFAULT_COMPONENT
from dishka.entities.key import DependencyKey
from dishka.entities.marker import (
AndMarker,
BaseMarker,
BoolMarker,
Marker,
NotMarker,
OrMarker,
)


class ProcessedFactoryFilter:
def __init__(
self,
activation_results: dict[DependencyKey, bool],
) -> None:
self._activation_results = activation_results

def _eval_simple_marker(
self,
marker: Marker,
component: str,
) -> bool | None:
key = DependencyKey(marker, component)
return self._activation_results.get(key)

def _eval_not_marker(
self,
marker: NotMarker,
component: str,
) -> bool | None:
inner = self._is_marker_active(marker.marker, component)
return None if inner is None else not inner

def _eval_or_marker(
self,
marker: OrMarker,
component: str,
) -> bool | None:
left = self._is_marker_active(marker.left, component)
right = self._is_marker_active(marker.right, component)
if left is True or right is True:
return True
if left is None or right is None:
return None
return False

def _eval_and_marker(
self,
marker: AndMarker,
component: str,
) -> bool | None:
left = self._is_marker_active(marker.left, component)
right = self._is_marker_active(marker.right, component)
if left is False or right is False:
return False
if left is None or right is None:
return None
return True

def _is_marker_active( # noqa: PLR0911
self,
marker: BaseMarker | None,
component: str,
) -> bool | None:
match marker:
case None:
return True
case BoolMarker():
return None if not marker.value else True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return marker.value?

case NotMarker():
return self._eval_not_marker(marker, component)
case OrMarker():
return self._eval_or_marker(marker, component)
case AndMarker():
return self._eval_and_marker(marker, component)
case Marker():
return self._eval_simple_marker(marker, component)
case _:
return None

def _should_include_factory(
self,
factory: Factory,
component: str,
) -> bool:
if factory.when_active is None:
return True

result = self._is_marker_active(factory.when_active, component)

if result is None:
return True # Dynamic - keep
return result # True=keep, False=remove

def filter(
self,
processed_factories: dict[DependencyKey, list[Factory]],
) -> dict[DependencyKey, list[Factory]]:
if not self._activation_results:
return processed_factories

filtered: dict[DependencyKey, list[Factory]] = {}

for key, factory_list in processed_factories.items():
component = key.component or DEFAULT_COMPONENT
kept = [
f for f in factory_list
if self._should_include_factory(f, component)
]
if kept:
filtered[key] = kept

return filtered
Loading
Loading