Skip to content

Commit 8409e33

Browse files
author
Kevin D Smith
committed
Fix tests in mpp mode
1 parent 8b598a3 commit 8409e33

File tree

4 files changed

+90 / -87 lines changed

4 files changed

+90 / -87 lines changed

swat/cas/table.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4963,7 +4963,7 @@ def _get_casout_slice(self, n, columns=None, ascending=True,
49634963
try:
49644964
casin = self.to_view()
49654965

4966-
out = self._retrieve('datastep.runcode', code=r'''
4966+
out = self._retrieve('datastep.runcode', single='yes', code=r'''
49674967
data %s;
49684968
%s
49694969
set %s;
@@ -6160,12 +6160,14 @@ def _apply_datastep(self, code, inplace=False, casout=None,
61606160
if casout is None:
61616161
casout = {}
61626162

6163+
default_caslib = self.getsessopt('caslib').caslib
6164+
61636165
if casout.get('caslib'):
61646166
caslib = casout['caslib']
61656167
elif inplace and 'caslib' in self.params:
61666168
caslib = self.params['caslib']
61676169
else:
6168-
caslib = self.getsessopt('caslib').caslib
6170+
caslib = default_caslib
61696171

61706172
if casout.get('name'):
61716173
newname = casout['name']
@@ -6179,7 +6181,7 @@ def _apply_datastep(self, code, inplace=False, casout=None,
61796181
dscode = []
61806182
dscode.append('data %s(caslib=%s);' % (_quote(newname), _quote(caslib)))
61816183
dscode.append(' set %s(caslib=%s);' % (_quote(self.params.name),
6182-
_quote(caslib)))
6184+
_quote(self.params.get('caslib', default_caslib))))
61836185
if isinstance(code, items_types):
61846186
dscode.extend(code)
61856187
else:

swat/tests/cas/test_bygroups.py

Lines changed: 16 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -579,7 +579,7 @@ def test_column_nunique(self):
579579

580580
tblgrp = tbl['MSRP'].groupby(['Origin', 'Cylinders'], as_index=False).nunique()
581581
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
582-
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=None)
582+
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=['Origin', 'Cylinders', 'MSRP'])
583583

584584
def test_nunique(self):
585585
tbl = self.table.sort_values(SORT_KEYS)
@@ -663,7 +663,7 @@ def test_column_max(self):
663663

664664
tblgrp = tbl['EngineSize'].groupby('Origin', as_index=False).max()
665665
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
666-
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=None)
666+
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=['Origin', 'EngineSize'])
667667

668668
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
669669
def test_max(self):
@@ -698,7 +698,7 @@ def test_max(self):
698698
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
699699
# Drop Model since they get sorted differently
700700
self.assertTablesEqual(dfgrp.drop('Model', axis=1), tblgrp.drop('Model', axis=1),
701-
sortby=None, include_index=True)
701+
sortby=['Origin', 'Make', 'Type', 'DriveTrain'])
702702

703703
def test_column_min(self):
704704
df = self.get_cars_df().sort_values(SORT_KEYS)
@@ -722,7 +722,7 @@ def test_column_min(self):
722722

723723
tblgrp = tbl['EngineSize'].groupby('Origin', as_index=False).min()
724724
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
725-
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=None)
725+
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=['Origin', 'EngineSize'])
726726

727727
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
728728
def test_min(self):
@@ -757,7 +757,7 @@ def test_min(self):
757757
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
758758
# Drop Type since it gets sorted differently
759759
self.assertTablesEqual(dfgrp.drop('Type', axis=1), tblgrp.drop('Type', axis=1),
760-
sortby=None)
760+
sortby=['Origin', 'Make', 'Model'])
761761

762762
def test_column_mean(self):
763763
df = self.get_cars_df().sort_values(SORT_KEYS)
@@ -781,7 +781,7 @@ def test_column_mean(self):
781781

782782
tblgrp = tbl['EngineSize'].groupby('Origin', as_index=False).mean()
783783
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
784-
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=None, decimals=5)
784+
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=['Origin', 'EngineSize'], decimals=5)
785785

786786
@unittest.skipIf(sys.version_info.major < 3, 'Need newer version of Python')
787787
def test_mean(self):
@@ -804,7 +804,7 @@ def test_mean(self):
804804
dfgrp = df.groupby('Origin', as_index=False).mean()
805805
tblgrp = tbl.groupby('Origin', as_index=False).mean()
806806
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
807-
self.assertTablesEqual(dfgrp, tblgrp, sortby=None, decimals=5)
807+
self.assertTablesEqual(dfgrp, tblgrp, sortby=['Origin', 'MSRP', 'Invoice'], decimals=5)
808808

809809
def test_column_median(self):
810810
df = self.get_cars_df().sort_values(SORT_KEYS)
@@ -978,7 +978,7 @@ def test_column_sum(self):
978978

979979
tblgrp = tbl['EngineSize'].groupby('Origin', as_index=False).sum()
980980
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
981-
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=None, decimals=5)
981+
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=['Origin', 'EngineSize'], decimals=5)
982982

983983
def test_sum(self):
984984
df = self.get_cars_df().sort_values(SORT_KEYS)
@@ -1000,7 +1000,7 @@ def test_sum(self):
10001000
dfgrp = df.groupby('Origin', as_index=False).sum()
10011001
tblgrp = tbl.groupby('Origin', as_index=False).sum()
10021002
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
1003-
self.assertTablesEqual(dfgrp, tblgrp, decimals=5, sortby=None)
1003+
self.assertTablesEqual(dfgrp, tblgrp, decimals=5, sortby=['Origin', 'MSRP', 'Invoice'])
10041004

