@@ -55,6 +55,16 @@ def verify_aws_token(token: str, region: str):
55
55
assert headers ["X-Snowflake-Audience" ] == "snowflakecomputing.com"
56
56
57
57
58
+ def test_mro ():
59
+ """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin."""
60
+ from snowflake .connector .aio .auth import AuthByPlugin as AuthByPluginAsync
61
+ from snowflake .connector .auth import AuthByPlugin as AuthByPluginSync
62
+
63
+ assert AuthByWorkloadIdentity .mro ().index (
64
+ AuthByPluginAsync
65
+ ) < AuthByWorkloadIdentity .mro ().index (AuthByPluginSync )
66
+
67
+
58
68
# -- OIDC Tests --
59
69
60
70
@@ -319,95 +329,3 @@ async def test_explicit_azure_uses_explicit_entra_resource(fake_azure_metadata_s
319
329
token = fake_azure_metadata_service .token
320
330
parsed = jwt .decode (token , options = {"verify_signature" : False })
321
331
assert parsed ["aud" ] == "api://non-standard"
322
-
323
-
324
- # -- Auto-detect Tests --
325
-
326
-
327
- async def test_autodetect_aws_present (
328
- no_metadata_service , fake_aws_environment : FakeAwsEnvironmentAsync
329
- ):
330
- auth_class = AuthByWorkloadIdentity (provider = None )
331
- await auth_class .prepare ()
332
-
333
- data = await extract_api_data (auth_class )
334
- assert data ["AUTHENTICATOR" ] == "WORKLOAD_IDENTITY"
335
- assert data ["PROVIDER" ] == "AWS"
336
- verify_aws_token (data ["TOKEN" ], fake_aws_environment .region )
337
-
338
-
339
- @mock .patch ("snowflake.connector.aio._wif_util.AioInstanceMetadataRegionFetcher" )
340
- async def test_autodetect_gcp_present (
341
- mock_fetcher ,
342
- fake_gce_metadata_service : FakeGceMetadataServiceAsync ,
343
- ):
344
- # Mock AioInstanceMetadataRegionFetcher to return None properly as an async function
345
- async def mock_retrieve_region ():
346
- return None
347
-
348
- mock_fetcher .return_value .retrieve_region .side_effect = mock_retrieve_region
349
-
350
- auth_class = AuthByWorkloadIdentity (provider = None )
351
- await auth_class .prepare ()
352
-
353
- assert await extract_api_data (auth_class ) == {
354
- "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
355
- "PROVIDER" : "GCP" ,
356
- "TOKEN" : fake_gce_metadata_service .token ,
357
- }
358
-
359
-
360
- @mock .patch ("snowflake.connector.aio._wif_util.AioInstanceMetadataRegionFetcher" )
361
- async def test_autodetect_azure_present (mock_fetcher , fake_azure_metadata_service ):
362
- # Mock AioInstanceMetadataRegionFetcher to return None properly as an async function
363
- async def mock_retrieve_region ():
364
- return None
365
-
366
- mock_fetcher .return_value .retrieve_region .side_effect = mock_retrieve_region
367
-
368
- auth_class = AuthByWorkloadIdentity (provider = None )
369
- await auth_class .prepare ()
370
-
371
- assert await extract_api_data (auth_class ) == {
372
- "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
373
- "PROVIDER" : "AZURE" ,
374
- "TOKEN" : fake_azure_metadata_service .token ,
375
- }
376
-
377
-
378
- async def test_autodetect_oidc_present (no_metadata_service ):
379
- dummy_token = gen_dummy_id_token (sub = "service-1" , iss = "issuer-1" )
380
- auth_class = AuthByWorkloadIdentity (provider = None , token = dummy_token )
381
- await auth_class .prepare ()
382
-
383
- assert await extract_api_data (auth_class ) == {
384
- "AUTHENTICATOR" : "WORKLOAD_IDENTITY" ,
385
- "PROVIDER" : "OIDC" ,
386
- "TOKEN" : dummy_token ,
387
- }
388
-
389
-
390
- @mock .patch ("snowflake.connector.aio._wif_util.AioInstanceMetadataRegionFetcher" )
391
- async def test_autodetect_no_provider_raises_error (mock_fetcher , no_metadata_service ):
392
- # Mock AioInstanceMetadataRegionFetcher to return None properly as an async function
393
- async def mock_retrieve_region ():
394
- return None
395
-
396
- mock_fetcher .return_value .retrieve_region .side_effect = mock_retrieve_region
397
-
398
- auth_class = AuthByWorkloadIdentity (provider = None , token = None )
399
- with pytest .raises (ProgrammingError ) as excinfo :
400
- await auth_class .prepare ()
401
- assert "No workload identity credential was found for 'auto-detect" in str (
402
- excinfo .value
403
- )
404
-
405
-
406
- def test_mro ():
407
- """Ensure that methods from AuthByPluginAsync override those from AuthByPlugin."""
408
- from snowflake .connector .aio .auth import AuthByPlugin as AuthByPluginAsync
409
- from snowflake .connector .auth import AuthByPlugin as AuthByPluginSync
410
-
411
- assert AuthByWorkloadIdentity .mro ().index (
412
- AuthByPluginAsync
413
- ) < AuthByWorkloadIdentity .mro ().index (AuthByPluginSync )
0 commit comments