Skip to content

Commit 8d39568

Browse files
committed
test: Updated CredentialProvider test infrastructure
1 parent f76afb2 commit 8d39568

File tree

1 file changed

+62
-37
lines changed

1 file changed

+62
-37
lines changed

tests/conftest.py

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
from datetime import datetime, timezone
77
from enum import Enum
8-
from typing import Callable, TypeVar
8+
from typing import Callable, TypeVar, Union
99
from unittest import mock
1010
from unittest.mock import Mock
1111
from urllib.parse import urlparse
@@ -17,6 +17,7 @@
1717
from redis import Sentinel
1818
from redis.auth.idp import IdentityProviderInterface
1919
from redis.auth.token import JWToken
20+
from redis.auth.token_manager import RetryPolicy, TokenManagerConfig
2021
from redis.backoff import NoBackoff
2122
from redis.cache import (
2223
CacheConfig,
@@ -29,12 +30,21 @@
2930
from redis.credentials import CredentialProvider
3031
from redis.exceptions import RedisClusterException
3132
from redis.retry import Retry
32-
from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig
33+
from redis_entraid.cred_provider import (
34+
DEFAULT_DELAY_IN_MS,
35+
DEFAULT_EXPIRATION_REFRESH_RATIO,
36+
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
37+
DEFAULT_MAX_ATTEMPTS,
38+
DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
39+
EntraIdCredentialsProvider,
40+
)
3341
from redis_entraid.identity_provider import (
3442
ManagedIdentityIdType,
43+
ManagedIdentityProviderConfig,
3544
ManagedIdentityType,
36-
create_provider_from_managed_identity,
37-
create_provider_from_service_principal,
45+
ServicePrincipalIdentityProviderConfig,
46+
_create_provider_from_managed_identity,
47+
_create_provider_from_service_principal,
3848
)
3949
from tests.ssl_utils import get_tls_certificates
4050

@@ -623,41 +633,58 @@ def identity_provider(request) -> IdentityProviderInterface:
623633
return mock_identity_provider()
624634

625635
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
636+
config = get_identity_provider_config(request=request)
626637

627638
if auth_type == "MANAGED_IDENTITY":
628-
return _get_managed_identity_provider(request)
639+
return _create_provider_from_managed_identity(config)
640+
641+
return _create_provider_from_service_principal(config)
629642

630-
return _get_service_principal_provider(request)
631643

644+
def get_identity_provider_config(
645+
request,
646+
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
647+
if hasattr(request, "param"):
648+
kwargs = request.param.get("idp_kwargs", {})
649+
else:
650+
kwargs = {}
651+
652+
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
632653

633-
def _get_managed_identity_provider(request):
634-
authority = os.getenv("AZURE_AUTHORITY")
654+
if auth_type == AuthType.MANAGED_IDENTITY:
655+
return _get_managed_identity_provider_config(request)
656+
657+
return _get_service_principal_provider_config(request)
658+
659+
660+
def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
635661
resource = os.getenv("AZURE_RESOURCE")
636-
id_value = os.getenv("AZURE_ID_VALUE", None)
662+
id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)
637663

638664
if hasattr(request, "param"):
639665
kwargs = request.param.get("idp_kwargs", {})
640666
else:
641667
kwargs = {}
642668

643669
identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
644-
id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID)
670+
id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)
645671

646-
return create_provider_from_managed_identity(
672+
return ManagedIdentityProviderConfig(
647673
identity_type=identity_type,
648674
resource=resource,
649675
id_type=id_type,
650676
id_value=id_value,
651-
authority=authority,
652-
**kwargs,
677+
kwargs=kwargs,
653678
)
654679

655680

656-
def _get_service_principal_provider(request):
681+
def _get_service_principal_provider_config(
682+
request,
683+
) -> ServicePrincipalIdentityProviderConfig:
657684
client_id = os.getenv("AZURE_CLIENT_ID")
658685
client_credential = os.getenv("AZURE_CLIENT_SECRET")
659-
authority = os.getenv("AZURE_AUTHORITY")
660-
scopes = os.getenv("AZURE_REDIS_SCOPES", [])
686+
tenant_id = os.getenv("AZURE_TENANT_ID")
687+
scopes = os.getenv("AZURE_REDIS_SCOPES", None)
661688

662689
if hasattr(request, "param"):
663690
kwargs = request.param.get("idp_kwargs", {})
@@ -671,14 +698,14 @@ def _get_service_principal_provider(request):
671698
if isinstance(scopes, str):
672699
scopes = scopes.split(",")
673700

674-
return create_provider_from_service_principal(
701+
return ServicePrincipalIdentityProviderConfig(
675702
client_id=client_id,
676703
client_credential=client_credential,
677704
scopes=scopes,
678705
timeout=timeout,
679706
token_kwargs=token_kwargs,
680-
authority=authority,
681-
**kwargs,
707+
tenant_id=tenant_id,
708+
app_kwargs=kwargs,
682709
)
683710

684711

@@ -690,31 +717,29 @@ def get_credential_provider(request) -> CredentialProvider:
690717
return cred_provider_class(**cred_provider_kwargs)
691718

692719
idp = identity_provider(request)
693-
initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0)
694-
block_for_initial = cred_provider_kwargs.get("block_for_initial", False)
695720
expiration_refresh_ratio = cred_provider_kwargs.get(
696-
"expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO
721+
"expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
697722
)
698723
lower_refresh_bound_millis = cred_provider_kwargs.get(
699-
"lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS
700-
)
701-
max_attempts = cred_provider_kwargs.get(
702-
"max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS
724+
"lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
703725
)
704-
delay_in_ms = cred_provider_kwargs.get(
705-
"delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS
726+
max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
727+
delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)
728+
729+
token_mgr_config = TokenManagerConfig(
730+
expiration_refresh_ratio=expiration_refresh_ratio,
731+
lower_refresh_bound_millis=lower_refresh_bound_millis,
732+
token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS,
733+
retry_policy=RetryPolicy(
734+
max_attempts=max_attempts,
735+
delay_in_ms=delay_in_ms,
736+
),
706737
)
707738

708-
auth_config = TokenAuthConfig(idp)
709-
auth_config.expiration_refresh_ratio = expiration_refresh_ratio
710-
auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis
711-
auth_config.max_attempts = max_attempts
712-
auth_config.delay_in_ms = delay_in_ms
713-
714739
return EntraIdCredentialsProvider(
715-
config=auth_config,
716-
initial_delay_in_ms=initial_delay_in_ms,
717-
block_for_initial=block_for_initial,
740+
identity_provider=idp,
741+
token_manager_config=token_mgr_config,
742+
initial_delay_in_ms=delay_in_ms,
718743
)
719744

720745

0 commit comments

Comments
 (0)