diff --git a/chunk_downloader.py b/chunk_downloader.py index 3c419b179..e6f112181 100644 --- a/chunk_downloader.py +++ b/chunk_downloader.py @@ -11,8 +11,6 @@ from .errorcode import (ER_NO_ADDITIONAL_CHUNK, ER_CHUNK_DOWNLOAD_FAILED) from .errors import (Error, OperationalError) -from .network import (SnowflakeRestful, NO_TOKEN, MAX_CONNECTION_POOL) -from .ssl_wrap_socket import (set_proxies) DEFAULT_REQUEST_TIMEOUT = 3600 DEFAULT_CLIENT_RESULT_PREFETCH_SLOTS = 2 @@ -126,9 +124,7 @@ def _download_chunk(self, idx): logger.debug(u"started getting the result set %s: %s", idx + 1, self._chunks[idx].url) - result_data = self._get_request( - self._chunks[idx].url, - headers, max_connection_pool=self._effective_threads) + result_data = self._fetch_chunk(self._chunks[idx].url, headers) logger.debug(u"finished getting the result set %s: %s", idx + 1, self._chunks[idx].url) @@ -255,35 +251,15 @@ def __del__(self): # ignore all errors in the destructor pass - def _get_request( - self, url, headers, - is_raw_binary_iterator=True, - max_connection_pool=MAX_CONNECTION_POOL): + def _fetch_chunk(self, url, headers): """ - GET request for Large Result set chunkloader + Fetch the chunk from S3. """ - # sharing the proxy and certificate - proxies = set_proxies( - self._connection.rest._proxy_host, - self._connection.rest._proxy_port, - self._connection.rest._proxy_user, - self._connection.rest._proxy_password) - - logger.debug(u'proxies=%s, url=%s', proxies, url) - - return SnowflakeRestful.access_url( - self._connection, - self, - u'get', - full_url=url, - headers=headers, - data=None, - proxies=proxies, - timeout=(self._connection._connect_timeout, - self._connection._connect_timeout, - DEFAULT_REQUEST_TIMEOUT), - token=NO_TOKEN, - is_raw_binary=True, - is_raw_binary_iterator=is_raw_binary_iterator, - max_connection_pool=max_connection_pool, + timeouts = ( + self._connection._connect_timeout, + self._connection._connect_timeout, + DEFAULT_REQUEST_TIMEOUT + ) + return self._connection.rest.fetch(u'get', url, headers, + timeouts=timeouts, is_raw_binary=True, is_raw_binary_iterator=True, use_ijson=self._use_ijson) diff --git a/connection.py b/connection.py index a18e88ac9..b42e5be33 100644 --- a/connection.py +++ b/connection.py @@ -35,6 +35,7 @@ import logging from logging import getLogger + # default configs DEFAULT_CONFIGURATION = { u'dsn': None, # standard @@ -72,7 +73,6 @@ u'session_parameters': None, # snowflake internal u'autocommit': None, # snowflake u'numpy': False, # snowflake - u'max_connection_pool': network.MAX_CONNECTION_POOL, # snowflake internal u'ocsp_response_cache_filename': None, # snowflake internal u'converter_class': SnowflakeConverter, # snowflake internal u'chunk_downloader_class': SnowflakeChunkDownloader, # snowflake internal @@ -428,7 +428,6 @@ def __open_connection(self, mfa_callback, password_callback): connect_timeout=self._connect_timeout, request_timeout=self._request_timeout, injectClientPause=self._injectClientPause, - max_connection_pool=self._max_connection_pool, connection=self) self.logger.debug(u'REST API object was created: %s:%s, proxy=%s:%s, ' u'proxy_user=%s', diff --git a/network.py b/network.py index 1083daf6b..b5f121d60 100644 --- a/network.py +++ b/network.py @@ -3,20 +3,23 @@ # # Copyright (c) 2012-2017 Snowflake Computing Inc. All right reserved. # +import collections +import contextlib import copy import gzip +import itertools import json +import logging import platform import sys import time import uuid from io import StringIO, BytesIO -from logging import getLogger from threading import Thread import OpenSSL from botocore.vendored import requests -from botocore.vendored.requests.adapters import (HTTPAdapter, DEFAULT_POOLSIZE) +from botocore.vendored.requests.adapters import HTTPAdapter from botocore.vendored.requests.auth import AuthBase from botocore.vendored.requests.exceptions import (ConnectionError, SSLError) from botocore.vendored.requests.packages.urllib3.exceptions import ( @@ -37,14 +40,16 @@ GatewayTimeoutError, ServiceUnavailableError, InterfaceError, InternalServerError, ForbiddenError, BadGatewayError, BadRequest) -from .gzip_decoder import (decompress_raw_data) +from .gzip_decoder import decompress_raw_data from .sqlstate import (SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_CONNECTION_WAS_NOT_ESTABLISHED, SQLSTATE_CONNECTION_REJECTED) -from .ssl_wrap_socket import (set_proxies) +from .ssl_wrap_socket import set_proxies from .util_text import split_rows_from_stream from .version import VERSION +logger = logging.getLogger(__name__) + """ Monkey patch for PyOpenSSL Socket wrapper """ @@ -69,8 +74,6 @@ HEADER_AUTHORIZATION_KEY = u"Authorization" HEADER_SNOWFLAKE_TOKEN = u'Snowflake Token="{token}"' -MAX_CONNECTION_POOL = DEFAULT_POOLSIZE # max connetion pool size in urllib3 - SNOWFLAKE_CONNECTOR_VERSION = u'.'.join(TO_UNICODE(v) for v in VERSION[0:3]) PYTHON_VERSION = u'.'.join(TO_UNICODE(v) for v in sys.version_info[:3]) PLATFORM = platform.platform() @@ -135,7 +138,6 @@ def __init__(self, host=u'127.0.0.1', port=8080, connect_timeout=DEFAULT_CONNECT_TIMEOUT, request_timeout=DEFAULT_REQUEST_TIMEOUT, injectClientPause=0, - max_connection_pool=MAX_CONNECTION_POOL, connection=None): self._host = host self._port = port @@ -146,11 +148,11 @@ def __init__(self, host=u'127.0.0.1', port=8080, self._protocol = protocol self._connect_timeout = connect_timeout or DEFAULT_CONNECT_TIMEOUT self._request_timeout = request_timeout or DEFAULT_REQUEST_TIMEOUT - self._session = None self._injectClientPause = injectClientPause - self._max_connection_pool = max_connection_pool self._connection = connection - self.logger = getLogger(__name__) + self._idle_sessions = collections.deque() + self._active_sessions = set() + self._request_count = itertools.count() # insecure mode (disabled by default) ssl_wrap_socket.FEATURE_INSECURE_MODE = \ @@ -181,7 +183,17 @@ def close(self): del self._token if hasattr(self, u'_master_token'): del self._master_token - self._session = None + sessions = list(self._active_sessions) + if sessions: + logger.warn("Closing %d active sessions" % len(sessions)) + sessions.extend(self._idle_sessions) + self._active_sessions.clear() + self._idle_sessions.clear() + for s in sessions: + try: + s.close() + except Exception as e: + logger.warn("Session cleanup failed: %s" % e) def authenticate(self, account, user, password, master_token=None, token=None, database=None, schema=None, @@ -189,12 +201,12 @@ def authenticate(self, account, user, password, master_token=None, passcode_in_password=False, saml_response=None, mfa_callback=None, password_callback=None, session_parameters=None): - self.logger.info(u'authenticate') + logger.info(u'authenticate') if token and master_token: self._token = token self._master_token = token - self.logger.debug(u'token is given. no authentication was done') + logger.debug(u'token is given. no authentication was done') return application = self._connection.application if \ @@ -229,14 +241,14 @@ def authenticate(self, account, user, password, master_token=None, } body = copy.deepcopy(body_template) - self.logger.info(u'saml: %s', saml_response is not None) + logger.info(u'saml: %s', saml_response is not None) if saml_response: body[u'data'][u'RAW_SAML_RESPONSE'] = saml_response else: body[u'data'][u"LOGIN_NAME"] = user body[u'data'][u"PASSWORD"] = password - self.logger.debug( + logger.debug( u'account=%s, user=%s, database=%s, schema=%s, ' u'warehouse=%s, role=%s, request_id=%s', account, @@ -271,7 +283,7 @@ def authenticate(self, account, user, password, master_token=None, if session_parameters: body[u'data'][u'SESSION_PARAMETERS'] = session_parameters - self.logger.debug( + logger.debug( "body['data']: %s", {k: v for (k, v) in body[u'data'].items() if k != u'PASSWORD'}) @@ -359,7 +371,7 @@ def post_request_wrapper(self, url, headers, body): url, headers, json.dumps(body), timeout=self._connection._login_timeout) - self.logger.debug(u'completed authentication') + logger.debug(u'completed authentication') if not ret[u'success']: Error.errorhandler_wrapper( self._connection, None, DatabaseError, @@ -381,8 +393,8 @@ def post_request_wrapper(self, url, headers, body): else: self._token = ret[u'data'][u'token'] self._master_token = ret[u'data'][u'masterToken'] - self.logger.debug(u'token = %s', self._token) - self.logger.debug(u'master_token = %s', self._master_token) + logger.debug(u'token = %s', self._token) + logger.debug(u'master_token = %s', self._master_token) if u'sessionId' in ret[u'data']: self._connection._session_id = ret[u'data'][u'sessionId'] if u'sessionInfo' in ret[u'data']: @@ -439,15 +451,15 @@ def _renew_session(self): u'sqlstate': SQLSTATE_CONNECTION_NOT_EXISTS, }) - self.logger.debug(u'updating session') - self.logger.debug(u'master_token: %s', self._master_token) + logger.debug(u'updating session') + logger.debug(u'master_token: %s', self._master_token) headers = { u'Content-Type': CONTENT_TYPE_APPLICATION_JSON, u"accept": CONTENT_TYPE_APPLICATION_JSON, u"User-Agent": PYTHON_CONNECTOR_USER_AGENT, } request_id = TO_UNICODE(uuid.uuid4()) - self.logger.debug(u'request_id: %s', request_id) + logger.debug(u'request_id: %s', request_id) url = u'/session/token-request?' + urlencode({ u'requestId': request_id}) @@ -462,13 +474,13 @@ def _renew_session(self): timeout=self._connection._network_timeout) if ret[u'success'] and u'data' in ret \ and u'sessionToken' in ret[u'data']: - self.logger.debug(u'success: %s', ret) + logger.debug(u'success: %s', ret) self._token = ret[u'data'][u'sessionToken'] self._master_token = ret[u'data'][u'masterToken'] - self.logger.debug(u'updating session completed') + logger.debug(u'updating session completed') return ret else: - self.logger.debug(u'failed: %s', ret) + logger.debug(u'failed: %s', ret) err = ret[u'message'] if u'data' in ret and u'errorMessage' in ret[u'data']: err += ret[u'data'][u'errorMessage'] @@ -522,26 +534,11 @@ def _get_request(self, url, headers, token=None, timeout=None): port=self._port, url=url, ) - proxies = set_proxies( - self._proxy_host, self._proxy_port, self._proxy_user, - self._proxy_password - ) - self.logger.debug(u'url=%s, proxies=%s', full_url, proxies) - ret = SnowflakeRestful.access_url( - conn=self._connection, - session_context=self, - method=u'get', - full_url=full_url, - headers=headers, - data=None, - proxies=proxies, - timeout=(self._connect_timeout, self._connect_timeout, timeout), - token=token, - max_connection_pool=self._max_connection_pool) - + ret = self.fetch(u'get', full_url, headers, timeout=timeout, + token=token) if u'code' in ret and ret[u'code'] == SESSION_EXPIRED_GS_CODE: ret = self._renew_session() - self.logger.debug( + logger.debug( u'ret[code] = {code} after renew_session'.format( code=(ret[u'code'] if u'code' in ret else u'N/A'))) if u'success' in ret and ret[u'success']: @@ -557,29 +554,15 @@ def _post_request(self, url, headers, body, token=None, port=self._port, url=url, ) - proxies = set_proxies( - self._proxy_host, self._proxy_port, self._proxy_user, - self._proxy_password) - - ret = SnowflakeRestful.access_url( - conn=self._connection, - session_context=self, - method=u'post', - full_url=full_url, - headers=headers, - data=body, - proxies=proxies, - timeout=( - self._connect_timeout, self._connect_timeout, timeout), - token=token, - max_connection_pool=self._max_connection_pool) - self.logger.debug( + ret = self.fetch(u'post', full_url, headers, data=body, + timeout=timeout, token=token) + logger.debug( u'ret[code] = {code}, after post request'.format( code=(ret.get(u'code', u'N/A')))) if u'code' in ret and ret[u'code'] == SESSION_EXPIRED_GS_CODE: ret = self._renew_session() - self.logger.debug( + logger.debug( u'ret[code] = {code} after renew_session'.format( code=(ret[u'code'] if u'code' in ret else u'N/A'))) if u'success' in ret and ret[u'success']: @@ -595,23 +578,23 @@ def _post_request(self, url, headers, body, token=None, while is_session_renewed or u'code' in ret and ret[u'code'] in \ (QUERY_IN_PROGRESS_CODE, QUERY_IN_PROGRESS_ASYNC_CODE): if self._injectClientPause > 0: - self.logger.debug( + logger.debug( u'waiting for {inject_client_pause}...'.format( inject_client_pause=self._injectClientPause)) time.sleep(self._injectClientPause) # ping pong result_url = ret[u'data'][ u'getResultUrl'] if not is_session_renewed else result_url - self.logger.debug(u'ping pong starting...') + logger.debug(u'ping pong starting...') ret = self._get_request( result_url, headers, token=self._token, timeout=timeout) - self.logger.debug( + logger.debug( u'ret[code] = %s', ret[u'code'] if u'code' in ret else u'N/A') - self.logger.debug(u'ping pong done') + logger.debug(u'ping pong done') if u'code' in ret and ret[u'code'] == SESSION_EXPIRED_GS_CODE: ret = self._renew_session() - self.logger.debug( + logger.debug( u'ret[code] = %s after renew_session', ret[u'code'] if u'code' in ret else u'N/A') if u'success' in ret and ret[u'success']: @@ -621,43 +604,41 @@ def _post_request(self, url, headers, body, token=None, return ret - @staticmethod - def access_url(conn, session_context, method, full_url, headers, data, - proxies, timeout=( - DEFAULT_CONNECT_TIMEOUT, - DEFAULT_CONNECT_TIMEOUT, - DEFAULT_REQUEST_TIMEOUT), - requests_retry=REQUESTS_RETRY, - token=None, - is_raw_text=False, - catch_okta_unauthorized_error=False, - is_raw_binary=False, - is_raw_binary_iterator=True, - max_connection_pool=MAX_CONNECTION_POOL, - use_ijson=False, is_single_thread=False): - logger = getLogger(__name__) - - connection_timeout = timeout[0:2] - request_timeout = timeout[2] # total request timeout - request_thread_timeout = 60 # one request thread timeout - - def request_thread(result_queue): + def fetch(self, method, full_url, headers, *, data=None, timeout=None, **kwargs): + """ Curried API request with session management. """ + if timeout is not None and 'timeouts' in kwargs: + raise TypeError("Mutually exclusive args: timeout, timeouts") + if timeout is None: + timeout = self._request_timeout + timeouts = kwargs.pop('timeouts', (self._connect_timeout, + self._connect_timeout, timeout)) + proxies = set_proxies(self._proxy_host, self._proxy_port, + self._proxy_user, self._proxy_password) + with self._use_requests_session() as session: + return self._fetch(session, method, full_url, headers, data, + proxies, timeouts, **kwargs) + + def _fetch(self, session, method, full_url, headers, data, proxies, + timeouts=(DEFAULT_CONNECT_TIMEOUT, DEFAULT_CONNECT_TIMEOUT, + DEFAULT_REQUEST_TIMEOUT), + token=NO_TOKEN, + is_raw_text=False, + catch_okta_unauthorized_error=False, + is_raw_binary=False, + is_raw_binary_iterator=True, + use_ijson=False, is_single_thread=False): + """ This is the lowest level of HTTP handling. All arguments culminate + here and the `requests.request` is issued and monitored from this + call using an inline thread for timeout monitoring. """ + connection_timeout = timeouts[0:2] + request_timeout = timeouts[2] # total request timeout + request_exec_timeout = 60 # one request thread timeout + conn = self._connection + proxies = set_proxies(conn.rest._proxy_host, conn.rest._proxy_port, + conn.rest._proxy_user, conn.rest._proxy_password) + + def request_exec(result_queue): try: - if session_context._session is None: - session_context._session = requests.Session() - session_context._session.mount( - u'http://', - HTTPAdapter( - pool_connections=int(max_connection_pool), - pool_maxsize=int(max_connection_pool), - max_retries=requests_retry)) - session_context._session.mount( - u'https://', - HTTPAdapter( - pool_connections=int(max_connection_pool), - pool_maxsize=int(max_connection_pool), - max_retries=requests_retry)) - if not catch_okta_unauthorized_error and data and len(data) > 0: gzdata = BytesIO() gzip.GzipFile(fileobj=gzdata, mode=u'wb').write( @@ -668,7 +649,7 @@ def request_thread(result_queue): else: input_data = data - raw_ret = session_context._session.request( + raw_ret = session.request( method=method, url=full_url, proxies=proxies, @@ -762,7 +743,7 @@ def request_thread(result_queue): if is_single_thread: # This is dedicated code for DELETE SESSION when Python exists. request_result_queue = Queue() - request_thread(request_result_queue) + request_exec(request_result_queue) try: # don't care about the return value, because no retry and # no error will show up @@ -775,19 +756,19 @@ def request_thread(result_queue): while True: return_object = None request_result_queue = Queue() - th = Thread(name='request_thread', target=request_thread, - args=(request_result_queue,)) + th = Thread(name='RequestExec-%d' % next(self._request_count), + target=request_exec, args=(request_result_queue,)) th.daemon = True th.start() try: logger.debug('request thread timeout: %s, ' 'rest of request timeout: %s, ' 'retry cnt: %s', - request_thread_timeout, + request_exec_timeout, request_timeout, retry_cnt + 1) start_request_thread = time.time() - th.join(timeout=request_thread_timeout) + th.join(timeout=request_exec_timeout) logger.debug('request thread joined') if request_timeout is not None: request_timeout -= min( @@ -795,7 +776,7 @@ def request_thread(result_queue): request_timeout) start_get_queue = time.time() return_object, retryable = request_result_queue.get( - timeout=int(request_thread_timeout / 4)) + timeout=int(request_exec_timeout / 4)) if request_timeout is not None: request_timeout -= min( int(time.time() - start_get_queue), request_timeout) @@ -886,11 +867,18 @@ def request_thread(result_queue): }) return return_object + def make_requests_session(self): + s = requests.Session() + s.mount(u'http://', HTTPAdapter(max_retries=REQUESTS_RETRY)) + s.mount(u'https://', HTTPAdapter(max_retries=REQUESTS_RETRY)) + s._reuse_count = itertools.count() + return s + def authenticate_by_saml(self, authenticator, account, user, password): u""" SAML Authentication """ - self.logger.info(u'authenticating by SAML') + logger.info(u'authenticating by SAML') headers = { u'Content-Type': CONTENT_TYPE_APPLICATION_JSON, u"accept": CONTENT_TYPE_APPLICATION_JSON, @@ -907,7 +895,7 @@ def authenticate_by_saml(self, authenticator, account, user, password): }, } - self.logger.debug( + logger.debug( u'account=%s, authenticator=%s', account, authenticator, ) @@ -940,28 +928,13 @@ def authenticate_by_saml(self, authenticator, account, user, password): token_url = data[u'tokenUrl'] sso_url = data[u'ssoUrl'] - proxies = set_proxies( - self._proxy_host, self._proxy_port, self._proxy_user, - self._proxy_password) - self.logger.debug(u'token_url=%s, proxies=%s', token_url, proxies) - data = { u'username': user, u'password': password, } - self.logger.debug(u'token url: %s', token_url) - ret = SnowflakeRestful.access_url( - conn=self._connection, - session_context=self, - method=u'post', - full_url=token_url, - headers=headers, - data=json.dumps(data), - proxies=proxies, - timeout=(self._connect_timeout, - self._connect_timeout, - self._connection._login_timeout), - catch_okta_unauthorized_error=True) + ret = self.fetch(u'post', token_url, headers, data=json.dumps(data), + timeout=self._connection._login_timeout, + catch_okta_unauthorized_error=True) one_time_token = ret[u'cookieToken'] url_parameters = { @@ -973,17 +946,28 @@ def authenticate_by_saml(self, authenticator, account, user, password): headers = { u"Accept": u'*/*', } - self.logger.debug(u'sso url: %s', sso_url) - ret = SnowflakeRestful.access_url( - conn=self._connection, - session_context=self, - method=u'get', - full_url=sso_url, - headers=headers, - data=None, - proxies=proxies, - timeout=(self._connect_timeout, - self._connect_timeout, - self._connection._login_timeout), - is_raw_text=True) - return ret + return self.fetch(u'get', sso_url, headers, + timeout=self._connection._login_timeout, + is_raw_text=True) + + @contextlib.contextmanager + def _use_requests_session(self): + """ Session caching context manager. Note that the session is not + closed until close() is called so each session may be used multiple + times. """ + try: + session = self._idle_sessions.pop() + except IndexError: + session = self.make_requests_session() + self._active_sessions.add(session) + logger.info("Active requests sessions: %d, idle: %d" % ( + len(self._active_sessions), len(self._idle_sessions))) + try: + yield session + finally: + self._idle_sessions.appendleft(session) + self._active_sessions.remove(session) + active = len(self._active_sessions) + idle = len(self._idle_sessions) + logger.info("Active requests sessions: %d, idle: %d" % (active, + idle)) diff --git a/ocsp_pyopenssl.py b/ocsp_pyopenssl.py index c753fa05f..522c25821 100644 --- a/ocsp_pyopenssl.py +++ b/ocsp_pyopenssl.py @@ -539,9 +539,6 @@ def execute_ocsp_request(ocsp_uri, cert_id, proxies=None, do_retry=True): # transform objects into data in requests data = der_encoder.encode(ocsp_request) parsed_url = urlsplit(ocsp_uri) - session = requests.Session() - session.mount('http://', HTTPAdapter(max_retries=5)) - session.mount('https://', HTTPAdapter(max_retries=5)) max_retry = 100 if do_retry else 1 # NOTE: This retry is to retry getting HTTP 200. @@ -554,25 +551,27 @@ def execute_ocsp_request(ocsp_uri, cert_id, proxies=None, do_retry=True): } logger.debug('url: %s, headers: %s, proxies: %s', ocsp_uri, headers, proxies) - for attempt in range(max_retry): - response = session.post( - ocsp_uri, - headers=headers, - proxies=proxies, - data=data) - if response.status_code == OK: - logger.debug("OCSP response was successfully returned") - break - elif max_retry > 1: - wait_time = 2 ** attempt - wait_time = 16 if wait_time > 16 else wait_time - logger.debug("OCSP server returned %s. Retrying in %s(s)", - response.status_code, wait_time) - time.sleep(wait_time) - else: - logger.error("Failed to get OCSP response after %s attempt.", - max_retry) - + with requests.Session() as session: + session.mount('http://', HTTPAdapter(max_retries=5)) + session.mount('https://', HTTPAdapter(max_retries=5)) + for attempt in range(max_retry): + response = session.post( + ocsp_uri, + headers=headers, + proxies=proxies, + data=data) + if response.status_code == OK: + logger.debug("OCSP response was successfully returned") + break + elif max_retry > 1: + wait_time = 2 ** attempt + wait_time = 16 if wait_time > 16 else wait_time + logger.debug("OCSP server returned %s. Retrying in %s(s)", + response.status_code, wait_time) + time.sleep(wait_time) + else: + logger.error("Failed to get OCSP response after %s attempt.", + max_retry) return response.content @@ -793,10 +792,10 @@ def download_ocsp_response_cache(url): Downloads OCSP response cache from Snowflake. """ import binascii - session = requests.session() - session.mount('http://', HTTPAdapter(max_retries=5)) - session.mount('https://', HTTPAdapter(max_retries=5)) - response = session.get(url) + with requests.Session() as session: + session.mount('http://', HTTPAdapter(max_retries=5)) + session.mount('https://', HTTPAdapter(max_retries=5)) + response = session.get(url) if response.status_code == OK: try: _decode_ocsp_response_cache(response.json(), OCSP_VALIDATION_CACHE)