Skip to content

Commit d2b846c

Browse files
authored
JDBC Url param to enable or disable token federation (databricks#1105)
## Description Add EnableTokenFederation JDBC URL parameter to control whether token federation is applied to credentials providers (enabled by default). - Introduce new `DatabricksJdbcUrlParams.ENABLE_TOKEN_FEDERATION (default "1")` and IDatabricksConnectionContext.isTokenFederationEnabled(). - DatabricksConnectionContext exposes i`sTokenFederationEnabled()` to read the setting from URL or Properties. ## Testing - Unit tests Fixes databricks#1013
1 parent a0756ec commit d2b846c

File tree

7 files changed

+164
-24
lines changed

7 files changed

+164
-24
lines changed

NEXT_CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## [Unreleased]
44

55
### Added
6+
- Added the EnableTokenFederation url param to enable or disable Token federation feature. By default it is set to 1
67

78
### Updated
89

src/main/java/com/databricks/jdbc/api/impl/DatabricksConnectionContext.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,4 +1145,9 @@ public boolean isSeaSyncMetadataEnabled() {
11451145
public boolean getDisableOauthRefreshToken() {
11461146
return getParameter(DatabricksJdbcUrlParams.DISABLE_OAUTH_REFRESH_TOKEN, "1").equals("1");
11471147
}
1148+
1149+
@Override
1150+
public boolean isTokenFederationEnabled() {
1151+
return getParameter(DatabricksJdbcUrlParams.ENABLE_TOKEN_FEDERATION, "1").equals("1");
1152+
}
11481153
}

src/main/java/com/databricks/jdbc/api/internal/IDatabricksConnectionContext.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,4 +413,7 @@ public interface IDatabricksConnectionContext {
413413

414414
/** Returns whether OAuth refresh tokens should be disabled (omit offline_access by default). */
415415
boolean getDisableOauthRefreshToken();
416+
417+
/** Returns whether token federation is enabled for authentication. */
418+
boolean isTokenFederationEnabled();
416419
}

src/main/java/com/databricks/jdbc/common/DatabricksJdbcUrlParams.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,9 @@ public enum DatabricksJdbcUrlParams {
185185
DISABLE_OAUTH_REFRESH_TOKEN(
186186
"DisableOauthRefreshToken",
187187
"Disable requesting OAuth refresh tokens (omit offline_access unless explicitly provided)",
188-
"1");
188+
"1"),
189+
ENABLE_TOKEN_FEDERATION(
190+
"EnableTokenFederation", "Enable token federation for authentication", "1");
189191

190192
private final String paramName;
191193
private final String defaultValue;

src/main/java/com/databricks/jdbc/dbclient/impl/common/ClientConfigurator.java

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,7 @@ public void setupU2MConfig() throws DatabricksParsingException {
209209
if (databricksConfig.isAzure()) {
210210
LOGGER.debug("Using Azure U2M Auth");
211211
databricksConfig.setCredentialsProvider(
212-
new DatabricksTokenFederationProvider(
213-
connectionContext,
212+
wrapWithTokenFederationIfEnabled(
214213
new AzureExternalBrowserProvider(connectionContext, redirectPort)));
215214
return;
216215
}
@@ -229,8 +228,7 @@ public void setupU2MConfig() throws DatabricksParsingException {
229228
}
230229

231230
databricksConfig.setCredentialsProvider(
232-
new DatabricksTokenFederationProvider(
233-
connectionContext, new ExternalBrowserCredentialsProvider(tokenCache)));
231+
wrapWithTokenFederationIfEnabled(new ExternalBrowserCredentialsProvider(tokenCache)));
234232
}
235233

236234
/**
@@ -305,9 +303,9 @@ public void setupAccessTokenConfig() throws DatabricksParsingException {
305303
public void setupOAuthAccessTokenConfig() throws DatabricksParsingException {
306304
// Token Federation is only supported for JWT tokens
307305
if (DatabricksAuthUtil.isTokenJWT(connectionContext.getPassThroughAccessToken())) {
308-
DatabricksTokenFederationProvider databricksTokenFederationProvider =
309-
new DatabricksTokenFederationProvider(connectionContext, new PatCredentialsProvider());
310-
databricksConfig.setCredentialsProvider(databricksTokenFederationProvider);
306+
CredentialsProvider credentialsProvider =
307+
wrapWithTokenFederationIfEnabled(new PatCredentialsProvider());
308+
databricksConfig.setCredentialsProvider(credentialsProvider);
311309
}
312310

313311
databricksConfig
@@ -329,12 +327,11 @@ public void setupU2MRefreshConfig() throws DatabricksParsingException {
329327
.setClientSecret(connectionContext.getClientSecret());
330328
CredentialsProvider provider =
331329
new OAuthRefreshCredentialsProvider(connectionContext, databricksConfig);
332-
CredentialsProvider databricksTokenFederationProvider =
333-
new DatabricksTokenFederationProvider(connectionContext, provider);
330+
CredentialsProvider wrappedProvider = wrapWithTokenFederationIfEnabled(provider);
334331

335332
databricksConfig
336-
.setAuthType(databricksTokenFederationProvider.authType()) // oauth-refresh
337-
.setCredentialsProvider(databricksTokenFederationProvider);
333+
.setAuthType(wrappedProvider.authType()) // oauth-refresh
334+
.setCredentialsProvider(wrappedProvider);
338335
}
339336

340337
/** Setup the OAuth M2M authentication settings in the databricks config. */
@@ -353,13 +350,11 @@ public void setupM2MConfig() throws DatabricksParsingException {
353350
if (authType.equals(GCP_GOOGLE_CREDENTIALS_AUTH_TYPE)) {
354351
databricksConfig.setGoogleCredentials(connectionContext.getGoogleCredentials());
355352
databricksConfig.setCredentialsProvider(
356-
new DatabricksTokenFederationProvider(
357-
connectionContext, new GoogleCredentialsCredentialsProvider()));
353+
wrapWithTokenFederationIfEnabled(new GoogleCredentialsCredentialsProvider()));
358354
} else {
359355
databricksConfig.setGoogleServiceAccount(connectionContext.getGoogleServiceAccount());
360356
databricksConfig.setCredentialsProvider(
361-
new DatabricksTokenFederationProvider(
362-
connectionContext, new GoogleIdCredentialsProvider()));
357+
wrapWithTokenFederationIfEnabled(new GoogleIdCredentialsProvider()));
363358
}
364359

