Skip to content

Commit c488a39

Browse files
author
Andrzej Pijanowski
committed
feat: add queryables cache and optional validation for search parameters
1 parent c1a7bc1 commit c488a39

File tree

5 files changed

+396
-0
lines changed

5 files changed

+396
-0
lines changed

stac_fastapi/core/stac_fastapi/core/core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@
3939
BulkTransactionMethod,
4040
Items,
4141
)
42+
from stac_fastapi.sfeos_helpers.queryables import (
43+
get_properties_from_cql2_filter,
44+
initialize_queryables_cache,
45+
validate_queryables,
46+
)
4247
from stac_fastapi.types import stac as stac_types
4348
from stac_fastapi.types.conformance import BASE_CONFORMANCE_CLASSES
4449
from stac_fastapi.types.core import AsyncBaseCoreClient
@@ -88,6 +93,10 @@ class CoreClient(AsyncBaseCoreClient):
8893
title: str = attr.ib(default="stac-fastapi")
8994
description: str = attr.ib(default="stac-fastapi")
9095

96+
def __attrs_post_init__(self):
97+
"""Initialize the queryables cache."""
98+
initialize_queryables_cache(self.database)
99+
91100
def _landing_page(
92101
self,
93102
base_url: str,
@@ -815,6 +824,8 @@ async def post_search(
815824
)
816825

817826
if hasattr(search_request, "query") and getattr(search_request, "query"):
827+
query_fields = set(getattr(search_request, "query").keys())
828+
await validate_queryables(query_fields)
818829
for field_name, expr in getattr(search_request, "query").items():
819830
field = "properties__" + field_name
820831
for op, value in expr.items():
@@ -833,7 +844,11 @@ async def post_search(
833844

834845
if cql2_filter is not None:
835846
try:
847+
query_fields = get_properties_from_cql2_filter(cql2_filter)
848+
await validate_queryables(query_fields)
836849
search = await self.database.apply_cql2_filter(search, cql2_filter)
850+
except HTTPException:
851+
raise
837852
except Exception as e:
838853
raise HTTPException(
839854
status_code=400, detail=f"Error with cql2 filter: {e}"
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""A module for managing queryable attributes."""
2+
3+
import asyncio
4+
import os
5+
import time
6+
from typing import Any, Dict, List, Optional, Set
7+
8+
from fastapi import HTTPException
9+
10+
from stac_fastapi.core.base_database_logic import BaseDatabaseLogic
11+
12+
13+
class QueryablesCache:
    """A time-based cache of queryable property names.

    Refreshes are serialized with an ``asyncio.Lock`` so that concurrent
    coroutines on the same event loop trigger at most one database lookup.
    The lock does not make the cache safe to share between OS threads.
    """

    def __init__(self, database_logic: Any):
        """
        Initialize the QueryablesCache.

        Args:
            database_logic: An instance of a class with an awaitable
                `get_queryables_mapping` method.
        """
        self._db_logic = database_logic
        self._cache: Dict[str, List[str]] = {}
        self._all_queryables: Set[str] = set()
        self._last_updated: float = 0
        self._lock = asyncio.Lock()
        self.validation_enabled: bool = False
        self.cache_ttl: int = 3600  # How often to refresh cache (in seconds)
        self.reload_settings()

    def reload_settings(self):
        """Reload settings from environment variables.

        Reads `VALIDATE_QUERYABLES` (enabled only for the case-insensitive
        value "true") and `QUERYABLES_CACHE_TTL` (seconds, default 3600).

        Raises:
            ValueError: If `QUERYABLES_CACHE_TTL` is not an integer.
        """
        self.validation_enabled = (
            os.getenv("VALIDATE_QUERYABLES", "false").lower() == "true"
        )
        self.cache_ttl = int(os.getenv("QUERYABLES_CACHE_TTL", "3600"))

    def _is_stale(self) -> bool:
        """Return True when the cache is empty or older than `cache_ttl`."""
        return not self._cache or (
            time.time() - self._last_updated >= self.cache_ttl
        )

    async def _update_cache(self):
        """Update the cache with the latest queryables from the database."""
        if not self.validation_enabled:
            return

        async with self._lock:
            # Re-check under the lock: another coroutine may have refreshed
            # the cache while this one was waiting to acquire it.
            if not self._is_stale():
                return

            queryables_mapping = await self._db_logic.get_queryables_mapping()
            all_queryables_set = set(queryables_mapping.keys())

            self._all_queryables = all_queryables_set
            self._cache = {"*": list(all_queryables_set)}
            self._last_updated = time.time()

    async def get_all_queryables(self) -> Set[str]:
        """
        Return a set of all queryable attributes across all collections.

        This method will update the cache if it's stale or has been cleared.
        An empty set is returned when validation is disabled. The returned
        set is a copy, so callers cannot mutate the cached state.
        """
        if not self.validation_enabled:
            return set()

        if self._is_stale():
            await self._update_cache()
        return set(self._all_queryables)

    async def validate(self, fields: Set[str]) -> None:
        """
        Validate if the provided fields are queryable.

        Does nothing when validation is disabled.

        Args:
            fields: Field names to check; any iterable of strings is accepted.

        Raises:
            HTTPException: With status 400 if invalid fields are found.
        """
        if not self.validation_enabled:
            return

        allowed_fields = await self.get_all_queryables()
        invalid_fields = set(fields) - allowed_fields
        if invalid_fields:
            # Sort both lists so the error message is stable between requests.
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Invalid query fields: {', '.join(sorted(invalid_fields))}. "
                    f"Allowed fields are: {', '.join(sorted(allowed_fields))}"
                ),
            )
85+
86+
87+
# Process-wide cache shared by the module-level helper functions below.
_queryables_cache_instance: Optional[QueryablesCache] = None


def initialize_queryables_cache(database_logic: BaseDatabaseLogic):
    """
    Create the global queryables cache on first use.

    Later calls are no-ops: once a cache exists it is kept, and the
    `database_logic` passed to those later calls is ignored.

    :param database_logic: An instance of DatabaseLogic.
    """
    global _queryables_cache_instance
    if _queryables_cache_instance is not None:
        return
    _queryables_cache_instance = QueryablesCache(database_logic)
99+
100+
101+
async def all_queryables() -> Set[str]:
    """Get all queryable properties from the cache.

    Raises:
        RuntimeError: If `initialize_queryables_cache` has not been called yet.
    """
    if _queryables_cache_instance is None:
        # RuntimeError is still caught by callers handling `Exception`,
        # but lets others distinguish this misuse from unrelated failures.
        raise RuntimeError("Queryables cache not initialized.")
    return await _queryables_cache_instance.get_all_queryables()
106+
107+
108+
async def validate_queryables(fields: Set[str]) -> None:
    """Check `fields` against the global queryables cache.

    Returns without validating when the cache has not been initialized;
    otherwise delegates to `QueryablesCache.validate`.
    """
    cache = _queryables_cache_instance
    if cache is not None:
        await cache.validate(fields)
113+
114+
115+
def reload_queryables_settings():
    """Re-read the queryables settings from the environment.

    Does nothing when the global cache has not been created yet.
    """
    cache = _queryables_cache_instance
    if not cache:
        return
    cache.reload_settings()
119+
120+
121+
def get_properties_from_cql2_filter(cql2_filter: Dict[str, Any]) -> Set[str]:
    """Extract every property name referenced anywhere in a CQL2 JSON filter.

    The whole structure is walked, not just the direct `args` of an operator,
    so `{"property": ...}` references nested inside array literals, interval
    objects or function arguments are collected as well and cannot slip past
    queryables validation.

    Args:
        cql2_filter: The parsed CQL2 JSON filter.

    Returns:
        The set of property names found; empty when there are none.
    """
    props: Set[str] = set()
    # Explicit stack instead of recursion: the filter is user-supplied, so a
    # deeply nested document must not be able to exhaust the call stack.
    pending: List[Any] = [cql2_filter]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            name = node.get("property")
            if isinstance(name, str):
                props.add(name)
            pending.extend(node.values())
        elif isinstance(node, (list, tuple)):
            pending.extend(node)
    return props
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import json
2+
import os
3+
from unittest import mock
4+
5+
import pytest
6+
7+
from stac_fastapi.sfeos_helpers.queryables import reload_queryables_settings
8+
9+
10+
@pytest.fixture(autouse=True)
def enable_validation():
    """Turn queryables validation on for every test in this module."""
    env_overrides = {"VALIDATE_QUERYABLES": "true"}
    with mock.patch.dict(os.environ, env_overrides):
        # Settings are read from the environment, so re-read them while
        # the override is active.
        reload_queryables_settings()
        yield
    # The environment is restored here; reload again so the cache settings
    # follow it.
    reload_queryables_settings()
16+
17+
18+
@pytest.mark.asyncio
async def test_search_post_query_valid_param(app_client, ctx):
    """POST /search with a `query` on a valid field returns 200."""
    body = {"query": {"eo:cloud_cover": {"lt": 10}}}
    response = await app_client.post("/search", json=body)
    assert response.status_code == 200
24+
25+
26+
@pytest.mark.asyncio
async def test_search_post_query_invalid_param(app_client, ctx):
    """POST /search with a `query` on an invalid field returns 400."""
    body = {"query": {"invalid_param": {"eq": "test"}}}
    response = await app_client.post("/search", json=body)
    assert response.status_code == 400
    # The error detail must name the rejected field.
    detail = response.json()["detail"]
    assert "Invalid query fields: invalid_param" in detail
34+
35+
36+
@pytest.mark.asyncio
async def test_item_collection_get_filter_valid_param(app_client, ctx):
    """GET /collections/{collection_id}/items with a valid CQL2 filter returns 200."""
    cql2 = {"op": "<", "args": [{"property": "eo:cloud_cover"}, 10]}
    collection_id = ctx.item["collection"]
    response = await app_client.get(
        f"/collections/{collection_id}/items",
        params={"filter-lang": "cql2-json", "filter": json.dumps(cql2)},
    )
    assert response.status_code == 200
50+
51+
52+
@pytest.mark.asyncio
async def test_item_collection_get_filter_invalid_param(app_client, ctx):
    """GET /collections/{collection_id}/items with an invalid CQL2 filter property returns 400."""
    cql2 = {"op": "=", "args": [{"property": "invalid_param"}, "test"]}
    collection_id = ctx.item["collection"]
    response = await app_client.get(
        f"/collections/{collection_id}/items",
        params={"filter-lang": "cql2-json", "filter": json.dumps(cql2)},
    )
    assert response.status_code == 400
    # The error detail must name the rejected field.
    detail = response.json()["detail"]
    assert "Invalid query fields: invalid_param" in detail

stac_fastapi/tests/data/test_collection.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
"type": "Collection",
77
"description": "Landat 8 imagery radiometrically calibrated and orthorectified using gound points and Digital Elevation Model (DEM) data to correct relief displacement.",
88
"stac_version": "1.0.0",
9+
"queryables": {
10+
"eo:cloud_cover": {
11+
"$ref": "https://stac-extensions.github.io/eo/v1.0.0/schema.json#/definitions/fields/properties/eo:cloud_cover"
12+
}
13+
},
914
"license": "PDDL-1.0",
1015
"summaries": {
1116
"platform": [

0 commit comments

Comments
 (0)