Skip to content

Commit 30f7676

Browse files
author
Kevin D Smith
committed
Add data type maps for uploading DataFrames; add server version / features attributes to connection; misc cleanup
1 parent 61e69f3 commit 30f7676

File tree

7 files changed

+362
-44
lines changed

7 files changed

+362
-44
lines changed

swat/cas/connection.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from __future__ import print_function, division, absolute_import, unicode_literals
2525

26+
import collections
2627
import contextlib
2728
import copy
2829
import json
@@ -37,7 +38,7 @@
3738
from ..utils.config import subscribe, get_option
3839
from ..clib import errorcheck
3940
from ..utils.compat import (a2u, a2n, int32, int64, float64, text_types,
40-
binary_types, items_types, int_types)
41+
binary_types, items_types, int_types, dict_types)
4142
from ..utils import getsoptions
4243
from ..utils.args import iteroptions
4344
from ..formatter import SASFormatter
@@ -388,11 +389,33 @@ def _id_generator():
388389
num = num + 1
389390
self._id_generator = _id_generator()
390391

392+
self.server_version, self.server_features = self._get_server_features()
393+
391394
def _gen_id(self):
392395
''' Generate an ID unique to the session '''
393396
import numpy
394397
return numpy.base_repr(next(self._id_generator), 36)
395398

