@@ -6,8 +6,6 @@

 import collections.abc
 import os
-import random
-import string
 import warnings
 from functools import partial
 from logging import getLogger
@@ -19,6 +17,7 @@
 from snowflake.connector import ProgrammingError
 from snowflake.connector.options import pandas
 from snowflake.connector.telemetry import TelemetryData, TelemetryField
+from snowflake.connector.util_text import random_string

 if TYPE_CHECKING:  # pragma: no cover
     from .connection import SnowflakeConnection
@@ -152,37 +151,21 @@ def write_pandas(
         )

     if quote_identifiers:
-        location = (
-            (('"' + database + '".') if database else "")
-            + (('"' + schema + '".') if schema else "")
-            + ('"' + table_name + '"')
+        location = (f'"{database}".' if database else "") + (
+            f'"{schema}".' if schema else ""
         )
     else:
-        location = (
-            (database + "." if database else "")
-            + (schema + "." if schema else "")
-            + (table_name)
+        location = (f"{database}." if database else "") + (
+            f"{schema}." if schema else ""
         )
     if chunk_size is None:
         chunk_size = len(df)
+
     cursor = conn.cursor()
-    stage_name = None  # Forward declaration
-    while True:
-        try:
-            stage_name = "".join(
-                random.choice(string.ascii_lowercase) for _ in range(5)
-            )
-            create_stage_sql = (
-                "create temporary stage /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-                '"{stage_name}"'
-            ).format(stage_name=stage_name)
-            logger.debug(f"creating stage with '{create_stage_sql}'")
-            cursor.execute(create_stage_sql, _is_internal=True).fetchall()
-            break
-        except ProgrammingError as pe:
-            if pe.msg.endswith("already exists."):
-                continue
-            raise
+    stage_name = random_string()
+    create_stage_sql = f'CREATE TEMP STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ "{stage_name}"'
+    logger.debug(f"creating stage with '{create_stage_sql}'")
+    cursor.execute(create_stage_sql, _is_internal=True).fetchall()

     with TemporaryDirectory() as tmp_folder:
         for i, chunk in chunk_helper(df, chunk_size):
@@ -202,42 +185,33 @@ def write_pandas(
             cursor.execute(upload_sql, _is_internal=True)
             # Remove chunk file
             os.remove(chunk_path)
+
+    # in Snowflake, all parquet data is stored in a single column, $1, so we must select columns explicitly
+    # see (https://docs.snowflake.com/en/user-guide/script-data-load-transform-parquet.html)
     if quote_identifiers:
+        quote = '"'
         columns = '"' + '","'.join(list(df.columns)) + '"'
+        parquet_columns = "$1:" + ",$1:".join(f'"{c}"' for c in df.columns)
     else:
+        quote = ""
         columns = ",".join(list(df.columns))
+        parquet_columns = "$1:" + ",$1:".join(df.columns)
+
+    def drop_object(name: str, object_type: str) -> None:
+        drop_sql = f"DROP {object_type.upper()} IF EXISTS {name} /* Python:snowflake.connector.pandas_tools.write_pandas() */"
+        logger.debug(f"dropping {object_type} with '{drop_sql}'")
+        cursor.execute(drop_sql, _is_internal=True)
+
+    if auto_create_table or overwrite:
+        file_format_name = random_string()
+        file_format_sql = (
+            f"CREATE TEMP FILE FORMAT {file_format_name} "
+            f"/* Python:snowflake.connector.pandas_tools.write_pandas() */ "
+            f"TYPE=PARQUET COMPRESSION={compression_map[compression]}"
+        )
+        logger.debug(f"creating file format with '{file_format_sql}'")
+        cursor.execute(file_format_sql, _is_internal=True)

-    if overwrite:
-        if auto_create_table:
-            drop_table_sql = f"DROP TABLE IF EXISTS {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-            logger.debug(f"dropping table with '{drop_table_sql}'")
-            cursor.execute(drop_table_sql, _is_internal=True)
-        else:
-            truncate_table_sql = f"TRUNCATE TABLE IF EXISTS {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-            logger.debug(f"truncating table with '{truncate_table_sql}'")
-            cursor.execute(truncate_table_sql, _is_internal=True)
-
-    if auto_create_table:
-        file_format_name = None
-        while True:
-            try:
-                file_format_name = (
-                    '"'
-                    + "".join(random.choice(string.ascii_lowercase) for _ in range(5))
-                    + '"'
-                )
-                file_format_sql = (
-                    f"CREATE FILE FORMAT {file_format_name} "
-                    f"/* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-                    f"TYPE=PARQUET COMPRESSION={compression_map[compression]}"
-                )
-                logger.debug(f"creating file format with '{file_format_sql}'")
-                cursor.execute(file_format_sql, _is_internal=True)
-                break
-            except ProgrammingError as pe:
-                if pe.msg.endswith("already exists."):
-                    continue
-                raise
         infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@\"{stage_name}\"', file_format=>'{file_format_name}'))"
         logger.debug(f"inferring schema with '{infer_schema_sql}'")
         column_type_mapping = dict(
@@ -246,46 +220,48 @@ def write_pandas(
         # Infer schema can return the columns out of order depending on the chunking we do when uploading
         # so we have to iterate through the dataframe columns to make sure we create the table with its
         # columns in order
-        quote = '"' if quote_identifiers else ""
         create_table_columns = ", ".join(
             [f"{quote}{c}{quote} {column_type_mapping[c]}" for c in df.columns]
         )
+
+        target_table_name = (
+            f"{location}{quote}{random_string() if overwrite else table_name}{quote}"
+        )
         create_table_sql = (
-            f"CREATE {table_type.upper()} TABLE IF NOT EXISTS {location} "
+            f"CREATE {table_type.upper()} TABLE IF NOT EXISTS {target_table_name} "
             f"({create_table_columns})"
             f" /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
         )
         logger.debug(f"auto creating table with '{create_table_sql}'")
         cursor.execute(create_table_sql, _is_internal=True)
-        drop_file_format_sql = f"DROP FILE FORMAT IF EXISTS {file_format_name}"
-        logger.debug(f"dropping file format with '{drop_file_format_sql}'")
-        cursor.execute(drop_file_format_sql, _is_internal=True)
-
-    # in Snowflake, all parquet data is stored in a single column, $1, so we must select columns explicitly
-    # see (https://docs.snowflake.com/en/user-guide/script-data-load-transform-parquet.html)
-    if quote_identifiers:
-        parquet_columns = "$1:" + ",$1:".join(f'"{c}"' for c in df.columns)
     else:
-        parquet_columns = "$1:" + ",$1:".join(df.columns)
+        target_table_name = f"{location}{quote}{table_name}{quote}"
+
+    try:
+        copy_into_sql = (
+            f"COPY INTO {target_table_name} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
+            f"({columns}) "
+            f'FROM (SELECT {parquet_columns} FROM @"{stage_name}") '
+            f"FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression_map[compression]}) "
+            f"PURGE=TRUE ON_ERROR={on_error}"
+        )
+        logger.debug(f"copying into with '{copy_into_sql}'")
+        copy_results = cursor.execute(copy_into_sql, _is_internal=True).fetchall()
+
+        if overwrite:
+            original_table_name = f"{location}{quote}{table_name}{quote}"
+            drop_object(original_table_name, "table")
+            rename_table_sql = f"ALTER TABLE {target_table_name} RENAME TO {original_table_name} /* Python:snowflake.connector.pandas_tools.write_pandas() */"
+            logger.debug(f"rename table with '{rename_table_sql}'")
+            cursor.execute(rename_table_sql, _is_internal=True)
+    except ProgrammingError:
+        if overwrite:
+            drop_object(target_table_name, "table")
+        raise
+    finally:
+        cursor._log_telemetry_job_data(TelemetryField.PANDAS_WRITE, TelemetryData.TRUE)
+        cursor.close()

-    copy_into_sql = (
-        "COPY INTO {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-        "({columns}) "
-        'FROM (SELECT {parquet_columns} FROM @"{stage_name}") '
-        "FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}) "
-        "PURGE=TRUE ON_ERROR={on_error}"
-    ).format(
-        location=location,
-        columns=columns,
-        parquet_columns=parquet_columns,
-        stage_name=stage_name,
-        compression=compression_map[compression],
-        on_error=on_error,
-    )
-    logger.debug(f"copying into with '{copy_into_sql}'")
-    copy_results = cursor.execute(copy_into_sql, _is_internal=True).fetchall()
-    cursor._log_telemetry_job_data(TelemetryField.PANDAS_WRITE, TelemetryData.TRUE)
-    cursor.close()
     return (
         all(e[1] == "LOADED" for e in copy_results),
         len(copy_results),
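
Taken together, the rewritten path uploads the DataFrame as Parquet to a temporary stage and, with `overwrite=True`, now COPYs into a temporary, randomly named table before dropping the original and renaming the new table into place, so a failed load no longer leaves the target table dropped or truncated. A minimal usage sketch follows; the connection parameters and the `CUSTOMERS` table are placeholders for illustration, not part of this change:

```python
import pandas as pd
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

df = pd.DataFrame({"ID": [1, 2, 3], "NAME": ["a", "b", "c"]})

# Placeholder connection parameters; substitute real credentials.
conn = snowflake.connector.connect(
    account="<account>",
    user="<user>",
    password="<password>",
    database="TESTDB",
    schema="PUBLIC",
)
try:
    # With overwrite=True, data first lands in a temporary, randomly named
    # table; "CUSTOMERS" is dropped and the new table renamed into place
    # only after the COPY INTO succeeds.
    success, nchunks, nrows, _ = write_pandas(
        conn,
        df,
        table_name="CUSTOMERS",
        auto_create_table=True,
        overwrite=True,
    )
    assert success
finally:
    conn.close()
```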