Skip to content

Commit b6e0a38

Browse files
committed
SNOW-1572304: asyncio add proxy support and test (#2066)
1 parent a7f35b8 commit b6e0a38

File tree

2 files changed

+20
-19
lines changed

2 files changed

+20
-19
lines changed

src/snowflake/connector/aio/_network.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import TYPE_CHECKING, Any
1616

1717
import OpenSSL.SSL
18+
from urllib3.util.url import parse_url
1819

1920
from ..compat import (
2021
FORBIDDEN,
@@ -80,7 +81,7 @@
8081
SQLSTATE_CONNECTION_REJECTED,
8182
SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED,
8283
)
83-
from ..time_util import TimeoutBackoffCtx, get_time_millis
84+
from ..time_util import TimeoutBackoffCtx
8485
from ._ssl_connector import SnowflakeSSLConnector
8586

8687
if TYPE_CHECKING:
@@ -162,6 +163,10 @@ def __init__(
162163
self._ocsp_mode = (
163164
self._connection._ocsp_mode() if self._connection else OCSPMode.FAIL_OPEN
164165
)
166+
if self._connection.proxy_host:
167+
self._get_proxy_headers = lambda url: {"Host": parse_url(url).hostname}
168+
else:
169+
self._get_proxy_headers = lambda _: None
165170

166171
async def close(self) -> None:
167172
if hasattr(self, "_token"):
@@ -704,11 +709,6 @@ async def _request_exec(
704709
else:
705710
input_data = data
706711

707-
download_start_time = get_time_millis()
708-
# socket timeout is constant. You should be able to receive
709-
# the response within the time. If not, ConnectReadTimeout or
710-
# ReadTimeout is raised.
711-
712712
# TODO: aiohttp auth parameter works differently than requests.session.request
713713
# we can check if there's other aiohttp built-in mechanism to update this
714714
if HEADER_AUTHORIZATION_KEY in headers:
@@ -718,26 +718,31 @@ async def _request_exec(
718718
token=token
719719
)
720720

721-
# TODO: sync feature parity, parameters verify/stream in sync version
721+
# socket timeout is constant. You should be able to receive
722+
# the response within the time. If not, asyncio.TimeoutError is raised.
723+
724+
# delta compared to sync:
725+
# - in sync, we specify "verify" to True; in aiohttp,
726+
# the counter parameter is "ssl" and it already defaults to True
722727
raw_ret = await session.request(
723728
method=method,
724729
url=full_url,
725730
headers=headers,
726731
data=input_data,
727732
timeout=aiohttp.ClientTimeout(socket_timeout),
733+
proxy_headers=self._get_proxy_headers(full_url),
728734
)
729-
730-
download_end_time = get_time_millis()
731-
732735
try:
733736
if raw_ret.status == OK:
734737
logger.debug("SUCCESS")
735738
if is_raw_text:
736739
ret = await raw_ret.text()
737740
elif is_raw_binary:
738-
content = await raw_ret.read()
739-
ret = binary_data_handler.to_iterator(
740-
content, download_end_time - download_start_time
741+
# check SNOW-1738595 for is_raw_binary support
742+
raise NotImplementedError(
743+
"reading raw binary data is not supported in asyncio connector,"
744+
" please open a feature request issue in"
745+
" github: https://github.com/snowflakedb/snowflake-connector-python/issues/new/choose"
741746
)
742747
else:
743748
ret = await raw_ret.json()
@@ -818,12 +823,9 @@ async def _request_exec(
818823

819824
def make_requests_session(self) -> aiohttp.ClientSession:
820825
s = aiohttp.ClientSession(
821-
connector=SnowflakeSSLConnector(snowflake_ocsp_mode=self._ocsp_mode)
826+
connector=SnowflakeSSLConnector(snowflake_ocsp_mode=self._ocsp_mode),
827+
trust_env=True, # this is for proxy support, proxy.set_proxy will set envs and trust_env allows reading env
822828
)
823-
# TODO: sync feature parity, proxy support
824-
# s.mount("http://", ProxySupportAdapter(max_retries=REQUESTS_RETRY))
825-
# s.mount("https://", ProxySupportAdapter(max_retries=REQUESTS_RETRY))
826-
# s._reuse_count = itertools.count()
827829
return s
828830

829831
@contextlib.asynccontextmanager

test/integ/aio/test_connection_async.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,6 @@ async def test_invalid_account_timeout():
578578
pass
579579

580580

581-
@pytest.mark.skip("SNOW-1572304 proxy support")
582581
@pytest.mark.timeout(15)
583582
async def test_invalid_proxy(db_parameters):
584583
with pytest.raises(OperationalError):

0 commit comments

Comments
 (0)