Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/dishka/code_tools/factory_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,9 @@ def _selector_factory_body(
builder: FactoryBuilder, source_call: str, factory: Factory,
) -> None:
first = True
has_default = False
for variant in factory.when_dependencies:
condition = builder.when(variant.when_override, factory.when_component)
solved_value = builder.getter(variant.provides)
has_default = False
if first and not condition:
builder.assign_solved(solved_value)
elif first:
Expand All @@ -189,12 +187,12 @@ def _selector_factory_body(
elif not condition:
with builder.else_():
builder.assign_solved(solved_value)
has_default = True
first = True
else:
with builder.elif_(condition):
builder.assign_solved(solved_value)
if not has_default:
# if-chain not closed with else or not generated at all
if not first or not factory.when_dependencies:
error_call = builder.call(
builder.global_(NoActiveFactoryError),
builder.global_(factory.provides),
Expand Down
1 change: 0 additions & 1 deletion src/dishka/dependency_source/activator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def _replace_dep(
) -> DependencyKey:
if (
dependency.type_hint is self.marker_type or
dependency.type_hint is Marker or
dependency.type_hint is Marker
):
return const_dependency_key(marker)
Expand Down
2 changes: 1 addition & 1 deletion src/dishka/dependency_source/context_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def as_factory(
return aliased.as_factory(scope=self.scope, component=component)

def __get__(self, instance: Any, owner: Any) -> ContextVariable:
scope = self.scope or instance.scope
scope = self.scope or getattr(instance, "scope", None)
return ContextVariable(
scope=scope,
provides=self.provides,
Expand Down
9 changes: 5 additions & 4 deletions src/dishka/dependency_source/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dishka.entities.component import Component
from dishka.entities.key import DependencyKey
from dishka.entities.marker import BaseMarker, BoolMarker, combine_when
from dishka.entities.marker import BaseMarker, combine_when
from dishka.entities.scope import BaseScope
from .factory import Factory
from .type_match import get_typevar_replacement, is_broader_or_same_type
Expand Down Expand Up @@ -42,6 +42,8 @@ def as_factory(
new_dependency: DependencyKey,
cache: bool,
component: Component,
when_override: BaseMarker | None,
when_active: BaseMarker | None,
) -> Factory:
typevar_replacement = get_typevar_replacement(
self.provides.type_hint,
Expand All @@ -50,7 +52,6 @@ def as_factory(
if self.scope is not None:
scope = self.scope

when = self.when or BoolMarker(False)
return Factory(
scope=scope,
source=self.factory.source,
Expand All @@ -70,8 +71,8 @@ def as_factory(
},
type_=self.factory.type,
cache=cache,
when_override=when,
when_active=when,
when_override=combine_when(self.when, when_override),
when_active=combine_when(self.when, when_active),
when_component=self.factory.when_component or component,
when_dependencies=[],
)
Expand Down
5 changes: 1 addition & 4 deletions src/dishka/dependency_source/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class Factory(FactoryData):
"when_active",
"when_component",
"when_dependencies",
"when_override",
)

def __init__(
Expand Down Expand Up @@ -73,9 +72,7 @@ def __init__(
self.when_dependencies = when_dependencies

def __get__(self, instance: Any, owner: Any) -> Factory:
scope = self.scope or instance.scope
if instance is None:
return self
scope = self.scope or getattr(instance, "scope", None)
provider_when = getattr(instance, "when", None)
when_active = combine_when(provider_when, self.when_active)
when_override = combine_when(provider_when, self.when_override)
Expand Down
22 changes: 21 additions & 1 deletion src/dishka/entities/marker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any, ClassVar

from dishka.exception_base import InvalidMarkerError


class BaseMarker:
"""
Expand Down Expand Up @@ -153,4 +156,21 @@ class HasContext(Marker):
Special marker for checking if a type is available in current context.
"""
def __repr__(self) -> str:
return f"HasContext({self.value.__name__})"
return f"HasContext({self.value})"


def unpack_marker(marker: BaseMarker | None) -> Iterator[Marker]:
match marker:
case Marker():
yield marker
case NotMarker():
yield from unpack_marker(marker.marker)
case BinOpMarker():
yield from unpack_marker(marker.left)
yield from unpack_marker(marker.right)
case BoolMarker():
return
case None:
return
case _:
raise InvalidMarkerError(marker)
11 changes: 11 additions & 0 deletions src/dishka/exception_base.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
from typing import Any


class DishkaError(Exception):
pass


class InvalidMarkerError(DishkaError):
def __init__(self, marker: Any) -> None:
self.marker = marker

def __str__(self) -> str:
return f"Cannot use {self.marker!r} as marker."
9 changes: 0 additions & 9 deletions src/dishka/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections.abc import Sequence
from typing import Any

from dishka.entities.factory_type import FactoryData
from dishka.entities.marker import Marker
Expand Down Expand Up @@ -48,14 +47,6 @@ class InvalidGraphError(DishkaError):
pass


class InvalidMarkerError(DishkaError):
def __init__(self, marker: Any) -> None:
self.marker = marker

def __str__(self) -> str:
return f"Cannot use {self.marker!r} as marker."


class NoActivatorError(DishkaError):
def __init__(self, marker_key: DependencyKey) -> None:
self.marker_key = marker_key
Expand Down
9 changes: 8 additions & 1 deletion src/dishka/plotter/d2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@ def _render_node(self, node: Node) -> str:
if node.type is NodeType.ALIAS:
return ""
name = self._node_type(node) + self._escape(node.name)

if node.source_name:
source = f' "{self._escape(node.source_name)}()": ""\n'
else:
source = ""

res = f'{node.id}: "{name}"' + "{\n"
res += " shape: class\n"

if node.type is NodeType.SELECTOR:
res += "}\n"
return res

res += source
for dep in node.dependencies:
dep_name = self._escape(self.nodes[dep].name)
Expand Down Expand Up @@ -54,6 +59,8 @@ def _node_type(self, node: Node) -> str:
prefix = ""
if node.type is NodeType.DECORATOR:
return "🎭 " + prefix
elif node.type is NodeType.SELECTOR:
return "🤔 " + prefix
elif node.type is NodeType.CONTEXT:
return "📥 " + prefix
elif node.type is NodeType.ALIAS:
Expand Down
14 changes: 12 additions & 2 deletions src/dishka/plotter/mermaid.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def _render_node(self, node: Node) -> str:
if node.type is NodeType.ALIAS:
return ""
name = self._node_type(node) + self._escape(node.name)
if node.type is NodeType.SELECTOR:
return f'class {node.id}["{name}"]'
source_name = self._escape(node.source_name)
return "\n".join([
f'class {node.id}["{name}"]{{',
Expand Down Expand Up @@ -77,6 +79,8 @@ def _node_type(self, node: Node) -> str:
prefix = ""
if node.type is NodeType.DECORATOR:
return "🎭 " + prefix
elif node.type is NodeType.SELECTOR:
return "🤔 " + prefix
elif node.type is NodeType.CONTEXT:
return "📥 " + prefix
elif node.type is NodeType.ALIAS:
Expand Down Expand Up @@ -116,8 +120,14 @@ def _fill_nodes(self, groups: list[Group]) -> None:

def render(self, groups: list[Group]) -> str:
self._fill_nodes(groups)

res = "classDiagram\n"
res = (
"---\n"
" config:\n"
" class:\n"
" hideEmptyMembersBox: true\n"
"---\n"
)
res += "classDiagram\n"
res += "direction LR\n"
for group in groups:
res += self._render_group(group)
Expand Down
1 change: 1 addition & 0 deletions src/dishka/plotter/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class NodeType(Enum):
CONTEXT = "Context"
FACTORY = "Factory"
DECORATOR = "Decorator"
SELECTOR = "Selector"
ALIAS = "Alias"


Expand Down
45 changes: 36 additions & 9 deletions src/dishka/plotter/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
from dishka._adaptix.type_tools.basic_utils import is_protocol
from dishka.dependency_source import Factory
from dishka.entities.factory_type import FactoryType
from dishka.entities.marker import unpack_marker
from dishka.registry import Registry
from dishka.registry_builder import DECORATED_COMPONENT_PREFIX
from dishka.registry_builder import (
DECORATED_COMPONENT_PREFIX,
SELECTOR_COMPONENT_PREFIX,
)
from dishka.text_rendering import get_name
from .model import Group, GroupType, Node, NodeType

Expand Down Expand Up @@ -38,6 +42,8 @@ def _node_type(self, factory: Factory) -> NodeType:
for dep in factory.kw_dependencies.values():
if self._is_decorated(dep):
return NodeType.DECORATOR
if factory.when_dependencies:
return NodeType.SELECTOR

if factory.type is FactoryType.ALIAS:
return NodeType.ALIAS
Expand All @@ -51,11 +57,23 @@ def _is_decorated_component(self, group: Group) -> bool:
return False
return group.name.startswith(DECORATED_COMPONENT_PREFIX)

def _trace_decorator(self, node: Node, target_group: Group) -> None:
if node.type is not NodeType.DECORATOR:
def _is_selector_component(self, group: Group) -> bool:
if group.type is not GroupType.COMPONENT:
return False
return group.name.startswith(SELECTOR_COMPONENT_PREFIX)

def _trace_internal_components(
self,
node: Node,
target_group: Group,
) -> None:
if node.type not in (NodeType.DECORATOR, NodeType.SELECTOR):
return
for group in self.groups.values():
if not self._is_decorated_component(group):
if not (
self._is_decorated_component(group) or
self._is_selector_component(group)
):
continue
nodes_to_move = [
n
Expand All @@ -65,7 +83,7 @@ def _trace_decorator(self, node: Node, target_group: Group) -> None:
for moved in nodes_to_move:
group.nodes.remove(moved)
target_group.nodes.append(moved)
self._trace_decorator(moved, target_group)
self._trace_internal_components(moved, target_group)

def _make_factories(
self, scope: BaseScope, group: Group, registry: Registry,
Expand All @@ -84,7 +102,11 @@ def _make_factories(
)
group.children.append(component_group)
node_name = get_name(key.type_hint, include_module=False)
if factory.type in (FactoryType.CONTEXT, FactoryType.ALIAS):
if factory.type in (
FactoryType.CONTEXT,
FactoryType.ALIAS,
FactoryType.SELECTOR,
):
source_name = ""
else:
source_name = get_name(factory.source, include_module=False)
Expand All @@ -107,8 +129,13 @@ def _fill_dependencies(
for key, factory in registry.factories.items():
node = self.nodes[key, registry.scope]
all_deps = (
list(factory.dependencies)
+ list(factory.kw_dependencies.values())
list(factory.dependencies)
+ list(factory.kw_dependencies.values())
+ [
DependencyKey(m, factory.when_component)
for m in unpack_marker(factory.when_override)
]
+ [sub.provides for sub in factory.when_dependencies]
)
for dep in all_deps:
for dep_registry in parent_registries:
Expand Down Expand Up @@ -151,6 +178,6 @@ def transform(self, container: Container | AsyncContainer) -> list[Group]:

for group in self.groups.values():
for node in group.nodes:
self._trace_decorator(node, group)
self._trace_internal_components(node, group)

return self.clean_groups(result)
Loading
Loading