11"""Item crud client."""
22
3+ import json
34import re
45from typing import Any , Dict , List , Optional , Set , Union
56from urllib .parse import unquote_plus , urljoin
1415from pygeofilter .parsers .cql2_text import parse as parse_cql2_text
1516from pypgstac .hydration import hydrate
1617from stac_fastapi .api .models import JSONResponse
17- from stac_fastapi .types .core import AsyncBaseCoreClient
18+ from stac_fastapi .types .core import AsyncBaseCoreClient , Relations
1819from stac_fastapi .types .errors import InvalidQueryParameter , NotFoundError
1920from stac_fastapi .types .requests import get_base_url
2021from stac_fastapi .types .rfc3339 import DateTimeType
2122from stac_fastapi .types .stac import Collection , Collections , Item , ItemCollection
22- from stac_pydantic .links import Relations
2323from stac_pydantic .shared import BBox , MimeTypes
2424
2525from stac_fastapi .pgstac .config import Settings
3939class CoreCrudClient (AsyncBaseCoreClient ):
4040 """Client for core endpoints defined by stac."""
4141
42- async def all_collections (self , request : Request , ** kwargs ) -> Collections :
43- """Read all collections from the database."""
42+ async def all_collections ( # noqa: C901
43+ self ,
44+ request : Request ,
45+ # Extensions
46+ bbox : Optional [BBox ] = None ,
47+ datetime : Optional [DateTimeType ] = None ,
48+ limit : Optional [int ] = None ,
49+ query : Optional [str ] = None ,
50+ token : Optional [str ] = None ,
51+ fields : Optional [List [str ]] = None ,
52+ sortby : Optional [str ] = None ,
53+ filter : Optional [str ] = None ,
54+ filter_lang : Optional [str ] = None ,
55+ ** kwargs ,
56+ ) -> Collections :
57+ """Cross catalog search (GET).
58+
59+ Called with `GET /collections`.
60+
61+ Returns:
62+ Collections which match the search criteria, returns all
63+ collections by default.
64+ """
4465 base_url = get_base_url (request )
4566
67+ # Parse request parameters
68+ base_args = {
69+ "bbox" : bbox ,
70+ "limit" : limit ,
71+ "token" : token ,
72+ "query" : orjson .loads (unquote_plus (query )) if query else query ,
73+ }
74+
75+ clean_args = clean_search_args (
76+ base_args = base_args ,
77+ datetime = datetime ,
78+ fields = fields ,
79+ sortby = sortby ,
80+ filter_query = filter ,
81+ filter_lang = filter_lang ,
82+ )
83+
4684 async with request .app .state .get_connection (request , "r" ) as conn :
47- collections = await conn .fetchval (
48- """
49- SELECT * FROM all_collections();
85+ q , p = render (
5086 """
87+ SELECT * FROM collection_search(:req::text::jsonb);
88+ """ ,
89+ req = json .dumps (clean_args ),
5190 )
91+ collections_result : Collections = await conn .fetchval (q , * p )
92+
93+ next : Optional [str ] = None
94+ prev : Optional [str ] = None
95+
96+ if links := collections_result .get ("links" ):
97+ next = collections_result ["links" ].pop ("next" )
98+ prev = collections_result ["links" ].pop ("prev" )
99+
52100 linked_collections : List [Collection ] = []
101+ collections = collections_result ["collections" ]
53102 if collections is not None and len (collections ) > 0 :
54103 for c in collections :
55104 coll = Collection (** c )
@@ -71,25 +120,16 @@ async def all_collections(self, request: Request, **kwargs) -> Collections:
71120
72121 linked_collections .append (coll )
73122
74- links = [
75- {
76- "rel" : Relations .root .value ,
77- "type" : MimeTypes .json ,
78- "href" : base_url ,
79- },
80- {
81- "rel" : Relations .parent .value ,
82- "type" : MimeTypes .json ,
83- "href" : base_url ,
84- },
85- {
86- "rel" : Relations .self .value ,
87- "type" : MimeTypes .json ,
88- "href" : urljoin (base_url , "collections" ),
89- },
90- ]
91- collection_list = Collections (collections = linked_collections or [], links = links )
92- return collection_list
123+ links = await PagingLinks (
124+ request = request ,
125+ next = next ,
126+ prev = prev ,
127+ ).get_links ()
128+
129+ return Collections (
130+ collections = linked_collections or [],
131+ links = links ,
132+ )
93133
94134 async def get_collection (
95135 self , collection_id : str , request : Request , ** kwargs
@@ -386,7 +426,7 @@ async def post_search(
386426
387427 return ItemCollection (** item_collection )
388428
389- async def get_search ( # noqa: C901
429+ async def get_search (
390430 self ,
391431 request : Request ,
392432 collections : Optional [List [str ]] = None ,
@@ -421,51 +461,15 @@ async def get_search( # noqa: C901
421461 "query" : orjson .loads (unquote_plus (query )) if query else query ,
422462 }
423463
424- if filter :
425- if filter_lang == "cql2-text" :
426- filter = to_cql2 (parse_cql2_text (filter ))
427- filter_lang = "cql2-json"
428-
429- base_args ["filter" ] = orjson .loads (filter )
430- base_args ["filter-lang" ] = filter_lang
431-
432- if datetime :
433- base_args ["datetime" ] = format_datetime_range (datetime )
434-
435- if intersects :
436- base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
437-
438- if sortby :
439- # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
440- sort_param = []
441- for sort in sortby :
442- sortparts = re .match (r"^([+-]?)(.*)$" , sort )
443- if sortparts :
444- sort_param .append (
445- {
446- "field" : sortparts .group (2 ).strip (),
447- "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
448- }
449- )
450- base_args ["sortby" ] = sort_param
451-
452- if fields :
453- includes = set ()
454- excludes = set ()
455- for field in fields :
456- if field [0 ] == "-" :
457- excludes .add (field [1 :])
458- elif field [0 ] == "+" :
459- includes .add (field [1 :])
460- else :
461- includes .add (field )
462- base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
463-
464- # Remove None values from dict
465- clean = {}
466- for k , v in base_args .items ():
467- if v is not None and v != []:
468- clean [k ] = v
464+ clean = clean_search_args (
465+ base_args = base_args ,
466+ intersects = intersects ,
467+ datetime = datetime ,
468+ fields = fields ,
469+ sortby = sortby ,
470+ filter_query = filter ,
471+ filter_lang = filter_lang ,
472+ )
469473
470474 # Do the request
471475 try :
@@ -476,3 +480,62 @@ async def get_search( # noqa: C901
476480 ) from e
477481
478482 return await self .post_search (search_request , request = request )
483+
484+
485+ def clean_search_args ( # noqa: C901
486+ base_args : Dict [str , Any ],
487+ intersects : Optional [str ] = None ,
488+ datetime : Optional [DateTimeType ] = None ,
489+ fields : Optional [List [str ]] = None ,
490+ sortby : Optional [str ] = None ,
491+ filter_query : Optional [str ] = None ,
492+ filter_lang : Optional [str ] = None ,
493+ ) -> Dict [str , Any ]:
494+ """Clean up search arguments to match format expected by pgstac"""
495+ if filter_query :
496+ if filter_lang == "cql2-text" :
497+ filter_query = to_cql2 (parse_cql2_text (filter_query ))
498+ filter_lang = "cql2-json"
499+
500+ base_args ["filter" ] = orjson .loads (filter_query )
501+ base_args ["filter_lang" ] = filter_lang
502+
503+ if datetime :
504+ base_args ["datetime" ] = format_datetime_range (datetime )
505+
506+ if intersects :
507+ base_args ["intersects" ] = orjson .loads (unquote_plus (intersects ))
508+
509+ if sortby :
510+ # https://github.com/radiantearth/stac-spec/tree/master/api-spec/extensions/sort#http-get-or-post-form
511+ sort_param = []
512+ for sort in sortby :
513+ sortparts = re .match (r"^([+-]?)(.*)$" , sort )
514+ if sortparts :
515+ sort_param .append (
516+ {
517+ "field" : sortparts .group (2 ).strip (),
518+ "direction" : "desc" if sortparts .group (1 ) == "-" else "asc" ,
519+ }
520+ )
521+ base_args ["sortby" ] = sort_param
522+
523+ if fields :
524+ includes = set ()
525+ excludes = set ()
526+ for field in fields :
527+ if field [0 ] == "-" :
528+ excludes .add (field [1 :])
529+ elif field [0 ] == "+" :
530+ includes .add (field [1 :])
531+ else :
532+ includes .add (field )
533+ base_args ["fields" ] = {"include" : includes , "exclude" : excludes }
534+
535+ # Remove None values from dict
536+ clean = {}
537+ for k , v in base_args .items ():
538+ if v is not None and v != []:
539+ clean [k ] = v
540+
541+ return clean
0 commit comments