Skip to content

Commit a7f35b8

Browse files
sfc-gh-yuwangsfc-gh-aling
authored andcommitted
SNOW-1728340: support gcp and azure (#2067)
1 parent 9a85c62 commit a7f35b8

12 files changed

+1751
-37
lines changed
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
from __future__ import annotations
6+
7+
import json
8+
import xml.etree.ElementTree as ET
9+
from datetime import datetime, timezone
10+
from logging import getLogger
11+
from random import choice
12+
from string import hexdigits
13+
from typing import TYPE_CHECKING, Any
14+
15+
import aiohttp
16+
17+
from ..azure_storage_client import (
18+
SnowflakeAzureRestClient as SnowflakeAzureRestClientSync,
19+
)
20+
from ..compat import quote
21+
from ..constants import FileHeader, ResultStatus
22+
from ..encryption_util import EncryptionMetadata
23+
from ._storage_client import SnowflakeStorageClient as SnowflakeStorageClientAsync
24+
25+
if TYPE_CHECKING: # pragma: no cover
26+
from ..file_transfer_agent import SnowflakeFileMeta, StorageCredential
27+
28+
logger = getLogger(__name__)
29+
30+
from ..azure_storage_client import (
31+
ENCRYPTION_DATA,
32+
MATDESC,
33+
TOKEN_EXPIRATION_ERR_MESSAGE,
34+
)
35+
36+
37+
class SnowflakeAzureRestClient(
38+
SnowflakeStorageClientAsync, SnowflakeAzureRestClientSync
39+
):
40+
def __init__(
41+
self,
42+
meta: SnowflakeFileMeta,
43+
credentials: StorageCredential | None,
44+
chunk_size: int,
45+
stage_info: dict[str, Any],
46+
use_s3_regional_url: bool = False,
47+
) -> None:
48+
SnowflakeAzureRestClientSync.__init__(
49+
self,
50+
meta=meta,
51+
stage_info=stage_info,
52+
chunk_size=chunk_size,
53+
credentials=credentials,
54+
)
55+
56+
async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool:
57+
return response.status == 403 and any(
58+
message in response.reason for message in TOKEN_EXPIRATION_ERR_MESSAGE
59+
)
60+
61+
async def _send_request_with_authentication_and_retry(
62+
self,
63+
verb: str,
64+
url: str,
65+
retry_id: int | str,
66+
headers: dict[str, Any] = None,
67+
data: bytes = None,
68+
) -> aiohttp.ClientResponse:
69+
if not headers:
70+
headers = {}
71+
72+
def generate_authenticated_url_and_rest_args() -> tuple[str, dict[str, Any]]:
73+
curtime = datetime.now(timezone.utc).replace(tzinfo=None)
74+
timestamp = curtime.strftime("YYYY-MM-DD")
75+
sas_token = self.credentials.creds["AZURE_SAS_TOKEN"]
76+
if sas_token and sas_token.startswith("?"):
77+
sas_token = sas_token[1:]
78+
if "?" in url:
79+
_url = url + "&" + sas_token
80+
else:
81+
_url = url + "?" + sas_token
82+
headers["Date"] = timestamp
83+
rest_args = {"headers": headers}
84+
if data:
85+
rest_args["data"] = data
86+
return _url, rest_args
87+
88+
return await self._send_request_with_retry(
89+
verb, generate_authenticated_url_and_rest_args, retry_id
90+
)
91+
92+
async def get_file_header(self, filename: str) -> FileHeader | None:
93+
"""Gets Azure file properties."""
94+
container_name = quote(self.azure_location.container_name)
95+
path = quote(self.azure_location.path) + quote(filename)
96+
meta = self.meta
97+
# HTTP HEAD request
98+
url = f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}"
99+
retry_id = "HEAD"
100+
self.retry_count[retry_id] = 0
101+
r = await self._send_request_with_authentication_and_retry(
102+
"HEAD", url, retry_id
103+
)
104+
if r.status == 200:
105+
meta.result_status = ResultStatus.UPLOADED
106+
enc_data_str = r.headers.get(ENCRYPTION_DATA)
107+
encryption_data = None if enc_data_str is None else json.loads(enc_data_str)
108+
encryption_metadata = (
109+
None
110+
if not encryption_data
111+
else EncryptionMetadata(
112+
key=encryption_data["WrappedContentKey"]["EncryptedKey"],
113+
iv=encryption_data["ContentEncryptionIV"],
114+
matdesc=r.headers.get(MATDESC),
115+
)
116+
)
117+
return FileHeader(
118+
digest=r.headers.get("x-ms-meta-sfcdigest"),
119+
content_length=int(r.headers.get("Content-Length")),
120+
encryption_metadata=encryption_metadata,
121+
)
122+
elif r.status == 404:
123+
meta.result_status = ResultStatus.NOT_FOUND_FILE
124+
return FileHeader(
125+
digest=None, content_length=None, encryption_metadata=None
126+
)
127+
else:
128+
r.raise_for_status()
129+
130+
async def _initiate_multipart_upload(self) -> None:
131+
self.block_ids = [
132+
"".join(choice(hexdigits) for _ in range(20))
133+
for _ in range(self.num_of_chunks)
134+
]
135+
136+
async def _upload_chunk(self, chunk_id: int, chunk: bytes) -> None:
137+
container_name = quote(self.azure_location.container_name)
138+
path = quote(self.azure_location.path + self.meta.dst_file_name.lstrip("/"))
139+
140+
if self.num_of_chunks > 1:
141+
block_id = self.block_ids[chunk_id]
142+
url = (
143+
f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}?comp=block"
144+
f"&blockid={block_id}"
145+
)
146+
headers = {"Content-Length": str(len(chunk))}
147+
r = await self._send_request_with_authentication_and_retry(
148+
"PUT", url, chunk_id, headers=headers, data=chunk
149+
)
150+
else:
151+
# single request
152+
azure_metadata = self._prepare_file_metadata()
153+
url = f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}"
154+
headers = {
155+
"x-ms-blob-type": "BlockBlob",
156+
"Content-Encoding": "utf-8",
157+
}
158+
headers.update(azure_metadata)
159+
r = await self._send_request_with_authentication_and_retry(
160+
"PUT", url, chunk_id, headers=headers, data=chunk
161+
)
162+
r.raise_for_status() # expect status code 201
163+
164+
async def _complete_multipart_upload(self) -> None:
165+
container_name = quote(self.azure_location.container_name)
166+
path = quote(self.azure_location.path + self.meta.dst_file_name.lstrip("/"))
167+
url = (
168+
f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}?comp"
169+
f"=blocklist"
170+
)
171+
root = ET.Element("BlockList")
172+
for block_id in self.block_ids:
173+
part = ET.Element("Latest")
174+
part.text = block_id
175+
root.append(part)
176+
headers = {"x-ms-blob-content-encoding": "utf-8"}
177+
azure_metadata = self._prepare_file_metadata()
178+
headers.update(azure_metadata)
179+
retry_id = "COMPLETE"
180+
self.retry_count[retry_id] = 0
181+
r = await self._send_request_with_authentication_and_retry(
182+
"PUT", url, "COMPLETE", headers=headers, data=ET.tostring(root)
183+
)
184+
r.raise_for_status() # expects status code 201
185+
186+
async def download_chunk(self, chunk_id: int) -> None:
187+
container_name = quote(self.azure_location.container_name)
188+
path = quote(self.azure_location.path + self.meta.src_file_name.lstrip("/"))
189+
url = f"https://{self.storage_account}.blob.{self.endpoint}/{container_name}/{path}"
190+
if self.num_of_chunks > 1:
191+
chunk_size = self.chunk_size
192+
if chunk_id < self.num_of_chunks - 1:
193+
_range = f"{chunk_id * chunk_size}-{(chunk_id + 1) * chunk_size - 1}"
194+
else:
195+
_range = f"{chunk_id * chunk_size}-"
196+
headers = {"Range": f"bytes={_range}"}
197+
r = await self._send_request_with_authentication_and_retry(
198+
"GET", url, chunk_id, headers=headers
199+
) # expect 206
200+
else:
201+
# single request
202+
r = await self._send_request_with_authentication_and_retry(
203+
"GET", url, chunk_id
204+
)
205+
if r.status in (200, 206):
206+
self.write_downloaded_chunk(chunk_id, await r.read())
207+
r.raise_for_status()

