Skip to content
This repository is currently being migrated. It's locked while the migration is in progress.

Commit 32ba221

Browse files
authored
[Fix] Decouple OAuth functionality from Config (databricks#784)
## Changes ### OAuth Refactoring Currently, OAuthClient uses Config internally to resolve the OIDC endpoints by passing the client ID and host to an internal Config instance and calling its `oidc_endpoints` method. This has a few drawbacks: 1. There is a near-cyclic dependency: `Config` depends on methods in `oauth.py`, and `OAuthClient` depends on `Config`. This currently doesn't break because the `Config` import is done at runtime in the `OAuthClient` constructor. 2. Databricks supports both in-house OAuth and Azure Entra ID OAuth. Currently, the choice between these options depends on whether a user specifies the azure_client_id or client_id parameter in the Config. Because Config is used within OAuthClient, this means that OAuthClient needs to expose a parameter to configure either client_id or azure_client_id. Rather than having these classes deeply coupled to one another, we can allow users to fetch the OIDC endpoints for a given account/workspace as a top-level functionality and provide this to `OAuthClient`. This breaks the cyclic dependency and doesn't require `OAuthClient` to expose any unnecessary parameters. Further, I've also tried to remove the coupling of the other classes in `oauth.py` to `OAuthClient`. Currently, `OAuthClient` serves both as the mechanism to initialize OAuth and as a kind of configuration object, capturing OAuth endpoint URLs, client ID/secret, redirect URL, and scopes. Now, the parameters for each of these classes are explicit, removing all unnecessary coupling between them. One nice advantage is that the Consent can be serialized/deserialized without any reference to the `OAuthClient` anymore. There is definitely more work to be done to simplify and clean up the OAuth implementation, but this should at least unblock users who need to use Azure Entra ID U2M OAuth in the SDK. ## Tests The new OIDC endpoint methods are tested, and those tests also verify that those endpoints are retried in case of rate limiting. 
I ran the flask app example against an AWS workspace, and I ran the external-browser demo example against AWS, Azure and GCP workspaces with the default client ID and with a newly created OAuth app with and without credentials. - [ ] `make test` run locally - [ ] `make fmt` applied - [ ] relevant integration tests applied
1 parent 15257eb commit 32ba221

File tree

7 files changed

+459
-157
lines changed

7 files changed

+459
-157
lines changed

databricks/sdk/_base_client.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import urllib.parse
23
from datetime import timedelta
34
from types import TracebackType
45
from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List,
@@ -17,6 +18,25 @@
1718
logger = logging.getLogger('databricks.sdk')
1819

1920

21+
def _fix_host_if_needed(host: Optional[str]) -> Optional[str]:
22+
if not host:
23+
return host
24+
25+
# Add a default scheme if it's missing
26+
if '://' not in host:
27+
host = 'https://' + host
28+
29+
o = urllib.parse.urlparse(host)
30+
# remove trailing slash
31+
path = o.path.rstrip('/')
32+
# remove port if 443
33+
netloc = o.netloc
34+
if o.port == 443:
35+
netloc = netloc.split(':')[0]
36+
37+
return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))
38+
39+
2040
class _BaseClient:
2141

