Skip to content

Commit 66a53ea

Browse files
committed
Added testing
1 parent 974ad4f commit 66a53ea

File tree

6 files changed

+559
-5
lines changed

6 files changed

+559
-5
lines changed

.github/actions/run-tests/action.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ runs:
103103

104104
if (( $REDIS_MAJOR_VERSION < 7 )) && [ "$protocol" == "3" ]; then
105105
echo "Skipping module tests: Modules doesn't support RESP3 for Redis versions < 7"
106-
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod"
106+
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod and not cp_integration"
107107
else
108108
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}"
109109
fi

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
async-timeout>=4.0.3
2-
redispy-entraid-credentials @ git+https://github.com/redis-developer/redispy-entra-credentials.git/@main
2+
redispy-entraid-credentials @ git+https://github.com/redis-developer/redispy-entra-credentials.git/@main
3+
PyJWT~=2.9.0

tests/conftest.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,27 @@
11
import argparse
2+
import os
23
import random
34
import time
5+
from datetime import datetime, timezone
46
from typing import Callable, TypeVar
57
from unittest import mock
68
from unittest.mock import Mock
79
from urllib.parse import urlparse
810

11+
import jwt
912
import pytest
13+
from entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig
14+
from entraid.identity_provider import ManagedIdentityType, create_provider_from_managed_identity, \
15+
create_provider_from_service_principal
16+
17+
from redis.auth.token import JWToken
18+
from redis.auth.token_manager import TokenManager
19+
from redis.credentials import CredentialProvider, StreamingCredentialProvider
20+
1021
import redis
1122
from packaging.version import Version
1223
from redis import Sentinel
24+
from redis.auth.idp import IdentityProviderInterface
1325
from redis.backoff import NoBackoff
1426
from redis.cache import (
1527
CacheConfig,
@@ -575,6 +587,113 @@ def cache_key(request) -> CacheKey:
575587
return CacheKey(command, keys)
576588

577589

590+
def mock_identity_provider() -> IdentityProviderInterface:
591+
mock_provider = Mock(spec=IdentityProviderInterface)
592+
token = {
593+
"exp": datetime.now(timezone.utc).timestamp() + 3600,
594+
"oid": "username"
595+
}
596+
encoded = jwt.encode(token, "secret", algorithm='HS256')
597+
jwt_token = JWToken(encoded)
598+
mock_provider.request_token.return_value = jwt_token
599+
return mock_provider
600+
601+
602+
def identity_provider(request) -> IdentityProviderInterface:
603+
auth_type = os.getenv("IDP_AUTH_TYPE")
604+
605+
if request.param.get("mock_idp", None) is not None:
606+
return mock_identity_provider()
607+
608+
if auth_type == "MANAGED_IDENTITY":
609+
return _get_managed_identity_provider(request)
610+
611+
return _get_service_principal_provider(request)
612+
613+
614+
def _get_managed_identity_provider(request):
615+
authority = os.getenv("IDP_AUTHORITY")
616+
identity_type = ManagedIdentityType(os.getenv("IDP_IDENTITY_TYPE"))
617+
resource = os.getenv("IDP_RESOURCE")
618+
id_type = os.getenv("IDP_ID_TYPE", None)
619+
id_value = os.getenv("IDP_ID_VALUE", None)
620+
kwargs = request.param.get("idp_kwargs", {})
621+
622+
return create_provider_from_managed_identity(
623+
identity_type=identity_type,
624+
resource=resource,
625+
id_type=id_type,
626+
id_value=id_value,
627+
authority=authority,
628+
**kwargs
629+
)
630+
631+
632+
def _get_service_principal_provider(request):
633+
client_id = os.getenv("IDP_CLIENT_ID")
634+
client_credential = os.getenv("IDP_CLIENT_CREDENTIAL")
635+
authority = os.getenv("IDP_AUTHORITY")
636+
scopes = os.getenv("IDP_SCOPES", [])
637+
kwargs = request.param.get("idp_kwargs", {})
638+
639+
if isinstance(scopes, str):
640+
scopes = scopes.split(',')
641+
642+
token_kwargs = request.param.get("token_kwargs", {})
643+
timeout = request.param.get("timeout", None)
644+
645+
return create_provider_from_service_principal(
646+
client_id=client_id,
647+
client_credential=client_credential,
648+
scopes=scopes,
649+
timeout=timeout,
650+
token_kwargs=token_kwargs,
651+
authority=authority,
652+
**kwargs
653+
)
654+
655+
656+
def get_credential_provider(request) -> CredentialProvider:
657+
cred_provider_class = request.param.get("cred_provider_class")
658+
cred_provider_kwargs = request.param.get("cred_provider_kwargs", {})
659+
660+
if cred_provider_class != EntraIdCredentialsProvider:
661+
return cred_provider_class(**cred_provider_kwargs)
662+
663+
idp = identity_provider(request)
664+
initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0)
665+
block_for_initial = cred_provider_kwargs.get("block_for_initial", False)
666+
expiration_refresh_ratio = cred_provider_kwargs.get(
667+
"expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO
668+
)
669+
lower_refresh_bound_millis = cred_provider_kwargs.get(
670+
"lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS
671+
)
672+
max_attempts = cred_provider_kwargs.get(
673+
"max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS
674+
)
675+
delay_in_ms = cred_provider_kwargs.get(
676+
"delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS
677+
)
678+
679+
auth_config = TokenAuthConfig(idp)
680+
auth_config.expiration_refresh_ratio(expiration_refresh_ratio)
681+
auth_config.lower_refresh_bound_millis(lower_refresh_bound_millis)
682+
auth_config.max_attempts(max_attempts)
683+
auth_config.delay_in_ms(delay_in_ms)
684+
685+
return EntraIdCredentialsProvider(
686+
config=auth_config,
687+
initial_delay_in_ms=initial_delay_in_ms,
688+
block_for_initial=block_for_initial,
689+
)
690+
691+
692+
@pytest.fixture()
693+
def credential_provider(request) -> CredentialProvider:
694+
return get_credential_provider(request)
695+
696+
578697
def wait_for_command(client, monitor, command, key=None):
579698
# issue a command with a key name that's local to this process.
580699
# if we find a command with our key before the command we're waiting