src/snowflake/connector/aio/_file_transfer_agent.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from logging import getLogger
1111
from typing import IO, TYPE_CHECKING, Any
1212

13-
from ..azure_storage_client import SnowflakeAzureRestClient
1413
from ..constants import (
1514
AZURE_CHUNK_SIZE,
1615
AZURE_FS,
@@ -29,8 +28,9 @@
2928
SnowflakeFileTransferAgent as SnowflakeFileTransferAgentSync,
3029
)
3130
from ..file_transfer_agent import SnowflakeProgressPercentage, _chunk_size_calculator
32-
from ..gcs_storage_client import SnowflakeGCSRestClient
3331
from ..local_storage_client import SnowflakeLocalStorageClient
32+
from ._azure_storage_client import SnowflakeAzureRestClient
33+
from ._gcs_storage_client import SnowflakeGCSRestClient
3434
from ._s3_storage_client import SnowflakeS3RestClient
3535
from ._storage_client import SnowflakeStorageClient
3636

@@ -92,7 +92,7 @@ async def execute(self) -> None:
9292
for m in self._file_metadata:
9393
m.sfagent = self
9494

95-
self._transfer_accelerate_config()
95+
await self._transfer_accelerate_config()
9696

9797
if self._command_type == CMD_TYPE_DOWNLOAD:
9898
if not os.path.isdir(self._local_location):
@@ -139,7 +139,7 @@ async def execute(self) -> None:
139139
result.result_status = result.result_status.value
140140

