Skip to content

Commit ca57138

Browse files
mdesmethashhar
authored andcommitted
Support Trino's ROW datatype
1 parent 45cac72 commit ca57138

File tree

3 files changed

+50
-17
lines changed

3 files changed

+50
-17
lines changed

tests/integration/test_dbapi_integration.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,32 @@ def test_boolean_query_param(trino_connection):
616616
assert rows[0][0] is False
617617

618618

619+
def test_row(trino_connection):
620+
cur = trino_connection.cursor(experimental_python_types=True)
621+
params = (1, Decimal("2.0"), datetime(2020, 1, 1, 0, 0, 0))
622+
cur.execute("SELECT ?", (params,))
623+
rows = cur.fetchall()
624+
625+
assert rows[0][0] == params
626+
627+
628+
def test_nested_row(trino_connection):
629+
cur = trino_connection.cursor(experimental_python_types=True)
630+
params = ((1, "test", Decimal("3.1")), Decimal("2.0"), datetime(2020, 1, 1, 0, 0, 0))
631+
cur.execute("SELECT ?", (params,))
632+
rows = cur.fetchall()
633+
634+
assert rows[0][0] == params
635+
636+
637+
def test_named_row(trino_connection):
638+
cur = trino_connection.cursor(experimental_python_types=True)
639+
cur.execute("SELECT CAST(ROW(1, 2e0) AS ROW(x BIGINT, y DOUBLE))")
640+
rows = cur.fetchall()
641+
642+
assert rows[0][0] == (1, 2.0)
643+
644+
619645
def test_float_query_param(trino_connection):
620646
cur = trino_connection.cursor()
621647
cur.execute("SELECT ?", params=(1.1,))

trino/client.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -530,25 +530,30 @@ def _map_to_python_type(cls, item: Tuple[Any, Dict]) -> Any:
530530

531531
try:
532532
if isinstance(value, list):
533-
raw_type = {
534-
"typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
535-
}
536-
return [cls._map_to_python_type((array_item, raw_type)) for array_item in value]
537-
if isinstance(value, dict):
538-
if raw_type == "map":
539-
raw_key_type = {
533+
if raw_type == "array":
534+
raw_type = {
540535
"typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
541536
}
542-
raw_value_type = {
543-
"typeSignature": data_type["typeSignature"]["arguments"][1]["value"]
544-
}
545-
return {
546-
cls._map_to_python_type((key, raw_key_type)):
547-
cls._map_to_python_type((value[key], raw_value_type))
548-
for key in value
549-
}
550-
# TODO: support row type
537+
return [cls._map_to_python_type((array_item, raw_type)) for array_item in value]
538+
if raw_type == "row":
539+
raw_types = map(lambda arg: arg["value"], data_type["typeSignature"]["arguments"])
540+
return tuple(
541+
cls._map_to_python_type((array_item, raw_type))
542+
for (array_item, raw_type) in zip(value, raw_types)
543+
)
551544
return value
545+
if isinstance(value, dict):
546+
raw_key_type = {
547+
"typeSignature": data_type["typeSignature"]["arguments"][0]["value"]
548+
}
549+
raw_value_type = {
550+
"typeSignature": data_type["typeSignature"]["arguments"][1]["value"]
551+
}
552+
return {
553+
cls._map_to_python_type((key, raw_key_type)):
554+
cls._map_to_python_type((value[key], raw_value_type))
555+
for key in value
556+
}
552557
elif "decimal" in raw_type:
553558
return Decimal(value)
554559
elif raw_type == "double":

trino/dbapi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,10 @@ def _format_prepared_param(self, param):
394394
if isinstance(param, list):
395395
return "ARRAY[%s]" % ','.join(map(self._format_prepared_param, param))
396396

397+
if isinstance(param, tuple):
398+
return "ROW(%s)" % ','.join(map(self._format_prepared_param, param))
399+
397400
if isinstance(param, dict):
398-
# TODO: support mixed types in dicts and convert to ROW
399401
keys = list(param.keys())
400402
values = [param[key] for key in keys]
401403
return "MAP({}, {})".format(

0 commit comments

Comments
 (0)