Skip to content

Commit 2583452

Browse files
authored
SNOW-2161990 introduce a tiny abstraction to allow sproc to override … (#2370)
1 parent 1062e1e commit 2583452

File tree

2 files changed

+56
-3
lines changed

2 files changed

+56
-3
lines changed

src/snowflake/connector/file_transfer_agent.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,11 +1062,14 @@ def _init_file_metadata(self) -> None:
10621062
for idx, file_name in enumerate(self._src_files):
10631063
if not file_name:
10641064
continue
1065-
first_path_sep = file_name.find("/")
10661065
dst_file_name = (
1067-
file_name[first_path_sep + 1 :]
1066+
self._strip_stage_prefix_from_dst_file_name_for_download(file_name)
1067+
)
1068+
first_path_sep = dst_file_name.find("/")
1069+
dst_file_name = (
1070+
dst_file_name[first_path_sep + 1 :]
10681071
if first_path_sep >= 0
1069-
else file_name
1072+
else dst_file_name
10701073
)
10711074
url = None
10721075
if self._presigned_urls and idx < len(self._presigned_urls):
@@ -1201,3 +1204,12 @@ def _process_file_compression_type(self) -> None:
12011204
else:
12021205
m.dst_file_name = m.name
12031206
m.dst_compression_type = None
1207+
1208+
def _strip_stage_prefix_from_dst_file_name_for_download(self, dst_file_name):
1209+
"""Strips the stage prefix from dst_file_name for download.
1210+
1211+
Note that this is no-op in most cases, and therefore we return as is.
1212+
But for some workloads they will monkeypatch this method to add their
1213+
stripping logic.
1214+
"""
1215+
return dst_file_name

test/unit/test_put_get.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from os import chmod, path
55
from unittest import mock
6+
from unittest.mock import patch
67

78
import pytest
89

@@ -252,3 +253,43 @@ def test_iobound_limit(tmp_path):
252253
pass
253254
# 2 IObound TPEs should be created for 3 files limited to 2
254255
assert len(list(filter(lambda e: e.args == (2,), tpe.call_args_list))) == 2
256+
257+
258+
def test_strip_stage_prefix_from_dst_file_name_for_download():
259+
"""Verifies that _strip_stage_prefix_from_dst_file_name_for_download is called when initializing file meta.
260+
261+
Workloads like sproc will need to monkeypatch _strip_stage_prefix_from_dst_file_name_for_download on the server side
262+
to maintain its behavior. So we add this unit test to make sure that we do not accidentally refactor this method and
263+
break sproc workloads.
264+
"""
265+
file = "test.txt"
266+
agent = SnowflakeFileTransferAgent(
267+
mock.MagicMock(autospec=SnowflakeCursor),
268+
"GET @stage_foo/test.txt file:///tmp",
269+
{
270+
"data": {
271+
"localLocation": "/tmp",
272+
"command": "DOWNLOAD",
273+
"autoCompress": False,
274+
"src_locations": [file],
275+
"sourceCompression": "none",
276+
"stageInfo": {
277+
"creds": {},
278+
"location": "",
279+
"locationType": "S3",
280+
"path": "remote_loc",
281+
},
282+
},
283+
"success": True,
284+
},
285+
)
286+
agent._parse_command()
287+
with patch.object(
288+
agent,
289+
"_strip_stage_prefix_from_dst_file_name_for_download",
290+
return_value="mock value",
291+
):
292+
agent._init_file_metadata()
293+
agent._strip_stage_prefix_from_dst_file_name_for_download.assert_called_with(
294+
file
295+
)

0 commit comments

Comments
 (0)