141141
async def transfer(self, metas: list[SnowflakeFileMeta]) -> None:
142-
files = [self._create_file_transfer_client(m) for m in metas]
142+
files = [await self._create_file_transfer_client(m) for m in metas]
143143
is_upload = self._command_type == CMD_TYPE_UPLOAD
144144
finish_download_upload_tasks = []
145145

@@ -258,7 +258,12 @@ def postprocess_done_cb(
258258

259259
self._results = metas
260260

261-
def _create_file_transfer_client(
261+
async def _transfer_accelerate_config(self) -> None:
262+
if self._stage_location_type == S3_FS and self._file_metadata:
263+
client = await self._create_file_transfer_client(self._file_metadata[0])
264+
self._use_accelerate_endpoint = client.transfer_accelerate_config()
265+
266+
async def _create_file_transfer_client(
262267
self, meta: SnowflakeFileMeta
263268
) -> SnowflakeStorageClient:
264269
if self._stage_location_type == LOCAL_FS:
@@ -276,21 +281,30 @@ def _create_file_transfer_client(
276281
use_s3_regional_url=self._use_s3_regional_url,
277282
)
278283
elif self._stage_location_type == S3_FS:
279-
return SnowflakeS3RestClient(
284+
client = SnowflakeS3RestClient(
280285
meta=meta,
281286
credentials=self._credentials,
282287
stage_info=self._stage_info,
283288
chunk_size=_chunk_size_calculator(meta.src_file_size),
284289
use_accelerate_endpoint=self._use_accelerate_endpoint,
285290
use_s3_regional_url=self._use_s3_regional_url,
286291
)
292+
return client
287293
elif self._stage_location_type == GCS_FS:
288-
return SnowflakeGCSRestClient(
294+
client = SnowflakeGCSRestClient(
289295
meta,
290296
self._credentials,
291297
self._stage_info,
292298
self._cursor._connection,
293299
self._command,
294300
use_s3_regional_url=self._use_s3_regional_url,
295301
)
302+
if client.security_token:
303+
logger.debug(f"len(GCS_ACCESS_TOKEN): {len(client.security_token)}")
304+
else:
305+
logger.debug(
306+
"No access token received from GS, requesting presigned url"
307+
)
308+
await client._update_presigned_url()
309+
return client
296310
raise Exception(f"{self._stage_location_type} is an unknown stage type")

0 commit comments

Comments
 (0)