2242
def __init__(self,

databricks/sdk/config.py

Lines changed: 10 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@
1010
import requests
1111

1212
from . import useragent
13+
from ._base_client import _fix_host_if_needed
1314
from .clock import Clock, RealClock
1415
from .credentials_provider import CredentialsStrategy, DefaultCredentials
1516
from .environments import (ALL_ENVS, AzureEnvironment, Cloud,
1617
DatabricksEnvironment, get_environment_for_hostname)
17-
from .oauth import OidcEndpoints, Token
18+
from .oauth import (OidcEndpoints, Token, get_account_endpoints,
19+
get_azure_entra_id_workspace_endpoints,
20+
get_workspace_endpoints)
1821

1922
logger = logging.getLogger('databricks.sdk')
2023

@@ -254,24 +257,10 @@ def oidc_endpoints(self) -> Optional[OidcEndpoints]:
254257
if not self.host:
255258
return None
256259
if self.is_azure and self.azure_client_id:
257-
# Retrieve authorize endpoint to retrieve token endpoint after
258-
res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False)
259-
real_auth_url = res.headers.get('location')
260-
if not real_auth_url:
261-
return None
262-
return OidcEndpoints(authorization_endpoint=real_auth_url,
263-
token_endpoint=real_auth_url.replace('/authorize', '/token'))
260+
return get_azure_entra_id_workspace_endpoints(self.host)
264261
if self.is_account_client and self.account_id:
265-
prefix = f'{self.host}/oidc/accounts/{self.account_id}'
266-
return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize',
267-
token_endpoint=f'{prefix}/v1/token')
268-
oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server'
269-
res = requests.get(oidc)
270-
if res.status_code != 200:
271-
return None
272-
auth_metadata = res.json()
273-
return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'),
274-
token_endpoint=auth_metadata.get('token_endpoint'))
262+
return get_account_endpoints(self.host, self.account_id)
263+
return get_workspace_endpoints(self.host)
275264

276265
def debug_string(self) -> str:
277266
""" Returns log-friendly representation of configured attributes """
@@ -346,22 +335,9 @@ def attributes(cls) -> Iterable[ConfigAttribute]:
346335
return cls._attributes
347336

348337
def _fix_host_if_needed(self):
349-
if not self.host:
350-
return
351-
352-
# Add a default scheme if it's missing
353-
if '://' not in self.host:
354-
self.host = 'https://' + self.host
355-
356-
o = urllib.parse.urlparse(self.host)
357-
# remove trailing slash
358-
path = o.path.rstrip('/')
359-
# remove port if 443
360-
netloc = o.netloc
361-
if o.port == 443:
362-
netloc = netloc.split(':')[0]
363-
364-
self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))
338+
updated_host = _fix_host_if_needed(self.host)
339+
if updated_host:
340+
self.host = updated_host
365341

366342
def load_azure_tenant_id(self):
367343
"""[Internal] Load the Azure tenant ID from the Azure Databricks login page.

databricks/sdk/credentials_provider.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -187,30 +187,35 @@ def token() -> Token:
187187
def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]:
188188
if cfg.auth_type != 'external-browser':
189189
return None
190+
client_id, client_secret = None, None
190191
if cfg.client_id:
191192
client_id = cfg.client_id
192-
elif cfg.is_aws:
193+
client_secret = cfg.client_secret
194+
elif cfg.azure_client_id:
195+
client_id = cfg.azure_client
196+
client_secret = cfg.azure_client_secret
197+
198+
if not client_id:
193199
client_id = 'databricks-cli'
194-
elif cfg.is_azure:
195-
# Use Azure AD app for cases when Azure CLI is not available on the machine.
196-
# App has to be registered as Single-page multi-tenant to support PKCE
197-
# TODO: temporary app ID, change it later.
198-
client_id = '6128a518-99a9-425b-8333-4cc94f04cacd'
199-
else:
200-
raise ValueError(f'local browser SSO is not supported')
201-
oauth_client = OAuthClient(host=cfg.host,
202-
client_id=client_id,
203-
redirect_url='http://localhost:8020',
204-
client_secret=cfg.client_secret)
205200

206201
# Load cached credentials from disk if they exist.
207202
# Note that these are local to the Python SDK and not reused by other SDKs.
208-
token_cache = TokenCache(oauth_client)
203+
oidc_endpoints = cfg.oidc_endpoints
204+
redirect_url = 'http://localhost:8020'
205+
token_cache = TokenCache(host=cfg.host,
206+
oidc_endpoints=oidc_endpoints,
207+
client_id=client_id,
208+
client_secret=client_secret,
209+
redirect_url=redirect_url)
209210
credentials = token_cache.load()
210211
if credentials:
211212
# Force a refresh in case the loaded credentials are expired.
212213
credentials.token()
213214
else:
215+
oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints,
216+
client_id=client_id,
217+
redirect_url=redirect_url,
218+
client_secret=client_secret)
214219
consent = oauth_client.initiate_consent()
215220
if not consent:
216221
return None

0 commit comments

Comments
 (0)