Skip to content

Commit a959fc6

Browse files
SNOW-1325701: Use scoped temp object in write pandas (#2068)
1 parent aeb771c commit a959fc6

File tree

5 files changed

+188
-16
lines changed

5 files changed

+188
-16
lines changed

src/snowflake/connector/_utils.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from __future__ import annotations
6+
7+
import string
8+
from enum import Enum
9+
from random import choice
10+
11+
12+
class TempObjectType(Enum):
13+
TABLE = "TABLE"
14+
VIEW = "VIEW"
15+
STAGE = "STAGE"
16+
FUNCTION = "FUNCTION"
17+
FILE_FORMAT = "FILE_FORMAT"
18+
QUERY_TAG = "QUERY_TAG"
19+
COLUMN = "COLUMN"
20+
PROCEDURE = "PROCEDURE"
21+
TABLE_FUNCTION = "TABLE_FUNCTION"
22+
DYNAMIC_TABLE = "DYNAMIC_TABLE"
23+
AGGREGATE_FUNCTION = "AGGREGATE_FUNCTION"
24+
CTE = "CTE"
25+
26+
27+
TEMP_OBJECT_NAME_PREFIX = "SNOWPARK_TEMP_"
28+
ALPHANUMERIC = string.digits + string.ascii_lowercase
29+
TEMPORARY_STRING = "TEMP"
30+
SCOPED_TEMPORARY_STRING = "SCOPED TEMPORARY"
31+
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING = (
32+
"PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS"
33+
)
34+
35+
36+
def generate_random_alphanumeric(length: int = 10) -> str:
37+
return "".join(choice(ALPHANUMERIC) for _ in range(length))
38+
39+
40+
def random_name_for_temp_object(object_type: TempObjectType) -> str:
41+
return f"{TEMP_OBJECT_NAME_PREFIX}{object_type.value}_{generate_random_alphanumeric().upper()}"
42+
43+
44+
def get_temp_type_for_object(use_scoped_temp_objects: bool) -> str:
45+
return SCOPED_TEMPORARY_STRING if use_scoped_temp_objects else TEMPORARY_STRING

src/snowflake/connector/bind_upload_agent.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from logging import getLogger
1111
from typing import TYPE_CHECKING
1212

13+
from ._utils import (
14+
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING,
15+
get_temp_type_for_object,
16+
)
1317
from .errors import BindUploadError, Error
1418

1519
if TYPE_CHECKING: # pragma: no cover
@@ -19,11 +23,6 @@
1923

2024

2125
class BindUploadAgent:
22-
_STAGE_NAME = "SYSTEMBIND"
23-
_CREATE_STAGE_STMT = (
24-
f"create or replace temporary stage {_STAGE_NAME} "
25-
"file_format=(type=csv field_optionally_enclosed_by='\"')"
26-
)
2726

2827
def __init__(
2928
self,
@@ -38,13 +37,27 @@ def __init__(
3837
rows: Rows of binding parameters in CSV format.
3938
stream_buffer_size: Size of each file, default to 10MB.
4039
"""
40+
self._use_scoped_temp_object = (
41+
cursor.connection._session_parameters.get(
42+
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING, False
43+
)
44+
if cursor.connection._session_parameters
45+
else False
46+
)
47+
self._STAGE_NAME = (
48+
"SNOWPARK_TEMP_STAGE_BIND" if self._use_scoped_temp_object else "SYSTEMBIND"
49+
)
4150
self.cursor = cursor
4251
self.rows = rows
4352
self._stream_buffer_size = stream_buffer_size
4453
self.stage_path = f"@{self._STAGE_NAME}/{uuid.uuid4().hex}"
4554

4655
def _create_stage(self) -> None:
47-
self.cursor.execute(self._CREATE_STAGE_STMT)
56+
create_stage_sql = (
57+
f"create or replace {get_temp_type_for_object(self._use_scoped_temp_object)} stage {self._STAGE_NAME} "
58+
"file_format=(type=csv field_optionally_enclosed_by='\"')"
59+
)
60+
self.cursor.execute(create_stage_sql)
4861

4962
def upload(self) -> None:
5063
try:

src/snowflake/connector/pandas_tools.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
from snowflake.connector.telemetry import TelemetryData, TelemetryField
2727
from snowflake.connector.util_text import random_string
2828

29+
from ._utils import (
30+
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING,
31+
TempObjectType,
32+
get_temp_type_for_object,
33+
random_name_for_temp_object,
34+
)
2935
from .cursor import SnowflakeCursor
3036

3137
if TYPE_CHECKING: # pragma: no cover
@@ -77,8 +83,9 @@ def _do_create_temp_stage(
7783
compression: str,
7884
auto_create_table: bool,
7985
overwrite: bool,
86+
use_scoped_temp_object: bool,
8087
) -> None:
81-
create_stage_sql = f"CREATE TEMP STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ {stage_location} FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}{' BINARY_AS_TEXT=FALSE' if auto_create_table or overwrite else ''})"
88+
create_stage_sql = f"CREATE {get_temp_type_for_object(use_scoped_temp_object)} STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ {stage_location} FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression}{' BINARY_AS_TEXT=FALSE' if auto_create_table or overwrite else ''})"
8289
logger.debug(f"creating stage with '{create_stage_sql}'")
8390
cursor.execute(create_stage_sql, _is_internal=True).fetchall()
8491

@@ -91,8 +98,13 @@ def _create_temp_stage(
9198
compression: str,
9299
auto_create_table: bool,
93100
overwrite: bool,
101+
use_scoped_temp_object: bool = False,
94102
) -> str:
95-
stage_name = random_string()
103+
stage_name = (
104+
random_name_for_temp_object(TempObjectType.STAGE)
105+
if use_scoped_temp_object
106+
else random_string()
107+
)
96108
stage_location = build_location_helper(
97109
database=database,
98110
schema=schema,
@@ -101,7 +113,12 @@ def _create_temp_stage(
101113
)
102114
try:
103115
_do_create_temp_stage(
104-
cursor, stage_location, compression, auto_create_table, overwrite
116+
cursor,
117+
stage_location,
118+
compression,
119+
auto_create_table,
120+
overwrite,
121+
use_scoped_temp_object,
105122
)
106123
except ProgrammingError as e:
107124
# User may not have the privilege to create stage on the target schema, so fall back to use current schema as
@@ -111,7 +128,12 @@ def _create_temp_stage(
111128
)
112129
stage_location = stage_name
113130
_do_create_temp_stage(
114-
cursor, stage_location, compression, auto_create_table, overwrite
131+
cursor,
132+
stage_location,
133+
compression,
134+
auto_create_table,
135+
overwrite,
136+
use_scoped_temp_object,
115137
)
116138

117139
return stage_location
@@ -122,9 +144,10 @@ def _do_create_temp_file_format(
122144
file_format_location: str,
123145
compression: str,
124146
sql_use_logical_type: str,
147+
use_scoped_temp_object: bool,
125148
) -> None:
126149
file_format_sql = (
127-
f"CREATE TEMP FILE FORMAT {file_format_location} "
150+
f"CREATE {get_temp_type_for_object(use_scoped_temp_object)} FILE FORMAT {file_format_location} "
128151
f"/* Python:snowflake.connector.pandas_tools.write_pandas() */ "
129152
f"TYPE=PARQUET COMPRESSION={compression}{sql_use_logical_type}"
130153
)
@@ -139,8 +162,13 @@ def _create_temp_file_format(
139162
quote_identifiers: bool,
140163
compression: str,
141164
sql_use_logical_type: str,
165+
use_scoped_temp_object: bool = False,
142166
) -> str:
143-
file_format_name = random_string()
167+
file_format_name = (
168+
random_name_for_temp_object(TempObjectType.FILE_FORMAT)
169+
if use_scoped_temp_object
170+
else random_string()
171+
)
144172
file_format_location = build_location_helper(
145173
database=database,
146174
schema=schema,
@@ -149,7 +177,11 @@ def _create_temp_file_format(
149177
)
150178
try:
151179
_do_create_temp_file_format(
152-
cursor, file_format_location, compression, sql_use_logical_type
180+
cursor,
181+
file_format_location,
182+
compression,
183+
sql_use_logical_type,
184+
use_scoped_temp_object,
153185
)
154186
except ProgrammingError as e:
155187
# User may not have the privilege to create file format on the target schema, so fall back to use current schema
@@ -159,7 +191,11 @@ def _create_temp_file_format(
159191
)
160192
file_format_location = file_format_name
161193
_do_create_temp_file_format(
162-
cursor, file_format_location, compression, sql_use_logical_type
194+
cursor,
195+
file_format_location,
196+
compression,
197+
sql_use_logical_type,
198+
use_scoped_temp_object,
163199
)
164200

165201
return file_format_location
@@ -263,6 +299,14 @@ def write_pandas(
263299
f"Invalid compression '{compression}', only acceptable values are: {compression_map.keys()}"
264300
)
265301

302+
_use_scoped_temp_object = (
303+
conn._session_parameters.get(
304+
_PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING, False
305+
)
306+
if conn._session_parameters
307+
else False
308+
)
309+
266310
if create_temp_table:
267311
warnings.warn(
268312
"create_temp_table is deprecated, we still respect this parameter when it is True but "
@@ -324,6 +368,7 @@ def write_pandas(
324368
compression,
325369
auto_create_table,
326370
overwrite,
371+
_use_scoped_temp_object,
327372
)
328373

329374
with TemporaryDirectory() as tmp_folder:
@@ -370,6 +415,7 @@ def drop_object(name: str, object_type: str) -> None:
370415
quote_identifiers,
371416
compression_map[compression],
372417
sql_use_logical_type,
418+
_use_scoped_temp_object,
373419
)
374420
infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@{stage_location}', file_format=>'{file_format_location}'))"
375421
logger.debug(f"inferring schema with '{infer_schema_sql}'")

test/integ/pandas/test_pandas_tools.py

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,59 @@ def mocked_execute(*args, **kwargs):
602602
)
603603
assert m_execute.called and any(
604604
map(
605-
lambda e: "CREATE TEMP STAGE" in str(e[0]),
605+
lambda e: ("CREATE TEMP STAGE" in str(e[0])),
606+
m_execute.call_args_list,
607+
)
608+
)
609+
610+
611+
@pytest.mark.parametrize(
612+
"database,schema,quote_identifiers,expected_db_schema",
613+
[
614+
("database", "schema", True, '"database"."schema"'),
615+
("database", "schema", False, "database.schema"),
616+
(None, "schema", True, '"schema"'),
617+
(None, "schema", False, "schema"),
618+
(None, None, True, ""),
619+
(None, None, False, ""),
620+
],
621+
)
622+
def test_use_scoped_object(
623+
conn_cnx,
624+
database: str | None,
625+
schema: str | None,
626+
quote_identifiers: bool,
627+
expected_db_schema: str,
628+
):
629+
"""This tests that write_pandas constructs stage location correctly with database and schema."""
630+
from snowflake.connector.cursor import SnowflakeCursor
631+
632+
with conn_cnx() as cnx:
633+
634+
def mocked_execute(*args, **kwargs):
635+
if len(args) >= 1 and args[0].startswith("create temporary stage"):
636+
db_schema = ".".join(args[0].split(" ")[-1].split(".")[:-1])
637+
assert db_schema == expected_db_schema
638+
cur = SnowflakeCursor(cnx)
639+
cur._result = iter([])
640+
return cur
641+
642+
with mock.patch(
643+
"snowflake.connector.cursor.SnowflakeCursor.execute",
644+
side_effect=mocked_execute,
645+
) as m_execute:
646+
cnx._update_parameters({"PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS": True})
647+
success, nchunks, nrows, _ = write_pandas(
648+
cnx,
649+
sf_connector_version_df.get(),
650+
"table",
651+
database=database,
652+
schema=schema,
653+
quote_identifiers=quote_identifiers,
654+
)
655+
assert m_execute.called and any(
656+
map(
657+
lambda e: ("CREATE SCOPED TEMPORARY STAGE" in str(e[0])),
606658
m_execute.call_args_list,
607659
)
608660
)
@@ -660,7 +712,7 @@ def mocked_execute(*args, **kwargs):
660712
)
661713
assert m_execute.called and any(
662714
map(
663-
lambda e: "CREATE TEMP FILE FORMAT" in str(e[0]),
715+
lambda e: ("CREATE TEMP FILE FORMAT" in str(e[0])),
664716
m_execute.call_args_list,
665717
)
666718
)

test/unit/test_bind_upload_agent.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import annotations
77

8+
from unittest import mock
89
from unittest.mock import MagicMock
910

1011

@@ -26,3 +27,18 @@ def test_bind_upload_agent_row_size_exceed_buffer_size():
2627
agent = BindUploadAgent(csr, rows, stream_buffer_size=10)
2728
agent.upload()
2829
assert csr.execute.call_count == 11 # 1 for stage creation + 10 files
30+
31+
32+
def test_bind_upload_agent_scoped_temp_object():
33+
from snowflake.connector.bind_upload_agent import BindUploadAgent
34+
35+
csr = MagicMock(auto_spec=True)
36+
rows = [bytes(15)] * 10
37+
agent = BindUploadAgent(csr, rows, stream_buffer_size=10)
38+
with mock.patch.object(agent, "_use_scoped_temp_object", new=True):
39+
with mock.patch.object(agent.cursor, "execute") as mock_execute:
40+
agent._create_stage()
41+
assert (
42+
"create or replace SCOPED TEMPORARY stage"
43+
in mock_execute.call_args[0][0]
44+
)

0 commit comments

Comments
 (0)