365360
} else if (connectionContext.getAzureTenantId() != null) {
@@ -376,8 +371,7 @@ public void setupM2MConfig() throws DatabricksParsingException {
376371
.setAzureClientSecret(connectionContext.getClientSecret())
377372
.setAzureTenantId(connectionContext.getAzureTenantId())
378373
.setCredentialsProvider(
379-
new DatabricksTokenFederationProvider(
380-
connectionContext, new AzureServicePrincipalCredentialsProvider()));
374+
wrapWithTokenFederationIfEnabled(new AzureServicePrincipalCredentialsProvider()));
381375
} else {
382376
databricksConfig
383377
.setClientId(connectionContext.getClientId())
@@ -387,14 +381,12 @@ public void setupM2MConfig() throws DatabricksParsingException {
387381
new PrivateKeyClientCredentialProvider(connectionContext, databricksConfig);
388382
databricksConfig
389383
.setAuthType(jwtProvider.authType())
390-
.setCredentialsProvider(
391-
new DatabricksTokenFederationProvider(connectionContext, jwtProvider));
384+
.setCredentialsProvider(wrapWithTokenFederationIfEnabled(jwtProvider));
392385
} else {
393386
CredentialsProvider m2mProvider = new OAuthM2MServicePrincipalCredentialsProvider();
394387
databricksConfig
395388
.setAuthType(DatabricksJdbcConstants.M2M_AUTH_TYPE)
396-
.setCredentialsProvider(
397-
new DatabricksTokenFederationProvider(connectionContext, m2mProvider));
389+
.setCredentialsProvider(wrapWithTokenFederationIfEnabled(m2mProvider));
398390
}
399391
}
400392
}
@@ -403,8 +395,7 @@ private void setupAzureMI() {
403395
databricksConfig.setHost(connectionContext.getHostForOAuth());
404396
databricksConfig.setAuthType(DatabricksJdbcConstants.AZURE_MSI_AUTH_TYPE);
405397
databricksConfig.setCredentialsProvider(
406-
new DatabricksTokenFederationProvider(
407-
connectionContext, new AzureMSICredentialProvider(connectionContext)));
398+
wrapWithTokenFederationIfEnabled(new AzureMSICredentialProvider(connectionContext)));
408399
}
409400

