66import sys
77from asyncio import as_completed
88from pathlib import Path
9- from typing import Any , Awaitable , Dict , List , Optional , Union , Tuple
10- from aiohttp . client_reqrep import ClientResponse
11- from jsonpath_ng import parse as jparse
9+ from typing import Any , Awaitable , Dict , List , Optional , Tuple , Union
10+ from warnings import warn
11+
1212import pandas as pd
1313from aiohttp import ClientSession
14+ from aiohttp .client_reqrep import ClientResponse
1415from jinja2 import Environment , StrictUndefined , Template , UndefinedError
16+ from jsonpath_ng import parse as jparse
1517
1618from .config_manager import config_directory , ensure_config
1719from .errors import InvalidParameterError , RequestError , UniversalParameterOverridden
2123 ConfigDef ,
2224 FieldDefUnion ,
2325 OffsetPaginationDef ,
24- SeekPaginationDef ,
2526 PagePaginationDef ,
26- TokenPaginationDef ,
27+ SeekPaginationDef ,
2728 TokenLocation ,
29+ TokenPaginationDef ,
2830)
2931from .throttler import OrderedThrottler , ThrottleSession
3032
@@ -108,6 +110,7 @@ async def query( # pylint: disable=too-many-locals
108110 self ,
109111 table : str ,
110112 * ,
113+ _q : Optional [str ] = None ,
111114 _auth : Optional [Dict [str , Any ]] = None ,
112115 _count : Optional [int ] = None ,
113116 ** where : Any ,
@@ -119,6 +122,8 @@ async def query( # pylint: disable=too-many-locals
119122 ----------
120123 table
121124 The table name.
125+ _q: Optional[str] = None
126+ Search string to be matched in the response.
122127 _auth: Optional[Dict[str, Any]] = None
123128 The parameters for authentication. Usually the authentication parameters
124129 should be defined when instantiating the Connector. In case some tables have different
@@ -134,12 +139,13 @@ async def query( # pylint: disable=too-many-locals
134139 if key not in allowed_params :
135140 raise InvalidParameterError (key )
136141
137- return await self ._query_imp (table , where , _auth = _auth , _count = _count )
142+ return await self ._query_imp (table , where , _auth = _auth , _q = _q , _count = _count )
138143
139144 @property
140145 def table_names (self ) -> List [str ]:
141146 """
142147 Return all the names of the available tables in a list.
148+
143149 Note
144150 ----
145151 We abstract each website as a database containing several tables.
@@ -148,9 +154,8 @@ def table_names(self) -> List[str]:
148154 return list (self ._impdb .tables .keys ())
149155
150156 def info (self ) -> None :
151- """
152- Show the basic information and provide guidance for users to issue queries.
153- """
157+ """Show the basic information and provide guidance for users
158+ to issue queries."""
154159
155160 # get info
156161 tbs : Dict [str , Any ] = {}
@@ -216,6 +221,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m
216221 * ,
217222 _auth : Optional [Dict [str , Any ]] = None ,
218223 _count : Optional [int ] = None ,
224+ _q : Optional [str ] = None ,
219225 ) -> pd .DataFrame :
220226 if table not in self ._impdb .tables :
221227 raise ValueError (f"No such table { table } in { self ._impdb .name } " )
@@ -238,7 +244,12 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m
238244
239245 if reqconf .pagination is None or _count is None :
240246 df = await self ._fetch (
241- itable , kwargs , _client = client , _throttler = throttler , _auth = _auth ,
247+ itable ,
248+ kwargs ,
249+ _client = client ,
250+ _throttler = throttler ,
251+ _auth = _auth ,
252+ _q = _q ,
242253 )
243254 return df
244255
@@ -263,6 +274,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m
263274 _throttler = throttler ,
264275 _page = i ,
265276 _auth = _auth ,
277+ _q = _q ,
266278 _limit = count ,
267279 _anchor = last_id - 1 ,
268280 )
@@ -274,7 +286,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m
274286 # The API returns empty for this page, maybe we've reached the end
275287 break
276288
277- cid = df .columns .get_loc (pagdef .seek_id )
289+ cid = df .columns .get_loc (pagdef .seek_id ) # type: ignore
278290 last_id = int (df .iloc [- 1 , cid ]) - 1 # type: ignore
279291
280292 dfs .append (df )
@@ -291,6 +303,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m
291303 _throttler = throttler ,
292304 _page = i ,
293305 _auth = _auth ,
306+ _q = _q ,
294307 _limit = count ,
295308 _anchor = next_token ,
296309 _raw = True ,
@@ -326,6 +339,7 @@ async def _query_imp( # pylint: disable=too-many-locals,too-many-branches,too-m
326339 _page = i ,
327340 _allowed_page = allowed_page ,
328341 _auth = _auth ,
342+ _q = _q ,
329343 _limit = count ,
330344 _anchor = anchor ,
331345 )
@@ -355,6 +369,7 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many-
355369 _limit : Optional [int ] = None ,
356370 _anchor : Optional [Any ] = None ,
357371 _auth : Optional [Dict [str , Any ]] = None ,
372+ _q : Optional [str ] = None ,
358373 _raw : bool = False ,
359374 ) -> Union [Optional [pd .DataFrame ], Tuple [Optional [pd .DataFrame ], ClientResponse ]]:
360375
@@ -371,12 +386,6 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many-
371386 if reqdef .authorization is not None :
372387 reqdef .authorization .build (req_data , _auth or self ._auth , self ._storage )
373388
374- for key in ["headers" , "params" , "cookies" ]:
375- field_def = getattr (reqdef , key , None )
376- if field_def is not None :
377- instantiated_fields = populate_field (field_def , self ._jenv , merged_vars )
378- req_data [key ].update (** instantiated_fields )
379-
380389 if reqdef .body is not None :
381390 # TODO: do we support binary body?
382391 instantiated_fields = populate_field (
@@ -414,6 +423,39 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many-
414423 if _anchor is not None :
415424 req_data ["params" ][anchor ] = _anchor
416425
426+ if _q is not None :
427+ if reqdef .search is None :
428+ raise ValueError (
429+ "_q specified but the API does not support custom search."
430+ )
431+
432+ searchdef = reqdef .search
433+ search_key = searchdef .key
434+
435+ if search_key in req_data ["params" ]:
436+ raise UniversalParameterOverridden (search_key , "_q" )
437+ req_data ["params" ][search_key ] = _q
438+
439+ for key in ["headers" , "params" , "cookies" ]:
440+ field_def = getattr (reqdef , key , None )
441+ if field_def is not None :
442+ instantiated_fields = populate_field (
443+ field_def , self ._jenv , merged_vars ,
444+ )
445+ for ikey in instantiated_fields :
446+ if ikey in req_data [key ]:
447+ warn (
448+ f"Query parameter { ikey } ={ req_data [key ][ikey ]} "
449+ " is overriden by {ikey}={instantiated_fields[ikey]}" ,
450+ RuntimeWarning ,
451+ )
452+ req_data [key ].update (** instantiated_fields )
453+
454+ for key in ["headers" , "params" , "cookies" ]:
455+ field_def = getattr (reqdef , key , None )
456+ if field_def is not None :
457+ validate_fields (field_def , req_data [key ])
458+
417459 await _throttler .acquire (_page )
418460
419461 if _allowed_page is not None and int (_allowed_page ) <= _page :
@@ -445,21 +487,37 @@ async def _fetch( # pylint: disable=too-many-locals,too-many-branches,too-many-
445487 return df
446488
447489
def validate_fields(fields: Dict[str, FieldDefUnion], data: Dict[str, Any]) -> None:
    """Verify that every field marked as required has an entry in ``data``.

    Parameters
    ----------
    fields
        Mapping from field name to its definition. A definition may be a
        bool (the value itself is the "required" flag), a string (never
        checked here), or an object exposing ``to_key``, ``from_key`` and
        ``required`` attributes.
    data
        The already-populated values, keyed by destination key.

    Raises
    ------
    KeyError
        If a required field's destination key is absent from ``data``. The
        message names the source key of the field.
    """
    for name, spec in fields.items():
        # String definitions carry no "required" flag, so nothing to check.
        if isinstance(spec, str):
            continue

        if isinstance(spec, bool):
            # A bare bool is the required flag; source and destination keys
            # are both the field name.
            target, source, mandatory = name, name, spec
        else:
            # Structured definition: fall back to the field name when no
            # explicit key is configured.
            target = spec.to_key or name
            source = spec.from_key or name
            mandatory = spec.required

        if mandatory and target not in data:
            raise KeyError(f"'{source}' is required but not provided")
508+
509+
448510def populate_field ( # pylint: disable=too-many-branches
449- fields : Dict [str , FieldDefUnion ], jenv : Environment , params : Dict [str , Any ]
511+ fields : Dict [str , FieldDefUnion ], jenv : Environment , params : Dict [str , Any ],
450512) -> Dict [str , str ]:
451513 """Populate a dict based on the fields definition and provided vars."""
452-
453514 ret : Dict [str , str ] = {}
454515
455516 for key , def_ in fields .items ():
456517 from_key , to_key = key , key
457518
458519 if isinstance (def_ , bool ):
459- required = def_
460520 value = params .get (from_key )
461- if value is None and required :
462- raise KeyError (from_key )
463521 remove_if_empty = False
464522 elif isinstance (def_ , str ):
465523 # is a template
@@ -473,10 +531,7 @@ def populate_field( # pylint: disable=too-many-branches
473531 from_key = def_ .from_key or from_key
474532
475533 if template is None :
476- required = def_ .required
477534 value = params .get (from_key )
478- if value is None and required :
479- raise KeyError (from_key )
480535 else :
481536 tmplt = jenv .from_string (template )
482537 try :
@@ -486,9 +541,12 @@ def populate_field( # pylint: disable=too-many-branches
486541
487542 if value is not None :
488543 str_value = str (value )
489- if not ( remove_if_empty and not str_value ) :
544+ if not remove_if_empty or str_value :
490545 if to_key in ret :
491- print (f"Param { key } conflicting with { to_key } " , file = sys .stderr )
546+ warn (
547+ f"{ to_key } ={ ret [to_key ]} overriden by { to_key } ={ str_value } " ,
548+ RuntimeWarning ,
549+ )
492550 ret [to_key ] = str_value
493551 continue
494552 return ret
0 commit comments