Skip to content

Commit d7adde2

Browse files
committed
Updated async test infra
1 parent eca1cd3 commit d7adde2

File tree

1 file changed

+60
-36
lines changed

1 file changed

+60
-36
lines changed

tests/test_asyncio/conftest.py

Lines changed: 60 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,23 @@
1717
from redis.asyncio.retry import Retry
1818
from redis.auth.idp import IdentityProviderInterface
1919
from redis.auth.token import JWToken
20+
from redis.auth.token_manager import RetryPolicy, TokenManagerConfig
2021
from redis.backoff import NoBackoff
2122
from redis.credentials import CredentialProvider
22-
from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig
23+
from redis_entraid.cred_provider import (
24+
DEFAULT_DELAY_IN_MS,
25+
DEFAULT_EXPIRATION_REFRESH_RATIO,
26+
DEFAULT_LOWER_REFRESH_BOUND_MILLIS,
27+
DEFAULT_MAX_ATTEMPTS,
28+
EntraIdCredentialsProvider,
29+
)
2330
from redis_entraid.identity_provider import (
2431
ManagedIdentityIdType,
32+
ManagedIdentityProviderConfig,
2533
ManagedIdentityType,
26-
create_provider_from_managed_identity,
27-
create_provider_from_service_principal,
34+
ServicePrincipalIdentityProviderConfig,
35+
_create_provider_from_managed_identity,
36+
_create_provider_from_service_principal,
2837
)
2938
from tests.conftest import REDIS_INFO
3039

@@ -255,41 +264,58 @@ def identity_provider(request) -> IdentityProviderInterface:
255264
return mock_identity_provider()
256265

257266
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
267+
config = get_identity_provider_config(request=request)
258268

259269
if auth_type == "MANAGED_IDENTITY":
260-
return _get_managed_identity_provider(request)
270+
return _create_provider_from_managed_identity(config)
271+
272+
return _create_provider_from_service_principal(config)
273+
274+
275+
def get_identity_provider_config(
276+
request,
277+
) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]:
278+
if hasattr(request, "param"):
279+
kwargs = request.param.get("idp_kwargs", {})
280+
else:
281+
kwargs = {}
261282

262-
return _get_service_principal_provider(request)
283+
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
284+
285+
if auth_type == AuthType.MANAGED_IDENTITY:
286+
return _get_managed_identity_provider_config(request)
263287

288+
return _get_service_principal_provider_config(request)
264289

265-
def _get_managed_identity_provider(request):
266-
authority = os.getenv("AZURE_AUTHORITY")
290+
291+
def _get_managed_identity_provider_config(request) -> ManagedIdentityProviderConfig:
267292
resource = os.getenv("AZURE_RESOURCE")
268-
id_value = os.getenv("AZURE_ID_VALUE", None)
293+
id_value = os.getenv("AZURE_USER_ASSIGNED_MANAGED_ID", None)
269294

270295
if hasattr(request, "param"):
271296
kwargs = request.param.get("idp_kwargs", {})
272297
else:
273298
kwargs = {}
274299

275300
identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
276-
id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID)
301+
id_type = kwargs.pop("id_type", ManagedIdentityIdType.OBJECT_ID)
277302

278-
return create_provider_from_managed_identity(
303+
return ManagedIdentityProviderConfig(
279304
identity_type=identity_type,
280305
resource=resource,
281306
id_type=id_type,
282307
id_value=id_value,
283-
authority=authority,
284-
**kwargs,
308+
kwargs=kwargs,
285309
)
286310

287311

288-
def _get_service_principal_provider(request):
312+
def _get_service_principal_provider_config(
313+
request,
314+
) -> ServicePrincipalIdentityProviderConfig:
289315
client_id = os.getenv("AZURE_CLIENT_ID")
290316
client_credential = os.getenv("AZURE_CLIENT_SECRET")
291-
authority = os.getenv("AZURE_AUTHORITY")
292-
scopes = os.getenv("AZURE_REDIS_SCOPES", [])
317+
tenant_id = os.getenv("AZURE_TENANT_ID")
318+
scopes = os.getenv("AZURE_REDIS_SCOPES", None)
293319

294320
if hasattr(request, "param"):
295321
kwargs = request.param.get("idp_kwargs", {})
@@ -303,14 +329,14 @@ def _get_service_principal_provider(request):
303329
if isinstance(scopes, str):
304330
scopes = scopes.split(",")
305331

306-
return create_provider_from_service_principal(
332+
return ServicePrincipalIdentityProviderConfig(
307333
client_id=client_id,
308334
client_credential=client_credential,
309335
scopes=scopes,
310336
timeout=timeout,
311337
token_kwargs=token_kwargs,
312-
authority=authority,
313-
**kwargs,
338+
tenant_id=tenant_id,
339+
app_kwargs=kwargs,
314340
)
315341

316342

@@ -322,31 +348,29 @@ def get_credential_provider(request) -> CredentialProvider:
322348
return cred_provider_class(**cred_provider_kwargs)
323349

324350
idp = identity_provider(request)
325-
initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0)
326-
block_for_initial = cred_provider_kwargs.get("block_for_initial", False)
327351
expiration_refresh_ratio = cred_provider_kwargs.get(
328-
"expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO
352+
"expiration_refresh_ratio", DEFAULT_EXPIRATION_REFRESH_RATIO
329353
)
330354
lower_refresh_bound_millis = cred_provider_kwargs.get(
331-
"lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS
332-
)
333-
max_attempts = cred_provider_kwargs.get(
334-
"max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS
355+
"lower_refresh_bound_millis", DEFAULT_LOWER_REFRESH_BOUND_MILLIS
335356
)
336-
delay_in_ms = cred_provider_kwargs.get(
337-
"delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS
357+
max_attempts = cred_provider_kwargs.get("max_attempts", DEFAULT_MAX_ATTEMPTS)
358+
delay_in_ms = cred_provider_kwargs.get("delay_in_ms", DEFAULT_DELAY_IN_MS)
359+
360+
token_mgr_config = TokenManagerConfig(
361+
expiration_refresh_ratio=expiration_refresh_ratio,
362+
lower_refresh_bound_millis=lower_refresh_bound_millis,
363+
token_request_execution_timeout_in_ms=DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS, # noqa
364+
retry_policy=RetryPolicy(
365+
max_attempts=max_attempts,
366+
delay_in_ms=delay_in_ms,
367+
),
338368
)
339369

340-
auth_config = TokenAuthConfig(idp)
341-
auth_config.expiration_refresh_ratio = expiration_refresh_ratio
342-
auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis
343-
auth_config.max_attempts = max_attempts
344-
auth_config.delay_in_ms = delay_in_ms
345-
346370
return EntraIdCredentialsProvider(
347-
config=auth_config,
348-
initial_delay_in_ms=initial_delay_in_ms,
349-
block_for_initial=block_for_initial,
371+
identity_provider=idp,
372+
token_manager_config=token_mgr_config,
373+
initial_delay_in_ms=delay_in_ms,
350374
)
351375

352376

0 commit comments

Comments
 (0)