410401
/**
@@ -446,4 +437,18 @@ private void setupDiscoveryEndpoint() {
446437
databricksConfig.setDiscoveryUrl(connectionContext.getOAuthDiscoveryURL());
447438
}
448439
}
440+
441+
/**
442+
* Conditionally wraps a credentials provider with DatabricksTokenFederationProvider based on the
443+
* EnableTokenFederation connection parameter.
444+
*
445+
* @param provider The credentials provider to potentially wrap
446+
* @return The original provider if token federation is disabled, or wrapped provider if enabled
447+
*/
448+
private CredentialsProvider wrapWithTokenFederationIfEnabled(CredentialsProvider provider) {
449+
if (connectionContext.isTokenFederationEnabled()) {
450+
return new DatabricksTokenFederationProvider(connectionContext, provider);
451+
}
452+
return provider;
453+
}
449454
}

src/test/java/com/databricks/jdbc/api/impl/DatabricksConnectionContextTest.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1153,4 +1153,47 @@ public void testDisableOauthRefreshTokenParam() throws DatabricksSQLException {
11531153
DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, props);
11541154
assertFalse(ctx.getDisableOauthRefreshToken());
11551155
}
1156+
1157+
@Test
1158+
public void testEnableTokenFederation() throws DatabricksSQLException {
1159+
// Test default value (should be enabled by default)
1160+
DatabricksConnectionContext ctx =
1161+
(DatabricksConnectionContext)
1162+
DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, properties);
1163+
assertTrue(ctx.isTokenFederationEnabled()); // Default should be true
1164+
1165+
// Test via URL parameter - enabled
1166+
String urlWithTokenFederationEnabled =
1167+
"jdbc:databricks://sample-host.18.azuredatabricks.net:9999/default;httpPath=/sql/1.0/warehouses/999999999;EnableTokenFederation=1";
1168+
ctx =
1169+
(DatabricksConnectionContext)
1170+
DatabricksConnectionContext.parse(urlWithTokenFederationEnabled, properties);
1171+
assertTrue(ctx.isTokenFederationEnabled());
1172+
1173+
// Test via URL parameter - disabled
1174+
String urlWithTokenFederationDisabled =
1175+
"jdbc:databricks://sample-host.18.azuredatabricks.net:9999/default;httpPath=/sql/1.0/warehouses/999999999;EnableTokenFederation=0";
1176+
ctx =
1177+
(DatabricksConnectionContext)
1178+
DatabricksConnectionContext.parse(urlWithTokenFederationDisabled, properties);
1179+
assertFalse(ctx.isTokenFederationEnabled());
1180+
1181+
// Test via Properties - enabled
1182+
Properties props = new Properties();
1183+
props.setProperty("password", "passwd");
1184+
props.setProperty("EnableTokenFederation", "1");
1185+
ctx =
1186+
(DatabricksConnectionContext)
1187+
DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, props);
1188+
assertTrue(ctx.isTokenFederationEnabled());
1189+
1190+
// Test via Properties - disabled
1191+
props = new Properties();
1192+
props.setProperty("password", "passwd");
1193+
props.setProperty("EnableTokenFederation", "0");
1194+
ctx =
1195+
(DatabricksConnectionContext)
1196+
DatabricksConnectionContext.parse(TestConstants.VALID_URL_1, props);
1197+
assertFalse(ctx.isTokenFederationEnabled());
1198+
}
11561199
}

