Skip to content

Commit aca68a0

Browse files
mdesmethashhar
authored andcommitted
Refactor value mappers to separate classes
1 parent a975e5d commit aca68a0

File tree

1 file changed

+150
-103
lines changed

1 file changed

+150
-103
lines changed

trino/client.py

Lines changed: 150 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,18 @@
3232
>> query = TrinoQuery(request, sql)
3333
>> rows = list(query.execute())
3434
"""
35-
35+
import abc
3636
import copy
3737
import functools
3838
import os
3939
import random
4040
import re
4141
import threading
4242
import urllib.parse
43-
from datetime import datetime, timedelta, timezone
43+
from datetime import date, datetime, time, timedelta, timezone
4444
from decimal import Decimal
4545
from time import sleep
46-
from typing import Any, Dict, List, Optional, Tuple, Union
46+
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar, Union
4747

4848
import pytz
4949
import requests
@@ -64,9 +64,7 @@
6464

6565
_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re.compile(r'^\S[^\s=]*$')
6666

67-
INF = float("inf")
68-
NEGATIVE_INF = float("-inf")
69-
NAN = float("nan")
67+
T = TypeVar("T")
7068

7169

7270
class ClientSession(object):
@@ -840,6 +838,128 @@ def decorated(*args, **kwargs):
840838
return wrapper
841839

842840

841+
class ValueMapper(abc.ABC, Generic[T]):
842+
@abc.abstractmethod
843+
def map(self, value: Any) -> T:
844+
pass
845+
846+
847+
class NoOpValueMapper(ValueMapper[Any]):
848+
def map(self, value) -> Optional[Any]:
849+
return value
850+
851+
852+
class DecimalValueMapper(ValueMapper[Decimal]):
853+
def map(self, value) -> Decimal:
854+
return Decimal(value)
855+
856+
857+
class DoubleValueMapper(ValueMapper[float]):
858+
def map(self, value) -> float:
859+
if value == 'Infinity':
860+
return float("inf")
861+
if value == '-Infinity':
862+
return float("-inf")
863+
if value == 'NaN':
864+
return float("nan")
865+
return float(value)
866+
867+
868+
class TemporalValueMapper():
869+
def _get_number_of_digits(self, column):
870+
args = column['arguments']
871+
if len(args) == 0:
872+
return 3, 0
873+
ms_size = args[0]['value']
874+
if ms_size == 0:
875+
return -1, 0
876+
ms_to_trim = ms_size - min(ms_size, 6)
877+
return ms_size, ms_to_trim
878+
879+
880+
class TimeValueMapper(ValueMapper[time], TemporalValueMapper):
881+
def __init__(self, column):
882+
pattern = "%H:%M:%S"
883+
ms_size, ms_to_trim = self._get_number_of_digits(column)
884+
if ms_size > 0:
885+
pattern += ".%f"
886+
self.pattern = pattern
887+
self.time_size = 9 + ms_size - ms_to_trim
888+
889+
def map(self, value) -> time:
890+
return datetime.strptime(value[:self.time_size], self.pattern).time()
891+
892+
893+
class TimeWithTimeZoneValueMapper(TimeValueMapper):
894+
PATTERN = r'^(.*)([\+\-])(\d{2}):(\d{2})$'
895+
896+
def map(self, value) -> time:
897+
matches = re.match(TimeWithTimeZoneValueMapper.PATTERN, value)
898+
assert matches is not None
899+
assert len(matches.groups()) == 4
900+
if matches.group(2) == '-':
901+
tz = -timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
902+
else:
903+
tz = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
904+
return datetime.strptime(matches.group(1)[:self.time_size], self.pattern).time().replace(tzinfo=timezone(tz))
905+
906+
907+
class DateValueMapper(ValueMapper[date]):
908+
def map(self, value) -> date:
909+
return datetime.strptime(value, '%Y-%m-%d').date()
910+
911+
912+
class TimestampValueMapper(ValueMapper[datetime], TemporalValueMapper):
913+
def __init__(self, column):
914+
datetime_default_size = 20 # size of 'YYYY-MM-DD HH:MM:SS.' (the datetime string up to the milliseconds)
915+
pattern = "%Y-%m-%d %H:%M:%S"
916+
ms_size, ms_to_trim = self._get_number_of_digits(column)
917+
if ms_size > 0:
918+
pattern += ".%f"
919+
self.pattern = pattern
920+
self.dt_size = datetime_default_size + ms_size - ms_to_trim
921+
self.dt_tz_offset = datetime_default_size + ms_size
922+
923+
def map(self, value) -> datetime:
924+
return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern)
925+
926+
927+
class TimestampWithTimeZoneValueMapper(TimestampValueMapper):
928+
def map(self, value) -> datetime:
929+
dt, tz = value.rsplit(' ', 1)
930+
if tz.startswith('+') or tz.startswith('-'):
931+
return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern + ' %z')
932+
date_str = dt[:self.dt_size] + dt[self.dt_tz_offset:]
933+
return datetime.strptime(date_str, self.pattern).replace(tzinfo=pytz.timezone(tz))
934+
935+
936+
class ArrayValueMapper(ValueMapper[List[Any]]):
937+
def __init__(self, mapper: ValueMapper[Any]):
938+
self.mapper = mapper
939+
940+
def map(self, values: List[Any]) -> List[Any]:
941+
return [self.mapper.map(value) for value in values]
942+
943+
944+
class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]):
945+
def __init__(self, mappers: List[ValueMapper[Any]]):
946+
self.mappers = mappers
947+
948+
def map(self, values: List[Any]) -> Tuple[Any, ...]:
949+
return tuple(self.mappers[index].map(value) for index, value in enumerate(values))
950+
951+
952+
class MapValueMapper(ValueMapper[Dict[Any, Any]]):
953+
def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]):
954+
self.key_mapper = key_mapper
955+
self.value_mapper = value_mapper
956+
957+
def map(self, values: Any) -> Dict[Any, Any]:
958+
return {
959+
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
960+
}
961+
962+
843963
class NoOpRowMapper:
844964
"""
845965
No-op RowMapper which does not perform any transformation
@@ -856,117 +976,44 @@ class RowMapperFactory:
856976
lambda functions (one for each column) which will process a data value
857977
and returns a RowMapper instance which will process rows of data
858978
"""
859-
no_op_row_mapper = NoOpRowMapper()
979+
NO_OP_ROW_MAPPER = NoOpRowMapper()
860980

861981
def create(self, columns, experimental_python_types):
862982
assert columns is not None
863983

864984
if experimental_python_types:
865-
return RowMapper([self._col_func(column['typeSignature']) for column in columns])
866-
return RowMapperFactory.no_op_row_mapper
985+
return RowMapper([self._create_value_mapper(column['typeSignature']) for column in columns])
986+
return RowMapperFactory.NO_OP_ROW_MAPPER
867987

868-
def _col_func(self, column):
988+
def _create_value_mapper(self, column) -> ValueMapper:
869989
col_type = column['rawType']
870990

871991
if col_type == 'array':
872-
return self._array_map_func(column)
992+
value_mapper = self._create_value_mapper(column['arguments'][0]['value'])
993+
return ArrayValueMapper(value_mapper)
873994
elif col_type == 'row':
874-
return self._row_map_func(column)
995+
mappers = [self._create_value_mapper(arg['value']['typeSignature']) for arg in column['arguments']]
996+
return RowValueMapper(mappers)
875997
elif col_type == 'map':
876-
return self._map_map_func(column)
998+
key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
999+
value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
1000+
return MapValueMapper(key_mapper, value_mapper)
8771001
elif col_type.startswith('decimal'):
878-
return lambda val: Decimal(val)
1002+
return DecimalValueMapper()
8791003
elif col_type.startswith('double') or col_type.startswith('real'):
880-
return self._double_map_func()
1004+
return DoubleValueMapper()
1005+
elif col_type.startswith('timestamp') and 'with time zone' in col_type:
1006+
return TimestampWithTimeZoneValueMapper(column)
8811007
elif col_type.startswith('timestamp'):
882-
return self._timestamp_map_func(column, col_type)
1008+
return TimestampValueMapper(column)
1009+
elif col_type.startswith('time') and 'with time zone' in col_type:
1010+
return TimeWithTimeZoneValueMapper(column)
8831011
elif col_type.startswith('time'):
884-
return self._time_map_func(column, col_type)
1012+
return TimeValueMapper(column)
8851013
elif col_type == 'date':
886-
return lambda val: datetime.strptime(val, '%Y-%m-%d').date()
887-
else:
888-
return lambda val: val
889-
890-
def _array_map_func(self, column):
891-
element_mapping_func = self._col_func(column['arguments'][0]['value'])
892-
return lambda values: [element_mapping_func(value) for value in values]
893-
894-
def _row_map_func(self, column):
895-
element_mapping_func = [self._col_func(arg['value']['typeSignature']) for arg in column['arguments']]
896-
return lambda values: tuple(element_mapping_func[idx](value) for idx, value in enumerate(values))
897-
898-
def _map_map_func(self, column):
899-
key_mapping_func = self._col_func(column['arguments'][0]['value'])
900-
value_mapping_func = self._col_func(column['arguments'][1]['value'])
901-
return lambda values: {key_mapping_func(key): value_mapping_func(value) for key, value in values.items()}
902-
903-
def _double_map_func(self):
904-
return lambda val: INF if val == 'Infinity' \
905-
else NEGATIVE_INF if val == '-Infinity' \
906-
else NAN if val == 'NaN' \
907-
else float(val)
908-
909-
def _timestamp_map_func(self, column, col_type):
910-
datetime_default_size = 20 # size of 'YYYY-MM-DD HH:MM:SS.' (the datetime string up to the milliseconds)
911-
pattern = "%Y-%m-%d %H:%M:%S"
912-
ms_size, ms_to_trim = self._get_number_of_digits(column)
913-
if ms_size > 0:
914-
pattern += ".%f"
915-
916-
dt_size = datetime_default_size + ms_size - ms_to_trim
917-
dt_tz_offset = datetime_default_size + ms_size
918-
if 'with time zone' in col_type:
919-
920-
if ms_to_trim > 0:
921-
return lambda val: \
922-
[datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern + ' %z')
923-
if tz.startswith('+') or tz.startswith('-')
924-
else datetime.strptime(dt[:dt_size] + dt[dt_tz_offset:], pattern)
925-
.replace(tzinfo=pytz.timezone(tz))
926-
for dt, tz in [val.rsplit(' ', 1)]][0]
927-
else:
928-
return lambda val: [datetime.strptime(val, pattern + ' %z')
929-
if tz.startswith('+') or tz.startswith('-')
930-
else datetime.strptime(dt, pattern).replace(tzinfo=pytz.timezone(tz))
931-
for dt, tz in [val.rsplit(' ', 1)]][0]
932-
933-
if ms_to_trim > 0:
934-
return lambda val: datetime.strptime(val[:dt_size] + val[dt_tz_offset:], pattern)
1014+
return DateValueMapper()
9351015
else:
936-
return lambda val: datetime.strptime(val, pattern)
937-
938-
def _time_map_func(self, column, col_type):
939-
pattern = "%H:%M:%S"
940-
ms_size, ms_to_trim = self._get_number_of_digits(column)
941-
if ms_size > 0:
942-
pattern += ".%f"
943-
944-
time_size = 9 + ms_size - ms_to_trim
945-
946-
if 'with time zone' in col_type:
947-
return lambda val: self._get_time_with_timezome(val, time_size, pattern)
948-
else:
949-
return lambda val: datetime.strptime(val[:time_size], pattern).time()
950-
951-
def _get_time_with_timezome(self, value, time_size, pattern):
952-
matches = re.match(r'^(.*)([\+\-])(\d{2}):(\d{2})$', value)
953-
assert matches is not None
954-
assert len(matches.groups()) == 4
955-
if matches.group(2) == '-':
956-
tz = -timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
957-
else:
958-
tz = timedelta(hours=int(matches.group(3)), minutes=int(matches.group(4)))
959-
return datetime.strptime(matches.group(1)[:time_size], pattern).time().replace(tzinfo=timezone(tz))
960-
961-
def _get_number_of_digits(self, column):
962-
args = column['arguments']
963-
if len(args) == 0:
964-
return 3, 0
965-
ms_size = column['arguments'][0]['value']
966-
if ms_size == 0:
967-
return -1, 0
968-
ms_to_trim = ms_size - min(ms_size, 6)
969-
return ms_size, ms_to_trim
1016+
return NoOpValueMapper()
9701017

9711018

9721019
class RowMapper:
@@ -982,14 +1029,14 @@ def map(self, rows):
9821029
return [self._map_row(row) for row in rows]
9831030

9841031
def _map_row(self, row):
985-
return [self._map_value(value, self.columns[idx]) for idx, value in enumerate(row)]
1032+
return [self._map_value(value, self.columns[index]) for index, value in enumerate(row)]
9861033

987-
def _map_value(self, value, col_mapping_func):
1034+
def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]:
9881035
if value is None:
9891036
return None
9901037

9911038
try:
992-
return col_mapping_func(value)
1039+
return value_mapper.map(value)
9931040
except ValueError as e:
9941041
error_str = f"Could not convert '{value}' into the associated python type"
9951042
raise trino.exceptions.TrinoDataError(error_str) from e

0 commit comments

Comments
 (0)