Skip to content

Commit a974a75

Browse files
committed
add error on prefix with mysql part and test cases
1 parent 52aa815 commit a974a75

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

modules/mysql/testcontainers/mysql/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ def __init__(
7373
seed: Optional[str] = None,
7474
**kwargs,
7575
) -> None:
76+
if dialect is not None and dialect.startswith("mysql+"):
77+
msg = "Please remove 'mysql+' prefix from dialect parameter"
78+
raise ValueError(msg)
79+
7680
raise_for_deprecated_parameter(kwargs, "MYSQL_USER", "username")
7781
raise_for_deprecated_parameter(kwargs, "MYSQL_ROOT_PASSWORD", "root_password")
7882
raise_for_deprecated_parameter(kwargs, "MYSQL_PASSWORD", "password")
@@ -85,7 +89,9 @@ def __init__(
8589
self.root_password = root_password or environ.get("MYSQL_ROOT_PASSWORD", "test")
8690
self.password = password or environ.get("MYSQL_PASSWORD", "test")
8791
self.dbname = dbname or environ.get("MYSQL_DATABASE", "test")
92+
8893
self.dialect = dialect or environ.get("MYSQL_DIALECT", None)
94+
self._db_url_dialect_part = "mysql" if self.dialect is None else f"mysql+{self.dialect}"
8995

9096
if self.username == "root":
9197
self.root_password = self.password
@@ -106,9 +112,12 @@ def _connect(self) -> None:
106112
)
107113

108114
def get_connection_url(self) -> str:
109-
dialect = "mysql" if self.dialect is None else f"mysql+{self.dialect}"
110115
return super()._create_connection_url(
111-
dialect=dialect, username=self.username, password=self.password, dbname=self.dbname, port=self.port
116+
dialect=self._db_url_dialect_part,
117+
username=self.username,
118+
password=self.password,
119+
dbname=self.dbname,
120+
port=self.port,
112121
)
113122

114123
def _transfer_seed(self) -> None:

modules/mysql/tests/test_mysql.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ def test_docker_env_variables():
6868
assert re.match(pattern, url)
6969

7070

71+
@pytest.mark.parametrize(
72+
"dialect",
73+
[
74+
"mysql+pymysql",
75+
"mysql+mariadb",
76+
"mysql+mysqldb",
77+
],
78+
)
79+
def test_mysql_dialect_expecting_error_on_mysql_prefix(dialect: str):
80+
match = f"Please remove *.* prefix from dialect parameter"
81+
82+
with pytest.raises(ValueError, match=match):
83+
_ = MySqlContainer("mariadb:10.6.5", dialect=dialect)
84+
85+
7186
# This is a feature in the generic DbContainer class
7287
# but it can't be tested on its own
7388
# so is tested in various database modules:

0 commit comments

Comments
 (0)