src/test/java/com/databricks/jdbc/dbclient/impl/common/ClientConfiguratorTest.java

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ void getWorkspaceClient_OAuthWithClientCredentials_AuthenticatesCorrectlyWithJWT
135135
when(mockContext.getClientSecret()).thenReturn("client-secret");
136136
when(mockContext.useJWTAssertion()).thenReturn(true);
137137
when(mockContext.getTokenEndpoint()).thenReturn("token-endpoint");
138+
when(mockContext.isTokenFederationEnabled()).thenReturn(true);
138139
when(mockContext.getHttpConnectionPoolSize()).thenReturn(100);
139140
when(mockContext.getHttpMaxConnectionsPerRoute()).thenReturn(100);
140141
configurator = new ClientConfigurator(mockContext);
@@ -516,6 +517,7 @@ void testSetupU2MConfig_WithTokenCache()
516517
when(mockContext.getTokenCachePassPhrase()).thenReturn("testPassphrase");
517518
when(mockContext.getHttpMaxConnectionsPerRoute()).thenReturn(100);
518519
when(mockContext.getDisableOauthRefreshToken()).thenReturn(true);
520+
when(mockContext.isTokenFederationEnabled()).thenReturn(true);
519521

520522
configurator = new ClientConfigurator(mockContext);
521523
WorkspaceClient client = configurator.getWorkspaceClient();
@@ -567,6 +569,7 @@ void testSetupU2MConfig_WithoutTokenCache()
567569
when(mockContext.isTokenCacheEnabled()).thenReturn(false);
568570
when(mockContext.getHttpMaxConnectionsPerRoute()).thenReturn(100);
569571
when(mockContext.getDisableOauthRefreshToken()).thenReturn(true);
572+
when(mockContext.isTokenFederationEnabled()).thenReturn(true);
570573

