Skip to content

Commit d4a11fd

Browse files
author
Kevin D Smith
committed
Add support for authorization codes
1 parent 1726d99 commit d4a11fd

File tree

3 files changed

+68
-49
lines changed

3 files changed

+68
-49
lines changed

swat/cas/connection.py

Lines changed: 41 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@
3232
import os
3333
import random
3434
import re
35+
import requests
3536
import six
3637
import warnings
3738
import weakref
38-
from six.moves.urllib.parse import urlparse
39+
from six.moves.urllib.parse import urlparse, urlencode, urljoin
3940
from . import rest
4041
from .. import clib
4142
from .. import config as cf
@@ -177,6 +178,8 @@ class CAS(object):
177178
Base path of URL when using the REST protocol.
178179
ssl_ca_list : string, optional
179180
The path to the SSL certificates for the CAS server.
181+
authcode : string, optional
182+
Authorization code from SASLogon used to retrieve an OAuth token.
180183
**kwargs : any, optional
181184
Arbitrary keyword arguments used for internal purposes only.
182185
@@ -346,7 +349,7 @@ def _get_connection_info(cls, hostname, port, username, password, protocol, path
346349
def __init__(self, hostname=None, port=None, username=None, password=None,
347350
session=None, locale=None, nworkers=None, name=None,
348351
authinfo=None, protocol=None, path=None, ssl_ca_list=None,
349-
auth_code=None, **kwargs):
352+
authcode=None, **kwargs):
350353

