11
11
# License for the specific language governing permissions and limitations
12
12
# under the License.
13
13
import os
14
+ from time import sleep
14
15
from typing import Optional
15
- from testcontainers .core .generic import DbContainer
16
+
17
+ from testcontainers .core .config import MAX_TRIES , SLEEP_TIME
18
+ from testcontainers .core .generic import DependencyFreeDbContainer
16
19
from testcontainers .core .utils import raise_for_deprecated_parameter
20
+ from testcontainers .core .waiting_utils import (wait_container_is_ready ,
21
+ wait_for_logs )
17
22
18
23
19
- class PostgresContainer (DbContainer ):
24
+ class PostgresContainer (DependencyFreeDbContainer ):
20
25
"""
21
26
Postgres database container.
22
27
@@ -41,14 +46,14 @@ class PostgresContainer(DbContainer):
41
46
"""
42
47
def __init__ (self , image : str = "postgres:latest" , port : int = 5432 ,
43
48
username : Optional [str ] = None , password : Optional [str ] = None ,
44
- dbname : Optional [str ] = None , driver : str = "psycopg2" , ** kwargs ) -> None :
49
+ dbname : Optional [str ] = None , driver : str | None = "psycopg2" , ** kwargs ) -> None :
45
50
raise_for_deprecated_parameter (kwargs , "user" , "username" )
46
51
super (PostgresContainer , self ).__init__ (image = image , ** kwargs )
47
- self .username = username or os .environ .get ("POSTGRES_USER" , "test" )
48
- self .password = password or os .environ .get ("POSTGRES_PASSWORD" , "test" )
49
- self .dbname = dbname or os .environ .get ("POSTGRES_DB" , "test" )
52
+ self .username : str = username or os .environ .get ("POSTGRES_USER" , "test" )
53
+ self .password : str = password or os .environ .get ("POSTGRES_PASSWORD" , "test" )
54
+ self .dbname : str = dbname or os .environ .get ("POSTGRES_DB" , "test" )
50
55
self .port = port
51
- self .driver = driver
56
+ self .driver = f"+ { driver } " if driver else ""
52
57
53
58
self .with_exposed_ports (self .port )
54
59
@@ -59,7 +64,22 @@ def _configure(self) -> None:
59
64
60
65
def get_connection_url (self , host = None ) -> str :
61
66
return super ()._create_connection_url (
62
- dialect = f"postgresql+ { self .driver } " , username = self .username ,
67
+ dialect = f"postgresql{ self .driver } " , username = self .username ,
63
68
password = self .password , dbname = self .dbname , host = host ,
64
69
port = self .port ,
65
70
)
71
+
72
+ @wait_container_is_ready ()
73
+ def _verify_status (self ) -> None :
74
+ wait_for_logs (self , ".*database system is ready to accept connections.*" , MAX_TRIES , SLEEP_TIME )
75
+
76
+ count = 0
77
+ while count < MAX_TRIES :
78
+ status , _ = self .exec (f"pg_isready -hlocalhost -p{ self .port } -U{ self .username } " )
79
+ if status == 0 :
80
+ return
81
+
82
+ sleep (SLEEP_TIME )
83
+ count += 1
84
+
85
+ raise RuntimeError ("Postgres could not get into a ready state" )
0 commit comments