Skip to content

Commit ee8f1a5

Browse files
mdesmetebyhr
authored andcommitted
Support Trino's map data type
1 parent 81184b5 commit ee8f1a5

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,16 @@ def test_dict_query_param(trino_connection):
591591
assert rows[0][0] == "map(varchar(3), varchar(3))"
592592

593593

594+
def test_dict_timestamp_query_param_types(trino_connection):
595+
cur = trino_connection.cursor(experimental_python_types=True)
596+
597+
params = {"foo": datetime(2020, 1, 1, 16, 43, 22, 320000)}
598+
cur.execute("SELECT ?", params=(params,))
599+
rows = cur.fetchall()
600+
601+
assert rows[0][0] == params
602+
603+
594604
def test_boolean_query_param(trino_connection):
595605
cur = trino_connection.cursor()
596606

trino/client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,21 @@ def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
528528
"typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
529529
}
530530
return [cls._map_to_python_type((array_item, raw_type)) for array_item in value]
531+
if isinstance(value, dict):
532+
if raw_type == "map":
533+
raw_key_type = {
534+
"typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
535+
}
536+
raw_value_type = {
537+
"typeSignature": data_type["typeSignature"]["arguments"][1]["value"]
538+
}
539+
return {
540+
cls._map_to_python_type((key, raw_key_type)):
541+
cls._map_to_python_type((value[key], raw_value_type))
542+
for key in value
543+
}
544+
# TODO: support row type
545+
return value
531546
elif "decimal" in raw_type:
532547
return Decimal(value)
533548
elif raw_type == "double":

trino/dbapi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ def _format_prepared_param(self, param):
392392
return "ARRAY[%s]" % ','.join(map(self._format_prepared_param, param))
393393

394394
if isinstance(param, dict):
395+
# TODO: support mixed types in dicts and convert to ROW
395396
keys = list(param.keys())
396397
values = [param[key] for key in keys]
397398
return "MAP({}, {})".format(

0 commit comments

Comments
 (0)