351354
# Filter session options allowed as parameters
352355
_kwargs = {}
@@ -392,11 +395,11 @@ def __init__(self, hostname=None, port=None, username=None, password=None,
392395
soptions = a2n(getsoptions(session=session, locale=locale,
393396
nworkers=nworkers, protocol=protocol))
394397

395-
# Check for auth_code authentication
396-
auth_code = auth_code or cf.get_option('cas.auth_code')
397-
if protocol in ['http', 'https'] and auth_code:
398+
# Check for authcode authentication
399+
authcode = authcode or cf.get_option('cas.authcode')
400+
if protocol in ['http', 'https'] and authcode:
398401
username = None
399-
password = self.get_token(auth_code=auth_code, url=hostname)
402+
password = type(self)._get_token(authcode=authcode, url=hostname)
400403

401404
# Create error handler
402405
try:
@@ -529,33 +532,41 @@ def _id_generator():
529532
num = num + 1
530533
self._id_generator = _id_generator()
531534

532-
def _get_token(self, username=None, password=None, auth_code=None,
533-
client_id=None, client_secret=None, base_url=None):
535+
@classmethod
536+
def _get_token(cls, username=None, password=None, authcode=None,
537+
client_id=None, client_secret=None, url=None):
534538
''' Retrieve token from Viya installation '''
535-
headers = {'Accept': 'application/vnd.sas.compute.session+json',
536-
'Content-Type': 'application/x-www-form-urlencoded'}
539+
from .rest.connection import _print_request, _setup_ssl
540+
541+
with requests.Session() as req_sess:
542+
543+
_setup_ssl(req_sess)
544+
545+
req_sess.headers.update({'Accept': 'application/vnd.sas.compute.session+json',
546+
'Content-Type': 'application/x-www-form-urlencoded'})
537547

538-
auth_code = auth_code or cf.get_option('cas.auth_code')
539-
if auth_code:
540548
client_id = client_id or cf.get_option('cas.client_id') or 'SWAT'
541-
client_secret = client_secret or cf.get_option('cas.client_secret') or ''
542-
body = {'grant_type': 'authorization_code', 'code': auth_code,
543-
'client_id': client_id, 'client_secret': client_secret}
544-
else:
545-
user = client_id or cf.get_option('cas.username')
546-
password = client_secret or cf.get_option('cas.token')
547-
body = {'grant_type': 'password', 'username': user,
548-
'password': password}
549-
550-
resp = requests.post(base_url + '/SASLogon/oauth/token',
551-
headers=headers, user=('sas.tkmtrb', ''),
552-
data=urlparse.quote_plus(body))
553-
554-
if resp.status_code >= 300:
555-
raise SWATError('Token request resulted in a status of %s' %
556-
resp.status_code)
557-
558-
return resp['access_token']
549+
550+
authcode = authcode or cf.get_option('cas.authcode')
551+
if authcode:
552+
client_secret = client_secret or cf.get_option('cas.client_secret') or ''
553+
body = {'grant_type': 'authorization_code', 'code': authcode,
554+
'client_id': client_id, 'client_secret': client_secret}
555+
else:
556+
username = username or cf.get_option('cas.username')
557+
password = password or cf.get_option('cas.token')
558+
body = {'grant_type': 'password', 'username': username,
559+
'password': password}
560+
561+
resp = req_sess.post(urljoin(url, '/SASLogon/oauth/token'),
562+
auth=(client_id, ''),
563+
data=urlencode(body))
564+
565+
if resp.status_code >= 300:
566+
raise SWATError('Token request resulted in a status of %s' %
567+
resp.status_code)
568+
569+
return resp.json()['access_token']
559570

560571
def _gen_id(self):
561572
''' Generate an ID unique to the session '''

swat/cas/rest/connection.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,24 @@ def init_poolmanager(self, connections, maxsize,
175175
**pool_kwargs)
176176

177177

178+
def _setup_ssl(req_sess):
179+
''' Configure a Requests session for SSL '''
180+
if os.environ.get('SSLREQCERT', 'y').lower() in ['n', 'no', '0',
181+
'f', 'false', 'off']:
182+
req_sess.verify = False
183+
elif 'CAS_CLIENT_SSL_CA_LIST' in os.environ:
184+
req_sess.verify = os.path.expanduser(
185+
os.environ['CAS_CLIENT_SSL_CA_LIST'])
186+
elif 'SAS_TRUSTED_CA_CERTIFICATES_PEM_FILE' in os.environ:
187+
req_sess.verify = os.path.expanduser(
188+
os.environ['SAS_TRUSTED_CA_CERTIFICATES_PEM_FILE'])
189+
elif 'SSLCALISTLOC' in os.environ:
190+
req_sess.verify = os.path.expanduser(
191+
os.environ['SSLCALISTLOC'])
192+
elif 'REQUESTS_CA_BUNDLE' not in os.environ:
193+
req_sess.mount('https://', SSLContextAdapter())
194+
195+
178196
class REST_CASConnection(object):
179197
'''
180198
Create a REST CAS connection
@@ -286,22 +304,10 @@ def __init__(self, hostname, port, username, password, soptions, error):
286304

287305
self._req_sess = requests.Session()
288306

289-
if os.environ.get('SSLREQCERT', 'y').lower() in ['n', 'no', '0',
290-
'f', 'false', 'off']:
291-
self._req_sess.verify = False
292-
elif 'CAS_CLIENT_SSL_CA_LIST' in os.environ:
293-
self._req_sess.verify = os.path.expanduser(
294-
os.environ['CAS_CLIENT_SSL_CA_LIST'])
295-
elif 'SAS_TRUSTED_CA_CERTIFICATES_PEM_FILE' in os.environ:
296-
self._req_sess.verify = os.path.expanduser(
297-
os.environ['SAS_TRUSTED_CA_CERTIFICATES_PEM_FILE'])
298-
elif 'SSLCALISTLOC' in os.environ:
299-
self._req_sess.verify = os.path.expanduser(
300-
os.environ['SSLCALISTLOC'])
301-
elif 'REQUESTS_CA_BUNDLE' not in os.environ:
302-
self._req_sess.mount('https://', SSLContextAdapter())
307+
_setup_ssl(self._req_sess)
303308

304309
self._req_sess.headers.update({
310+
'Accept': 'application/json',
305311
'Content-Type': 'application/json',
306312
'Content-Length': '0',
307313
'Authorization': self._auth,
@@ -491,6 +497,7 @@ def invoke(self, action_name, kwargs):
491497

492498
post_data = a2u(kwargs).encode('utf-8')
493499
self._req_sess.headers.update({
500+
'Accept': 'application/json',
494501
'Content-Type': 'application/json',
495502
'Content-Length': str(len(post_data)),
496503
})
@@ -706,6 +713,7 @@ def upload(self, file_name, params):
706713
data = datafile.read()
707714

708715
self._req_sess.headers.update({
716+
'Accept': 'application/json',
709717
'Content-Type': 'application/octet-stream',
710718
'Content-Length': str(len(data)),
711719
'JSON-Parameters': json.dumps(_normalize_params(params))

swat/config.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,20 +173,20 @@ def check_tz(value):
173173
'Sets the port number for the CAS server.',
174174
environ='CAS_PORT')
175175

176-
register_option('cas.auth_code', 'string', check_string, None,
177-
'Sets the auth_code for retrieving an OAuth token from '
176+
register_option('cas.authcode', 'string', check_string, None,
177+
'Sets the authorization code for retrieving an OAuth token from '
178178
'a Viya deployment.',
179-
environ='CAS_AUTHCODE')
179+
environ=['CAS_AUTHCODE', 'VIYA_AUTHCODE'])
180180

181181
register_option('cas.client_id', 'string', check_string, 'SWAT',
182182
'Sets the client ID for retrieving an OAuth token from '
183183
'a Viya deployment.',
184-
environ='CAS_CLIENT_ID')
184+
environ=['CAS_CLIENT_ID', 'VIYA_CLIENT_ID'])
185185

186186
register_option('cas.client_secret', 'string', check_string, '',
187187
'Sets the client secret for retrieving an OAuth token from '
188188
'a Viya deployment.',
189-
environ='CAS_CLIENT_SECRET')
189+
environ=['CAS_CLIENT_SECRET', 'VIYA_CLIENT_SECRET'])
190190

191191
register_option('cas.protocol', 'string',
192192
functools.partial(check_string,

0 commit comments

Comments
 (0)