diff --git a/core/testcontainers/core/generic.py b/core/testcontainers/core/generic.py index 7faac273a..3b9dcfe28 100644 --- a/core/testcontainers/core/generic.py +++ b/core/testcontainers/core/generic.py @@ -29,11 +29,13 @@ class DbContainer(DockerContainer): """ Generic database container. """ + @wait_container_is_ready(*ADDITIONAL_TRANSIENT_ERRORS) def _connect(self) -> None: import sqlalchemy engine = sqlalchemy.create_engine(self.get_connection_url()) - engine.connect() + conn = engine.connect() + conn.close() def get_connection_url(self) -> str: raise NotImplementedError diff --git a/postgres/setup.py b/postgres/setup.py index 1d9abd351..4257fc33e 100644 --- a/postgres/setup.py +++ b/postgres/setup.py @@ -14,6 +14,7 @@ "testcontainers-core", "sqlalchemy", "psycopg2-binary", + "asyncpg", ], python_requires=">=3.7", ) diff --git a/postgres/testcontainers/postgres/__init__.py b/postgres/testcontainers/postgres/__init__.py index 85e0bac80..45e12d9cf 100644 --- a/postgres/testcontainers/postgres/__init__.py +++ b/postgres/testcontainers/postgres/__init__.py @@ -14,6 +14,14 @@ from typing import Optional from testcontainers.core.generic import DbContainer from testcontainers.core.utils import raise_for_deprecated_parameter +from testcontainers.core.waiting_utils import wait_container_is_ready + +ADDITIONAL_TRANSIENT_ERRORS = [] +try: + from sqlalchemy.exc import DBAPIError + ADDITIONAL_TRANSIENT_ERRORS.append(DBAPIError) +except ImportError: + pass class PostgresContainer(DbContainer): @@ -39,10 +47,15 @@ class PostgresContainer(DbContainer): >>> version 'PostgreSQL 9.5...' """ + + DEFAULT_DRIVER = "psycopg2" + def __init__(self, image: str = "postgres:latest", port: int = 5432, username: Optional[str] = None, password: Optional[str] = None, - dbname: Optional[str] = None, driver: str = "psycopg2", **kwargs) -> None: + dbname: Optional[str] = None, driver: Optional[str] = None, **kwargs) -> None: raise_for_deprecated_parameter(kwargs, "user", "username") + if driver is None: + driver = self.DEFAULT_DRIVER super(PostgresContainer, self).__init__(image=image, **kwargs) self.username = username or os.environ.get("POSTGRES_USER", "test") self.password = password or os.environ.get("POSTGRES_PASSWORD", "test") @@ -52,14 +65,24 @@ def __init__(self, image: str = "postgres:latest", port: int = 5432, self.with_exposed_ports(self.port) + @wait_container_is_ready(*ADDITIONAL_TRANSIENT_ERRORS) + def _connect(self) -> None: + import sqlalchemy + engine = sqlalchemy.create_engine(self.get_connection_url(driver=self.DEFAULT_DRIVER)) + conn = engine.connect() + conn.close() + def _configure(self) -> None: self.with_env("POSTGRES_USER", self.username) self.with_env("POSTGRES_PASSWORD", self.password) self.with_env("POSTGRES_DB", self.dbname) - def get_connection_url(self, host=None) -> str: + def get_connection_url(self, host: Optional[str] = None, driver: Optional[str] = None) -> str: + if driver is None: + driver = self.driver + return super()._create_connection_url( - dialect=f"postgresql+{self.driver}", username=self.username, + dialect=f"postgresql+{driver}", username=self.username, password=self.password, dbname=self.dbname, host=host, port=self.port, ) diff --git a/postgres/tests/test_postgres.py b/postgres/tests/test_postgres.py index c00c1b3fe..9042f0b8e 100644 --- a/postgres/tests/test_postgres.py +++ b/postgres/tests/test_postgres.py @@ -1,4 +1,6 @@ +import pytest import sqlalchemy +from sqlalchemy.ext.asyncio import create_async_engine from testcontainers.postgres import PostgresContainer @@ -18,3 +20,13 @@ def test_docker_run_postgres_with_driver_pg8000(): engine = sqlalchemy.create_engine(postgres.get_connection_url()) with engine.begin() as connection: connection.execute(sqlalchemy.text("select 1=1")) + + +@pytest.mark.asyncio +async def test_docker_run_async_postgres(): + with PostgresContainer("postgres:9.5", driver="asyncpg") as postgres: + engine = create_async_engine(postgres.get_connection_url()) + async with engine.begin() as connection: + result = await connection.execute(sqlalchemy.text("select version()")) + for row in result: + assert row[0].lower().startswith("postgresql 9.5") diff --git a/requirements.in b/requirements.in index e9e122610..1d269f885 100644 --- a/requirements.in +++ b/requirements.in @@ -26,6 +26,7 @@ flake8<3.8.0 # 3.8.0 adds a dependency on importlib-metadata which conflicts wi pg8000 pytest pytest-cov +pytest-asyncio sphinx twine wheel