Skip to content

Commit b637a22

Browse files
committed
SNOW-1664063: sync main branch changes into async part (#2081)
1 parent d13e4e2 commit b637a22

13 files changed

+91
-21
lines changed

src/snowflake/connector/aio/_azure_storage_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import aiohttp
1616

17+
from ..azure_storage_client import AzureCredentialFilter
1718
from ..azure_storage_client import (
1819
SnowflakeAzureRestClient as SnowflakeAzureRestClientSync,
1920
)
@@ -25,14 +26,16 @@
2526
if TYPE_CHECKING: # pragma: no cover
2627
from ..file_transfer_agent import SnowflakeFileMeta, StorageCredential
2728

28-
logger = getLogger(__name__)
29-
3029
from ..azure_storage_client import (
3130
ENCRYPTION_DATA,
3231
MATDESC,
3332
TOKEN_EXPIRATION_ERR_MESSAGE,
3433
)
3534

35+
logger = getLogger(__name__)
36+
37+
getLogger("aiohttp").addFilter(AzureCredentialFilter())
38+
3639

3740
class SnowflakeAzureRestClient(
3841
SnowflakeStorageClientAsync, SnowflakeAzureRestClientSync

src/snowflake/connector/aio/_build_upload_agent.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import TYPE_CHECKING, cast
1111

1212
from snowflake.connector import Error
13+
from snowflake.connector._utils import get_temp_type_for_object
1314
from snowflake.connector.bind_upload_agent import BindUploadAgent as BindUploadAgentSync
1415
from snowflake.connector.errors import BindUploadError
1516

@@ -30,7 +31,11 @@ def __init__(
3031
self.cursor = cast("SnowflakeCursor", cursor)
3132

3233
async def _create_stage(self) -> None:
33-
await self.cursor.execute(self._CREATE_STAGE_STMT)
34+
create_stage_sql = (
35+
f"create or replace {get_temp_type_for_object(self._use_scoped_temp_object)} stage {self._STAGE_NAME} "
36+
"file_format=(type=csv field_optionally_enclosed_by='\"')"
37+
)
38+
await self.cursor.execute(create_stage_sql)
3439

3540
async def upload(self) -> None:
3641
try:

src/snowflake/connector/aio/_connection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from ..connection import _get_private_bytes_from_file
3636
from ..connection_diagnostic import ConnectionDiagnostic
3737
from ..constants import (
38+
_CONNECTIVITY_ERR_MSG,
3839
ENV_VAR_PARTNER,
3940
PARAMETER_AUTOCOMMIT,
4041
PARAMETER_CLIENT_PREFETCH_THREADS,
@@ -443,6 +444,8 @@ async def _authenticate(self, auth_instance: AuthByPlugin):
443444
)
444445
except OperationalError as auth_op:
445446
if auth_op.errno == ER_FAILED_TO_CONNECT_TO_DB:
447+
if _CONNECTIVITY_ERR_MSG in e.msg:
448+
auth_op.msg += f"\n{_CONNECTIVITY_ERR_MSG}"
446449
raise auth_op from e
447450
logger.debug("Continuing authenticator specific timeout handling")
448451
continue

src/snowflake/connector/aio/_cursor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,10 @@ async def _timebomb_task(self, timeout, query):
129129
logger.debug("started timebomb in %ss", timeout)
130130
await asyncio.sleep(timeout)
131131
await self.__cancel_query(query)
132+
return True
132133
except asyncio.CancelledError:
133134
logger.debug("cancelled timebomb in timebomb task")
135+
return False
134136

135137
async def __cancel_query(self, query) -> None:
136138
if self._sequence_counter >= 0 and not self.is_closed():
@@ -284,7 +286,10 @@ def interrupt_handler(*_): # pragma: no cover
284286
)
285287
if self._timebomb is not None:
286288
self._timebomb.cancel()
287-
self._timebomb = None
289+
try:
290+
await self._timebomb
291+
except asyncio.CancelledError:
292+
pass
288293
logger.debug("cancelled timebomb in finally")
289294

290295
if "data" in ret and "parameters" in ret["data"]:
@@ -674,6 +679,11 @@ async def execute(
674679
logger.debug(ret)
675680
err = ret["message"]
676681
code = ret.get("code", -1)
682+
if self._timebomb and self._timebomb.result():
683+
err = (
684+
f"SQL execution was cancelled by the client due to a timeout. "
685+
f"Error message received from the server: {err}"
686+
)
677687
if "data" in ret:
678688
err += ret["data"].get("errorMessage", "")
679689
errvalue = {
@@ -1067,6 +1077,7 @@ async def wait_until_ready() -> None:
10671077
self._prefetch_hook = wait_until_ready
10681078

10691079
async def query_result(self, qid: str) -> SnowflakeCursor:
1080+
"""Query the result of a previously executed query."""
10701081
url = f"/queries/{qid}/result"
10711082
ret = await self._connection.rest.request(url=url, method="get")
10721083
self._sfqid = (

src/snowflake/connector/aio/_network.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
urlparse,
2929
)
3030
from ..constants import (
31+
_CONNECTIVITY_ERR_MSG,
3132
HTTP_HEADER_ACCEPT,
3233
HTTP_HEADER_CONTENT_TYPE,
3334
HTTP_HEADER_SERVICE_NAME,
@@ -798,8 +799,19 @@ async def _request_exec(
798799
finally:
799800
raw_ret.close() # ensure response is closed
800801
except aiohttp.ClientSSLError as se:
801-
logger.debug("Hit non-retryable SSL error, %s", str(se))
802-
802+
msg = f"Hit non-retryable SSL error, {str(se)}.\n{_CONNECTIVITY_ERR_MSG}"
803+
logger.debug(msg)
804+
# the following code is for backward compatibility with old versions of python connector which calls
805+
# self._handle_unknown_error to process SSLError
806+
Error.errorhandler_wrapper(
807+
self._connection,
808+
None,
809+
OperationalError,
810+
{
811+
"msg": msg,
812+
"errno": ER_FAILED_TO_REQUEST,
813+
},
814+
)
803815
# TODO: sync feature parity, aiohttp network error handling
804816
except (
805817
BadStatusLine,

src/snowflake/connector/aio/auth/_auth.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from typing import TYPE_CHECKING, Any, Callable
1414

1515
from ...auth import Auth as AuthSync
16-
from ...auth._auth import ID_TOKEN, MFA_TOKEN, delete_temporary_credential
16+
from ...auth._auth import (
17+
AUTHENTICATION_REQUEST_KEY_WHITELIST,
18+
ID_TOKEN,
19+
MFA_TOKEN,
20+
delete_temporary_credential,
21+
)
1722
from ...compat import urlencode
1823
from ...constants import (
1924
HTTP_HEADER_ACCEPT,
@@ -103,7 +108,6 @@ async def authenticate(
103108

104109
body = copy.deepcopy(body_template)
105110
# updating request body
106-
logger.debug("assertion content: %s", auth_instance.assertion_content)
107111
await auth_instance.update_body(body)
108112

109113
logger.debug(
@@ -141,7 +145,10 @@ async def authenticate(
141145

142146
logger.debug(
143147
"body['data']: %s",
144-
{k: v for (k, v) in body["data"].items() if k != "PASSWORD"},
148+
{
149+
k: v if k in AUTHENTICATION_REQUEST_KEY_WHITELIST else "******"
150+
for (k, v) in body["data"].items()
151+
},
145152
)
146153

147154
try:

test/integ/aio/test_cursor_async.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,11 @@ async def test_timeout_query(conn_cnx):
728728
"select seq8() as c1 from table(generator(timeLimit => 60))",
729729
timeout=5,
730730
)
731-
assert err.value.errno == 604, "Invalid error code"
731+
assert err.value.errno == 604, (
732+
"Invalid error code"
733+
and "SQL execution was cancelled by the client due to a timeout"
734+
in err.value.msg
735+
)
732736

733737

734738
async def test_executemany(conn, db_parameters):

test/integ/aio/test_dbapi_async.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,8 @@ async def conn_local(request, conn_cnx):
4444
async def fin():
4545
await drop_dbapi_tables(conn_cnx)
4646

47-
request.addfinalizer(fin)
48-
49-
return conn_cnx
47+
yield conn_cnx
48+
await fin()
5049

5150

5251
async def _paraminsert(cur):

test/integ/aio/test_large_result_set_async.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ async def fin():
8787
"drop table if exists {name}".format(name=db_parameters["name"])
8888
)
8989

90-
request.addfinalizer(fin)
91-
return first_val, last_val
90+
yield first_val, last_val
91+
await fin()
9292

9393

9494
@pytest.mark.aws

test/integ/aio/test_put_get_async.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def test_utf8_filename(tmp_path, aio_connection):
4444
await aio_connection.connect()
4545
cursor = aio_connection.cursor()
4646
await cursor.execute(f"create temporary stage {stage_name}")
47-
(
47+
await (
4848
await cursor.execute(
4949
"PUT 'file://{}' @{}".format(str(test_file).replace("\\", "/"), stage_name)
5050
)
@@ -128,7 +128,7 @@ async def test_put_special_file_name(tmp_path, aio_connection):
128128
cursor = aio_connection.cursor()
129129
await cursor.execute(f"create temporary stage {stage_name}")
130130
filename_in_put = str(test_file).replace("\\", "/")
131-
(
131+
await (
132132
await cursor.execute(
133133
f"PUT 'file://{filename_in_put}' @{stage_name}",
134134
)

0 commit comments

Comments
 (0)