Skip to content

Commit 1726d99

Browse files
author
Kevin D Smith
committed
Add method of retrieving token from Viya deployment
1 parent af4e780 commit 1726d99

File tree

2 files changed

+63
-14
lines changed

2 files changed

+63
-14
lines changed

swat/cas/connection.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def _get_connection_info(cls, hostname, port, username, password, protocol, path
346346
def __init__(self, hostname=None, port=None, username=None, password=None,
347347
session=None, locale=None, nworkers=None, name=None,
348348
authinfo=None, protocol=None, path=None, ssl_ca_list=None,
349-
**kwargs):
349+
auth_code=None, **kwargs):
350350

351351
# Filter session options allowed as parameters
352352
_kwargs = {}
@@ -364,19 +364,6 @@ def __init__(self, hostname=None, port=None, username=None, password=None,
364364
warnings.warn('Unrecognized keys in connection parameters: %s' %
365365
', '.join(unknown_keys))
366366

367-
# If a prototype exists, use it for the connection config
368-
prototype = kwargs.get('prototype')
369-
if prototype is not None:
370-
soptions = a2n(prototype._soptions)
371-
protocol = a2n(prototype._protocol)
372-
else:
373-
# Distill connection information from parameters, config, and environment
374-
hostname, port, username, password, protocol = \
375-
self._get_connection_info(hostname, port, username,
376-
password, protocol, path)
377-
soptions = a2n(getsoptions(session=session, locale=locale,
378-
nworkers=nworkers, protocol=protocol))
379-
380367
# Check for SSL certificate
381368
if ssl_ca_list is None:
382369
ssl_ca_list = cf.get_option('cas.ssl_ca_list')
@@ -392,6 +379,25 @@ def __init__(self, hostname=None, port=None, username=None, password=None,
392379
raise OSError('None of the specified authinfo files from'
393380
'list exist: %s' % ', '.join(authinfo))
394381

382+
# If a prototype exists, use it for the connection config
383+
prototype = kwargs.get('prototype')
384+
if prototype is not None:
385+
soptions = a2n(prototype._soptions)
386+
protocol = a2n(prototype._protocol)
387+
else:
388+
# Distill connection information from parameters, config, and environment
389+
hostname, port, username, password, protocol = \
390+
self._get_connection_info(hostname, port, username,
391+
password, protocol, path)
392+
soptions = a2n(getsoptions(session=session, locale=locale,
393+
nworkers=nworkers, protocol=protocol))
394+
395+
# Check for auth_code authentication
396+
auth_code = auth_code or cf.get_option('cas.auth_code')
397+
if protocol in ['http', 'https'] and auth_code:
398+
username = None
399+
password = self.get_token(auth_code=auth_code, url=hostname)
400+
395401
# Create error handler
396402
try:
397403
if protocol in ['http', 'https']:
@@ -523,6 +529,34 @@ def _id_generator():
523529
num = num + 1
524530
self._id_generator = _id_generator()
525531

532+
def _get_token(self, username=None, password=None, auth_code=None,
533+
client_id=None, client_secret=None, base_url=None):
534+
''' Retrieve token from Viya installation '''
535+
headers = {'Accept': 'application/vnd.sas.compute.session+json',
536+
'Content-Type': 'application/x-www-form-urlencoded'}
537+
538+
auth_code = auth_code or cf.get_option('cas.auth_code')
539+
if auth_code:
540+
client_id = client_id or cf.get_option('cas.client_id') or 'SWAT'
541+
client_secret = client_secret or cf.get_option('cas.client_secret') or ''
542+
body = {'grant_type': 'authorization_code', 'code': auth_code,
543+
'client_id': client_id, 'client_secret': client_secret}
544+
else:
545+
user = client_id or cf.get_option('cas.username')
546+
password = client_secret or cf.get_option('cas.token')
547+
body = {'grant_type': 'password', 'username': user,
548+
'password': password}
549+
550+
resp = requests.post(base_url + '/SASLogon/oauth/token',
551+
headers=headers, user=('sas.tkmtrb', ''),
552+
data=urlparse.quote_plus(body))
553+
554+
if resp.status_code >= 300:
555+
raise SWATError('Token request resulted in a status of %s' %
556+
resp.status_code)
557+
558+
return resp['access_token']
559+
526560
def _gen_id(self):
527561
''' Generate an ID unique to the session '''
528562
import numpy

swat/config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,21 @@ def check_tz(value):
173173
'Sets the port number for the CAS server.',
174174
environ='CAS_PORT')
175175

176+
register_option('cas.auth_code', 'string', check_string, None,
177+
'Sets the auth_code for retrieving an OAuth token from '
178+
'a Viya deployment.',
179+
environ='CAS_AUTHCODE')
180+
181+
register_option('cas.client_id', 'string', check_string, 'SWAT',
182+
'Sets the client ID for retrieving an OAuth token from '
183+
'a Viya deployment.',
184+
environ='CAS_CLIENT_ID')
185+
186+
register_option('cas.client_secret', 'string', check_string, '',
187+
'Sets the client secret for retrieving an OAuth token from '
188+
'a Viya deployment.',
189+
environ='CAS_CLIENT_SECRET')
190+
176191
register_option('cas.protocol', 'string',
177192
functools.partial(check_string,
178193
valid_values=['auto', 'cas', 'http', 'https']),

0 commit comments

Comments
 (0)