Skip to content

Commit 075fde0

Browse files
committed
SNOW-1757241: migrate all integ test (#2076)
1 parent 4de4f55 commit 075fde0

33 files changed

+4520
-9
lines changed

src/snowflake/connector/aio/_network.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import itertools
1212
import json
1313
import logging
14+
import re
1415
import uuid
1516
from typing import TYPE_CHECKING, Any
1617

@@ -163,7 +164,7 @@ def __init__(
163164
self._ocsp_mode = (
164165
self._connection._ocsp_mode() if self._connection else OCSPMode.FAIL_OPEN
165166
)
166-
if self._connection.proxy_host:
167+
if self._connection and self._connection.proxy_host:
167168
self._get_proxy_headers = lambda url: {"Host": parse_url(url).hostname}
168169
else:
169170
self._get_proxy_headers = lambda _: None
@@ -416,6 +417,7 @@ async def _get_request(
416417
headers: dict[str, str],
417418
token: str = None,
418419
timeout: int | None = None,
420+
is_fetch_query_status: bool = False,
419421
) -> dict[str, Any]:
420422
if "Content-Encoding" in headers:
421423
del headers["Content-Encoding"]
@@ -429,6 +431,7 @@ async def _get_request(
429431
headers,
430432
timeout=timeout,
431433
token=token,
434+
is_fetch_query_status=is_fetch_query_status,
432435
)
433436
if ret.get("code") == SESSION_EXPIRED_GS_CODE:
434437
try:
@@ -443,7 +446,12 @@ async def _get_request(
443446
)
444447
)
445448
if ret.get("success"):
446-
return await self._get_request(url, headers, token=self.token)
449+
return await self._get_request(
450+
url,
451+
headers,
452+
token=self.token,
453+
is_fetch_query_status=is_fetch_query_status,
454+
)
447455

448456
return ret
449457

@@ -517,7 +525,13 @@ async def _post_request(
517525
result_url = ret["data"]["getResultUrl"]
518526
logger.debug("ping pong starting...")
519527
ret = await self._get_request(
520-
result_url, headers, token=self.token, timeout=timeout
528+
result_url,
529+
headers,
530+
token=self.token,
531+
timeout=timeout,
532+
is_fetch_query_status=bool(
533+
re.match(r"^/queries/.+/result$", result_url)
534+
),
521535
)
522536
logger.debug("ret[code] = %s", ret.get("code", "N/A"))
523537
logger.debug("ping pong done")
@@ -603,6 +617,7 @@ async def _request_exec_wrapper(
603617

604618
full_url = retry_ctx.add_retry_params(full_url)
605619
full_url = SnowflakeRestful.add_request_guid(full_url)
620+
is_fetch_query_status = kwargs.pop("is_fetch_query_status", False)
606621
try:
607622
return_object = await self._request_exec(
608623
session=session,
@@ -615,6 +630,13 @@ async def _request_exec_wrapper(
615630
)
616631
if return_object is not None:
617632
return return_object
633+
if is_fetch_query_status:
634+
err_msg = (
635+
"fetch query status failed and http request returned None, this"
636+
" is usually caused by transient network failures, retrying..."
637+
)
638+
logger.info(err_msg)
639+
raise RetryRequest(err_msg)
618640
self._handle_unknown_error(method, full_url, headers, data, conn)
619641
return {}
620642
except RetryRequest as e:

src/snowflake/connector/aio/_s3_storage_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ def generate_authenticated_url_and_args_v4() -> tuple[str, dict[str, bytes]]:
160160
if payload:
161161
rest_args["data"] = payload
162162

163-
# ignore_content_encoding is removed because it
164-
# does not apply to asyncio
163+
if ignore_content_encoding:
164+
rest_args["auto_decompress"] = False
165165

166166
return url, rest_args
167167

test/integ/aio/lambda/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#!/usr/bin/env python
2+
3+
#
4+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
5+
#
6+
7+
8+
async def test_connection(conn_cnx):
9+
"""Test basic connection."""
10+
async with conn_cnx() as cnx:
11+
cur = cnx.cursor()
12+
result = await (await cur.execute("select 1;")).fetchall()
13+
assert result == [(1,)]
14+
15+
16+
async def test_large_resultset(conn_cnx):
17+
"""Test large resultset."""
18+
async with conn_cnx() as cnx:
19+
cur = cnx.cursor()
20+
result = await (
21+
await cur.execute(
22+
"select seq8(), randstr(1000, random()) from table(generator(rowcount=>10000));"
23+
)
24+
).fetchall()
25+
assert len(result) == 10000

test/integ/aio/sso/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
4+
#
5+
6+
from __future__ import annotations
7+
8+
# This test requires the SSO and Snowflake admin connection parameters.
9+
#
10+
# CONNECTION_PARAMETERS_SSO = {
11+
# 'account': 'testaccount',
12+
# 'user': '[email protected]',
13+
# 'protocol': 'http',
14+
# 'host': 'testaccount.reg.snowflakecomputing.com',
15+
# 'port': '8082',
16+
# 'authenticator': 'externalbrowser',
17+
# 'timezone': 'UTC',
18+
# }
19+
#
20+
# CONNECTION_PARAMETERS_ADMIN = { ... Snowflake admin ... }
21+
import os
22+
import sys
23+
24+
import pytest
25+
26+
import snowflake.connector.aio
27+
from snowflake.connector.auth._auth import delete_temporary_credential
28+
29+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
30+
31+
try:
32+
from parameters import CONNECTION_PARAMETERS_SSO
33+
except ImportError:
34+
CONNECTION_PARAMETERS_SSO = {}
35+
36+
try:
37+
from parameters import CONNECTION_PARAMETERS_ADMIN
38+
except ImportError:
39+
CONNECTION_PARAMETERS_ADMIN = {}
40+
41+
ID_TOKEN = "ID_TOKEN"
42+
43+
44+
@pytest.fixture
45+
async def token_validity_test_values(request):
46+
async with snowflake.connector.aio.SnowflakeConnection(
47+
**CONNECTION_PARAMETERS_ADMIN
48+
) as cnx:
49+
await cnx.cursor().execute(
50+
"""
51+
ALTER SYSTEM SET
52+
MASTER_TOKEN_VALIDITY=60,
53+
SESSION_TOKEN_VALIDITY=5,
54+
ID_TOKEN_VALIDITY=60
55+
"""
56+
)
57+
# ALLOW_UNPROTECTED_ID_TOKEN is going to be deprecated in the future
58+
# cnx.cursor().execute("alter account testaccount set ALLOW_UNPROTECTED_ID_TOKEN=true;")
59+
await cnx.cursor().execute("alter account testaccount set ALLOW_ID_TOKEN=true;")
60+
await cnx.cursor().execute(
61+
"alter account testaccount set ID_TOKEN_FEATURE_ENABLED=true;"
62+
)
63+
64+
async def fin():
65+
async with snowflake.connector.connect(**CONNECTION_PARAMETERS_ADMIN) as cnx:
66+
await cnx.cursor().execute(
67+
"""
68+
ALTER SYSTEM SET
69+
MASTER_TOKEN_VALIDITY=default,
70+
SESSION_TOKEN_VALIDITY=default,
71+
ID_TOKEN_VALIDITY=default
72+
"""
73+
)
74+
75+
request.addfinalizer(fin)
76+
return None
77+
78+
79+
@pytest.mark.skipif(
80+
not (
81+
CONNECTION_PARAMETERS_SSO
82+
and CONNECTION_PARAMETERS_ADMIN
83+
and delete_temporary_credential
84+
),
85+
reason="SSO and ADMIN connection parameters must be provided.",
86+
)
87+
async def test_connect_externalbrowser(token_validity_test_values):
88+
"""SSO Id Token Cache tests. This test should only be ran if keyring optional dependency is installed.
89+
90+
In order to run this test, remove the above pytest.mark.skip annotation and run it. It will popup a windows once
91+
but the rest connections should not create popups.
92+
"""
93+
delete_temporary_credential(
94+
host=CONNECTION_PARAMETERS_SSO["host"],
95+
user=CONNECTION_PARAMETERS_SSO["user"],
96+
cred_type=ID_TOKEN,
97+
) # delete existing temporary credential
98+
CONNECTION_PARAMETERS_SSO["client_store_temporary_credential"] = True
99+
100+
# change database and schema to non-default one
101+
print(
102+
"[INFO] 1st connection gets id token and stores in the local cache (keychain/credential manager/cache file). "
103+
"This popup a browser to SSO login"
104+
)
105+
cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO)
106+
await cnx.connect()
107+
assert cnx.database == "TESTDB"
108+
assert cnx.schema == "PUBLIC"
109+
assert cnx.role == "SYSADMIN"
110+
assert cnx.warehouse == "REGRESS"
111+
ret = await (
112+
await cnx.cursor().execute(
113+
"select current_database(), current_schema(), "
114+
"current_role(), current_warehouse()"
115+
)
116+
).fetchall()
117+
assert ret[0][0] == "TESTDB"
118+
assert ret[0][1] == "PUBLIC"
119+
assert ret[0][2] == "SYSADMIN"
120+
assert ret[0][3] == "REGRESS"
121+
await cnx.close()
122+
123+
print(
124+
"[INFO] 2nd connection reads the local cache and uses the id token. "
125+
"This should not popups a browser."
126+
)
127+
CONNECTION_PARAMETERS_SSO["database"] = "testdb"
128+
CONNECTION_PARAMETERS_SSO["schema"] = "testschema"
129+
cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO)
130+
await cnx.connect()
131+
print(
132+
"[INFO] Running a 10 seconds query. If the session expires in 10 "
133+
"seconds, the query should renew the token in the middle, "
134+
"and the current objects should be refreshed."
135+
)
136+
await cnx.cursor().execute("select seq8() from table(generator(timelimit=>10))")
137+
assert cnx.database == "TESTDB"
138+
assert cnx.schema == "TESTSCHEMA"
139+
assert cnx.role == "SYSADMIN"
140+
assert cnx.warehouse == "REGRESS"
141+
142+
print("[INFO] Running a 1 second query. ")
143+
await cnx.cursor().execute("select seq8() from table(generator(timelimit=>1))")
144+
assert cnx.database == "TESTDB"
145+
assert cnx.schema == "TESTSCHEMA"
146+
assert cnx.role == "SYSADMIN"
147+
assert cnx.warehouse == "REGRESS"
148+
149+
print(
150+
"[INFO] Running a 90 seconds query. This pops up a browser in the "
151+
"middle of the query."
152+
)
153+
await cnx.cursor().execute("select seq8() from table(generator(timelimit=>90))")
154+
assert cnx.database == "TESTDB"
155+
assert cnx.schema == "TESTSCHEMA"
156+
assert cnx.role == "SYSADMIN"
157+
assert cnx.warehouse == "REGRESS"
158+
159+
await cnx.close()
160+
161+
# change database and schema again to ensure they are overridden
162+
CONNECTION_PARAMETERS_SSO["database"] = "testdb"
163+
CONNECTION_PARAMETERS_SSO["schema"] = "testschema"
164+
cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO)
165+
await cnx.connect()
166+
assert cnx.database == "TESTDB"
167+
assert cnx.schema == "TESTSCHEMA"
168+
assert cnx.role == "SYSADMIN"
169+
assert cnx.warehouse == "REGRESS"
170+
await cnx.close()
171+
172+
async with snowflake.connector.aio.SnowflakeConnection(
173+
**CONNECTION_PARAMETERS_ADMIN
174+
) as cnx_admin:
175+
# cnx_admin.cursor().execute("alter account testaccount set ALLOW_UNPROTECTED_ID_TOKEN=false;")
176+
await cnx_admin.cursor().execute(
177+
"alter account testaccount set ALLOW_ID_TOKEN=false;"
178+
)
179+
await cnx_admin.cursor().execute(
180+
"alter account testaccount set ID_TOKEN_FEATURE_ENABLED=false;"
181+
)
182+
print(
183+
"[INFO] Login again with ALLOW_UNPROTECTED_ID_TOKEN unset. Please make sure this pops up the browser"
184+
)
185+
cnx = snowflake.connector.aio.SnowflakeConnection(**CONNECTION_PARAMETERS_SSO)
186+
await cnx.connect()
187+
await cnx.close()

0 commit comments

Comments
 (0)