tests/test_asyncio/conftest.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
1+
import os
12
import random
23
from contextlib import asynccontextmanager as _asynccontextmanager
4+
from datetime import datetime, timezone
35
from typing import Union
46

7+
import jwt
58
import pytest
69
import pytest_asyncio
10+
from entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig
11+
from entraid.identity_provider import ManagedIdentityType, create_provider_from_managed_identity, \
12+
create_provider_from_service_principal
13+
from mock.mock import Mock
14+
from redis.credentials import CredentialProvider
15+
716
import redis.asyncio as redis
817
from packaging.version import Version
918
from redis.asyncio import Sentinel
1019
from redis.asyncio.client import Monitor
1120
from redis.asyncio.connection import Connection, parse_url
1221
from redis.asyncio.retry import Retry
22+
from redis.auth.idp import IdentityProviderInterface
23+
from redis.auth.token import JWToken
1324
from redis.backoff import NoBackoff
1425
from tests.conftest import REDIS_INFO
1526

@@ -216,6 +227,113 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs):
216227
yield mocked
217228

218229

230+
def mock_identity_provider() -> IdentityProviderInterface:
231+
mock_provider = Mock(spec=IdentityProviderInterface)
232+
token = {
233+
"exp": datetime.now(timezone.utc).timestamp() + 3600,
234+
"oid": "username"
235+
}
236+
encoded = jwt.encode(token, "secret", algorithm='HS256')
237+
jwt_token = JWToken(encoded)
238+
mock_provider.request_token.return_value = jwt_token
239+
return mock_provider
240+
241+
242+
def identity_provider(request) -> IdentityProviderInterface:
243+
auth_type = os.getenv("IDP_AUTH_TYPE")
244+
245+
if request.param.get("mock_idp", None) is not None:
246+
return mock_identity_provider()
247+
248+
if auth_type == "MANAGED_IDENTITY":
249+
return _get_managed_identity_provider(request)
250+
251+
return _get_service_principal_provider(request)
252+
253+
254+
def _get_managed_identity_provider(request):
255+
authority = os.getenv("IDP_AUTHORITY")
256+
identity_type = ManagedIdentityType(os.getenv("IDP_IDENTITY_TYPE"))
257+
resource = os.getenv("IDP_RESOURCE")
258+
id_type = os.getenv("IDP_ID_TYPE", None)
259+
id_value = os.getenv("IDP_ID_VALUE", None)
260+
kwargs = request.param.get("idp_kwargs", {})
261+
262+
return create_provider_from_managed_identity(
263+
identity_type=identity_type,
264+
resource=resource,
265+
id_type=id_type,
266+
id_value=id_value,
267+
authority=authority,
268+
**kwargs
269+
)
270+
271+
272+
def _get_service_principal_provider(request):
273+
client_id = os.getenv("IDP_CLIENT_ID")
274+
client_credential = os.getenv("IDP_CLIENT_CREDENTIAL")
275+
authority = os.getenv("IDP_AUTHORITY")
276+
scopes = os.getenv("IDP_SCOPES", [])
277+
kwargs = request.param.get("idp_kwargs", {})
278+
279+
if isinstance(scopes, str):
280+
scopes = scopes.split(',')
281+
282+
token_kwargs = request.param.get("token_kwargs", {})
283+
timeout = request.param.get("timeout", None)
284+
285+
return create_provider_from_service_principal(
286+
client_id=client_id,
287+
client_credential=client_credential,
288+
scopes=scopes,
289+
timeout=timeout,
290+
token_kwargs=token_kwargs,
291+
authority=authority,
292+
**kwargs
293+
)
294+
295+
296+
def get_credential_provider(request) -> CredentialProvider:
297+
cred_provider_class = request.param.get("cred_provider_class")
298+
cred_provider_kwargs = request.param.get("cred_provider_kwargs", {})
299+
300+
if cred_provider_class != EntraIdCredentialsProvider:
301+
return cred_provider_class(**cred_provider_kwargs)
302+
303+
idp = identity_provider(request)
304+
initial_delay_in_ms = cred_provider_kwargs.get("initial_delay_in_ms", 0)
305+
block_for_initial = cred_provider_kwargs.get("block_for_initial", False)
306+
expiration_refresh_ratio = cred_provider_kwargs.get(
307+
"expiration_refresh_ratio", TokenAuthConfig.DEFAULT_EXPIRATION_REFRESH_RATIO
308+
)
309+
lower_refresh_bound_millis = cred_provider_kwargs.get(
310+
"lower_refresh_bound_millis", TokenAuthConfig.DEFAULT_LOWER_REFRESH_BOUND_MILLIS
311+
)
312+
max_attempts = cred_provider_kwargs.get(
313+
"max_attempts", TokenAuthConfig.DEFAULT_MAX_ATTEMPTS
314+
)
315+
delay_in_ms = cred_provider_kwargs.get(
316+
"delay_in_ms", TokenAuthConfig.DEFAULT_DELAY_IN_MS
317+
)
318+
319+
auth_config = TokenAuthConfig(idp)
320+
auth_config.expiration_refresh_ratio(expiration_refresh_ratio)
321+
auth_config.lower_refresh_bound_millis(lower_refresh_bound_millis)
322+
auth_config.max_attempts(max_attempts)
323+
auth_config.delay_in_ms(delay_in_ms)
324+
325+
return EntraIdCredentialsProvider(
326+
config=auth_config,
327+
initial_delay_in_ms=initial_delay_in_ms,
328+
block_for_initial=block_for_initial,
329+
)
330+
331+
332+
@pytest_asyncio.fixture()
333+
async def credential_provider(request) -> CredentialProvider:
334+
return get_credential_provider(request)
335+
336+
219337
async def wait_for_command(
220338
client: redis.Redis, monitor: Monitor, command: str, key: Union[str, None] = None
221339
):

0 commit comments

Comments
 (0)