Commit bee5dff
SNOW-534004: Added database and schema to the queries related to temporary stage (#1274)

* Added database and schema to the queries related to the temporary stage
* Refactored the location-building logic into a separate function; replaced format() with f-strings where backslashes are not used; added tests
* Fixed mocks and value extraction in the new tests
* Combined tests
* Fixed lint

Co-authored-by: Sophie Tan <[email protected]>
1 parent d957164 commit bee5dff

2 files changed: +148 −60 lines
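
In practice, the temporary stage, the temporary file format, and the target table in the generated SQL now carry an optional database and schema prefix instead of a bare, always-quoted name. A rough before/after sketch of the stage DDL; the stage name "ABCD1234" and the database/schema names are illustrative placeholders, not values from the commit:

# Before: the stage name was always double-quoted and never qualified.
#   CREATE TEMP STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ "ABCD1234"
# After: the location honors database=, schema=, and quote_identifiers=.
#   CREATE TEMP STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ "MY_DB"."MY_SCHEMA"."ABCD1234"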

src/snowflake/connector/pandas_tools.py

Lines changed: 61 additions & 28 deletions

@@ -41,6 +41,25 @@ def chunk_helper(lst: T, n: int) -> Iterator[tuple[int, T]]:
         yield int(i / n), lst[i : i + n]
 
 
+def build_location_helper(
+    database: str | None, schema: str | None, name: str, quote_identifiers: bool
+) -> str:
+    """Helper to format table/stage/file format's location."""
+    if quote_identifiers:
+        location = (
+            (('"' + database + '".') if database else "")
+            + (('"' + schema + '".') if schema else "")
+            + ('"' + name + '"')
+        )
+    else:
+        location = (
+            (database + "." if database else "")
+            + (schema + "." if schema else "")
+            + name
+        )
+    return location
+
+
 def write_pandas(
     conn: SnowflakeConnection,
     df: pandas.DataFrame,
@@ -130,9 +149,7 @@ def write_pandas(
     compression_map = {"gzip": "auto", "snappy": "snappy"}
     if compression not in compression_map.keys():
         raise ProgrammingError(
-            "Invalid compression '{}', only acceptable values are: {}".format(
-                compression, compression_map.keys()
-            )
+            f"Invalid compression '{compression}', only acceptable values are: {compression_map.keys()}"
         )
 
     if create_temp_table:
@@ -150,20 +167,17 @@ def write_pandas(
             "Unsupported table type. Expected table types: temp/temporary, transient"
         )
 
-    if quote_identifiers:
-        location = (f'"{database}".' if database else "") + (
-            f'"{schema}".' if schema else ""
-        )
-    else:
-        location = (f"{database}." if database else "") + (
-            f"{schema}." if schema else ""
-        )
     if chunk_size is None:
         chunk_size = len(df)
 
     cursor = conn.cursor()
-    stage_name = random_string()
-    create_stage_sql = f'CREATE TEMP STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ "{stage_name}"'
+    stage_location = build_location_helper(
+        database=database,
+        schema=schema,
+        name=random_string(),
+        quote_identifiers=quote_identifiers,
+    )
+    create_stage_sql = f"CREATE TEMP STAGE /* Python:snowflake.connector.pandas_tools.write_pandas() */ {stage_location}"
     logger.debug(f"creating stage with '{create_stage_sql}'")
     cursor.execute(create_stage_sql, _is_internal=True).fetchall()
 
@@ -175,10 +189,10 @@ def write_pandas(
         # Upload parquet file
         upload_sql = (
             "PUT /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
-            "'file://{path}' @\"{stage_name}\" PARALLEL={parallel}"
+            "'file://{path}' @{stage_location} PARALLEL={parallel}"
         ).format(
             path=chunk_path.replace("\\", "\\\\").replace("'", "\\'"),
-            stage_name=stage_name,
+            stage_location=stage_location,
             parallel=parallel,
         )
         logger.debug(f"uploading files with '{upload_sql}'")
@@ -203,16 +217,21 @@ def drop_object(name: str, object_type: str) -> None:
        cursor.execute(drop_sql, _is_internal=True)
 
     if auto_create_table or overwrite:
-        file_format_name = random_string()
+        file_format_location = build_location_helper(
+            database=database,
+            schema=schema,
+            name=random_string(),
+            quote_identifiers=quote_identifiers,
+        )
         file_format_sql = (
-            f"CREATE TEMP FILE FORMAT {file_format_name} "
+            f"CREATE TEMP FILE FORMAT {file_format_location} "
             f"/* Python:snowflake.connector.pandas_tools.write_pandas() */ "
             f"TYPE=PARQUET COMPRESSION={compression_map[compression]}"
         )
         logger.debug(f"creating file format with '{file_format_sql}'")
         cursor.execute(file_format_sql, _is_internal=True)
 
-        infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@\"{stage_name}\"', file_format=>'{file_format_name}'))"
+        infer_schema_sql = f"SELECT COLUMN_NAME, TYPE FROM table(infer_schema(location=>'@{stage_location}', file_format=>'{file_format_location}'))"
         logger.debug(f"inferring schema with '{infer_schema_sql}'")
         column_type_mapping = dict(
             cursor.execute(infer_schema_sql, _is_internal=True).fetchall()
@@ -224,39 +243,53 @@ def drop_object(name: str, object_type: str) -> None:
             [f"{quote}{c}{quote} {column_type_mapping[c]}" for c in df.columns]
         )
 
-        target_table_name = (
-            f"{location}{quote}{random_string() if overwrite else table_name}{quote}"
+        target_table_location = build_location_helper(
+            database,
+            schema,
+            random_string() if overwrite else table_name,
+            quote_identifiers,
         )
+
         create_table_sql = (
-            f"CREATE {table_type.upper()} TABLE IF NOT EXISTS {target_table_name} "
+            f"CREATE {table_type.upper()} TABLE IF NOT EXISTS {target_table_location} "
             f"({create_table_columns})"
             f" /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
         )
         logger.debug(f"auto creating table with '{create_table_sql}'")
         cursor.execute(create_table_sql, _is_internal=True)
     else:
-        target_table_name = f"{location}{quote}{table_name}{quote}"
+        target_table_location = build_location_helper(
+            database=database,
+            schema=schema,
+            name=table_name,
+            quote_identifiers=quote_identifiers,
+        )
 
     try:
         copy_into_sql = (
-            f"COPY INTO {target_table_name} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
+            f"COPY INTO {target_table_location} /* Python:snowflake.connector.pandas_tools.write_pandas() */ "
             f"({columns}) "
-            f'FROM (SELECT {parquet_columns} FROM @"{stage_name}") '
+            f"FROM (SELECT {parquet_columns} FROM @{stage_location}) "
             f"FILE_FORMAT=(TYPE=PARQUET COMPRESSION={compression_map[compression]}) "
             f"PURGE=TRUE ON_ERROR={on_error}"
        )
        logger.debug(f"copying into with '{copy_into_sql}'")
        copy_results = cursor.execute(copy_into_sql, _is_internal=True).fetchall()
 
        if overwrite:
-            original_table_name = f"{location}{quote}{table_name}{quote}"
-            drop_object(original_table_name, "table")
-            rename_table_sql = f"ALTER TABLE {target_table_name} RENAME TO {original_table_name} /* Python:snowflake.connector.pandas_tools.write_pandas() */"
+            original_table_location = build_location_helper(
+                database=database,
+                schema=schema,
+                name=table_name,
+                quote_identifiers=quote_identifiers,
+            )
+            drop_object(original_table_location, "table")
+            rename_table_sql = f"ALTER TABLE {target_table_location} RENAME TO {original_table_location} /* Python:snowflake.connector.pandas_tools.write_pandas() */"
             logger.debug(f"rename table with '{rename_table_sql}'")
             cursor.execute(rename_table_sql, _is_internal=True)
     except ProgrammingError:
         if overwrite:
-            drop_object(target_table_name, "table")
+            drop_object(target_table_location, "table")
         raise
     finally:
         cursor._log_telemetry_job_data(TelemetryField.PANDAS_WRITE, TelemetryData.TRUE)
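
For reference, a minimal standalone sketch of what the new build_location_helper computes; build_location and the sample values below are illustrative stand-ins that mirror the helper above, not code from the commit:

def build_location(database, schema, name, quote_identifiers):
    # Dot-join whichever of database/schema/name are present,
    # double-quoting each part when quote_identifiers is True.
    parts = [p for p in (database, schema, name) if p]
    if quote_identifiers:
        return ".".join(f'"{p}"' for p in parts)
    return ".".join(parts)

assert build_location("db", "sc", "t", True) == '"db"."sc"."t"'
assert build_location(None, "sc", "t", False) == "sc.t"
assert build_location(None, None, "t", True) == '"t"'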

test/integ/pandas/test_pandas_tools.py

Lines changed: 87 additions & 32 deletions

@@ -342,20 +342,33 @@ def test_empty_dataframe_write_pandas(
     ), f"sucess: {success}, num_chunks: {num_chunks}, num_rows: {num_rows}"
 
 
-@pytest.mark.parametrize("quote_identifiers", [True, False])
-def test_location_building_db_schema(conn_cnx, quote_identifiers: bool):
-    """This tests that write_pandas constructs location correctly with database, schema and table name."""
+@pytest.mark.parametrize(
+    "database,schema,quote_identifiers,expected_location",
+    [
+        ("database", "schema", True, '"database"."schema"."table"'),
+        ("database", "schema", False, "database.schema.table"),
+        (None, "schema", True, '"schema"."table"'),
+        (None, "schema", False, "schema.table"),
+        (None, None, True, '"table"'),
+        (None, None, False, "table"),
+    ],
+)
+def test_table_location_building(
+    conn_cnx,
+    database: str | None,
+    schema: str | None,
+    quote_identifiers: bool,
+    expected_location: str,
+):
+    """This tests that write_pandas constructs table location correctly with database, schema, and table name."""
     from snowflake.connector.cursor import SnowflakeCursor
 
     with conn_cnx() as cnx:
 
         def mocked_execute(*args, **kwargs):
             if len(args) >= 1 and args[0].startswith("COPY INTO"):
                 location = args[0].split(" ")[2]
-                if quote_identifiers:
-                    assert location == '"database"."schema"."table"'
-                else:
-                    assert location == "database.schema.table"
+                assert location == expected_location
             cur = SnowflakeCursor(cnx)
             cur._result = iter([])
             return cur
@@ -368,29 +381,42 @@ def mocked_execute(*args, **kwargs):
             cnx,
             sf_connector_version_df.get(),
             "table",
-            database="database",
-            schema="schema",
+            database=database,
+            schema=schema,
             quote_identifiers=quote_identifiers,
         )
         assert m_execute.called and any(
             map(lambda e: "COPY INTO" in str(e[0]), m_execute.call_args_list)
         )
 
 
-@pytest.mark.parametrize("quote_identifiers", [True, False])
-def test_location_building_schema(conn_cnx, quote_identifiers: bool):
-    """This tests that write_pandas constructs location correctly with schema and table name."""
+@pytest.mark.parametrize(
+    "database,schema,quote_identifiers,expected_db_schema",
+    [
+        ("database", "schema", True, '"database"."schema"'),
+        ("database", "schema", False, "database.schema"),
+        (None, "schema", True, '"schema"'),
+        (None, "schema", False, "schema"),
+        (None, None, True, ""),
+        (None, None, False, ""),
+    ],
+)
+def test_stage_location_building(
+    conn_cnx,
+    database: str | None,
+    schema: str | None,
+    quote_identifiers: bool,
+    expected_db_schema: str,
+):
+    """This tests that write_pandas constructs stage location correctly with database and schema."""
     from snowflake.connector.cursor import SnowflakeCursor
 
     with conn_cnx() as cnx:
 
         def mocked_execute(*args, **kwargs):
-            if len(args) >= 1 and args[0].startswith("COPY INTO"):
-                location = args[0].split(" ")[2]
-                if quote_identifiers:
-                    assert location == '"schema"."table"'
-                else:
-                    assert location == "schema.table"
+            if len(args) >= 1 and args[0].startswith("create temporary stage"):
+                db_schema = ".".join(args[0].split(" ")[-1].split(".")[:-1])
+                assert db_schema == expected_db_schema
             cur = SnowflakeCursor(cnx)
             cur._result = iter([])
             return cur
@@ -403,30 +429,53 @@ def mocked_execute(*args, **kwargs):
             cnx,
             sf_connector_version_df.get(),
             "table",
-            schema="schema",
+            database=database,
+            schema=schema,
             quote_identifiers=quote_identifiers,
         )
         assert m_execute.called and any(
-            map(lambda e: "COPY INTO" in str(e[0]), m_execute.call_args_list)
+            map(
+                lambda e: "CREATE TEMP STAGE" in str(e[0]),
+                m_execute.call_args_list,
+            )
         )
 
 
-@pytest.mark.parametrize("quote_identifiers", [True, False])
-def test_location_building(conn_cnx, quote_identifiers: bool):
-    """This tests that write_pandas constructs location correctly with schema and table name."""
+@pytest.mark.parametrize(
+    "database,schema,quote_identifiers,expected_db_schema",
+    [
+        ("database", "schema", True, '"database"."schema"'),
+        ("database", "schema", False, "database.schema"),
+        (None, "schema", True, '"schema"'),
+        (None, "schema", False, "schema"),
+        (None, None, True, ""),
+        (None, None, False, ""),
+    ],
+)
+def test_file_format_location_building(
+    conn_cnx,
+    database: str | None,
+    schema: str | None,
+    quote_identifiers: bool,
+    expected_db_schema: str,
+):
+    """This tests that write_pandas constructs file format location correctly with database and schema."""
     from snowflake.connector.cursor import SnowflakeCursor
 
     with conn_cnx() as cnx:
 
         def mocked_execute(*args, **kwargs):
-            if len(args) >= 1 and args[0].startswith("COPY INTO"):
-                location = args[0].split(" ")[2]
-                if quote_identifiers:
-                    assert location == '"teble.table"'
-                else:
-                    assert location == "teble.table"
+            if len(args) >= 1 and args[0].startswith("CREATE FILE FORMAT"):
+                db_schema = ".".join(args[0].split(" ")[3].split(".")[:-1])
+                assert db_schema == expected_db_schema
             cur = SnowflakeCursor(cnx)
-            cur._result = iter([])
+            if args[0].startswith("SELECT"):
+                cur._rownumber = 0
+                cur._result = iter(
+                    [(col, "") for col in sf_connector_version_df.get().columns]
+                )
+            else:
+                cur._result = iter([])
             return cur
 
         with mock.patch(
@@ -436,11 +485,17 @@ def mocked_execute(*args, **kwargs):
         success, nchunks, nrows, _ = write_pandas(
             cnx,
             sf_connector_version_df.get(),
-            "teble.table",
+            "table",
+            database=database,
+            schema=schema,
             quote_identifiers=quote_identifiers,
+            auto_create_table=True,
         )
         assert m_execute.called and any(
-            map(lambda e: "COPY INTO" in str(e[0]), m_execute.call_args_list)
+            map(
+                lambda e: "CREATE TEMP FILE FORMAT" in str(e[0]),
+                m_execute.call_args_list,
+            )
        )
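
From the caller's perspective the write_pandas signature is unchanged; the difference is that database= and schema= now also qualify the internally created temporary stage and file format, not just the target table. A hedged usage sketch; the connection parameters and object names below are placeholders, not values from the commit:

import pandas
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

# Placeholder credentials; substitute real connection parameters.
conn = snowflake.connector.connect(user="...", password="...", account="...")
df = pandas.DataFrame({"COL1": [1, 2]})

# With this commit, the temporary stage and file format created internally
# are qualified as "MY_DB"."MY_SCHEMA".<random name> rather than only a
# quoted random name in the session's current database/schema.
success, nchunks, nrows, _ = write_pandas(
    conn,
    df,
    "MY_TABLE",
    database="MY_DB",
    schema="MY_SCHEMA",
    quote_identifiers=True,
    auto_create_table=True,
)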