399+
def _get_server_features(self):
400+
'''
401+
Determine which features are available in the server
402+
403+
Returns
404+
-------
405+
set-of-strings
406+
407+
'''
408+
out = set()
409+
410+
info = self.retrieve('builtins.serverstatus', _messagelevel='error',
411+
_apptag='UI')
412+
version = tuple([int(x) for x in info['About']['Version'].split('.')][:2])
413+
414+
# if version >= (3, 4):
415+
# out.add('csv-ints')
416+
417+
return version, out
418+
396419
def _detect_protocol(self, hostname, port, protocol=None):
397420
'''
398421
Detect the protocol type for the given host and port
@@ -1147,6 +1170,110 @@ def _invoke_with_signature(self, _name_, **kwargs):
11471170

11481171
return signature
11491172

1173+
def _extract_dtypes(self, df):
1174+
'''
1175+
Extract importoptions= style data types from the DataFrame
1176+
1177+
Parameters
1178+
----------
1179+
df : pandas.DataFrame
1180+
The DataFrame to get types from
1181+
format : string, optional
1182+
The output format: dict or list
1183+
1184+
Returns
1185+
-------
1186+
OrderedDict
1187+
1188+
'''
1189+
out = collections.OrderedDict()
1190+
1191+
for key, value in df.dtypes.items():
1192+
value = value.name
1193+
1194+
if value == 'object':
1195+
value = 'varchar'
1196+
1197+
elif value.startswith('float'):
1198+
value = 'double'
1199+
1200+
elif value.endswith('int64'):
1201+
if 'csv-ints' in self.server_features:
1202+
value = 'int64'
1203+
else:
1204+
value = 'double'
1205+
1206+
elif value.startswith('int'):
1207+
if 'csv-ints' in self.server_features:
1208+
value = 'int32'
1209+
else:
1210+
value = 'double'
1211+
1212+
elif value.startswith('bool'):
1213+
if 'csv-ints' in self.server_features:
1214+
value = 'int32'
1215+
else:
1216+
value = 'double'
1217+
1218+
elif value.startswith('datetime'):
1219+
value = 'varchar'
1220+
1221+
else:
1222+
continue
1223+
1224+
out[key] = dict(type=value)
1225+
1226+
return out
1227+
1228+
def _apply_importoptions_vars(self, importoptions, df_dtypes):
1229+
'''
1230+
Merge in vars= parameters to importoptions=
1231+
1232+
Notes
1233+
-----
1234+
This method modifies the importoptions in-place.
1235+
1236+
Parameters
1237+
----------
1238+
importoptions : dict
1239+
The importoptions= parameter
1240+
df_dtypes : dict or list
1241+
The DataFrame data types dictionary
1242+
1243+
'''
1244+
if 'vars' not in importoptions:
1245+
importoptions['vars'] = df_dtypes
1246+
return
1247+
1248+
vars = importoptions['vars']
1249+
1250+
# Merge options into dict vars
1251+
if isinstance(vars, dict_types):
1252+
for key, value in six.iteritems(df_dtypes):
1253+
if key in vars:
1254+
for k, v in six.iteritems(value):
1255+
vars[key].setdefault(k, v)
1256+
else:
1257+
vars[key] = value
1258+
1259+
# Merge options into list vars
1260+
else:
1261+
df_dtypes_list = []
1262+
for key, value in six.iteritems(df_dtypes):
1263+
value = dict(value)
1264+
value['name'] = key
1265+
df_dtypes_list.append(value)
1266+
1267+
for i, item in enumerate(df_dtypes_list):
1268+
if i < len(vars):
1269+
if not vars[i]:
1270+
vars[i] = item
1271+
else:
1272+
for key, value in six.iteritems(item):
1273+
vars[i].setdefault(key, value)
1274+
else:
1275+
vars.append(item)
1276+
11501277
def upload(self, data, importoptions=None, casout=None, **kwargs):
11511278
'''
11521279
Upload data from a local file into a CAS table
@@ -1207,6 +1334,7 @@ def upload(self, data, importoptions=None, casout=None, **kwargs):
12071334
'''
12081335
delete = False
12091336
name = None
1337+
df_dtypes = None
12101338

12111339
for key, value in list(kwargs.items()):
12121340
if importoptions is None and key.lower() == 'importoptions':
@@ -1224,6 +1352,7 @@ def upload(self, data, importoptions=None, casout=None, **kwargs):
12241352
filename = tmp.name
12251353
name = os.path.splitext(os.path.basename(filename))[0]
12261354
data.to_csv(filename, encoding='utf-8', index=False)
1355+
df_dtypes = self._extract_dtypes(data)
12271356

12281357
elif data.startswith('http://') or \
12291358
data.startswith('https://') or \
@@ -1256,15 +1385,20 @@ def upload(self, data, importoptions=None, casout=None, **kwargs):
12561385

12571386
if importoptions is None:
12581387
importoptions = {}
1388+
12591389
if isinstance(importoptions, (dict, ParamManager)) and \
12601390
'filetype' not in [x.lower() for x in importoptions.keys()]:
12611391
ext = os.path.splitext(filename)[-1][1:].lower()
12621392
if ext in filetype:
12631393
importoptions['filetype'] = filetype[ext]
12641394
elif len(ext) == 3 and ext.endswith('sv'):
12651395
importoptions['filetype'] = 'csv'
1396+
12661397
kwargs['importoptions'] = importoptions
12671398

1399+
if df_dtypes:
1400+
self._apply_importoptions_vars(importoptions, df_dtypes)
1401+
12681402
if casout is None:
12691403
casout = {}
12701404
if isinstance(casout, CASTable):

swat/cas/rest/connection.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -248,9 +248,9 @@ def __init__(self, hostname, port, username, password, soptions, error):
248248
self._req_sess = requests.Session()
249249

250250
if 'SSLCALISTLOC' in os.environ:
251-
self._req_sess.verify = os.environ['SSLCALISTLOC']
251+
self._req_sess.verify = os.path.expanduser(os.environ['SSLCALISTLOC'])
252252
elif 'CAS_CLIENT_SSL_CA_LIST' in os.environ:
253-
self._req_sess.verify = os.environ['CAS_CLIENT_SSL_CA_LIST']
253+
self._req_sess.verify = os.path.expanduser(os.environ['CAS_CLIENT_SSL_CA_LIST'])
254254

255255
if os.environ.get('SSLREQCERT', 'y').lower().startswith('n'):
256256
self._req_sess.verify = False
@@ -526,13 +526,14 @@ def upload(self, file_name, params):
526526

527527
while True:
528528
try:
529+
url = urllib.parse.urljoin(self._current_baseurl,
530+
'cas/sessions/%s/actions/table.upload' %
531+
self._session)
532+
529533
if get_option('cas.debug.requests'):
530534
_print_request('DELETE', url, self._req_sess.headers, data)
531535

532-
res = self._req_sess.put(
533-
urllib.parse.urljoin(self._current_baseurl,
534-
'cas/sessions/%s/actions/table.upload' %
535-
self._session), data=data)
536+
res = self._req_sess.put(url, data=data)
536537

537538
if get_option('cas.debug.responses'):
538539
_print_response(res.text)

swat/tests/cas/test_bygroups.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ class TestByGroups(tm.TestCase):
5050
def setUp(self):
5151
swat.reset_option()
5252
swat.options.cas.print_messages = False
53-
swat.options.cas.trace_actions = False
54-
swat.options.cas.trace_ui_actions = False
53+
# swat.options.cas.trace_actions = False
54+
# swat.options.cas.trace_ui_actions = False
5555
swat.options.interactive_mode = False
5656

5757
self.s = swat.CAS(HOST, PORT, USER, PASSWD, protocol=PROTOCOL)
@@ -299,8 +299,6 @@ def test_nsmallest(self):
299299
#
300300
swat.options.cas.dataset.bygroup_casout_threshold = 2
301301

302-
swat.options.cas.trace_actions = True
303-
swat.options.cas.trace_ui_actions = True
304302
tblgrp = tbl[['Model', 'MSRP', 'Horsepower']].groupby(['Make',
305303
'Cylinders'], as_index=False).query('Make in ("Porsche", "BMW")').nsmallest(2, columns=['MSRP'])
306304
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
@@ -1119,9 +1117,6 @@ def test_column_nmiss(self):
11191117
self.assertEqual(len(tblgrp), 3)
11201118

11211119
# Test character missing values
1122-
swat.options.cas.trace_actions = True
1123-
swat.options.cas.trace_ui_actions = True
1124-
swat.options.cas.print_messages = True
11251120
tbl = self.table.replace({'Make': {'Buick': ''}})
11261121

11271122
tblgrp = tbl.groupby('Origin')['Make'].nmiss()

0 commit comments

Comments
 (0)