From 2ea3f6b9f02684357a7cca2d96eb005c7d2cb61a Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 6 Mar 2025 12:12:34 +0200 Subject: [PATCH 1/3] (tests): Added testing for auth via DefaultAzureCredential --- dev_requirements.txt | 2 +- tests/entraid_utils.py | 41 +++++++++++++++++++++++++++++++++------ tests/test_credentials.py | 13 +++++++++++-- 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index 2a0938bec3..ad7330598d 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -13,4 +13,4 @@ ujson>=4.2.0 uvloop vulture>=2.3.0 numpy>=1.24.0 -redis-entraid==0.3.0b1 +redis-entraid==0.4.0b2 diff --git a/tests/entraid_utils.py b/tests/entraid_utils.py index daefbd3956..22cfae0af2 100644 --- a/tests/entraid_utils.py +++ b/tests/entraid_utils.py @@ -18,7 +18,8 @@ ManagedIdentityType, ServicePrincipalIdentityProviderConfig, _create_provider_from_managed_identity, - _create_provider_from_service_principal, + _create_provider_from_service_principal, DefaultAzureCredentialIdentityProviderConfig, + _create_provider_from_default_azure_credential, ) from tests.conftest import mock_identity_provider @@ -26,6 +27,7 @@ class AuthType(Enum): MANAGED_IDENTITY = "managed_identity" SERVICE_PRINCIPAL = "service_principal" + DEFAULT_AZURE_CREDENTIAL = "default_azure_credential" def identity_provider(request) -> IdentityProviderInterface: @@ -37,18 +39,23 @@ def identity_provider(request) -> IdentityProviderInterface: if request.param.get("mock_idp", None) is not None: return mock_identity_provider() - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + auth_type = kwargs.get("auth_type", AuthType.SERVICE_PRINCIPAL) config = get_identity_provider_config(request=request) - if auth_type == "MANAGED_IDENTITY": + if auth_type == AuthType.MANAGED_IDENTITY: return _create_provider_from_managed_identity(config) + if auth_type == AuthType.DEFAULT_AZURE_CREDENTIAL: + return _create_provider_from_default_azure_credential(config) + return _create_provider_from_service_principal(config) -def get_identity_provider_config( - request, -) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: +def get_identity_provider_config(request) -> Union[ + ManagedIdentityProviderConfig, + ServicePrincipalIdentityProviderConfig, + DefaultAzureCredentialIdentityProviderConfig +]: if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) else: @@ -59,6 +66,9 @@ def get_identity_provider_config( if auth_type == AuthType.MANAGED_IDENTITY: return _get_managed_identity_provider_config(request) + if auth_type == AuthType.DEFAULT_AZURE_CREDENTIAL: + return _get_default_azure_credential_provider_config(request) + return _get_service_principal_provider_config(request) @@ -113,6 +123,25 @@ def _get_service_principal_provider_config( app_kwargs=kwargs, ) +def _get_default_azure_credential_provider_config(request) -> DefaultAzureCredentialIdentityProviderConfig: + scopes = os.getenv("AZURE_REDIS_SCOPES", ()) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + token_kwargs = request.param.get("token_kwargs", {}) + else: + kwargs = {} + token_kwargs = {} + + if isinstance(scopes, str): + scopes = scopes.split(',') + + return DefaultAzureCredentialIdentityProviderConfig( + scopes=scopes, + app_kwargs=kwargs, + token_kwargs=token_kwargs + ) + def get_entra_id_credentials_provider(request, cred_provider_kwargs): idp = identity_provider(request) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 1f98c5208d..58bbd01f28 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -22,6 +22,7 @@ get_endpoint, skip_if_redis_enterprise, ) +from tests.entraid_utils import AuthType try: from redis_entraid.cred_provider import EntraIdCredentialsProvider @@ -585,8 +586,12 @@ class TestEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "single_connection_client": True, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["pool", "single"], + ids=["pool", "single", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.onlynoncluster @@ -656,8 +661,12 @@ class TestClusterEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "single_connection_client": True, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["pool", "single"], + ids=["pool", "single", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.onlycluster From 64b4adaa05aa0d724049ba750b146d28ff44349c Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 6 Mar 2025 12:22:34 +0200 Subject: [PATCH 2/3] Added testing for async --- tests/entraid_utils.py | 20 ++++++++++++-------- tests/test_asyncio/conftest.py | 5 ----- tests/test_asyncio/test_credentials.py | 13 +++++++++++-- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/tests/entraid_utils.py b/tests/entraid_utils.py index 22cfae0af2..529c3ccdee 100644 --- a/tests/entraid_utils.py +++ b/tests/entraid_utils.py @@ -18,7 +18,8 @@ ManagedIdentityType, ServicePrincipalIdentityProviderConfig, _create_provider_from_managed_identity, - _create_provider_from_service_principal, DefaultAzureCredentialIdentityProviderConfig, + _create_provider_from_service_principal, + DefaultAzureCredentialIdentityProviderConfig, _create_provider_from_default_azure_credential, ) from tests.conftest import mock_identity_provider @@ -51,10 +52,12 @@ def identity_provider(request) -> IdentityProviderInterface: return _create_provider_from_service_principal(config) -def get_identity_provider_config(request) -> Union[ +def get_identity_provider_config( + request, +) -> Union[ ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig, - DefaultAzureCredentialIdentityProviderConfig + DefaultAzureCredentialIdentityProviderConfig, ]: if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) @@ -123,7 +126,10 @@ def _get_service_principal_provider_config( app_kwargs=kwargs, ) -def _get_default_azure_credential_provider_config(request) -> DefaultAzureCredentialIdentityProviderConfig: + +def _get_default_azure_credential_provider_config( + request, +) -> DefaultAzureCredentialIdentityProviderConfig: scopes = os.getenv("AZURE_REDIS_SCOPES", ()) if hasattr(request, "param"): @@ -134,12 +140,10 @@ def _get_default_azure_credential_provider_config(request) -> DefaultAzureCreden token_kwargs = {} if isinstance(scopes, str): - scopes = scopes.split(',') + scopes = scopes.split(",") return DefaultAzureCredentialIdentityProviderConfig( - scopes=scopes, - app_kwargs=kwargs, - token_kwargs=token_kwargs + scopes=scopes, app_kwargs=kwargs, token_kwargs=token_kwargs ) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 60e447e6fd..aedf4e97c7 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -18,11 +18,6 @@ from .compat import mock -class AuthType(Enum): - MANAGED_IDENTITY = "managed_identity" - SERVICE_PRINCIPAL = "service_principal" - - async def _get_info(redis_url): client = redis.Redis.from_url(redis_url) info = await client.info() diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index ce8d76ea45..b4824be469 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -18,6 +18,7 @@ from redis.exceptions import ConnectionError from redis.utils import str_if_bytes from tests.conftest import get_endpoint, skip_if_redis_enterprise +from tests.entraid_utils import AuthType from tests.test_asyncio.conftest import get_credential_provider try: @@ -616,8 +617,12 @@ class TestEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "cred_provider_kwargs": {"block_for_initial": True}, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["blocked", "non-blocked"], + ids=["blocked", "non-blocked", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.asyncio @@ -692,8 +697,12 @@ class TestClusterEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "cred_provider_kwargs": {"block_for_initial": True}, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["blocked", "non-blocked"], + ids=["blocked", "non-blocked", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.asyncio From 3d7107e052886411342c89205a360211765430c6 Mon Sep 17 00:00:00 2001 From: vladvildanov Date: Thu, 6 Mar 2025 12:25:26 +0200 Subject: [PATCH 3/3] Remove unused import --- tests/test_asyncio/conftest.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index aedf4e97c7..340d146ea3 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,6 +1,5 @@ import random from contextlib import asynccontextmanager as _asynccontextmanager -from enum import Enum from typing import Union import pytest