Skip to content

Commit 9e3d1d3

Browse files
[Async] Apply 2241 to async code
1 parent 9209e40 commit 9e3d1d3

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

src/snowflake/connector/aio/_file_transfer_agent.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ async def _create_file_transfer_client(
301301
self._cursor._connection,
302302
self._command,
303303
unsafe_file_write=self._unsafe_file_write,
304+
use_virtual_endpoints=self._gcs_use_virtual_endpoints,
304305
)
305306
if client.security_token:
306307
logger.debug(f"len(GCS_ACCESS_TOKEN): {len(client.security_token)}")

src/snowflake/connector/aio/_gcs_storage_client.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
cnx: SnowflakeConnection,
3939
command: str,
4040
unsafe_file_write: bool = False,
41+
use_virtual_endpoints: bool = False,
4142
) -> None:
4243
"""Creates a client object with given stage credentials.
4344
@@ -72,6 +73,7 @@ def __init__(
7273
self.endpoint: str | None = (
7374
None if "endPoint" not in stage_info else stage_info["endPoint"]
7475
)
76+
self.use_virtual_endpoints: bool = use_virtual_endpoints
7577

7678
async def _has_expired_token(self, response: aiohttp.ClientResponse) -> bool:
7779
return self.security_token and response.status == 401
@@ -147,6 +149,8 @@ def generate_url_and_rest_args() -> (
147149
if "region" not in self.stage_info
148150
else self.stage_info["region"]
149151
),
152+
self.endpoint,
153+
self.use_virtual_endpoints,
150154
)
151155
access_token = self.security_token
152156
else:
@@ -185,6 +189,7 @@ def generate_url_and_rest_args() -> (
185189
else self.stage_info["region"]
186190
),
187191
self.endpoint,
192+
self.use_virtual_endpoints,
188193
)
189194
access_token = self.security_token
190195
gcs_headers["Authorization"] = f"Bearer {access_token}"
@@ -309,6 +314,8 @@ def generate_url_and_authenticated_headers():
309314
if "region" not in self.stage_info
310315
else self.stage_info["region"]
311316
),
317+
self.endpoint,
318+
self.use_virtual_endpoints,
312319
)
313320
gcs_headers = {"Authorization": f"Bearer {self.security_token}"}
314321
rest_args = {"headers": gcs_headers}

0 commit comments

Comments
 (0)