Skip to content

Commit ceb8cad

Browse files
committed
fix: resolve __init__ introspection when Protocol is first in bases
Protocol's `_no_init_or_replace_init` stub shadows real `__init__` from base classes. Walk MRO to find the actual `__init__` with type hints. Closes #628
1 parent f0dffc8 commit ceb8cad

File tree

3 files changed

+137
-2
lines changed

3 files changed

+137
-2
lines changed

src/dishka/provider/make_factory.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,22 @@ def _is_bound_method(obj: Any) -> bool:
9292
return ismethod(obj) and bool(obj.__self__)
9393

9494

95+
def _resolve_init(tp: type) -> Any:
96+
init = tp.__init__ # type: ignore[misc]
97+
if init is not _protocol_init:
98+
return init
99+
for cls in tp.__mro__:
100+
if cls is object:
101+
continue
102+
init_candidate = cls.__dict__.get("__init__")
103+
if init_candidate is not None and init_candidate is not _protocol_init:
104+
return init_candidate
105+
return init
106+
107+
95108
def _get_init_members(tp: type) -> MembersStorage[str, None]:
96-
type_hints = get_all_type_hints(tp.__init__) # type: ignore[misc, no-untyped-call]
109+
real_init = _resolve_init(tp)
110+
type_hints = get_all_type_hints(real_init) # type: ignore[no-untyped-call]
97111
if "__init__" in tp.__dict__:
98112
overridden = frozenset(type_hints)
99113
else:
@@ -256,7 +270,7 @@ def _make_factory_by_class(
256270
if not provides:
257271
provides = source
258272

259-
init = strip_alias(source).__init__
273+
init = _resolve_init(strip_alias(source))
260274
if missing_hints := _params_without_hints(init, skip_self=True):
261275
raise MissingHintsError(source, missing_hints, append_init=True)
262276
# we need to fix concrete generics and normal classes as well

tests/unit/container/test_with_parents.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,32 @@ class MyProvider(Provider):
303303
container = make_container(MyProvider())
304304

305305
assert isinstance(container.get(IntRepo), ConcreteRepo)
306+
307+
308+
class _Dep:
309+
pass
310+
311+
312+
def test_protocol_first_with_parents() -> None:
313+
class Proto(Protocol):
314+
pass
315+
316+
class Base:
317+
def __init__(self, dep: _Dep) -> None:
318+
self.dep = dep
319+
320+
class Impl(Proto, Base):
321+
pass
322+
323+
provider = Provider(scope=Scope.APP)
324+
provider.provide(_Dep)
325+
provider.provide(WithParents[Impl])
326+
327+
container = make_container(provider)
328+
result = container.get(Impl)
329+
assert isinstance(result, Impl)
330+
assert isinstance(result.dep, _Dep)
331+
assert container.get(Proto) is result
332+
assert container.get(Base) is result
333+
334+

tests/unit/test_provider.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,98 @@ class MyImpl(MyProto):
413413
assert factory.dependencies == []
414414

415415

416+
class _Dep:
417+
pass
418+
419+
420+
def test_provide_multiple_protocols_before_base():
421+
class Proto1(Protocol):
422+
pass
423+
424+
class Proto2(Protocol):
425+
pass
426+
427+
class Base:
428+
def __init__(self, dep: _Dep) -> None:
429+
self.dep = dep
430+
431+
class Multi(Proto1, Proto2, Base):
432+
pass
433+
434+
provider = Provider(scope=Scope.APP)
435+
provider.provide(_Dep)
436+
provider.provide(Multi)
437+
438+
container = make_container(provider)
439+
result = container.get(Multi)
440+
assert isinstance(result, Multi)
441+
assert isinstance(result.dep, _Dep)
442+
443+
444+
def test_provide_own_init_overrides_protocol_stub():
445+
class Proto(Protocol):
446+
pass
447+
448+
class Base:
449+
def __init__(self, dep: _Dep) -> None:
450+
self.dep = dep
451+
452+
class Impl(Proto, Base):
453+
def __init__(self, dep: _Dep) -> None:
454+
self.own_dep = dep
455+
456+
provider = Provider(scope=Scope.APP)
457+
provider.provide(_Dep)
458+
provider.provide(Impl)
459+
460+
container = make_container(provider)
461+
result = container.get(Impl)
462+
assert isinstance(result, Impl)
463+
assert isinstance(result.own_dep, _Dep)
464+
465+
466+
def test_provide_protocol_with_explicit_init():
467+
class ProtoWithInit(Protocol):
468+
def __init__(self, dep: _Dep) -> None: ...
469+
470+
class Impl(ProtoWithInit):
471+
def __init__(self, dep: _Dep) -> None:
472+
self.dep = dep
473+
474+
provider = Provider(scope=Scope.APP)
475+
provider.provide(_Dep)
476+
provider.provide(Impl)
477+
478+
container = make_container(provider)
479+
result = container.get(Impl)
480+
assert isinstance(result, Impl)
481+
assert isinstance(result.dep, _Dep)
482+
483+
484+
def test_provide_deep_hierarchy_with_protocol():
485+
class Proto(Protocol):
486+
pass
487+
488+
class GrandBase:
489+
def __init__(self, dep: _Dep) -> None:
490+
self.dep = dep
491+
492+
class Base(GrandBase):
493+
pass
494+
495+
class Impl(Proto, Base):
496+
pass
497+
498+
provider = Provider(scope=Scope.APP)
499+
provider.provide(_Dep)
500+
provider.provide(Impl)
501+
502+
container = make_container(provider)
503+
result = container.get(Impl)
504+
assert isinstance(result, Impl)
505+
assert isinstance(result.dep, _Dep)
506+
507+
416508
class A:
417509
pass
418510

0 commit comments

Comments
 (0)