diff --git a/postgres/setup.py b/postgres/setup.py index 1d9abd351..f86fa5614 100644 --- a/postgres/setup.py +++ b/postgres/setup.py @@ -12,8 +12,9 @@ url="https://github.com/testcontainers/testcontainers-python", install_requires=[ "testcontainers-core", - "sqlalchemy", + "sqlalchemy[asyncio]", "psycopg2-binary", + "asyncpg" ], python_requires=">=3.7", ) diff --git a/postgres/testcontainers/postgres/__init__.py b/postgres/testcontainers/postgres/__init__.py index 85e0bac80..3e756f3c3 100644 --- a/postgres/testcontainers/postgres/__init__.py +++ b/postgres/testcontainers/postgres/__init__.py @@ -12,8 +12,9 @@ # under the License. import os from typing import Optional -from testcontainers.core.generic import DbContainer +from testcontainers.core.generic import DbContainer, ADDITIONAL_TRANSIENT_ERRORS from testcontainers.core.utils import raise_for_deprecated_parameter +from testcontainers.core.waiting_utils import wait_container_is_ready class PostgresContainer(DbContainer): @@ -63,3 +64,13 @@ def get_connection_url(self, host=None) -> str: password=self.password, dbname=self.dbname, host=host, port=self.port, ) + + @wait_container_is_ready(*ADDITIONAL_TRANSIENT_ERRORS) + def _connect(self) -> None: + if self.driver == "asyncpg": + from sqlalchemy.ext.asyncio import create_async_engine as create_engine + else: + from sqlalchemy import create_engine + engine = create_engine(self.get_connection_url()) + conn = engine.connect() + conn.close() diff --git a/postgres/tests/test_postgres.py b/postgres/tests/test_postgres.py index c00c1b3fe..262b61abf 100644 --- a/postgres/tests/test_postgres.py +++ b/postgres/tests/test_postgres.py @@ -1,4 +1,8 @@ +from asyncio import sleep + +import pytest import sqlalchemy +from sqlalchemy.ext.asyncio import create_async_engine from testcontainers.postgres import PostgresContainer @@ -18,3 +22,15 @@ def test_docker_run_postgres_with_driver_pg8000(): engine = sqlalchemy.create_engine(postgres.get_connection_url()) with engine.begin() as connection: connection.execute(sqlalchemy.text("select 1=1")) + + +@pytest.mark.asyncio +async def test_docker_run_postgres_with_driver_asyncio(): + postgres_container = PostgresContainer("postgres:9.5", driver="asyncpg") + with postgres_container as postgres: + # in local test need to wait while pg container is ready + # else raised `ConnectionError: unexpected connection_lost() call` + await sleep(5) + engine = create_async_engine(postgres.get_connection_url()) + async with engine.begin() as connection: + await connection.execute(sqlalchemy.text("select 1=1")) diff --git a/requirements.in b/requirements.in index e9e122610..1d269f885 100644 --- a/requirements.in +++ b/requirements.in @@ -26,6 +26,7 @@ flake8<3.8.0 # 3.8.0 adds a dependency on importlib-metadata which conflicts wi pg8000 pytest pytest-cov +pytest-asyncio sphinx twine wheel