diff --git a/src/snowflake/connector/__init__.py b/src/snowflake/connector/__init__.py index 41b5288ac7..ce570d3072 100644 --- a/src/snowflake/connector/__init__.py +++ b/src/snowflake/connector/__init__.py @@ -45,13 +45,26 @@ from .log_configuration import EasyLoggingConfigPython from .version import VERSION +from typing import TypeVar, ParamSpec, Unpack + +P = ParamSpec("P") +T = TypeVar("T", bound=SnowflakeConnection) + logging.getLogger(__name__).addHandler(NullHandler()) setup_external_libraries() - @wraps(SnowflakeConnection.__init__) -def Connect(**kwargs) -> SnowflakeConnection: - return SnowflakeConnection(**kwargs) +def connect( + __cls: type[T] = SnowflakeConnection, + /, + *args: P.args, + **kwargs: Unpack[P.kwargs] +) -> T: + return __cls(*args, **kwargs) + +# @wraps(SnowflakeConnection.__init__) +# def Connect(**kwargs) -> SnowflakeConnection: +# return SnowflakeConnection(**kwargs) connect = Connect diff --git a/test/unit/test_type_check.py b/test/unit/test_type_check.py new file mode 100644 index 0000000000..84f7976505 --- /dev/null +++ b/test/unit/test_type_check.py @@ -0,0 +1,13 @@ +import snowflake.connector as conn + +c = conn.connect( + user="user", + password="pass", + account="account" +) + +invalid = conn.connect( + user="user", + password=123, + account="account" +)