Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions chunk_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from .errorcode import (ER_NO_ADDITIONAL_CHUNK, ER_CHUNK_DOWNLOAD_FAILED)
from .errors import (Error, OperationalError)
from .network import (SnowflakeRestful, NO_TOKEN, MAX_CONNECTION_POOL)
from .network import (SnowflakeRestful, NO_TOKEN, make_session)
from .ssl_wrap_socket import (set_proxies)

DEFAULT_REQUEST_TIMEOUT = 3600
Expand Down Expand Up @@ -126,9 +126,7 @@ def _download_chunk(self, idx):

logger.debug(u"started getting the result set %s: %s",
idx + 1, self._chunks[idx].url)
result_data = self._get_request(
self._chunks[idx].url,
headers, max_connection_pool=self._effective_threads)
result_data = self._get_request(self._chunks[idx].url, headers)
logger.debug(u"finished getting the result set %s: %s",
idx + 1, self._chunks[idx].url)

Expand Down Expand Up @@ -255,10 +253,7 @@ def __del__(self):
# ignore all errors in the destructor
pass

def _get_request(
self, url, headers,
is_raw_binary_iterator=True,
max_connection_pool=MAX_CONNECTION_POOL):
def _get_request(self, url, headers, is_raw_binary_iterator=True):
"""
GET request for Large Result set chunkloader
"""
Expand All @@ -273,7 +268,7 @@ def _get_request(

return SnowflakeRestful.access_url(
self._connection,
self,
make_session(),
u'get',
full_url=url,
headers=headers,
Expand All @@ -285,5 +280,4 @@ def _get_request(
token=NO_TOKEN,
is_raw_binary=True,
is_raw_binary_iterator=is_raw_binary_iterator,
max_connection_pool=max_connection_pool,
use_ijson=self._use_ijson)
3 changes: 1 addition & 2 deletions connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import logging
from logging import getLogger


# default configs
DEFAULT_CONFIGURATION = {
u'dsn': None, # standard
Expand Down Expand Up @@ -72,7 +73,6 @@
u'session_parameters': None, # snowflake internal
u'autocommit': None, # snowflake
u'numpy': False, # snowflake
u'max_connection_pool': network.MAX_CONNECTION_POOL, # snowflake internal
u'ocsp_response_cache_filename': None, # snowflake internal
u'converter_class': SnowflakeConverter, # snowflake internal
u'chunk_downloader_class': SnowflakeChunkDownloader, # snowflake internal
Expand Down Expand Up @@ -428,7 +428,6 @@ def __open_connection(self, mfa_callback, password_callback):
connect_timeout=self._connect_timeout,
request_timeout=self._request_timeout,
injectClientPause=self._injectClientPause,
max_connection_pool=self._max_connection_pool,
connection=self)
self.logger.debug(u'REST API object was created: %s:%s, proxy=%s:%s, '
u'proxy_user=%s',
Expand Down
119 changes: 68 additions & 51 deletions network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#
# Copyright (c) 2012-2017 Snowflake Computing Inc. All right reserved.
#
import collections
import contextlib
import copy
import gzip
import json
Expand All @@ -16,7 +18,7 @@

import OpenSSL
from botocore.vendored import requests
from botocore.vendored.requests.adapters import (HTTPAdapter, DEFAULT_POOLSIZE)
from botocore.vendored.requests.adapters import HTTPAdapter
from botocore.vendored.requests.auth import AuthBase
from botocore.vendored.requests.exceptions import (ConnectionError, SSLError)
from botocore.vendored.requests.packages.urllib3.exceptions import (
Expand Down Expand Up @@ -69,8 +71,6 @@
HEADER_AUTHORIZATION_KEY = u"Authorization"
HEADER_SNOWFLAKE_TOKEN = u'Snowflake Token="{token}"'

MAX_CONNECTION_POOL = DEFAULT_POOLSIZE # max connetion pool size in urllib3

SNOWFLAKE_CONNECTOR_VERSION = u'.'.join(TO_UNICODE(v) for v in VERSION[0:3])
PYTHON_VERSION = u'.'.join(TO_UNICODE(v) for v in sys.version_info[:3])
PLATFORM = platform.platform()
Expand Down Expand Up @@ -99,6 +99,14 @@
}


def make_session():
""" Make a custom requests.Session instance. """
s = requests.Session()
s.mount(u'http://', HTTPAdapter(max_retries=REQUESTS_RETRY))
s.mount(u'https://', HTTPAdapter(max_retries=REQUESTS_RETRY))
return s


class RequestRetry(Exception):
pass

Expand Down Expand Up @@ -135,7 +143,6 @@ def __init__(self, host=u'127.0.0.1', port=8080,
connect_timeout=DEFAULT_CONNECT_TIMEOUT,
request_timeout=DEFAULT_REQUEST_TIMEOUT,
injectClientPause=0,
max_connection_pool=MAX_CONNECTION_POOL,
connection=None):
self._host = host
self._port = port
Expand All @@ -146,10 +153,10 @@ def __init__(self, host=u'127.0.0.1', port=8080,
self._protocol = protocol
self._connect_timeout = connect_timeout or DEFAULT_CONNECT_TIMEOUT
self._request_timeout = request_timeout or DEFAULT_REQUEST_TIMEOUT
self._session = None
self._injectClientPause = injectClientPause
self._max_connection_pool = max_connection_pool
self._connection = connection
self._avail_sessions = collections.deque()
self._active_sessions = set()
self.logger = getLogger(__name__)

# insecure mode (disabled by default)
Expand Down Expand Up @@ -181,7 +188,17 @@ def close(self):
del self._token
if hasattr(self, u'_master_token'):
del self._master_token
self._session = None
sessions = list(self._active_sessions)
if sessions:
self.logger.warn("Closing %d active sessions" % len(sessions))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use parameterized logging, e.g., self.logger.warn("Closing %s active sessions", len(sessions))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only avoid that syntax because it only works for logging functions and I sometimes switch back and forth with other printers like print(). But I can use printf style formatting from now on.

sessions.extend(self._avail_sessions)
self._active_sessions.clear()
self._avail_sessions.clear()
for s in sessions:
try:
s.close()
except Exception as e:
self.logger.warn("Session cleanup failed: %s" % e)

def authenticate(self, account, user, password, master_token=None,
token=None, database=None, schema=None,
Expand Down Expand Up @@ -527,17 +544,20 @@ def _get_request(self, url, headers, token=None, timeout=None):
self._proxy_password
)
self.logger.debug(u'url=%s, proxies=%s', full_url, proxies)
ret = SnowflakeRestful.access_url(
conn=self._connection,
session_context=self,
method=u'get',
full_url=full_url,
headers=headers,
data=None,
proxies=proxies,
timeout=(self._connect_timeout, self._connect_timeout, timeout),
token=token,
max_connection_pool=self._max_connection_pool)
with self._use_session() as session:
ret = SnowflakeRestful.access_url(
conn=self._connection,
session=session,
method=u'get',
full_url=full_url,
headers=headers,
data=None,
proxies=proxies,
timeout=(
self._connect_timeout,
self._connect_timeout,
timeout),
token=token)

if u'code' in ret and ret[u'code'] == SESSION_EXPIRED_GS_CODE:
ret = self._renew_session()
Expand All @@ -561,18 +581,18 @@ def _post_request(self, url, headers, body, token=None,
self._proxy_host, self._proxy_port, self._proxy_user,
self._proxy_password)

ret = SnowflakeRestful.access_url(
conn=self._connection,
session_context=self,
method=u'post',
full_url=full_url,
headers=headers,
data=body,
proxies=proxies,
timeout=(
self._connect_timeout, self._connect_timeout, timeout),
token=token,
max_connection_pool=self._max_connection_pool)
with self._use_session() as session:
ret = SnowflakeRestful.access_url(
conn=self._connection,
session=session,
method=u'post',
full_url=full_url,
headers=headers,
data=body,
proxies=proxies,
timeout=(
self._connect_timeout, self._connect_timeout, timeout),
token=token)
self.logger.debug(
u'ret[code] = {code}, after post request'.format(
code=(ret.get(u'code', u'N/A'))))
Expand Down Expand Up @@ -622,18 +642,16 @@ def _post_request(self, url, headers, body, token=None,
return ret

@staticmethod
def access_url(conn, session_context, method, full_url, headers, data,
def access_url(conn, session, method, full_url, headers, data,
proxies, timeout=(
DEFAULT_CONNECT_TIMEOUT,
DEFAULT_CONNECT_TIMEOUT,
DEFAULT_REQUEST_TIMEOUT),
requests_retry=REQUESTS_RETRY,
token=None,
is_raw_text=False,
catch_okta_unauthorized_error=False,
is_raw_binary=False,
is_raw_binary_iterator=True,
max_connection_pool=MAX_CONNECTION_POOL,
use_ijson=False, is_single_thread=False):
logger = getLogger(__name__)

Expand All @@ -643,21 +661,6 @@ def access_url(conn, session_context, method, full_url, headers, data,

def request_thread(result_queue):
try:
if session_context._session is None:
session_context._session = requests.Session()
session_context._session.mount(
u'http://',
HTTPAdapter(
pool_connections=int(max_connection_pool),
pool_maxsize=int(max_connection_pool),
max_retries=requests_retry))
session_context._session.mount(
u'https://',
HTTPAdapter(
pool_connections=int(max_connection_pool),
pool_maxsize=int(max_connection_pool),
max_retries=requests_retry))

if not catch_okta_unauthorized_error and data and len(data) > 0:
gzdata = BytesIO()
gzip.GzipFile(fileobj=gzdata, mode=u'wb').write(
Expand All @@ -668,7 +671,7 @@ def request_thread(result_queue):
else:
input_data = data

raw_ret = session_context._session.request(
raw_ret = session.request(
method=method,
url=full_url,
proxies=proxies,
Expand Down Expand Up @@ -952,7 +955,6 @@ def authenticate_by_saml(self, authenticator, account, user, password):
self.logger.debug(u'token url: %s', token_url)
ret = SnowflakeRestful.access_url(
conn=self._connection,
session_context=self,
method=u'post',
full_url=token_url,
headers=headers,
Expand All @@ -976,7 +978,6 @@ def authenticate_by_saml(self, authenticator, account, user, password):
self.logger.debug(u'sso url: %s', sso_url)
ret = SnowflakeRestful.access_url(
conn=self._connection,
session_context=self,
method=u'get',
full_url=sso_url,
headers=headers,
Expand All @@ -987,3 +988,19 @@ def authenticate_by_saml(self, authenticator, account, user, password):
self._connection._login_timeout),
is_raw_text=True)
return ret

@contextlib.contextmanager
def _use_session(self):
""" Session caching context manager. Note that the session is not
closed until Snowflakerestful.close is called so each session may be
used multiple times. """
try:
session = self._avail_sessions.pop()
except IndexError:
session = make_session()
self._active_sessions.add(session)
try:
yield session
finally:
self._avail_sessions.appendleft(session)
self._active_sessions.remove(session)
51 changes: 25 additions & 26 deletions ocsp_pyopenssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,9 +539,6 @@ def execute_ocsp_request(ocsp_uri, cert_id, proxies=None, do_retry=True):
# transform objects into data in requests
data = der_encoder.encode(ocsp_request)
parsed_url = urlsplit(ocsp_uri)
session = requests.Session()
session.mount('http://', HTTPAdapter(max_retries=5))
session.mount('https://', HTTPAdapter(max_retries=5))

max_retry = 100 if do_retry else 1
# NOTE: This retry is to retry getting HTTP 200.
Expand All @@ -554,25 +551,27 @@ def execute_ocsp_request(ocsp_uri, cert_id, proxies=None, do_retry=True):
}
logger.debug('url: %s, headers: %s, proxies: %s',
ocsp_uri, headers, proxies)
for attempt in range(max_retry):
response = session.post(
ocsp_uri,
headers=headers,
proxies=proxies,
data=data)
if response.status_code == OK:
logger.debug("OCSP response was successfully returned")
break
elif max_retry > 1:
wait_time = 2 ** attempt
wait_time = 16 if wait_time > 16 else wait_time
logger.debug("OCSP server returned %s. Retrying in %s(s)",
response.status_code, wait_time)
time.sleep(wait_time)
else:
logger.error("Failed to get OCSP response after %s attempt.",
max_retry)

with requests.Session() as session:
session.mount('http://', HTTPAdapter(max_retries=5))
session.mount('https://', HTTPAdapter(max_retries=5))
for attempt in range(max_retry):
response = session.post(
ocsp_uri,
headers=headers,
proxies=proxies,
data=data)
if response.status_code == OK:
logger.debug("OCSP response was successfully returned")
break
elif max_retry > 1:
wait_time = 2 ** attempt
wait_time = 16 if wait_time > 16 else wait_time
logger.debug("OCSP server returned %s. Retrying in %s(s)",
response.status_code, wait_time)
time.sleep(wait_time)
else:
logger.error("Failed to get OCSP response after %s attempt.",
max_retry)
return response.content


Expand Down Expand Up @@ -793,10 +792,10 @@ def download_ocsp_response_cache(url):
Downloads OCSP response cache from Snowflake.
"""
import binascii
session = requests.session()
session.mount('http://', HTTPAdapter(max_retries=5))
session.mount('https://', HTTPAdapter(max_retries=5))
response = session.get(url)
with requests.Session() as session:
session.mount('http://', HTTPAdapter(max_retries=5))
session.mount('https://', HTTPAdapter(max_retries=5))
response = session.get(url)
if response.status_code == OK:
try:
_decode_ocsp_response_cache(response.json(), OCSP_VALIDATION_CACHE)
Expand Down