diff --git a/README.md b/README.md index dd1734a0..7e820b58 100644 --- a/README.md +++ b/README.md @@ -634,3 +634,5 @@ The system uses a precise naming convention: - Ensures fair resource allocation among all clients - **Examples**: Implementation examples are available in the [examples/rate_limit](examples/rate_limit) directory. + +dummy diff --git a/stac_fastapi/core/setup.py b/stac_fastapi/core/setup.py index 92442997..b055eecd 100644 --- a/stac_fastapi/core/setup.py +++ b/stac_fastapi/core/setup.py @@ -19,6 +19,7 @@ "pygeofilter~=0.3.1", "jsonschema~=4.0.0", "slowapi~=0.1.9", + "redis==6.4.0", ] setup( diff --git a/stac_fastapi/core/stac_fastapi/core/core.py b/stac_fastapi/core/stac_fastapi/core/core.py index e36d71d7..1532494a 100644 --- a/stac_fastapi/core/stac_fastapi/core/core.py +++ b/stac_fastapi/core/stac_fastapi/core/core.py @@ -24,6 +24,12 @@ from stac_fastapi.core.base_settings import ApiBaseSettings from stac_fastapi.core.datetime_utils import format_datetime_range from stac_fastapi.core.models.links import PagingLinks +from stac_fastapi.core.redis_utils import ( + add_previous_link, + cache_current_url, + cache_previous_url, + connect_redis, +) from stac_fastapi.core.serializers import CollectionSerializer, ItemSerializer from stac_fastapi.core.session import Session from stac_fastapi.core.utilities import filter_fields @@ -237,6 +243,12 @@ async def all_collections(self, **kwargs) -> stac_types.Collections: base_url = str(request.base_url) limit = int(request.query_params.get("limit", os.getenv("STAC_ITEM_LIMIT", 10))) token = request.query_params.get("token") + current_url = str(request.url) + redis = None + try: + redis = await connect_redis() + except Exception: + redis = None collections, next_token = await self.database.get_all_collections( token=token, limit=limit, request=request @@ -252,6 +264,10 @@ async def all_collections(self, **kwargs) -> stac_types.Collections: }, ] + await add_previous_link(redis, links, "collections", current_url, token) + if redis: + await cache_previous_url(redis, current_url, "collections") + if next_token: next_link = PagingLinks(next=next_token, request=request).link_next() links.append(next_link) @@ -284,7 +300,6 @@ async def get_collection( async def item_collection( self, collection_id: str, - request: Request, bbox: Optional[BBox] = None, datetime: Optional[str] = None, limit: Optional[int] = None, @@ -323,6 +338,31 @@ async def item_collection( Raises: HTTPException: 404 if the collection does not exist. """ + request: Request = kwargs["request"] + token = request.query_params.get("token") + base_url = str(request.base_url) + + current_url = str(request.url) + + try: + redis = await connect_redis() + except Exception: + redis = None + + if redis: + await cache_current_url(redis, current_url, collection_id) + + collection = await self.get_collection( + collection_id=collection_id, request=request + ) + collection_id = collection.get("id") + if collection_id is None: + raise HTTPException(status_code=404, detail="Collection not found") + + search = self.database.make_search() + search = self.database.apply_collections_filter( + search=search, collection_ids=[collection_id] + ) try: await self.get_collection(collection_id=collection_id, request=request) except Exception: @@ -336,6 +376,45 @@ async def item_collection( datetime=datetime, limit=limit, token=token, + collection_ids=[collection_id], + datetime_search=datetime_search, + ) + + items = [ + self.item_serializer.db_to_stac(item, base_url=base_url) for item in items + ] + + collection_links = [ + { + "rel": "collection", + "type": "application/json", + "href": urljoin(str(request.base_url), f"collections/{collection_id}"), + }, + { + "rel": "parent", + "type": "application/json", + "href": urljoin(str(request.base_url), f"collections/{collection_id}"), + }, + ] + + paging_links = await PagingLinks(request=request, next=next_token).get_links() + + if redis: + await add_previous_link( + redis, paging_links, collection_id, current_url, token + ) + + if redis: + await cache_previous_url(redis, current_url, collection_id) + + links = collection_links + paging_links + + return stac_types.ItemCollection( + type="FeatureCollection", + features=items, + links=links, + numReturned=len(items), + numMatched=maybe_count, sortby=sortby, query=query, filter_expr=filter_expr, @@ -482,7 +561,14 @@ async def post_search( HTTPException: If there is an error with the cql2_json filter. """ base_url = str(request.base_url) + current_url = str(request.url) + try: + redis = await connect_redis() + except Exception: + redis = None + if redis: + await cache_current_url(redis, current_url, "search_result") search = self.database.make_search() if search_request.ids: @@ -592,6 +678,14 @@ async def post_search( ] links = await PagingLinks(request=request, next=next_token).get_links() + if redis: + await add_previous_link( + redis, links, "search_result", current_url, search_request.token + ) + + if redis: + await cache_previous_url(redis, current_url, "search_result") + return stac_types.ItemCollection( type="FeatureCollection", features=items, diff --git a/stac_fastapi/core/stac_fastapi/core/redis_utils.py b/stac_fastapi/core/stac_fastapi/core/redis_utils.py new file mode 100644 index 00000000..4013fd0f --- /dev/null +++ b/stac_fastapi/core/stac_fastapi/core/redis_utils.py @@ -0,0 +1,196 @@ +"""Utilities for connecting to and managing Redis connections.""" + +import logging +import os +from typing import Dict, List, Optional + +from pydantic_settings import BaseSettings +from redis import asyncio as aioredis +from stac_pydantic.shared import MimeTypes + +from stac_fastapi.core.utilities import get_bool_env + +redis_pool = None + +logger = logging.getLogger(__name__) + + +class RedisSentinelSettings(BaseSettings): + """Configuration settings for connecting to a Redis Sentinel server.""" + + sentinel_hosts: List[str] = os.getenv("REDIS_SENTINEL_HOSTS", "").split(",") + sentinel_ports: List[int] = [ + int(port) + for port in os.getenv("REDIS_SENTINEL_PORTS", "").split(",") + if port.strip() + ] + sentinel_master_name: str = os.getenv("REDIS_SENTINEL_MASTER_NAME", "") + redis_db: int = int(os.getenv("REDIS_DB", "0")) + + max_connections: int = int(os.getenv("REDIS_MAX_CONNECTIONS", "5")) + retry_on_timeout: bool = get_bool_env("REDIS_RETRY_TIMEOUT", True) + decode_responses: bool = get_bool_env("REDIS_DECODE_RESPONSES", True) + client_name: str = os.getenv("REDIS_CLIENT_NAME", "stac-fastapi-app") + health_check_interval: int = int(os.getenv("REDIS_HEALTH_CHECK_INTERVAL", "30")) + + +class RedisSettings(BaseSettings): + """Configuration settings for connecting to a Redis server.""" + + redis_host: str = os.getenv("REDIS_HOST", "localhost") + redis_port: int = int(os.getenv("REDIS_PORT", "6379")) + redis_db: int = int(os.getenv("REDIS_DB", "0")) + + max_connections: int = int(os.getenv("REDIS_MAX_CONNECTIONS", "5")) + retry_on_timeout: bool = get_bool_env("REDIS_RETRY_TIMEOUT", True) + decode_responses: bool = get_bool_env("REDIS_DECODE_RESPONSES", True) + client_name: str = os.getenv("REDIS_CLIENT_NAME", "stac-fastapi-app") + health_check_interval: int = int(os.getenv("REDIS_HEALTH_CHECK_INTERVAL", "30")) + + +# select which configuration to be used RedisSettings or RedisSentinelSettings +redis_settings = RedisSettings() + + +async def connect_redis_sentinel( + settings: Optional[RedisSentinelSettings] = None, +) -> Optional[aioredis.Redis]: + """Return a Redis Sentinel connection.""" + global redis_pool + settings = redis_settings + + if ( + not settings.sentinel_hosts + or not settings.sentinel_hosts[0] + or not settings.sentinel_master_name + ): + return None + + if redis_pool is None: + try: + sentinel = aioredis.Sentinel( + [ + (host, port) + for host, port in zip( + settings.sentinel_hosts, settings.sentinel_ports + ) + ], + decode_responses=settings.decode_responses, + retry_on_timeout=settings.retry_on_timeout, + client_name=f"{settings.client_name}-sentinel", + ) + + master = sentinel.master_for( + settings.sentinel_master_name, + db=settings.redis_db, + decode_responses=settings.decode_responses, + retry_on_timeout=settings.retry_on_timeout, + client_name=settings.client_name, + max_connections=settings.max_connections, + ) + + redis_pool = master + + except Exception: + return None + + return redis_pool + + +async def connect_redis( + settings: Optional[RedisSettings] = None, +) -> Optional[aioredis.Redis]: + """Return a Redis connection for regular Redis server.""" + global redis_pool + settings = redis_settings + + if not settings.redis_host: + return None + + if redis_pool is None: + try: + redis_pool = aioredis.Redis( + host=settings.redis_host, + port=settings.redis_port, + db=settings.redis_db, + decode_responses=settings.decode_responses, + retry_on_timeout=settings.retry_on_timeout, + client_name=settings.client_name, + health_check_interval=settings.health_check_interval, + max_connections=settings.max_connections, + ) + except Exception as e: + logger.error(f"Redis connection failed: {e}") + return None + + return redis_pool + + +async def close_redis() -> None: + """Close the Redis connection pool if it exists.""" + global redis_pool + if redis_pool: + await redis_pool.close() + redis_pool = None + + +async def cache_current_url(redis, current_url: str, key: str) -> None: + """Add to Redis cache the current URL for navigation.""" + if not redis: + return + + try: + current_key = f"current:{key}" + await redis.setex(current_key, 600, current_url) + except Exception as e: + logger.error(f"Redis cache error for {key}: {e}") + + +async def get_previous_url(redis, key: str) -> Optional[str]: + """Get previous URL from Redis cache if it exists.""" + if redis is None: + return None + + try: + prev_key = f"prev:{key}" + previous_url = await redis.get(prev_key) + if previous_url: + return previous_url + except Exception as e: + logger.error(f"Redis get previous error for {key}: {e}") + + return None + + +async def cache_previous_url(redis, current_url: str, key: str) -> None: + """Cache the current URL as previous for previous links in next page.""" + if not redis: + return + + try: + prev_key = f"prev:{key}" + await redis.setex(prev_key, 600, current_url) + except Exception as e: + logger.error(f"Redis cache previous error for {key}: {e}") + + +async def add_previous_link( + redis, + links: List[Dict], + key: str, + current_url: str, + token: Optional[str] = None, +) -> None: + """Add previous link into navigation.""" + if not redis or not token: + return + + previous_url = await get_previous_url(redis, key) + if previous_url: + links.append( + { + "rel": "previous", + "type": MimeTypes.json, + "href": previous_url, + } + ) diff --git a/stac_fastapi/tests/conftest.py b/stac_fastapi/tests/conftest.py index 08e3277d..94bdea91 100644 --- a/stac_fastapi/tests/conftest.py +++ b/stac_fastapi/tests/conftest.py @@ -6,6 +6,7 @@ import pytest import pytest_asyncio +import redis # noqa: F401 from fastapi import Depends, HTTPException, security, status from httpx import ASGITransport, AsyncClient from pydantic import ConfigDict