Skip to content

Commit dae4c41

Browse files
mdesmethashhar
authored andcommitted
Fix None support in structural types
1 parent aca68a0 commit dae4c41

File tree

2 files changed

+39
-17
lines changed

2 files changed

+39
-17
lines changed

tests/integration/test_types_integration.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,20 +156,25 @@ def test_array(trino_connection):
156156
SqlTest(trino_connection) \
157157
.add_field(sql="CAST(null AS ARRAY(VARCHAR))", python=None) \
158158
.add_field(sql="ARRAY['a', 'b', null]", python=['a', 'b', None]) \
159+
.add_field(sql="ARRAY[1.2, 2.4, null]", python=[Decimal("1.2"), Decimal("2.4"), None]) \
160+
.add_field(sql="ARRAY[CAST(4.9E-324 AS DOUBLE), null]", python=[5e-324, None]) \
159161
.execute()
160162

161163

162164
def test_map(trino_connection):
163165
SqlTest(trino_connection) \
164166
.add_field(sql="CAST(null AS MAP(VARCHAR, INTEGER))", python=None) \
165167
.add_field(sql="MAP(ARRAY['a', 'b'], ARRAY[1, null])", python={'a': 1, 'b': None}) \
168+
.add_field(sql="MAP(ARRAY['a', 'b'], ARRAY[2.4, null])", python={'a': Decimal("2.4"), 'b': None}) \
169+
.add_field(sql="MAP(ARRAY[2.4, 4.8], ARRAY[CAST(4.9E-324 AS DOUBLE), null])",
170+
python={Decimal("2.4"): 5e-324, Decimal("4.8"): None}) \
166171
.execute()
167172

168173

169174
def test_row(trino_connection):
170175
SqlTest(trino_connection) \
171176
.add_field(sql="CAST(null AS ROW(x BIGINT, y DOUBLE))", python=None) \
172-
.add_field(sql="CAST(ROW(1, 2e0) AS ROW(x BIGINT, y DOUBLE))", python=(1, 2.0)) \
177+
.add_field(sql="CAST(ROW(1, 2e0, null) AS ROW(x BIGINT, y DOUBLE, z DOUBLE))", python=(1, 2.0, None)) \
173178
.execute()
174179

175180

trino/client.py

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,7 @@ def decorated(*args, **kwargs):
840840

841841
class ValueMapper(abc.ABC, Generic[T]):
842842
@abc.abstractmethod
843-
def map(self, value: Any) -> T:
843+
def map(self, value: Any) -> Optional[T]:
844844
pass
845845

846846

@@ -850,12 +850,16 @@ def map(self, value) -> Optional[Any]:
850850

851851

852852
class DecimalValueMapper(ValueMapper[Decimal]):
853-
def map(self, value) -> Decimal:
853+
def map(self, value) -> Optional[Decimal]:
854+
if value is None:
855+
return None
854856
return Decimal(value)
855857

856858

857859
class DoubleValueMapper(ValueMapper[float]):
858-
def map(self, value) -> float:
860+
def map(self, value) -> Optional[float]:
861+
if value is None:
862+
return None
859863
if value == 'Infinity':
860864
return float("inf")
861865
if value == '-Infinity':
@@ -886,14 +890,18 @@ def __init__(self, column):
886890
self.pattern = pattern
887891
self.time_size = 9 + ms_size - ms_to_trim
888892

889-
def map(self, value) -> time:
893+
def map(self, value) -> Optional[time]:
894+
if value is None:
895+
return None
890896
return datetime.strptime(value[:self.time_size], self.pattern).time()
891897

892898

893899
class TimeWithTimeZoneValueMapper(TimeValueMapper):
894900
PATTERN = r'^(.*)([\+\-])(\d{2}):(\d{2})$'
895901

896-
def map(self, value) -> time:
902+
def map(self, value) -> Optional[time]:
903+
if value is None:
904+
return None
897905
matches = re.match(TimeWithTimeZoneValueMapper.PATTERN, value)
898906
assert matches is not None
899907
assert len(matches.groups()) == 4
@@ -905,7 +913,9 @@ def map(self, value) -> time:
905913

906914

907915
class DateValueMapper(ValueMapper[date]):
908-
def map(self, value) -> date:
916+
def map(self, value) -> Optional[date]:
917+
if value is None:
918+
return None
909919
return datetime.strptime(value, '%Y-%m-%d').date()
910920

911921

@@ -920,41 +930,51 @@ def __init__(self, column):
920930
self.dt_size = datetime_default_size + ms_size - ms_to_trim
921931
self.dt_tz_offset = datetime_default_size + ms_size
922932

923-
def map(self, value) -> datetime:
933+
def map(self, value) -> Optional[datetime]:
934+
if value is None:
935+
return None
924936
return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern)
925937

926938

927939
class TimestampWithTimeZoneValueMapper(TimestampValueMapper):
928-
def map(self, value) -> datetime:
940+
def map(self, value) -> Optional[datetime]:
941+
if value is None:
942+
return None
929943
dt, tz = value.rsplit(' ', 1)
930944
if tz.startswith('+') or tz.startswith('-'):
931945
return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern + ' %z')
932946
date_str = dt[:self.dt_size] + dt[self.dt_tz_offset:]
933947
return datetime.strptime(date_str, self.pattern).replace(tzinfo=pytz.timezone(tz))
934948

935949

936-
class ArrayValueMapper(ValueMapper[List[Any]]):
950+
class ArrayValueMapper(ValueMapper[List[Optional[Any]]]):
937951
def __init__(self, mapper: ValueMapper[Any]):
938952
self.mapper = mapper
939953

940-
def map(self, values: List[Any]) -> List[Any]:
954+
def map(self, values: List[Any]) -> Optional[List[Any]]:
955+
if values is None:
956+
return None
941957
return [self.mapper.map(value) for value in values]
942958

943959

944960
class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]):
945961
def __init__(self, mappers: List[ValueMapper[Any]]):
946962
self.mappers = mappers
947963

948-
def map(self, values: List[Any]) -> Tuple[Any, ...]:
964+
def map(self, values: List[Any]) -> Optional[Tuple[Optional[Any], ...]]:
965+
if values is None:
966+
return None
949967
return tuple(self.mappers[index].map(value) for index, value in enumerate(values))
950968

951969

952-
class MapValueMapper(ValueMapper[Dict[Any, Any]]):
970+
class MapValueMapper(ValueMapper[Dict[Any, Optional[Any]]]):
953971
def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]):
954972
self.key_mapper = key_mapper
955973
self.value_mapper = value_mapper
956974

957-
def map(self, values: Any) -> Dict[Any, Any]:
975+
def map(self, values: Any) -> Optional[Dict[Any, Optional[Any]]]:
976+
if values is None:
977+
return None
958978
return {
959979
self.key_mapper.map(key): self.value_mapper.map(value) for key, value in values.items()
960980
}
@@ -1032,9 +1052,6 @@ def _map_row(self, row):
10321052
return [self._map_value(value, self.columns[index]) for index, value in enumerate(row)]
10331053

10341054
def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]:
1035-
if value is None:
1036-
return None
1037-
10381055
try:
10391056
return value_mapper.map(value)
10401057
except ValueError as e:

0 commit comments

Comments
 (0)