@@ -173,16 +173,10 @@ def __init__(
173173 super ().__init__ (create_connection , dbms_type )
174174
175175 def to_snow_type (self , schema : List [Any ]) -> StructType :
176- # TODO: Implement this method to convert PostgreSQL types to Snowflake types.
177- # https://other-docs.snowflake.com/en/connectors/postgres6/view-data#postgresql-to-snowflake-data-type-mapping
178- # psycopg2 type code: https://github.com/psycopg/psycopg2/blob/master/psycopg/pgtypes.h
179- # https://www.postgresql.org/docs/current/datatype.html
176+ # The psycopg2 spec is defined in the following links:
180177 # https://www.psycopg.org/docs/cursor.html#cursor.description
181- # https://www.psycopg.org/docs/extensions.html#psycopg2.extensions.Column.type_code
182- # https://www.postgresql.org/docs/current/catalog-pg-type.html
183- # https://www.psycopg.org/docs/advanced.html#type-casting-from-sql-to-python
184- fields = []
185178 # https://www.psycopg.org/docs/extensions.html#psycopg2.extensions.Column
179+ fields = []
186180 for (
187181 name ,
188182 type_code ,
@@ -222,14 +216,6 @@ def data_source_data_to_pandas_df(
222216 data : List [Any ], schema : StructType
223217 ) -> "pd.DataFrame" :
224218 df = BaseDriver .data_source_data_to_pandas_df (data , schema )
225- # psycopg2 returns binary data as memoryview, we need to convert it to bytes
226- binary_type_indexes = [
227- i
228- for i , field in enumerate (schema .fields )
229- if isinstance (field .datatype , BinaryType )
230- ]
231- col_names = df .columns [binary_type_indexes ]
232- df [col_names ] = BaseDriver .df_map_method (df [col_names ])(lambda x : bytes (x ))
233219
234220 variant_type_indexes = [
235221 i
@@ -259,8 +245,8 @@ def to_result_snowpark_df(
259245 project_columns , _emit_ast = _emit_ast
260246 )
261247
248+ @staticmethod
262249 def prepare_connection (
263- self ,
264250 conn : "Connection" ,
265251 query_timeout : int = 0 ,
266252 ) -> "Connection" :
@@ -275,4 +261,97 @@ def prepare_connection(
275261 lambda data , cursor : data ,
276262 )
277263 register_type (SNOWPARK_INTERVAL_STR , conn )
264+
265+ # by default psycopg2 returns binary data as memoryview
266+ # to avoid using pandas to convert memoryview to bytes, we use the following native psycopg2 type conversion
267+ # psycopg2.extensions.new_type() only works for text format data, it returns bytes as hex string
268+ # we reconstruct the bytes from hex string
269+ SNOWPARK_BYTE = new_type (
270+ (Psycopg2TypeCode .BYTEAOID .value ,),
271+ "SNOWPARK_BYTE_BYTES" ,
272+ lambda data , cursor : bytes .fromhex (data [2 :])
273+ if data is not None
274+ else None , # [2:] to skip the '\\x' prefix
275+ )
276+ register_type (SNOWPARK_BYTE , conn )
277+
278+ if query_timeout :
279+ # https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
280+ # postgres default uses milliseconds
281+ conn .cursor ().execute (f"SET STATEMENT_TIMEOUT = { query_timeout * 1000 } " )
278282 return conn
283+
284+ def udtf_class_builder (
285+ self , fetch_size : int = 1000 , schema : StructType = None
286+ ) -> type :
287+ create_connection = self .create_connection
288+
289+ # TODO: SNOW-2101485 ues class method to prepare connection
290+ # ideally we should use the same function as prepare_connection
291+ # however, since we introduce new module for new driver support and initially the new module is not available in the backend
292+ # so if registering UDTF which uses the class method, cloudpickle will pickle the class method along with
293+ # the new module -- this leads to not being able to find the new module when unpickling on the backend.
294+ # once the new module is available in the backend, we can use the class method.
295+ def prepare_connection_in_udtf (
296+ conn : "Connection" ,
297+ query_timeout : int = 0 ,
298+ ) -> "Connection" :
299+ # The following is to align with Snowflake Connector behavior which get Interval as string
300+ # the default behavior of psycopg2 is to get Interval as datetime.timedelta
301+ # https://other-docs.snowflake.com/en/connectors/postgres6/view-data#postgresql-to-snowflake-data-type-mapping
302+ from psycopg2 .extensions import new_type , register_type
303+
304+ # we do not use Psycopg2TypeCode.INTERVALOID.value because UTDF pickles the psycopg2_driver module
305+ # unpickling in the UDTF would results in module not found error if package not available in the backend
306+ SNOWPARK_INTERVAL_STR = new_type (
307+ (1186 ,),
308+ "SNOWPARK_INTERVAL_STR" ,
309+ lambda data , cursor : data ,
310+ )
311+ register_type (SNOWPARK_INTERVAL_STR , conn )
312+
313+ if query_timeout :
314+ # https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
315+ # postgres default uses milliseconds
316+ conn .cursor ().execute (f"SET STATEMENT_TIMEOUT = { query_timeout * 1000 } " )
317+ return conn
318+
319+ binary_column_indexes = [
320+ i
321+ for i , field in enumerate (schema .fields )
322+ if isinstance (field .datatype , BinaryType )
323+ ]
324+ time_column_indexes = [
325+ i
326+ for i , field in enumerate (schema .fields )
327+ if isinstance (field .datatype , TimeType )
328+ ]
329+
330+ # postgres returns binary data as memoryview, we need to convert it to bytes
331+ def convert_rows (rows_to_update ):
332+ ret = []
333+ for row in rows_to_update :
334+ # convert tuple to list to make it mutable
335+ new_row = list (row )
336+ # convert bytes to hexstring so that variant column can be cast to bytes
337+ for idx in binary_column_indexes :
338+ new_row [idx ] = bytes (row [idx ]).hex () if row [idx ] else None
339+ # remove timezone info from time columns
340+ for idx in time_column_indexes :
341+ new_row [idx ] = row [idx ].replace (tzinfo = None ) if row [idx ] else None
342+ # convert list back to tuple as UDTF requires tuple
343+ ret .append (tuple (new_row ))
344+ return ret
345+
346+ class UDTFIngestion :
347+ def process (self , query : str ):
348+ conn = prepare_connection_in_udtf (create_connection ())
349+ cursor = conn .cursor ()
350+ cursor .execute (query )
351+ while True :
352+ rows = cursor .fetchmany (fetch_size )
353+ if not rows :
354+ break
355+ yield from convert_rows (rows )
356+
357+ return UDTFIngestion
0 commit comments