|
| 1 | +import os |
1 | 2 | import random
|
2 | 3 | from contextlib import asynccontextmanager as _asynccontextmanager
|
| 4 | +from datetime import datetime, timezone |
3 | 5 | from typing import Union
|
4 | 6 |
|
| 7 | +import jwt |
5 | 8 | import pytest
|
6 | 9 | import pytest_asyncio
|
| 10 | +from entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig |
| 11 | +from entraid.identity_provider import ManagedIdentityType, create_provider_from_managed_identity, \ |
| 12 | + create_provider_from_service_principal |
| 13 | +from mock.mock import Mock |
| 14 | +from redis.credentials import CredentialProvider |
| 15 | + |
7 | 16 | import redis.asyncio as redis
|
8 | 17 | from packaging.version import Version
|
9 | 18 | from redis.asyncio import Sentinel
|
10 | 19 | from redis.asyncio.client import Monitor
|
11 | 20 | from redis.asyncio.connection import Connection, parse_url
|
12 | 21 | from redis.asyncio.retry import Retry
|
| 22 | +from redis.auth.idp import IdentityProviderInterface |
| 23 | +from redis.auth.token import JWToken |
13 | 24 | from redis.backoff import NoBackoff
|
14 | 25 | from tests.conftest import REDIS_INFO
|
15 | 26 |
|
@@ -216,6 +227,113 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs):
|
216 | 227 | yield mocked
|
217 | 228 |
|
218 | 229 |
|
| 230 | +def mock_identity_provider() -> IdentityProviderInterface: |
| 231 | + mock_provider = Mock(spec=IdentityProviderInterface) |
| 232 | + token = { |
| 233 | + "exp": datetime.now(timezone.utc).timestamp() + 3600, |
| 234 | + "oid": "username" |
| 235 | + } |
| 236 | + encoded = jwt.encode(token, "secret", algorithm='HS256') |
| 237 | + jwt_token = JWToken(encoded) |
| 238 | + mock_provider.request_token.return_value = jwt_token |
| 239 | + return mock_provider |
| 240 | + |
| 241 | + |
| 242 | +def identity_provider(request) -> IdentityProviderInterface: |
| 243 | + auth_type = os.getenv("IDP_AUTH_TYPE") |
| 244 | + |
| 245 | + if request.param.get("mock_idp", None) is not None: |
| 246 | + return mock_identity_provider() |
| 247 | + |
| 248 | + if auth_type == "MANAGED_IDENTITY": |
| 249 | + return _get_managed_identity_provider(request) |
| 250 | + |
| 251 | + return _get_service_principal_provider(request) |
| 252 | + |
| 253 | + |
| 254 | +def _get_managed_identity_provider(request): |
| 255 | + authority = os.getenv("IDP_AUTHORITY") |
| 256 | + identity_type = ManagedIdentityType(os.getenv("IDP_IDENTITY_TYPE")) |
| 257 | + resource = os.getenv("IDP_RESOURCE") |
| 258 | + id_type = os.getenv("IDP_ID_TYPE", None) |
| 259 | + id_value = os.getenv("IDP_ID_VALUE", None) |
| 260 | + kwargs = request.param.get("idp_kwargs", {}) |
| 261 | + |
| 262 | + return create_provider_from_managed_identity( |
| 263 | + identity_type=identity_type, |
| 264 | + resource=resource, |
| 265 | + id_type=id_type, |
| 266 | + id_value=id_value, |
| 267 | + authority=authority, |
| 268 | + **kwargs |
| 269 | + ) |
| 270 | + |
| 271 | + |
| 272 | +def _get_service_principal_provider(request): |
| 273 | + client_id = os.getenv("IDP_CLIENT_ID") |
| 274 | + client_credential = os.getenv("IDP_CLIENT_CREDENTIAL") |
| 275 | + authority = os.getenv("IDP_AUTHORITY") |
| 276 | + scopes = os.getenv("IDP_SCOPES", []) |
| 277 | + kwargs = request.param.get("idp_kwargs", {}) |
| 278 | + |
| 279 | + if isinstance(scopes, str): |
| 280 | + scopes = scopes.split(',') |
| 281 | + |
| 282 | + token_kwargs = request.param.get("token_kwargs", {}) |
| 283 | + timeout = request.param.get("timeout", None) |
| 284 | + |
| 285 | + return create_provider_from_service_principal( |
| 286 | + client_id=client_id, |
| 287 | + client_credential=client_credential, |
| 288 | + scopes=scopes, |
| 289 | + timeout=timeout, |
| 290 | + token_kwargs=token_kwargs, |
| 291 | + authority=authority, |
| 292 | + **kwargs |
| 293 | + ) |
| 294 | + |
| 295 | + |
| 296 | +def get_credential_provider(request) -> CredentialProvider: |
| 297 | + cred_provider_class = request.param.get("cred_provider_class") |
| 298 | + cred_provider_kwargs = request.param.get("cred_provider_kwargs", {}) |
| 299 | + |
| 300 | + if cred_provider_class != EntraIdCredentialsProvider: |
| 301 | + return cred_provider_class(**cred_provider_kwargs) |
| 302 | + |
| 303 | + idp = identity_provider(request) |
| 304 | + initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0) |
| 305 | + block_for_initial = cred_provider_kwargs.get("block_for_initial", False) |
| 306 | + expiration_refresh_ratio = cred_provider_kwargs.get( |
| 307 | + "expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO |
| 308 | + ) |
| 309 | + lower_refresh_bound_millis = cred_provider_kwargs.get( |
| 310 | + "lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS |
| 311 | + ) |
| 312 | + max_attempts = cred_provider_kwargs.get( |
| 313 | + "max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS |
| 314 | + ) |
| 315 | + delay_in_ms = cred_provider_kwargs.get( |
| 316 | + "delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS |
| 317 | + ) |
| 318 | + |
| 319 | + auth_config = TokenAuthConfig(idp) |
| 320 | + auth_config.expiration_refresh_ratio(expiration_refresh_ratio) |
| 321 | + auth_config.lower_refresh_bound_millis(lower_refresh_bound_millis) |
| 322 | + auth_config.max_attempts(max_attempts) |
| 323 | + auth_config.delay_in_ms(delay_in_ms) |
| 324 | + |
| 325 | + return EntraIdCredentialsProvider( |
| 326 | + config=auth_config, |
| 327 | + initial_delay_in_ms=initial_delay_in_ms, |
| 328 | + block_for_initial=block_for_initial, |
| 329 | + ) |
| 330 | + |
| 331 | + |
| 332 | +@pytest_asyncio.fixture() |
| 333 | +async def credential_provider(request) -> CredentialProvider: |
| 334 | + return get_credential_provider(request) |
| 335 | + |
| 336 | + |
219 | 337 | async def wait_for_command(
|
220 | 338 | client: redis.Redis, monitor: Monitor, command: str, key: Union[str, None] = None
|
221 | 339 | ):
|
|
0 commit comments