32
32
>> query = TrinoQuery(request, sql)
33
33
>> rows = list(query.execute())
34
34
"""
35
-
35
+ import abc
36
36
import copy
37
37
import functools
38
38
import os
39
39
import random
40
40
import re
41
41
import threading
42
42
import urllib .parse
43
- from datetime import datetime , timedelta , timezone
43
+ from datetime import date , datetime , time , timedelta , timezone
44
44
from decimal import Decimal
45
45
from time import sleep
46
- from typing import Any , Dict , List , Optional , Tuple , Union
46
+ from typing import Any , Dict , Generic , List , Optional , Tuple , TypeVar , Union
47
47
48
48
import pytz
49
49
import requests
64
64
65
65
_HEADER_EXTRA_CREDENTIAL_KEY_REGEX = re .compile (r'^\S[^\s=]*$' )
66
66
67
- INF = float ("inf" )
68
- NEGATIVE_INF = float ("-inf" )
69
- NAN = float ("nan" )
67
+ T = TypeVar ("T" )
70
68
71
69
72
70
class ClientSession (object ):
@@ -840,6 +838,128 @@ def decorated(*args, **kwargs):
840
838
return wrapper
841
839
842
840
841
class ValueMapper(abc.ABC, Generic[T]):
    """Abstract converter that turns a single value into a value of type ``T``."""

    @abc.abstractmethod
    def map(self, value: Any) -> T:
        """Convert ``value`` and return the result; concrete mappers override this."""
        ...
845
+
846
+
847
class NoOpValueMapper(ValueMapper[Any]):
    """Identity mapper: hands the value back without any conversion."""

    def map(self, value) -> Optional[Any]:
        """Return ``value`` unchanged."""
        unchanged = value
        return unchanged
850
+
851
+
852
class DecimalValueMapper(ValueMapper[Decimal]):
    """Mapper that builds a ``decimal.Decimal`` from the incoming value."""

    def map(self, value) -> Decimal:
        """Return ``Decimal(value)``."""
        converted = Decimal(value)
        return converted
855
+
856
+
857
class DoubleValueMapper(ValueMapper[float]):
    """Mapper that converts a value to ``float``, including the non-finite tokens."""

    # (token to compare against, literal handed to float()) in the order checked.
    _NON_FINITE = (
        ('Infinity', "inf"),
        ('-Infinity', "-inf"),
        ('NaN', "nan"),
    )

    def map(self, value) -> float:
        """Return the float for ``value``.

        'Infinity', '-Infinity' and 'NaN' yield the matching non-finite float;
        anything else is passed to ``float()``.
        """
        for token, literal in self._NON_FINITE:
            if value == token:
                return float(literal)
        return float(value)
866
+
867
+
868
class TemporalValueMapper:
    """Mixin with the fractional-second precision logic shared by temporal mappers."""

    def _get_number_of_digits(self, column):
        """Return ``(ms_size, ms_to_trim)`` for a temporal type signature.

        ``column`` is a mapping whose ``'arguments'`` list optionally carries
        the declared precision as ``arguments[0]['value']``.

        Returns:
            * ``(3, 0)`` when no precision argument is present.
            * ``(-1, 0)`` when the precision is 0. Callers add the first value
              to a prefix length that already counts the '.' separator, so
              ``-1`` removes that separator from the computed size.
            * ``(precision, max(precision - 6, 0))`` otherwise. The second
              value is the number of trailing digits to drop, because
              ``strptime``'s ``%f`` accepts at most 6 digits.
        """
        args = column['arguments']
        if not args:
            return 3, 0
        ms_size = args[0]['value']
        if ms_size == 0:
            return -1, 0
        return ms_size, max(ms_size - 6, 0)
878
+
879
+
880
class TimeValueMapper(ValueMapper[time], TemporalValueMapper):
    """Mapper that parses 'HH:MM:SS[.fraction]' strings into ``datetime.time``."""

    def __init__(self, column):
        """Derive the strptime pattern and the usable string length from ``column``."""
        fraction_digits, surplus_digits = self._get_number_of_digits(column)
        # A fractional part is only expected when the type declares digits.
        self.pattern = "%H:%M:%S.%f" if fraction_digits > 0 else "%H:%M:%S"
        # 9 == len('HH:MM:SS.'); surplus digits are cut off before parsing.
        self.time_size = 9 + fraction_digits - surplus_digits

    def map(self, value) -> time:
        """Parse the first ``time_size`` characters of ``value`` into a ``time``."""
        truncated = value[:self.time_size]
        parsed = datetime.strptime(truncated, self.pattern)
        return parsed.time()
891
+
892
+
893
class TimeWithTimeZoneValueMapper(TimeValueMapper):
    """Mapper for time strings that end in a numeric '+HH:MM' / '-HH:MM' offset."""

    # Groups: (time part, sign, offset hours, offset minutes).
    PATTERN = r'^(.*)([\+\-])(\d{2}):(\d{2})$'

    def map(self, value) -> time:
        """Parse ``value`` into a timezone-aware ``datetime.time``.

        Raises:
            ValueError: if ``value`` does not end in a '+HH:MM' / '-HH:MM'
                offset, or the time part does not match the expected pattern.
        """
        matches = re.match(TimeWithTimeZoneValueMapper.PATTERN, value)
        if matches is None:
            # Raise instead of assert: asserts vanish under -O, and callers
            # handle conversion failures as ValueError.
            raise ValueError(f"Could not parse time with time zone from '{value}'")
        time_part, sign, hours, minutes = matches.groups()
        offset = timedelta(hours=int(hours), minutes=int(minutes))
        if sign == '-':
            offset = -offset
        naive = datetime.strptime(time_part[:self.time_size], self.pattern).time()
        return naive.replace(tzinfo=timezone(offset))
905
+
906
+
907
class DateValueMapper(ValueMapper[date]):
    """Mapper that parses 'YYYY-MM-DD' strings into ``datetime.date``."""

    def map(self, value) -> date:
        """Return the ``date`` parsed from ``value`` with the '%Y-%m-%d' format."""
        parsed = datetime.strptime(value, '%Y-%m-%d')
        return parsed.date()
910
+
911
+
912
class TimestampValueMapper(ValueMapper[datetime], TemporalValueMapper):
    """Mapper that parses 'YYYY-MM-DD HH:MM:SS[.fraction]' strings into ``datetime``."""

    def __init__(self, column):
        """Derive the strptime pattern and slice positions from ``column``."""
        # len('YYYY-MM-DD HH:MM:SS.'), i.e. everything up to the fraction digits.
        prefix_length = 20
        fraction_digits, surplus_digits = self._get_number_of_digits(column)
        if fraction_digits > 0:
            self.pattern = "%Y-%m-%d %H:%M:%S.%f"
        else:
            self.pattern = "%Y-%m-%d %H:%M:%S"
        # Characters kept from the start (surplus fraction digits excluded).
        self.dt_size = prefix_length + fraction_digits - surplus_digits
        # Position where the text after the full fraction resumes.
        self.dt_tz_offset = prefix_length + fraction_digits

    def map(self, value) -> datetime:
        """Parse ``value`` after cutting out the surplus fraction digits."""
        kept = value[:self.dt_size]
        remainder = value[self.dt_tz_offset:]
        return datetime.strptime(kept + remainder, self.pattern)
925
+
926
+
927
class TimestampWithTimeZoneValueMapper(TimestampValueMapper):
    """Mapper for timestamp strings followed by a space and a zone designator.

    The designator is either a numeric offset starting with '+' or '-', or a
    zone name that is resolved through ``pytz``.
    """

    def map(self, value) -> datetime:
        """Parse ``value`` into a timezone-aware ``datetime``."""
        dt, tz = value.rsplit(' ', 1)
        if tz.startswith(('+', '-')):
            # Numeric offset: strptime's %z parses it together with the timestamp.
            return datetime.strptime(value[:self.dt_size] + value[self.dt_tz_offset:], self.pattern + ' %z')
        naive = datetime.strptime(dt[:self.dt_size] + dt[self.dt_tz_offset:], self.pattern)
        # localize() selects the UTC offset valid at this wall-clock time;
        # replace(tzinfo=<pytz zone>) would attach the zone's first (LMT) offset.
        return pytz.timezone(tz).localize(naive)
934
+
935
+
936
class ArrayValueMapper(ValueMapper[List[Any]]):
    """Mapper that applies one element mapper to every item of a list."""

    def __init__(self, mapper: ValueMapper[Any]):
        """Store ``mapper``, the mapper used for each element."""
        self.mapper = mapper

    def map(self, values: List[Any]) -> List[Any]:
        """Return a new list with each element of ``values`` mapped.

        ``None`` elements are kept as ``None`` rather than handed to the
        element mapper, which would otherwise fail on them (for example
        ``Decimal(None)`` raises ``TypeError``).
        """
        map_element = self.mapper.map
        return [None if value is None else map_element(value) for value in values]
942
+
943
+
944
class RowValueMapper(ValueMapper[Tuple[Optional[Any], ...]]):
    """Mapper that applies a per-position mapper to each field of a row value."""

    def __init__(self, mappers: List[ValueMapper[Any]]):
        """Store ``mappers``; ``mappers[i]`` converts the field at position ``i``."""
        self.mappers = mappers

    def map(self, values: List[Any]) -> Tuple[Any, ...]:
        """Return a tuple with each field of ``values`` mapped by its positional mapper.

        ``None`` fields are kept as ``None`` rather than handed to the field
        mapper, which would otherwise fail on them.
        """
        return tuple(
            None if value is None else self.mappers[index].map(value)
            for index, value in enumerate(values)
        )
950
+
951
+
952
class MapValueMapper(ValueMapper[Dict[Any, Any]]):
    """Mapper that converts the keys and values of a mapping with separate mappers."""

    def __init__(self, key_mapper: ValueMapper[Any], value_mapper: ValueMapper[Any]):
        """Store the mapper used for keys and the mapper used for values."""
        self.key_mapper = key_mapper
        self.value_mapper = value_mapper

    def map(self, values: Any) -> Dict[Any, Any]:
        """Return a dict built from ``values.items()`` with keys and values mapped.

        ``None`` values are kept as ``None`` rather than handed to the value
        mapper, which would otherwise fail on them. Keys are always passed to
        the key mapper.
        """
        map_key = self.key_mapper.map
        map_value = self.value_mapper.map
        return {
            map_key(key): None if value is None else map_value(value)
            for key, value in values.items()
        }
961
+
962
+
843
963
class NoOpRowMapper :
844
964
"""
845
965
No-op RowMapper which does not perform any transformation
@@ -856,117 +976,44 @@ class RowMapperFactory:
856
976
lambda functions (one for each column) which will process a data value
857
977
and returns a RowMapper instance which will process rows of data
858
978
"""
859
- no_op_row_mapper = NoOpRowMapper ()
979
+ NO_OP_ROW_MAPPER = NoOpRowMapper ()
860
980
861
981
def create(self, columns, experimental_python_types):
    """Build the row mapper for ``columns``.

    When ``experimental_python_types`` is falsy the shared no-op row mapper
    is returned; otherwise a ``RowMapper`` holding one value mapper per
    column, created from each column's ``'typeSignature'``.
    """
    assert columns is not None

    if not experimental_python_types:
        return RowMapperFactory.NO_OP_ROW_MAPPER

    value_mappers = [self._create_value_mapper(column['typeSignature']) for column in columns]
    return RowMapper(value_mappers)
867
987
868
def _create_value_mapper(self, column) -> ValueMapper:
    """Return the value mapper matching the type signature ``column``.

    ``column['rawType']`` selects the mapper. Container types (array, row,
    map) recurse into the nested signatures found under
    ``column['arguments']``; unrecognised types get a no-op mapper.
    """
    col_type = column['rawType']

    # Container types: build mappers for the nested element types first.
    if col_type == 'array':
        element_mapper = self._create_value_mapper(column['arguments'][0]['value'])
        return ArrayValueMapper(element_mapper)
    if col_type == 'row':
        field_mappers = [
            self._create_value_mapper(arg['value']['typeSignature'])
            for arg in column['arguments']
        ]
        return RowValueMapper(field_mappers)
    if col_type == 'map':
        key_mapper = self._create_value_mapper(column['arguments'][0]['value'])
        value_mapper = self._create_value_mapper(column['arguments'][1]['value'])
        return MapValueMapper(key_mapper, value_mapper)

    if col_type.startswith('decimal'):
        return DecimalValueMapper()
    if col_type.startswith(('double', 'real')):
        return DoubleValueMapper()

    # 'timestamp' must be tested before 'time' because it shares that prefix.
    has_time_zone = 'with time zone' in col_type
    if col_type.startswith('timestamp'):
        if has_time_zone:
            return TimestampWithTimeZoneValueMapper(column)
        return TimestampValueMapper(column)
    if col_type.startswith('time'):
        if has_time_zone:
            return TimeWithTimeZoneValueMapper(column)
        return TimeValueMapper(column)

    if col_type == 'date':
        return DateValueMapper()
    return NoOpValueMapper()
970
1017
971
1018
972
1019
class RowMapper :
@@ -982,14 +1029,14 @@ def map(self, rows):
982
1029
return [self ._map_row (row ) for row in rows ]
983
1030
984
1031
def _map_row(self, row):
    """Return a list with each value of ``row`` converted by the mapper at the same position."""
    mappers = self.columns
    return [self._map_value(cell, mappers[position]) for position, cell in enumerate(row)]
986
1033
987
def _map_value(self, value, value_mapper: ValueMapper[T]) -> Optional[T]:
    """Convert ``value`` with ``value_mapper``; ``None`` is returned as-is.

    Raises:
        trino.exceptions.TrinoDataError: if the mapper raises ``ValueError``.
    """
    if value is not None:
        try:
            return value_mapper.map(value)
        except ValueError as e:
            error_str = f"Could not convert '{value}' into the associated python type"
            raise trino.exceptions.TrinoDataError(error_str) from e
    return None
0 commit comments