|
1 | 1 | import os |
2 | | -from typing import Any, Dict, List, Optional, Type |
| 2 | +from typing import Any, Dict, List, Optional, Type, TypeVar, Union, overload |
| 3 | +from urllib.parse import urlparse |
3 | 4 | from warnings import warn |
4 | 5 |
|
5 | 6 | from redis import Redis, RedisCluster |
|
11 | 12 | from redis.asyncio.connection import SSLConnection as AsyncSSLConnection |
12 | 13 | from redis.connection import SSLConnection |
13 | 14 | from redis.exceptions import ResponseError |
| 15 | +from redis.sentinel import Sentinel |
14 | 16 |
|
15 | 17 | from redisvl import __version__ |
16 | 18 | from redisvl.redis.constants import REDIS_URL_ENV_VAR |
@@ -192,6 +194,9 @@ def parse_attrs(attrs): |
192 | 194 | } |
193 | 195 |
|
194 | 196 |
|
| 197 | +T = TypeVar("T", Redis, AsyncRedis) |
| 198 | + |
| 199 | + |
195 | 200 | class RedisConnectionFactory: |
196 | 201 | """Builds connections to a Redis database, supporting both synchronous and |
197 | 202 | asynchronous clients. |
@@ -253,7 +258,9 @@ def get_redis_connection( |
253 | 258 | variable is not set. |
254 | 259 | """ |
255 | 260 | url = redis_url or get_address_from_env() |
256 | | - if is_cluster_url(url, **kwargs): |
| 261 | + if url.startswith("redis+sentinel"): |
| 262 | + client = RedisConnectionFactory._redis_sentinel_client(url, Redis, **kwargs) |
| 263 | + elif is_cluster_url(url, **kwargs): |
257 | 264 | client = RedisCluster.from_url(url, **kwargs) |
258 | 265 | else: |
259 | 266 | client = Redis.from_url(url, **kwargs) |
@@ -293,7 +300,11 @@ async def _get_aredis_connection( |
293 | 300 | """ |
294 | 301 | url = url or get_address_from_env() |
295 | 302 |
|
296 | | - if is_cluster_url(url, **kwargs): |
| 303 | + if url.startswith("redis+sentinel"): |
| 304 | + client = RedisConnectionFactory._redis_sentinel_client( |
| 305 | + url, AsyncRedis, **kwargs |
| 306 | + ) |
| 307 | + elif is_cluster_url(url, **kwargs): |
297 | 308 | client = AsyncRedisCluster.from_url(url, **kwargs) |
298 | 309 | else: |
299 | 310 | client = AsyncRedis.from_url(url, **kwargs) |
@@ -334,6 +345,10 @@ def get_async_redis_connection( |
334 | 345 | DeprecationWarning, |
335 | 346 | ) |
336 | 347 | url = url or get_address_from_env() |
| 348 | + if url.startswith("redis+sentinel"): |
| 349 | + return RedisConnectionFactory._redis_sentinel_client( |
| 350 | + url, AsyncRedis, **kwargs |
| 351 | + ) |
337 | 352 | return AsyncRedis.from_url(url, **kwargs) |
338 | 353 |
|
339 | 354 | @staticmethod |
@@ -440,3 +455,60 @@ async def validate_async_redis( |
440 | 455 | await redis_client.echo(_lib_name) |
441 | 456 |
|
442 | 457 | # Module validation removed - operations will fail naturally if modules are missing |
| 458 | + |
| 459 | + @staticmethod |
| 460 | + @overload |
| 461 | + def _redis_sentinel_client( |
| 462 | + redis_url: str, redis_class: type[Redis], **kwargs: Any |
| 463 | + ) -> Redis: ... |
| 464 | + |
| 465 | + @staticmethod |
| 466 | + @overload |
| 467 | + def _redis_sentinel_client( |
| 468 | + redis_url: str, redis_class: type[AsyncRedis], **kwargs: Any |
| 469 | + ) -> AsyncRedis: ... |
| 470 | + |
| 471 | + @staticmethod |
| 472 | + def _redis_sentinel_client( |
| 473 | + redis_url: str, redis_class: Union[type[Redis], type[AsyncRedis]], **kwargs: Any |
| 474 | + ) -> Union[Redis, AsyncRedis]: |
| 475 | + sentinel_list, service_name, db, username, password = ( |
| 476 | + RedisConnectionFactory._parse_sentinel_url(redis_url) |
| 477 | + ) |
| 478 | + |
| 479 | + sentinel_kwargs = {} |
| 480 | + if username: |
| 481 | + sentinel_kwargs["username"] = username |
| 482 | + kwargs["username"] = username |
| 483 | + if password: |
| 484 | + sentinel_kwargs["password"] = password |
| 485 | + kwargs["password"] = password |
| 486 | + if db: |
| 487 | + kwargs["db"] = db |
| 488 | + |
| 489 | + sentinel = Sentinel(sentinel_list, sentinel_kwargs=sentinel_kwargs, **kwargs) |
| 490 | + return sentinel.master_for(service_name, redis_class=redis_class, **kwargs) |
| 491 | + |
| 492 | + @staticmethod |
| 493 | + def _parse_sentinel_url(url: str) -> tuple: |
| 494 | + parsed_url = urlparse(url) |
| 495 | + hosts_part = parsed_url.netloc.split("@")[-1] |
| 496 | + sentinel_hosts = hosts_part.split(",") |
| 497 | + |
| 498 | + sentinel_list = [] |
| 499 | + for host in sentinel_hosts: |
| 500 | + host_parts = host.split(":") |
| 501 | + if len(host_parts) == 2: |
| 502 | + sentinel_list.append((host_parts[0], int(host_parts[1]))) |
| 503 | + else: |
| 504 | + sentinel_list.append((host_parts[0], 26379)) |
| 505 | + |
| 506 | + service_name = "mymaster" |
| 507 | + db = None |
| 508 | + if parsed_url.path: |
| 509 | + path_parts = parsed_url.path.split("/") |
| 510 | + service_name = path_parts[1] or "mymaster" |
| 511 | + if len(path_parts) > 2: |
| 512 | + db = path_parts[2] |
| 513 | + |
| 514 | + return sentinel_list, service_name, db, parsed_url.username, parsed_url.password |
0 commit comments