diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 83b2986dc8..6acdab3a3c 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -499,10 +499,9 @@ def __init__( is_kwargs_empty = not kwargs if "application" not in kwargs: - if ENV_VAR_PARTNER in os.environ.keys(): - kwargs["application"] = os.environ[ENV_VAR_PARTNER] - elif "streamlit" in sys.modules: - kwargs["application"] = "streamlit" + app = self._detect_application() + if app: + kwargs["application"] = app if "insecure_mode" in kwargs: warn_message = "The 'insecure_mode' connection property is deprecated. Please use 'disable_ocsp_checks' instead" @@ -2283,3 +2282,17 @@ def _check_oauth_parameters(self) -> None: "errno": ER_INVALID_VALUE, }, ) + + @staticmethod + def _detect_application() -> None | str: + if ENV_VAR_PARTNER in os.environ.keys(): + return os.environ[ENV_VAR_PARTNER] + if "streamlit" in sys.modules: + return "streamlit" + if all( + (jpmod in sys.modules) + for jpmod in ("ipykernel", "jupyter_core", "jupyter_client") + ): + return "jupyter_notebook" + if "snowbooks" in sys.modules: + return "snowflake_notebook" diff --git a/test/unit/test_connection.py b/test/unit/test_connection.py index 2da9e83754..1aa145252a 100644 --- a/test/unit/test_connection.py +++ b/test/unit/test_connection.py @@ -194,12 +194,23 @@ def test_partner_env_var(mock_post_requests): @pytest.mark.skipolddriver -def test_imported_module(mock_post_requests): - with patch.dict(sys.modules, {"streamlit": "foo"}): - assert fake_connector().application == "streamlit" +@pytest.mark.parametrize( + "sys_modules,application", + [ + ({"streamlit": None}, "streamlit"), + ( + {"ipykernel": None, "jupyter_core": None, "jupyter_client": None}, + "jupyter_notebook", + ), + ({"snowbooks": None}, "snowflake_notebook"), + ], +) +def test_imported_module(mock_post_requests, sys_modules, application): + with patch.dict(sys.modules, sys_modules): + assert fake_connector().application == application assert ( - mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == "streamlit" + mock_post_requests["data"]["CLIENT_ENVIRONMENT"]["APPLICATION"] == application )