571574
configurator = new ClientConfigurator(mockContext);
572575
WorkspaceClient client = configurator.getWorkspaceClient();
@@ -585,4 +588,82 @@ void testSetupU2MConfig_WithoutTokenCache()
585588
ExternalBrowserCredentialsProvider.class,
586589
databricksTokenFederationProvider.getCredentialsProvider());
587590
}
591+
592+
@Test
593+
void testTokenFederationEnabled_WrapsCredentialsProvider()
594+
throws DatabricksParsingException, DatabricksSSLException {
595+
// Setup OAuth M2M with token federation enabled
596+
when(mockContext.getAuthMech()).thenReturn(AuthMech.OAUTH);
597+
when(mockContext.getAuthFlow()).thenReturn(AuthFlow.CLIENT_CREDENTIALS);
598+
when(mockContext.getHostForOAuth()).thenReturn("https://oauth-m2m.databricks.com");
599+
when(mockContext.getClientId()).thenReturn("m2m-client-id");
600+
when(mockContext.getClientSecret()).thenReturn("m2m-client-secret");
601+
when(mockContext.getHttpConnectionPoolSize()).thenReturn(100);
602+
when(mockContext.getHttpMaxConnectionsPerRoute()).thenReturn(100);
603+
when(mockContext.getDisableOauthRefreshToken()).thenReturn(true);
604+
when(mockContext.isTokenFederationEnabled()).thenReturn(true); // Token federation enabled
605+
when(mockContext.useJWTAssertion()).thenReturn(false);
606+
when(mockContext.getAzureTenantId()).thenReturn(null);
607+
when(mockContext.getCloud()).thenReturn(Cloud.AWS);
608+
609+
configurator = new ClientConfigurator(mockContext);
610+
WorkspaceClient client = configurator.getWorkspaceClient();
611+
assertNotNull(client);
612+
DatabricksConfig config = client.config();
613+
614+
// Verify that the credentials provider is wrapped with DatabricksTokenFederationProvider
615+
assertInstanceOf(DatabricksTokenFederationProvider.class, config.getCredentialsProvider());
616+
assertEquals(DatabricksJdbcConstants.M2M_AUTH_TYPE, config.getAuthType());
617+
}
618+
619+
@Test
620+
void testTokenFederationDisabled_DoesNotWrapCredentialsProvider()
621+
throws DatabricksParsingException, DatabricksSSLException {
622+
// Setup OAuth M2M with token federation disabled
623+
when(mockContext.getAuthMech()).thenReturn(AuthMech.OAUTH);
624+
when(mockContext.getAuthFlow()).thenReturn(AuthFlow.CLIENT_CREDENTIALS);
625+
when(mockContext.getHostForOAuth()).thenReturn("https://oauth-m2m.databricks.com");
626+
when(mockContext.getClientId()).thenReturn("m2m-client-id");
627+
when(mockContext.getClientSecret()).thenReturn("m2m-client-secret");
628+
when(mockContext.getHttpConnectionPoolSize()).thenReturn(100);
629+
when(mockContext.getHttpMaxConnectionsPerRoute()).thenReturn(100);
630+
when(mockContext.getDisableOauthRefreshToken()).thenReturn(true);
631+
when(mockContext.isTokenFederationEnabled()).thenReturn(false); // Token federation disabled
632+
when(mockContext.useJWTAssertion()).thenReturn(false);
633+
when(mockContext.getAzureTenantId()).thenReturn(null);
634+
when(mockContext.getCloud()).thenReturn(Cloud.AWS);
635+
636+
configurator = new ClientConfigurator(mockContext);
637+
WorkspaceClient client = configurator.getWorkspaceClient();
638+
assertNotNull(client);
639+
DatabricksConfig config = client.config();
640+
641+
// Verify that the credentials provider is NOT wrapped with DatabricksTokenFederationProvider
642+
assertNotNull(config.getCredentialsProvider());
643+
// Should be the original OAuthM2MServicePrincipalCredentialsProvider, not wrapped
644+
assertFalse(config.getCredentialsProvider() instanceof DatabricksTokenFederationProvider);
645+
}
646+
647+
@Test
648+
void testTokenFederationWithPATAuth_DoesNotAffectPATAuth()
649+
throws DatabricksParsingException, DatabricksSSLException {
650+
// Setup PAT auth with token federation disabled - should not affect PAT auth
651+
when(mockContext.getAuthMech()).thenReturn(AuthMech.PAT);
652+
when(mockContext.getHostUrl()).thenReturn("https://pat.databricks.com");
653+
when(mockContext.getToken()).thenReturn("pat-token");
654+
when(mockContext.getHttpConnectionPoolSize()).thenReturn(100);
655+
when(mockContext.getHttpMaxConnectionsPerRoute()).thenReturn(100);
656+
657+
configurator = new ClientConfigurator(mockContext);
658+
WorkspaceClient client = configurator.getWorkspaceClient();
659+
assertNotNull(client);
660+
DatabricksConfig config = client.config();
661+
662+
// PAT auth should work normally regardless of token federation setting
663+
assertEquals("https://pat.databricks.com", config.getHost());
664+
assertEquals("pat-token", config.getToken());
665+
assertEquals(DatabricksJdbcConstants.ACCESS_TOKEN_AUTH_TYPE, config.getAuthType());
666+
// PAT auth doesn't use Token federation provider, so it should be SDK default provider
667+
assertFalse(config.getCredentialsProvider() instanceof DatabricksTokenFederationProvider);
668+
}
588669
}

0 commit comments

Comments
 (0)