Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/dishka/provider/make_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,22 @@ def _is_bound_method(obj: Any) -> bool:
return ismethod(obj) and bool(obj.__self__)


def _resolve_init(tp: type) -> Any:
init = tp.__init__ # type: ignore[misc]
if init is not _protocol_init:
return init
for cls in tp.__mro__:
if cls is object:
continue
init_candidate = cls.__dict__.get("__init__")
if init_candidate is not None and init_candidate is not _protocol_init:
return init_candidate
return init


def _get_init_members(tp: type) -> MembersStorage[str, None]:
type_hints = get_all_type_hints(tp.__init__) # type: ignore[misc, no-untyped-call]
real_init = _resolve_init(tp)
type_hints = get_all_type_hints(real_init) # type: ignore[no-untyped-call]
if "__init__" in tp.__dict__:
overridden = frozenset(type_hints)
else:
Expand Down Expand Up @@ -256,7 +270,7 @@ def _make_factory_by_class(
if not provides:
provides = source

init = strip_alias(source).__init__
init = _resolve_init(strip_alias(source))
if missing_hints := _params_without_hints(init, skip_self=True):
raise MissingHintsError(source, missing_hints, append_init=True)
# we need to fix concrete generics and normal classes as well
Expand Down
136 changes: 136 additions & 0 deletions tests/unit/container/test_resolve.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import math
from typing import Protocol

import pytest

from dishka import (
Expand All @@ -7,6 +10,7 @@
make_container,
provide,
)
from dishka.provider import provide_all
from ..sample_providers import (
A_VALUE,
ClassA,
Expand Down Expand Up @@ -144,3 +148,135 @@ def test_kwargs():

container = make_container(provider)
assert container.get(str) == "ok"


class _Dep:
pass


def test_provide_multiple_protocols_before_base():
class Proto1(Protocol):
pass

class Proto2(Protocol):
pass

class Base:
def __init__(self, dep: _Dep) -> None:
self.dep = dep

class Multi(Proto1, Proto2, Base):
pass

provider = Provider(scope=Scope.APP)
provider.provide(_Dep)
provider.provide(Multi)

container = make_container(provider)
result = container.get(Multi)
assert isinstance(result, Multi)
assert isinstance(result.dep, _Dep)


def test_provide_own_init_overrides_protocol_stub():
class Proto(Protocol):
pass

class Base:
def __init__(self, dep: _Dep) -> None:
self.dep = dep

class Impl(Proto, Base):
def __init__(self, dep: _Dep) -> None:
self.own_dep = dep

provider = Provider(scope=Scope.APP)
provider.provide(_Dep)
provider.provide(Impl)

container = make_container(provider)
result = container.get(Impl)
assert isinstance(result, Impl)
assert isinstance(result.own_dep, _Dep)


def test_provide_protocol_with_explicit_init():
class ProtoWithInit(Protocol):
def __init__(self, dep: _Dep) -> None: ...

class Impl(ProtoWithInit):
def __init__(self, dep: _Dep) -> None:
self.dep = dep

provider = Provider(scope=Scope.APP)
provider.provide(_Dep)
provider.provide(Impl)

container = make_container(provider)
result = container.get(Impl)
assert isinstance(result, Impl)
assert isinstance(result.dep, _Dep)


def test_provide_deep_hierarchy_with_protocol():
class Proto(Protocol):
pass

class GrandBase:
def __init__(self, dep: _Dep) -> None:
self.dep = dep

class Base(GrandBase):
pass

class Impl(Proto, Base):
pass

provider = Provider(scope=Scope.APP)
provider.provide(_Dep)
provider.provide(Impl)

container = make_container(provider)
result = container.get(Impl)
assert isinstance(result, Impl)
assert isinstance(result.dep, _Dep)


def test_provide_all_as_provider_method():
def a() -> int:
return 100

def b(num: int) -> float:
return num / 2

provider = Provider(scope=Scope.APP)
provider.provide_all(a, b)

container = make_container(provider)

hundred = container.get(int)
assert hundred == 100

fifty = container.get(float)
assert math.isclose(fifty, 50.0, abs_tol=1e-9)


def test_provide_all_in_class():
class MyProvider(Provider):
scope = Scope.APP

def a(self) -> int:
return 100

def b(self, num: int) -> float:
return num / 2

abcd = provide_all(a, b)

container = make_container(MyProvider())

hundred = container.get(int)
assert hundred == 100

fifty = container.get(float)
assert math.isclose(fifty, 50.0, abs_tol=1e-9)
29 changes: 29 additions & 0 deletions tests/unit/container/test_with_parents.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,3 +303,32 @@ class MyProvider(Provider):
container = make_container(MyProvider())

assert isinstance(container.get(IntRepo), ConcreteRepo)


class _Dep:
pass


def test_protocol_first_with_parents() -> None:
class Proto(Protocol):
pass

class Base:
def __init__(self, dep: _Dep) -> None:
self.dep = dep

class Impl(Proto, Base):
pass

provider = Provider(scope=Scope.APP)
provider.provide(_Dep)
provider.provide(WithParents[Impl])

container = make_container(provider)
result = container.get(Impl)
assert isinstance(result, Impl)
assert isinstance(result.dep, _Dep)
assert container.get(Proto) is result
assert container.get(Base) is result


42 changes: 0 additions & 42 deletions tests/unit/test_provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
from collections.abc import (
AsyncGenerator,
AsyncIterable,
Expand Down Expand Up @@ -28,7 +27,6 @@
Scope,
alias,
decorate,
make_container,
provide,
)
from dishka._adaptix.feature_requirement import HAS_TV_DEFAULT
Expand Down Expand Up @@ -479,46 +477,6 @@ def decorator(self, param: int) -> str:
decorate(decorator)


def test_provide_all_as_provider_method():
def a() -> int:
return 100

def b(num: int) -> float:
return num / 2

provider = Provider(scope=Scope.APP)
provider.provide_all(a, b)

container = make_container(provider)

hundred = container.get(int)
assert hundred == 100

fifty = container.get(float)
assert math.isclose(fifty, 50.0, abs_tol=1e-9)


def test_provide_all_in_class():
class MyProvider(Provider):
scope = Scope.APP

def a(self) -> int:
return 100

def b(self, num: int) -> float:
return num / 2

abcd = provide_all(a, b)

container = make_container(MyProvider())

hundred = container.get(int)
assert hundred == 100

fifty = container.get(float)
assert math.isclose(fifty, 50.0, abs_tol=1e-9)


make_factory_by_source = partial(
make_factory,
provides=None,
Expand Down
Loading