@@ -73,6 +73,10 @@ def __init__(
7373 seed : Optional [str ] = None ,
7474 ** kwargs ,
7575 ) -> None :
76+ if dialect is not None and dialect .startswith ("mysql+" ):
77+ msg = "Please remove 'mysql+' prefix from dialect parameter"
78+ raise ValueError (msg )
79+
7680 raise_for_deprecated_parameter (kwargs , "MYSQL_USER" , "username" )
7781 raise_for_deprecated_parameter (kwargs , "MYSQL_ROOT_PASSWORD" , "root_password" )
7882 raise_for_deprecated_parameter (kwargs , "MYSQL_PASSWORD" , "password" )
@@ -85,7 +89,9 @@ def __init__(
8589 self .root_password = root_password or environ .get ("MYSQL_ROOT_PASSWORD" , "test" )
8690 self .password = password or environ .get ("MYSQL_PASSWORD" , "test" )
8791 self .dbname = dbname or environ .get ("MYSQL_DATABASE" , "test" )
92+
8893 self .dialect = dialect or environ .get ("MYSQL_DIALECT" , None )
94+ self ._db_url_dialect_part = "mysql" if self .dialect is None else f"mysql+{ self .dialect } "
8995
9096 if self .username == "root" :
9197 self .root_password = self .password
@@ -106,9 +112,12 @@ def _connect(self) -> None:
106112 )
107113
108114 def get_connection_url (self ) -> str :
109- dialect = "mysql" if self .dialect is None else f"mysql+{ self .dialect } "
110115 return super ()._create_connection_url (
111- dialect = dialect , username = self .username , password = self .password , dbname = self .dbname , port = self .port
116+ dialect = self ._db_url_dialect_part ,
117+ username = self .username ,
118+ password = self .password ,
119+ dbname = self .dbname ,
120+ port = self .port ,
112121 )
113122
114123 def _transfer_seed (self ) -> None :
0 commit comments