Commit 8a12c64

SNOW-669650: Fix write_pandas atomicity when overwrite=True (#1291)
1 parent 5a4e7a6 commit 8a12c64

17 files changed: +128 −132 lines changed

DESCRIPTION.md

Lines changed: 4 additions & 0 deletions
@@ -8,6 +8,10 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-connector-python
 
 # Release Notes
 
+- v2.9.0(Unreleased)
+
+  - Enhanced the atomicity of write_pandas when overwrite is set to True
+
 - v2.8.0(September 27,2022)
 
   - Fixed a bug where rowcount was deleted when the cursor was closed
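
Note: a minimal usage sketch of the behavior this release note describes (the connection parameters and table name below are placeholders, not part of the change):

# Illustrative usage sketch; credentials and object names are placeholders.
import pandas
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

cnx = snowflake.connector.connect(
    account="<account>",
    user="<user>",
    password="<password>",
    database="TESTDB",
    schema="PUBLIC",
)
try:
    df = pandas.DataFrame([("Mark", 10), ("Luke", 20)], columns=["name", "points"])
    # With overwrite=True, the connector now loads into a randomly named table
    # first and only swaps it over the existing one after COPY INTO succeeds,
    # so a failed load no longer leaves the target dropped or truncated.
    success, nchunks, nrows, _ = write_pandas(
        cnx, df, "DEMO_TABLE", auto_create_table=True, overwrite=True
    )
finally:
    cnx.close()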

src/snowflake/connector/pandas_tools.py

Lines changed: 63 additions & 87 deletions
@@ -6,8 +6,6 @@
 
 import collections.abc
 import os
-import random
-import string
 import warnings
 from functools import partial
 from logging import getLogger
@@ -19,6 +17,7 @@
 from snowflake.connector import ProgrammingError
 from snowflake.connector.options import pandas
 from snowflake.connector.telemetry import TelemetryData, TelemetryField
+from snowflake.connector.util_text import random_string
 
 if TYPE_CHECKING:  # pragma: no cover
     from .connection import SnowflakeConnection
@@ -152,37 +151,21 @@ def write_pandas(
         )
 
     if quote_identifiers:
-        location = (
-            (('"' + database + '".') if database else "")
-            + (('"' + schema + '".') if schema else "")
-            + ('"' + table_name + '"')
+        location = (f'"{database}".' if database else "") + (
+            f'"{schema}".' if schema else ""
         )
     else:
-        location = (
-            (database + "." if database else "")
-            + (schema + "." if schema else "")
-            + (table_name)
+        location = (f"{database}." if database else "") + (
+            f"{schema}." if schema else ""
         )
     if chunk_size is None:
        chunk_size = len(df)
+
     cursor = conn.cursor()
-    stage_name = None  # Forward declaration
-    while True:
-        try:
-            stage_name = "".join(
-                random.choice(string.ascii_lowercase) for _ in range(5)
-            )
-            create_stage_sql = (
-                "create temporary stage /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-                '"{stage_name}"'
-            ).format(stage_name=stage_name)
-            logger.debug(f"creating stage with '{create_stage_sql}'")
-            cursor.execute(create_stage_sql, _is_internal=True).fetchall()
-            break
-        except ProgrammingError as pe:
-            if pe.msg.endswith("already exists."):
-                continue
-            raise
+    stage_name = random_string()
+    create_stage_sql = f'CREATE TEMP STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ "{stage_name}"'
+    logger.debug(f"creating stage with '{create_stage_sql}'")
+    cursor.execute(create_stage_sql, _is_internal=True).fetchall()
 
     with TemporaryDirectory() as tmp_folder:
         for i, chunk in chunk_helper(df, chunk_size):
@@ -202,42 +185,33 @@ def write_pandas(
             cursor.execute(upload_sql, _is_internal=True)
             # Remove chunk file
             os.remove(chunk_path)
+
+    # in Snowflake, all parquet data is stored in a single column, $1, so we must select columns explicitly
+    # see (https://docs.snowflake.com/en/user-guide/script-data-load-transform-parquet.html)
     if quote_identifiers:
+        quote = '"'
         columns = '"' + '","'.join(list(df.columns)) + '"'
+        parquet_columns = "$1:" + ",$1:".join(f'"{c}"' for c in df.columns)
     else:
+        quote = ""
         columns = ",".join(list(df.columns))
+        parquet_columns = "$1:" + ",$1:".join(df.columns)
+
+    def drop_object(name: str, object_type: str) -> None:
+        drop_sql = f"DROP {object_type.upper()} IF EXISTS {name} /* Python:snowflake.connector.pandas_tools.write_pandas() */"
+        logger.debug(f"dropping {object_type} with '{drop_sql}'")
+        cursor.execute(drop_sql, _is_internal=True)
+
+    if auto_create_table or overwrite:
+        file_format_name = random_string()
+        file_format_sql = (
+            f"CREATE TEMP FILE FORMAT {file_format_name} "
+            f"/* Python:snowflake.connector.pandas_tools.write_pandas() */ "
+            f"TYPE=PARQUET COMPRESSION={compression_map[compression]}"
+        )
+        logger.debug(f"creating file format with '{file_format_sql}'")
+        cursor.execute(file_format_sql, _is_internal=True)
 
-    if overwrite:
-        if auto_create_table:
-            drop_table_sql = f"DROP TABLE IF EXISTS {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-            logger.debug(f"dropping table with '{drop_table_sql}'")
-            cursor.execute(drop_table_sql, _is_internal=True)
-        else:
-            truncate_table_sql = f"TRUNCATE TABLE IF EXISTS {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-            logger.debug(f"truncating table with '{truncate_table_sql}'")
-            cursor.execute(truncate_table_sql, _is_internal=True)
-
-    if auto_create_table:
-        file_format_name = None
-        while True:
-            try:
-                file_format_name = (
-                    '"'
-                    + "".join(random.choice(string.ascii_lowercase) for _ in range(5))
-                    + '"'
-                )
-                file_format_sql = (
-                    f"CREATE FILE FORMAT {file_format_name} "
-                    f"/* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-                    f"TYPE=PARQUET COMPRESSION={compression_map[compression]}"
-                )
-                logger.debug(f"creating file format with '{file_format_sql}'")
-                cursor.execute(file_format_sql, _is_internal=True)
-                break
-            except ProgrammingError as pe:
-                if pe.msg.endswith("already exists."):
-                    continue
-                raise
         infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@\"{stage_name}\"', file_format=>'{file_format_name}'))"
         logger.debug(f"inferring schema with '{infer_schema_sql}'")
         column_type_mapping = dict(
@@ -246,46 +220,48 @@ def write_pandas(
         # Infer schema can return the columns out of order depending on the chunking we do when uploading
         # so we have to iterate through the dataframe columns to make sure we create the table with its
         # columns in order
-        quote = '"' if quote_identifiers else ""
         create_table_columns = ", ".join(
             [f"{quote}{c}{quote} {column_type_mapping[c]}" for c in df.columns]
         )
+
+        target_table_name = (
+            f"{location}{quote}{random_string() if overwrite else table_name}{quote}"
+        )
         create_table_sql = (
-            f"CREATE {table_type.upper()} TABLE IF NOT EXISTS {location} "
+            f"CREATE {table_type.upper()} TABLE IF NOT EXISTS {target_table_name} "
             f"({create_table_columns})"
             f" /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
         )
         logger.debug(f"auto creating table with '{create_table_sql}'")
         cursor.execute(create_table_sql, _is_internal=True)
-        drop_file_format_sql = f"DROP FILE FORMAT IF EXISTS {file_format_name}"
-        logger.debug(f"dropping file format with '{drop_file_format_sql}'")
-        cursor.execute(drop_file_format_sql, _is_internal=True)
-
-    # in Snowflake, all parquet data is stored in a single column, $1, so we must select columns explicitly
-    # see (https://docs.snowflake.com/en/user-guide/script-data-load-transform-parquet.html)
-    if quote_identifiers:
-        parquet_columns = "$1:" + ",$1:".join(f'"{c}"' for c in df.columns)
     else:
-        parquet_columns = "$1:" + ",$1:".join(df.columns)
+        target_table_name = f"{location}{quote}{table_name}{quote}"
+
+    try:
+        copy_into_sql = (
+            f"COPY INTO {target_table_name} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
+            f"({columns}) "
+            f'FROM (SELECT {parquet_columns} FROM @"{stage_name}") '
+            f"FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression_map[compression]}) "
+            f"PURGE=TRUE ON_ERROR={on_error}"
+        )
+        logger.debug(f"copying into with '{copy_into_sql}'")
+        copy_results = cursor.execute(copy_into_sql, _is_internal=True).fetchall()
+
+        if overwrite:
+            original_table_name = f"{location}{quote}{table_name}{quote}"
+            drop_object(original_table_name, "table")
+            rename_table_sql = f"ALTER TABLE {target_table_name} RENAME TO {original_table_name} /* Python:snowflake.connector.pandas_tools.write_pandas() */"
+            logger.debug(f"rename table with '{rename_table_sql}'")
+            cursor.execute(rename_table_sql, _is_internal=True)
+    except ProgrammingError:
+        if overwrite:
+            drop_object(target_table_name, "table")
+        raise
+    finally:
+        cursor._log_telemetry_job_data(TelemetryField.PANDAS_WRITE, TelemetryData.TRUE)
+        cursor.close()
 
-    copy_into_sql = (
-        "COPY INTO {location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-        "({columns}) "
-        'FROM (SELECT {parquet_columns} FROM @"{stage_name}") '
-        "FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}) "
-        "PURGE=TRUE ON_ERROR={on_error}"
-    ).format(
-        location=location,
-        columns=columns,
-        parquet_columns=parquet_columns,
-        stage_name=stage_name,
-        compression=compression_map[compression],
-        on_error=on_error,
-    )
-    logger.debug(f"copying into with '{copy_into_sql}'")
-    copy_results = cursor.execute(copy_into_sql, _is_internal=True).fetchall()
-    cursor._log_telemetry_job_data(TelemetryField.PANDAS_WRITE, TelemetryData.TRUE)
-    cursor.close()
     return (
         all(e[1] == "LOADED" for e in copy_results),
         len(copy_results),
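
The substance of the fix is in the last two hunks: when overwrite=True, data is copied into a randomly named table, and only after the COPY INTO succeeds is the original table dropped and the new one renamed into its place; if the load fails, the new table is dropped and the original is left untouched. A minimal standalone sketch of that drop-and-rename pattern, assuming an open cursor and a hypothetical load_into callable standing in for the PUT/COPY INTO sequence (names here are illustrative, not the connector's API):

# Sketch of the atomic-overwrite pattern, not the connector's actual code.
import random
import string


def _random_name(length: int = 10) -> str:
    # same idea as the new snowflake.connector.util_text.random_string() helper
    return "".join(random.choice(string.ascii_lowercase) for _ in range(length))


def overwrite_atomically(cursor, table_name: str, columns_ddl: str, load_into) -> None:
    """Load into a random staging table, then swap it in only on success."""
    staging_name = f'"{_random_name()}"'
    cursor.execute(f"CREATE TABLE IF NOT EXISTS {staging_name} ({columns_ddl})")
    try:
        load_into(cursor, staging_name)  # e.g. issue COPY INTO the staging table
        # success: replace the original table with the freshly loaded one
        cursor.execute(f'DROP TABLE IF EXISTS "{table_name}"')
        cursor.execute(f'ALTER TABLE {staging_name} RENAME TO "{table_name}"')
    except Exception:
        # failure: discard the staging table; the original is untouched
        cursor.execute(f"DROP TABLE IF EXISTS {staging_name}")
        raise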

src/snowflake/connector/util_text.py

Lines changed: 21 additions & 0 deletions
@@ -6,8 +6,11 @@
 from __future__ import annotations
 
 import logging
+import random
 import re
+import string
 from io import StringIO
+from typing import Sequence
 
 COMMENT_PATTERN_RE = re.compile(r"^\s*\-\-")
 EMPTY_LINE_RE = re.compile(r"^\s*$")
@@ -254,3 +257,21 @@ def parse_account(account):
         parsed_account = account
 
     return parsed_account
+
+
+def random_string(
+    length: int = 10,
+    prefix: str = "",
+    suffix: str = "",
+    choices: Sequence[str] = string.ascii_lowercase,
+) -> str:
+    """Our convenience function to generate random string for object names.
+
+    Args:
+        length: How many random characters to choose from choices.
+        prefix: Prefix to add to random string generated.
+        suffix: Suffix to add to random string generated.
+        choices: A generator of things to choose from.
+    """
+    random_part = "".join([random.choice(choices) for _ in range(length)])
+    return "".join([prefix, random_part, suffix])

test/integ/pandas/test_pandas_tools.py

Lines changed: 25 additions & 1 deletion
@@ -13,9 +13,10 @@
 import pytest
 
 from snowflake.connector import DictCursor
+from snowflake.connector.errors import ProgrammingError
+from snowflake.connector.util_text import random_string
 
 from ...lazy_var import LazyVar
-from ...randomize import random_string
 
 try:
     from snowflake.connector.options import pandas
@@ -60,6 +61,8 @@ def test_write_pandas_with_overwrite(
     df2 = pandas.DataFrame(df2_data, columns=["name", "points"])
     df3_data = [(2022, "Jan", 10000), (2022, "Feb", 10220)]
     df3 = pandas.DataFrame(df3_data, columns=["year", "month", "revenue"])
+    df4_data = [("Frank", 100)]
+    df4 = pandas.DataFrame(df4_data, columns=["name%", "points"])
 
     if quote_identifiers:
         table_name = '"' + random_table_name + '"'
@@ -133,6 +136,27 @@ def test_write_pandas_with_overwrite(
                 else "YEAR" in [col.name for col in result[0].description]
             )
 
+            if not quote_identifiers:
+                original_result = (
+                    cnx.cursor(DictCursor).execute(select_count_sql).fetchone()
+                )
+                # the column name contains special char which should fail
+                with pytest.raises(ProgrammingError, match="unexpected '%'"):
+                    write_pandas(
+                        cnx,
+                        df4,
+                        random_table_name,
+                        quote_identifiers=quote_identifiers,
+                        auto_create_table=auto_create_table,
+                        overwrite=True,
+                        index=index,
+                    )
+                # the original table shouldn't have any change
+                assert (
+                    original_result
+                    == cnx.cursor(DictCursor).execute(select_count_sql).fetchone()
+                )
+
         finally:
             cnx.execute_string(drop_sql)
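
Note: with quote_identifiers=False the df4 column name name% is interpolated unquoted into the generated SQL, which Snowflake rejects with a parse error ("unexpected '%'"); the final assertion exercises the rollback path of the new overwrite logic by checking that the failed load left the original table's contents unchanged.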

test/integ/test_bindings.py

Lines changed: 1 addition & 2 deletions
@@ -20,8 +20,7 @@
 
 from snowflake.connector.converter import convert_datetime_to_epoch
 from snowflake.connector.errors import ForbiddenError, ProgrammingError
-
-from ..randomize import random_string
+from snowflake.connector.util_text import random_string
 
 tempfile.gettempdir()
 
test/integ/test_cursor.py

Lines changed: 1 addition & 2 deletions
@@ -51,8 +51,7 @@ class ResultMetadata(NamedTuple):
 )
 from snowflake.connector.sqlstate import SQLSTATE_FEATURE_NOT_SUPPORTED
 from snowflake.connector.telemetry import TelemetryField
-
-from ..randomize import random_string
+from snowflake.connector.util_text import random_string
 
 try:
     from snowflake.connector.constants import (

test/integ/test_dataintegrity.py

Lines changed: 1 addition & 2 deletions
@@ -18,8 +18,7 @@
 import pytz
 
 from snowflake.connector.dbapi import DateFromTicks, TimeFromTicks, TimestampFromTicks
-
-from ..randomize import random_string
+from snowflake.connector.util_text import random_string
 
 
 def table_exists(conn_cnx, name):

test/integ/test_dbapi.py

Lines changed: 1 addition & 2 deletions
@@ -17,8 +17,7 @@
 import snowflake.connector
 import snowflake.connector.dbapi
 from snowflake.connector import dbapi, errorcode, errors
-
-from ..randomize import random_string
+from snowflake.connector.util_text import random_string
 
 TABLE1 = "dbapi_ddl1"
 TABLE2 = "dbapi_ddl2"

test/integ/test_put_get.py

Lines changed: 1 addition & 1 deletion
@@ -17,10 +17,10 @@
 import pytest
 
 from snowflake.connector import OperationalError
+from snowflake.connector.util_text import random_string
 
 from ..generate_test_files import generate_k_lines_of_n_files
 from ..integ_helpers import put
-from ..randomize import random_string
 
 if TYPE_CHECKING:
     from snowflake.connector import SnowflakeConnection

test/integ/test_put_get_compress_enc.py

Lines changed: 2 additions & 1 deletion
@@ -11,8 +11,9 @@
 
 import pytest
 
+from snowflake.connector.util_text import random_string
+
 from ..integ_helpers import put
-from ..randomize import random_string
 
 pytestmark = pytest.mark.skipolddriver  # old test driver tests won't run this module
 