
Is generic provider/container possible to implement? #660

@dartt0n

Is it possible to create a generic container or provider?

The idea is to have several generic interfaces whose concrete implementations fix the type parameter. The use case class then depends on these interfaces and ties their type parameters together. In the example below, the first two adapters must use exactly the same session type, while the third is free to use another one.

class MyUsecase[UserSession, AnotherSession]:
    def __init__(
        self,
        adapter1: UserRepoInterface[UserSession],
        adapter2: TokenRepoInterface[UserSession], # these two are restricted to use the same session
        adapter3: LogRepoInterface[AnotherSession], # this one is free
    ):
        ...

This can be solved by having the interfaces accept and return the session interface rather than a bounded generic type, but then I have no way to make sure that all adapters use the same session.
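To make the trade-off concrete, here is a minimal sketch of that non-generic alternative. It uses only the standard library, and the names `PgSession`, `RedisSession`, `PgUserRepo` and `RedisTokenRepo` are hypothetical, invented for illustration. Because both repositories are typed against the bare `SessionInterface`, nothing stops a caller from combining adapters that use different sessions:

```python
from typing import Protocol


class SessionInterface(Protocol):
    def commit(self) -> None: ...


# Two unrelated session implementations (hypothetical names).
class PgSession:
    def commit(self) -> None:
        print("pg commit")


class RedisSession:
    def commit(self) -> None:
        print("redis commit")


# Non-generic interfaces: they only promise "some" session.
class UserRepoInterface(Protocol):
    def get_session(self) -> SessionInterface: ...


class TokenRepoInterface(Protocol):
    def get_session(self) -> SessionInterface: ...


class PgUserRepo:
    def get_session(self) -> SessionInterface:
        return PgSession()


class RedisTokenRepo:
    def get_session(self) -> SessionInterface:
        return RedisSession()


class MyUsecase:
    def __init__(self, users: UserRepoInterface, tokens: TokenRepoInterface) -> None:
        self.users = users
        self.tokens = tokens


# The signatures are satisfied, yet the two adapters use different
# session classes. The "same session" rule is not expressed anywhere.
usecase = MyUsecase(PgUserRepo(), RedisTokenRepo())
session_types = {
    type(usecase.users.get_session()),
    type(usecase.tokens.get_session()),
}
```

With the generic version, `UserRepoInterface[S]` and `TokenRepoInterface[S]` share one type parameter, so a static type checker can reject this combination.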

A rough draft of the idea:

import random
import string
from dataclasses import dataclass
from typing import Protocol, final, override

from dishka import Provider, Scope, make_container, provide


class SessionInterface(Protocol):
    def begin(self) -> None: ...
    def savepoint(self) -> str: ...
    def commit(self) -> None: ...
    def rollback(self, savepoint: str | None = None) -> None: ...


class RepoInteface[SessionT: SessionInterface](Protocol):
    def get_session(self) -> SessionT: ...
    def save_user(self, session: SessionT, name: str) -> None: ...


@dataclass
class Usecase[T: SessionInterface]:
    repo: RepoInteface[T]

    def do_business_logic(self):
        session = self.repo.get_session()

        self.repo.save_user(session, "Ivan")


@dataclass
class MySessionImpl:
    name: str

    def begin(self) -> None:
        print(f"begin; -- {self.name}")

    def savepoint(self) -> str:
        random_string = "".join(random.choices(string.ascii_lowercase, k=10))
        print(f"savepoint {random_string}; -- {self.name}")
        return random_string

    def commit(self) -> None:
        print(f"commit; -- {self.name}")

    def rollback(self, savepoint: str | None = None) -> None:
        print(f"rollback{' ' + savepoint if savepoint else ''}; -- {self.name}")


@dataclass
class RepoImpl(RepoInteface[MySessionImpl]):
    @override
    def get_session(self) -> MySessionImpl:
        return MySessionImpl("session_my_repo_impl")

    @override
    def save_user(self, session: MySessionImpl, name: str) -> None:
        # Use the session passed in by the caller rather than opening a second one.
        session.begin()

        print(f"INSERT INTO users (name) VALUES ({name!r});")  # don't do this in production

        session.commit()


@final
class MyProvider[T: SessionInterface](Provider):
    # generic provider
    scope = Scope.APP

    usecase = provide(source=Usecase[T])


# in prod
container = make_container(MyProvider[MySessionImpl]())
# in tests: container = make_container(MyProvider[FakeSession]())

create_user_usecase = container.get(Usecase)
create_user_usecase.do_business_logic()
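Two runtime details of Python generics are relevant to this draft, independent of dishka. The sketch below uses only the standard library. `ProviderSketch`, `RealSession` and `FakeSession` are hypothetical stand-ins, and it is written with `TypeVar`/`Generic` instead of the 3.12 syntax so that it also runs on older interpreters. It does not show how dishka itself resolves generic types.

```python
from typing import Generic, TypeVar, get_args

T = TypeVar("T")


class RealSession: ...


class FakeSession: ...


class Usecase(Generic[T]): ...


class ProviderSketch(Generic[T]):
    # The class body runs exactly once, so T is still the bare TypeVar here.
    registered = Usecase[T]


prod = ProviderSketch[RealSession]()
test = ProviderSketch[FakeSession]()

# Subscripting the class does not re-run the body: both instances
# share the same, unspecialized Usecase[T] attribute.
body_args = get_args(prod.registered)
same_attribute = prod.registered is test.registered

# The chosen argument is only recorded on the instance, via the
# CPython-specific __orig_class__ attribute set after __init__.
prod_args = get_args(prod.__orig_class__)

# Parameterized aliases are hashable and distinct from the bare class,
# so a lookup keyed on Usecase[RealSession] does not match plain Usecase.
registry = {Usecase[RealSession]: "prod", Usecase[FakeSession]: "test"}
bare_class_registered = Usecase in registry
```

So a declaration like `provide(source=Usecase[T])` in a class body sees only the type variable, and a request for plain `Usecase` is a different key from `Usecase[MySessionImpl]`. A generic provider would presumably need to substitute `T` at instantiation time, or the provider could be built by a function that receives the session type as an ordinary argument.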

Labels: question
