Skip to content

Commit 24dc3e9

Browse files
authored
Merge pull request #658 from fadedDexofan/fix/protocol-inspection
fix: resolve `__init__` introspection when `Protocol` is first in bases
2 parents f0dffc8 + 51af83b commit 24dc3e9

File tree

4 files changed

+181
-44
lines changed

4 files changed

+181
-44
lines changed

src/dishka/provider/make_factory.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,22 @@ def _is_bound_method(obj: Any) -> bool:
9292
return ismethod(obj) and bool(obj.__self__)
9393

9494

95+
def _resolve_init(tp: type) -> Any:
96+
init = tp.__init__ # type: ignore[misc]
97+
if init is not _protocol_init:
98+
return init
99+
for cls in tp.__mro__:
100+
if cls is object:
101+
continue
102+
init_candidate = cls.__dict__.get("__init__")
103+
if init_candidate is not None and init_candidate is not _protocol_init:
104+
return init_candidate
105+
return init
106+
107+
95108
def _get_init_members(tp: type) -> MembersStorage[str, None]:
96-
type_hints = get_all_type_hints(tp.__init__) # type: ignore[misc, no-untyped-call]
109+
real_init = _resolve_init(tp)
110+
type_hints = get_all_type_hints(real_init) # type: ignore[no-untyped-call]
97111
if "__init__" in tp.__dict__:
98112
overridden = frozenset(type_hints)
99113
else:
@@ -256,7 +270,7 @@ def _make_factory_by_class(
256270
if not provides:
257271
provides = source
258272

259-
init = strip_alias(source).__init__
273+
init = _resolve_init(strip_alias(source))
260274
if missing_hints := _params_without_hints(init, skip_self=True):
261275
raise MissingHintsError(source, missing_hints, append_init=True)
262276
# we need to fix concrete generics and normal classes as well

tests/unit/container/test_resolve.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import math
2+
from typing import Protocol
3+
14
import pytest
25

36
from dishka import (
@@ -7,6 +10,7 @@
710
make_container,
811
provide,
912
)
13+
from dishka.provider import provide_all
1014
from ..sample_providers import (
1115
A_VALUE,
1216
ClassA,
@@ -144,3 +148,135 @@ def test_kwargs():
144148

145149
container = make_container(provider)
146150
assert container.get(str) == "ok"
151+
152+
153+
class _Dep:
154+
pass
155+
156+
157+
def test_provide_multiple_protocols_before_base():
158+
class Proto1(Protocol):
159+
pass
160+
161+
class Proto2(Protocol):
162+
pass
163+
164+
class Base:
165+
def __init__(self, dep: _Dep) -> None:
166+
self.dep = dep
167+
168+
class Multi(Proto1, Proto2, Base):
169+
pass
170+
171+
provider = Provider(scope=Scope.APP)
172+
provider.provide(_Dep)
173+
provider.provide(Multi)
174+
175+
container = make_container(provider)
176+
result = container.get(Multi)
177+
assert isinstance(result, Multi)
178+
assert isinstance(result.dep, _Dep)
179+
180+
181+
def test_provide_own_init_overrides_protocol_stub():
182+
class Proto(Protocol):
183+
pass
184+
185+
class Base:
186+
def __init__(self, dep: _Dep) -> None:
187+
self.dep = dep
188+
189+
class Impl(Proto, Base):
190+
def __init__(self, dep: _Dep) -> None:
191+
self.own_dep = dep
192+
193+
provider = Provider(scope=Scope.APP)
194+
provider.provide(_Dep)
195+
provider.provide(Impl)
196+
197+
container = make_container(provider)
198+
result = container.get(Impl)
199+
assert isinstance(result, Impl)
200+
assert isinstance(result.own_dep, _Dep)
201+
202+
203+
def test_provide_protocol_with_explicit_init():
204+
class ProtoWithInit(Protocol):
205+
def __init__(self, dep: _Dep) -> None: ...
206+
207+
class Impl(ProtoWithInit):
208+
def __init__(self, dep: _Dep) -> None:
209+
self.dep = dep
210+
211+
provider = Provider(scope=Scope.APP)
212+
provider.provide(_Dep)
213+
provider.provide(Impl)
214+
215+
container = make_container(provider)
216+
result = container.get(Impl)
217+
assert isinstance(result, Impl)
218+
assert isinstance(result.dep, _Dep)
219+
220+
221+
def test_provide_deep_hierarchy_with_protocol():
222+
class Proto(Protocol):
223+
pass
224+
225+
class GrandBase:
226+
def __init__(self, dep: _Dep) -> None:
227+
self.dep = dep
228+
229+
class Base(GrandBase):
230+
pass
231+
232+
class Impl(Proto, Base):
233+
pass
234+
235+
provider = Provider(scope=Scope.APP)
236+
provider.provide(_Dep)
237+
provider.provide(Impl)
238+
239+
container = make_container(provider)
240+
result = container.get(Impl)
241+
assert isinstance(result, Impl)
242+
assert isinstance(result.dep, _Dep)
243+
244+
245+
def test_provide_all_as_provider_method():
246+
def a() -> int:
247+
return 100
248+
249+
def b(num: int) -> float:
250+
return num / 2
251+
252+
provider = Provider(scope=Scope.APP)
253+
provider.provide_all(a, b)
254+
255+
container = make_container(provider)
256+
257+
hundred = container.get(int)
258+
assert hundred == 100
259+
260+
fifty = container.get(float)
261+
assert math.isclose(fifty, 50.0, abs_tol=1e-9)
262+
263+
264+
def test_provide_all_in_class():
265+
class MyProvider(Provider):
266+
scope = Scope.APP
267+
268+
def a(self) -> int:
269+
return 100
270+
271+
def b(self, num: int) -> float:
272+
return num / 2
273+
274+
abcd = provide_all(a, b)
275+
276+
container = make_container(MyProvider())
277+
278+
hundred = container.get(int)
279+
assert hundred == 100
280+
281+
fifty = container.get(float)
282+
assert math.isclose(fifty, 50.0, abs_tol=1e-9)

tests/unit/container/test_with_parents.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,32 @@ class MyProvider(Provider):
303303
container = make_container(MyProvider())
304304

305305
assert isinstance(container.get(IntRepo), ConcreteRepo)
306+
307+
308+
class _Dep:
309+
pass
310+
311+
312+
def test_protocol_first_with_parents() -> None:
313+
class Proto(Protocol):
314+
pass
315+
316+
class Base:
317+
def __init__(self, dep: _Dep) -> None:
318+
self.dep = dep
319+
320+
class Impl(Proto, Base):
321+
pass
322+
323+
provider = Provider(scope=Scope.APP)
324+
provider.provide(_Dep)
325+
provider.provide(WithParents[Impl])
326+
327+
container = make_container(provider)
328+
result = container.get(Impl)
329+
assert isinstance(result, Impl)
330+
assert isinstance(result.dep, _Dep)
331+
assert container.get(Proto) is result
332+
assert container.get(Base) is result
333+
334+

tests/unit/test_provider.py

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import math
21
from collections.abc import (
32
AsyncGenerator,
43
AsyncIterable,
@@ -28,7 +27,6 @@
2827
Scope,
2928
alias,
3029
decorate,
31-
make_container,
3230
provide,
3331
)
3432
from dishka._adaptix.feature_requirement import HAS_TV_DEFAULT
@@ -479,46 +477,6 @@ def decorator(self, param: int) -> str:
479477
decorate(decorator)
480478

481479

482-
def test_provide_all_as_provider_method():
483-
def a() -> int:
484-
return 100
485-
486-
def b(num: int) -> float:
487-
return num / 2
488-
489-
provider = Provider(scope=Scope.APP)
490-
provider.provide_all(a, b)
491-
492-
container = make_container(provider)
493-
494-
hundred = container.get(int)
495-
assert hundred == 100
496-
497-
fifty = container.get(float)
498-
assert math.isclose(fifty, 50.0, abs_tol=1e-9)
499-
500-
501-
def test_provide_all_in_class():
502-
class MyProvider(Provider):
503-
scope = Scope.APP
504-
505-
def a(self) -> int:
506-
return 100
507-
508-
def b(self, num: int) -> float:
509-
return num / 2
510-
511-
abcd = provide_all(a, b)
512-
513-
container = make_container(MyProvider())
514-
515-
hundred = container.get(int)
516-
assert hundred == 100
517-
518-
fifty = container.get(float)
519-
assert math.isclose(fifty, 50.0, abs_tol=1e-9)
520-
521-
522480
make_factory_by_source = partial(
523481
make_factory,
524482
provides=None,

0 commit comments

Comments
 (0)