10051005
def test_column_std(self):
10061006
df = self.get_cars_df().sort_values(SORT_KEYS)
@@ -1024,7 +1024,7 @@ def test_column_std(self):
10241024

10251025
tblgrp = tbl['EngineSize'].groupby('Origin', as_index=False).std()
10261026
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
1027-
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=None, decimals=5)
1027+
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=['Origin', 'EngineSize'], decimals=5)
10281028

10291029
def test_std(self):
10301030
df = self.get_cars_df().sort_values(SORT_KEYS)
@@ -1046,7 +1046,7 @@ def test_std(self):
10461046
#dfgrp = df.groupby('Origin', as_index=False).std()
10471047
tblgrp = tbl.groupby('Origin', as_index=False).std()
10481048
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
1049-
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, decimals=5, sortby=None)
1049+
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, decimals=5, sortby=['Origin', 'MSRP', 'Invoice'])
10501050

10511051
def test_column_var(self):
10521052
df = self.get_cars_df().sort_values(SORT_KEYS)
@@ -1074,7 +1074,7 @@ def test_column_var(self):
10741074
# For some reason Pandas drops this column, but I think it should be there.
10751075
tblgrp = tblgrp.drop('Origin', axis=1)
10761076
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
1077-
self.assertTablesEqual(dfgrp, tblgrp, decimals=5, sortby=None)
1077+
self.assertTablesEqual(dfgrp, tblgrp, decimals=5, sortby=['EngineSize'])
10781078

10791079
def test_var(self):
10801080
df = self.get_cars_df().sort_values(SORT_KEYS)
@@ -1096,7 +1096,7 @@ def test_var(self):
10961096
dfgrp = df.groupby('Origin', as_index=False).var()
10971097
tblgrp = tbl.groupby('Origin', as_index=False).var()
10981098
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
1099-
self.assertTablesEqual(dfgrp, tblgrp, decimals=3, sortby=None)
1099+
self.assertTablesEqual(dfgrp, tblgrp, decimals=3, sortby=['Origin', 'MSRP', 'Invoice'])
11001100

11011101
def test_column_nmiss(self):
11021102
# TODO: Not supported by Pandas; need comparison values
@@ -1119,6 +1119,9 @@ def test_column_nmiss(self):
11191119
self.assertEqual(len(tblgrp), 3)
11201120

11211121
# Test character missing values
1122+
swat.options.cas.trace_actions = True
1123+
swat.options.cas.trace_ui_actions = True
1124+
swat.options.cas.print_messages = True
11221125
tbl = self.table.replace({'Make': {'Buick': ''}})
11231126

11241127
tblgrp = tbl.groupby('Origin')['Make'].nmiss()
@@ -1138,12 +1141,10 @@ def test_column_nmiss(self):
11381141
#
11391142
swat.options.cas.dataset.bygroup_casout_threshold = 2
11401143

1141-
swat.options.cas.print_messages = True
11421144
tblgrp = tbl['Cylinders'].groupby('Origin').nmiss()
11431145
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
11441146
self.assertEqual(len(tblgrp), 3)
11451147
tblgrp = tblgrp.to_frame().set_index('Origin')['Cylinders']
1146-
print(tblgrp)
11471148
self.assertEqual(tblgrp.loc['Asia'], 2)
11481149
self.assertEqual(tblgrp.loc['Europe'], 0)
11491150
self.assertEqual(tblgrp.loc['USA'], 0)

swat/tests/cas/test_datamsg.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -458,13 +458,13 @@ def test_dbapi(self):
458458
self.assertEqual(colinfo['Column'].tolist(),
459459
'Make,Model,Type,Origin,DriveTrain,MSRP,Invoice,EngineSize,Cylinders,Horsepower,MPG_City,MPG_Highway,Weight,Wheelbase,Length'.split(','))
460460

461-
self.assertEqual(list(tbl.head().itertuples(index=False)),
462-
[('Dodge', 'Viper SRT-10 convertible 2dr', 'Sports', 'USA', 'Rear', 81795.0,
461+
self.assertEqual(sorted(tuple(x) for x in tbl.head().itertuples(index=False)),
462+
sorted([('Dodge', 'Viper SRT-10 convertible 2dr', 'Sports', 'USA', 'Rear', 81795.0,
463463
74451.0, '8.3', 10.0, 500.0, 12.0, 20.0, 3410.0, 99.0, 176.0),
464464
('Mercedes-Benz', 'CL600 2dr', 'Sedan', 'Europe', 'Rear', 128420.0, 119600.0,
465465
'5.5', 12.0, 493.0, 13.0, 19.0, 4473.0, 114.0, 196.0),
466466
('Mercedes-Benz', 'SL600 convertible 2dr', 'Sports', 'Europe', 'Rear',
467-
126670.0, 117854.0, '5.5', 12.0, 493.0, 13.0, 19.0, 4429.0, 101.0, 179.0)])
467+
126670.0, 117854.0, '5.5', 12.0, 493.0, 13.0, 19.0, 4429.0, 101.0, 179.0)]))
468468

469469
try:
470470
os.remove(tmpf)

0 commit comments

Comments
 (0)