def databricks_connection_function(host: str, http_path: str, token: str, database: str, schema: str, **kwargs):
    """
    Connection to databricks with databricks sql connector.

    The authentication mode is selected from the arguments:

    * Supplying a ``token`` passes it to the connector as the personal
      access token.
    * Supplying neither ``token`` nor the ``auth_method`` keyword builds an
      m2m oauth credentials provider from the ``configuration`` keyword.
      host, client_id and client_secret keys can be supplied in that
      mapping. This path imports ``databricks-sdk`` lazily, so the package
      is only needed when m2m oauth is actually used.
    * The ``auth_method`` keyword is forwarded to the connector as
      ``auth_type``. Setting it to "databricks-oauth" will enforce a u2m
      oauth connection.

    Read the python-sql-connector documentation for more information.

    Parameters
    ----------
    host : str
        The databricks server host name. When empty on the m2m oauth path,
        the hostname resolved by the SDK ``Config`` is used instead.
    http_path : str
        The http_path to your databricks sql warehouse or cluster
    token : str
        Databricks personal access token
    database : str
        The databricks catalog
    schema : str
        The databricks schema
    **kwargs
        Only ``auth_method`` and ``configuration`` are read here; any other
        keyword is ignored.

    Returns
    -------
    out : databricks.sql.Connection
        The databricks connection object
    """
    from databricks import sql

    user_agent_entry = f"soda-core-spark/{SODA_CORE_VERSION} (Databricks)"
    logging.getLogger("databricks.sql").setLevel(logging.INFO)

    auth_method = kwargs.get("auth_method")
    credentials_provider = None

    if not token and not auth_method:
        # Imported here so databricks-sdk is only required for m2m oauth.
        from databricks.sdk.core import Config, oauth_service_principal

        # `or {}` also covers a configuration key that is present but None,
        # which would otherwise fail on the ** unpacking.
        config = Config(**(kwargs.get("configuration") or {}))

        if not host:
            host = config.hostname

        def credentials_provider():
            # The connector calls this to obtain the service principal
            # credentials built from the SDK config.
            return oauth_service_principal(config)

    connection = sql.connect(
        server_hostname=host,
        catalog=database,
        schema=schema,
        http_path=http_path,
        access_token=token,
        credentials_provider=credentials_provider,
        auth_type=auth_method,
        _user_agent_entry=user_agent_entry,
    )
    return connection