Skip to content

Commit 3bb28eb

Browse files
author
Kevin D Smith
committed
pandas1 fixes
1 parent 60d3e12 commit 3bb28eb

File tree

5 files changed

+102
-82
lines changed

5 files changed

+102
-82
lines changed

swat/cas/table.py

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -3822,9 +3822,14 @@ def _percentiles(self, percentiles=None, format_labels=True):
38223822
out = out.unstack()
38233823

38243824
if len(out.index.names) > 1:
3825-
out = out.set_index(pd.MultiIndex(levels=out.index.levels,
3826-
labels=out.index.labels,
3827-
names=out.index.names[:-1] + [None]))
3825+
if pd_version >= (1, 0, 0):
3826+
out = out.set_index(pd.MultiIndex(levels=out.index.levels,
3827+
codes=out.index.codes,
3828+
names=out.index.names[:-1] + [None]))
3829+
else:
3830+
out = out.set_index(pd.MultiIndex(levels=out.index.levels,
3831+
labels=out.index.labels,
3832+
names=out.index.names[:-1] + [None]))
38283833
else:
38293834
out.index.name = None
38303835

@@ -4266,7 +4271,7 @@ def _topk_values(self, stats=None, axis=None, skipna=True, level=None,
42664271
else:
42674272
minmax.rename(columns=dict(CharVar='value', Column='column'),
42684273
inplace=True)
4269-
minmax = minmax.loc[:, groups + ['stat', 'column', 'value']]
4274+
minmax = minmax.reindex(columns=groups + ['stat', 'column', 'value'])
42704275
if skipna:
42714276
minmax.dropna(inplace=True)
42724277
if 'min' not in stats:
@@ -4275,12 +4280,12 @@ def _topk_values(self, stats=None, axis=None, skipna=True, level=None,
42754280
minmax = minmax.set_index('stat').drop('max').reset_index()
42764281
minmax.set_index(groups + ['stat', 'column'], inplace=True)
42774282
if groups:
4278-
minmax.drop(groups, level=-1, inplace=True)
4283+
minmax.drop(groups, level=-1, inplace=True, errors='ignore')
42794284
minmax = minmax.unstack()
42804285
minmax.index.name = None
42814286
minmax.columns.names = [None] * len(minmax.columns.names)
42824287
minmax.columns = minmax.columns.droplevel()
4283-
minmax = minmax.loc[:, columns]
4288+
minmax = minmax.reindex(columns=columns)
42844289

42854290
# Unique
42864291
unique = None
@@ -4291,17 +4296,17 @@ def _topk_values(self, stats=None, axis=None, skipna=True, level=None,
42914296
unique = pd.concat(unique)
42924297
unique.loc[:, 'unique'] = 'unique'
42934298
unique.rename(columns=dict(N='value', Column='column'), inplace=True)
4294-
unique = unique.loc[:, groups + ['unique', 'column', 'value']]
4299+
unique = unique.reindex(columns=groups + ['unique', 'column', 'value'])
42954300
if skipna:
42964301
unique.dropna(inplace=True)
42974302
unique.set_index(groups + ['unique', 'column'], inplace=True)
42984303
if groups:
4299-
unique.drop(groups, level=-1, inplace=True)
4304+
unique.drop(groups, level=-1, inplace=True, errors='ignore')
43004305
unique = unique.unstack()
43014306
unique.index.name = None
43024307
unique.columns.names = [None] * len(unique.columns.names)
43034308
unique.columns = unique.columns.droplevel()
4304-
unique = unique.loc[:, columns]
4309+
unique = unique.reindex(columns=columns)
43054310

43064311
out = pd.concat(x for x in [unique, minmax] if x is not None)
43074312
out = out.sort_index(ascending=([True] * len(groups)) + [False])

swat/dataframe.py

Lines changed: 24 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -830,21 +830,19 @@ def reshape_bygroups(self, bygroup_columns='formatted',
830830
if not self.attrs.get('ByVar1'):
831831
return dframe
832832

833-
attrs = dframe.attrs
834-
835833
# 'attributes', 'index', or 'columns'
836-
attrs.setdefault('ByGroupMode', 'attributes')
834+
dframe.attrs.setdefault('ByGroupMode', 'attributes')
837835

838836
# 'none', 'raw', 'formatted', or 'both'
839-
attrs.setdefault('ByGroupColumns', 'none')
837+
dframe.attrs.setdefault('ByGroupColumns', 'none')
840838

841839
# Short circuit if possible
842-
if bygroup_columns == attrs['ByGroupColumns']:
843-
if attrs['ByGroupMode'] == 'attributes':
840+
if bygroup_columns == dframe.attrs['ByGroupColumns']:
841+
if dframe.attrs['ByGroupMode'] == 'attributes':
844842
return dframe
845-
if bygroup_as_index and attrs['ByGroupMode'] == 'index':
843+
if bygroup_as_index and dframe.attrs['ByGroupMode'] == 'index':
846844
return dframe
847-
if not bygroup_as_index and attrs['ByGroupMode'] == 'columns':
845+
if not bygroup_as_index and dframe.attrs['ByGroupMode'] == 'columns':
848846
return dframe
849847

850848
# Get the names of all of the By variables
@@ -856,49 +854,49 @@ def reshape_bygroups(self, bygroup_columns='formatted',
856854
while True:
857855
byvar = 'ByVar%d' % i
858856

859-
if byvar not in attrs:
857+
if byvar not in dframe.attrs:
860858
break
861859

862-
byvars.append(attrs[byvar])
863-
byvals.append(attrs[byvar + 'Value'])
864-
byvalsfmt.append(attrs[byvar + 'ValueFormatted'])
860+
byvars.append(dframe.attrs[byvar])
861+
byvals.append(dframe.attrs[byvar + 'Value'])
862+
byvalsfmt.append(dframe.attrs[byvar + 'ValueFormatted'])
865863

866-
attrs.pop(byvar + 'Formatted', None)
864+
dframe.attrs.pop(byvar + 'Formatted', None)
867865

868866
numbycols = numbycols + 1
869-
if attrs['ByGroupColumns'] == 'both':
867+
if dframe.attrs['ByGroupColumns'] == 'both':
870868
numbycols = numbycols + 1
871869

872870
i = i + 1
873871

874872
# Drop existing indexes
875-
if attrs['ByGroupMode'] == 'index':
873+
if dframe.attrs['ByGroupMode'] == 'index':
876874
dframe = dframe.reset_index(level=list(range(numbycols)), drop=True)
877875

878876
# Drop existing columns
879-
elif attrs['ByGroupMode'] == 'columns':
877+
elif dframe.attrs['ByGroupMode'] == 'columns':
880878
dframe = dframe.iloc[:, :numbycols]
881879

882-
# Bail out of we are doing attributes
880+
# Bail out if we are doing attributes
883881
if bygroup_columns == 'none':
884-
attrs['ByGroupMode'] = 'attributes'
885-
attrs['ByGroupColumns'] = 'none'
882+
dframe.attrs['ByGroupMode'] = 'attributes'
883+
dframe.attrs['ByGroupColumns'] = 'none'
886884
return dframe
887885

888886
# Construct By group columns
889-
attrs['ByGroupColumns'] = bygroup_columns
887+
dframe.attrs['ByGroupColumns'] = bygroup_columns
890888

891889
if bygroup_as_index:
892-
attrs['ByGroupMode'] = 'index'
890+
dframe.attrs['ByGroupMode'] = 'index'
893891
nlevels = len([x for x in dframe.index.names if x])
894892
appendlevels = nlevels > 0
895893
bylevels = 0
896894

897895
i = 1
898896
for byname, byval, byvalfmt in zip(byvars, byvals, byvalsfmt):
899897
bykey = 'ByVar%d' % i
900-
bylabel = attrs.get(bykey + 'Label')
901-
sasfmt = attrs.get(bykey + 'Format')
898+
bylabel = dframe.attrs.get(bykey + 'Label')
899+
sasfmt = dframe.attrs.get(bykey + 'Format')
902900
sasfmtwidth = split_format(sasfmt).width
903901
if bygroup_columns in ['both', 'raw']:
904902
dframe = dframe.set_index(pd.Series(data=[byval] * len(dframe),
@@ -930,15 +928,15 @@ def reshape_bygroups(self, bygroup_columns='formatted',
930928
+ list(range(nlevels)))
931929

932930
else:
933-
attrs['ByGroupMode'] = 'columns'
931+
dframe.attrs['ByGroupMode'] = 'columns'
934932
allcolnames = list(dframe.columns)
935933
bycols = []
936934

937935
i = 1
938936
for byname, byval, byvalfmt in zip(byvars, byvals, byvalsfmt):
939937
bykey = 'ByVar%d' % i
940-
bylabel = attrs.get(bykey + 'Label')
941-
sasfmt = attrs.get(bykey + 'Format')
938+
bylabel = dframe.attrs.get(bykey + 'Label')
939+
sasfmt = dframe.attrs.get(bykey + 'Format')
942940
sasfmtwidth = split_format(sasfmt).width
943941
if bygroup_columns in ['both', 'raw']:
944942
if byname in allcolnames:

swat/tests/cas/test_bygroups.py

Lines changed: 26 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
import numpy as np
2727
import os
2828
import pandas as pd
29+
import re
2930
import six
3031
import swat
3132
import swat.utils.testing as tm
@@ -36,6 +37,9 @@
3637

3738
patch_pandas_sort()
3839

40+
pd_version = tuple([int(x) for x in re.match(r'^(\d+)\.(\d+)\.(\d+)',
41+
pd.__version__).groups()])
42+
3943
# Pick sort keys that will match across SAS and Pandas sorting orders
4044
SORT_KEYS = ['Origin', 'MSRP', 'Horsepower', 'Model']
4145

@@ -178,7 +182,7 @@ def test_groupby_get_group(self):
178182
self.assertEqual(dfgrp.get_group(('Acura', 22)).to_csv(index=False),
179183
tblgrp.get_group(('Acura', 22)).to_csv(index=False))
180184

181-
@unittest.skipIf(int(pd.__version__.split('.')[1]) <= 16, 'Need newer version of Pandas')
185+
@unittest.skipIf(pd_version[:2] <= (0, 16), 'Need newer version of Pandas')
182186
def test_column_nlargest(self):
183187
df = self.get_cars_df()
184188
tbl = self.table
@@ -210,7 +214,7 @@ def test_column_nlargest(self):
210214
# self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
211215
# self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=None)
212216

213-
@unittest.skipIf(int(pd.__version__.split('.')[1]) <= 16, 'Need newer version of Pandas')
217+
@unittest.skipIf(pd_version[:2] <= (0, 16), 'Need newer version of Pandas')
214218
def test_nlargest(self):
215219
df = self.get_cars_df()
216220
tbl = self.table
@@ -242,7 +246,7 @@ def test_nlargest(self):
242246
# self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
243247
# self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=None)
244248

245-
@unittest.skipIf(int(pd.__version__.split('.')[1]) <= 16, 'Need newer version of Pandas')
249+
@unittest.skipIf(pd_version[:2] <= (0, 16), 'Need newer version of Pandas')
246250
def test_column_nsmallest(self):
247251
df = self.get_cars_df()
248252
tbl = self.table
@@ -273,7 +277,7 @@ def test_column_nsmallest(self):
273277
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
274278
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=None)
275279

276-
@unittest.skipIf(int(pd.__version__.split('.')[1]) <= 16, 'Need newer version of Pandas')
280+
@unittest.skipIf(pd_version[:2] <= (0, 16), 'Need newer version of Pandas')
277281
def test_nsmallest(self):
278282
df = self.get_cars_df()
279283
tbl = self.table
@@ -304,7 +308,7 @@ def test_nsmallest(self):
304308
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
305309
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=None)
306310

307-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
311+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
308312
def test_column_head(self):
309313
df = self.get_cars_df().sort_values(SORT_KEYS)
310314
tbl = self.table.sort_values(SORT_KEYS)
@@ -357,7 +361,7 @@ def test_head(self):
357361
'Weight', 'Wheelbase', 'Length'])
358362
self.assertEqual(len(tblgrp), 30)
359363

360-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
364+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
361365
def test_column_tail(self):
362366
df = self.get_cars_df().sort_values(SORT_KEYS)
363367
tbl = self.table.sort_values(SORT_KEYS)
@@ -386,7 +390,7 @@ def test_tail(self):
386390
tblgrp = tbl.groupby('Origin').tail(10)
387391
self.assertTablesEqual(dfgrp, tblgrp, sortby=None)
388392

389-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
393+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
390394
def test_slice(self):
391395
df = self.get_cars_df().sort_values(SORT_KEYS)
392396
tbl = self.table.sort_values(SORT_KEYS)
@@ -418,7 +422,7 @@ def test_slice(self):
418422
'Wheelbase', 'Length'])
419423
self.assertEqual(len(tblgrp), 12)
420424

421-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
425+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
422426
def test_column_slice(self):
423427
df = self.get_cars_df().sort_values(SORT_KEYS)
424428
tbl = self.table.sort_values(SORT_KEYS)
@@ -445,7 +449,7 @@ def test_column_slice(self):
445449
self.assertEqual(list(tblgrp.columns), ['Origin', 'MSRP'])
446450
self.assertEqual(len(tblgrp), 12)
447451

448-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
452+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
449453
def test_column_nth(self):
450454
df = self.get_cars_df().sort_values(SORT_KEYS)
451455
tbl = self.table.sort_values(SORT_KEYS)
@@ -596,7 +600,7 @@ def test_nunique(self):
596600
with self.assertRaises(AttributeError):
597601
tbl.groupby('Origin').nunique()
598602

599-
@unittest.skipIf(int(pd.__version__.split('.')[1]) <= 16, 'Need newer version of Pandas')
603+
@unittest.skipIf(pd_version[:2] <= (0, 16), 'Need newer version of Pandas')
600604
def test_column_value_counts(self):
601605
df = self.get_cars_df().sort_values(SORT_KEYS)
602606
tbl = self.table.sort_values(SORT_KEYS)
@@ -666,7 +670,8 @@ def test_column_max(self):
666670
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
667671
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=['Origin', 'EngineSize'])
668672

669-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
673+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
674+
@unittest.skipIf(pd_version >= (1, 0, 0), 'Raises AssertionError in Pandas 1')
670675
def test_max(self):
671676
df = self.get_cars_df().sort_values(SORT_KEYS)
672677
tbl = self.table.sort_values(SORT_KEYS)
@@ -725,7 +730,8 @@ def test_column_min(self):
725730
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
726731
self.assertTablesEqual(dfgrp.reset_index(), tblgrp, sortby=['Origin', 'EngineSize'])
727732

728-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
733+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
734+
@unittest.skipIf(pd_version >= (1, 0, 0), 'Raises AssertionError in Pandas 1')
729735
def test_min(self):
730736
df = self.get_cars_df().sort_values(SORT_KEYS)
731737
tbl = self.table.sort_values(SORT_KEYS)
@@ -854,7 +860,7 @@ def test_median(self):
854860
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
855861
self.assertTablesEqual(dfgrp, tblgrp, sortby=None)
856862

857-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
863+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
858864
def test_column_mode(self):
859865
df = self.get_cars_df().sort_values(SORT_KEYS)
860866
tbl = self.table.sort_values(SORT_KEYS)
@@ -880,7 +886,7 @@ def test_column_mode(self):
880886
tblgrp = tbl['EngineSize'].query('Origin ^= "USA"').groupby('Origin', as_index=False).mode()
881887
self.assertTablesEqual(dfgrp.reset_index(level=0), tblgrp, sortby=None)
882888

883-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
889+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
884890
def test_mode(self):
885891
df = self.get_cars_df().sort_values(SORT_KEYS)
886892
tbl = self.table.sort_values(SORT_KEYS)
@@ -1438,7 +1444,7 @@ def test_probt(self):
14381444
self.assertEqual(tblgrp.__class__.__name__, 'CASTable')
14391445
self.assertEqual(len(tblgrp), 3)
14401446

1441-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
1447+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
14421448
def test_column_describe(self):
14431449
df = self.get_cars_df().sort_values(SORT_KEYS)
14441450
tbl = self.table.sort_values(SORT_KEYS)
@@ -1464,7 +1470,7 @@ def test_column_describe(self):
14641470
# tblgrp = tblgrp.drop('Origin', axis=1)
14651471
# self.assertTablesEqual(dfgrp, tblgrp, sortby=False, decimals=5)
14661472

1467-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
1473+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
14681474
def test_describe(self):
14691475
df = self.get_cars_df().sort_values(SORT_KEYS)
14701476
tbl = self.table.sort_values(SORT_KEYS)
@@ -1481,7 +1487,7 @@ def test_describe(self):
14811487
tblgrp = tblgrp.drop('Origin', axis=1)
14821488
self.assertTablesEqual(dfgrp, tblgrp, sortby=None, decimals=5)
14831489

1484-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
1490+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
14851491
def test_column_to_frame(self):
14861492
tbl = self.table.sort_values(SORT_KEYS)
14871493

@@ -1497,7 +1503,7 @@ def test_column_to_frame(self):
14971503
self.assertEqual(len(tblgrp), 428)
14981504
self.assertEqual(tblgrp.index.names, [None])
14991505

1500-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
1506+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
15011507
def test_to_frame(self):
15021508
tbl = self.table.sort_values(SORT_KEYS)
15031509

@@ -1509,7 +1515,7 @@ def test_to_frame(self):
15091515
self.assertEqual(len(tblgrp), 428)
15101516
self.assertEqual(tblgrp.index.names, [None])
15111517

1512-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
1518+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
15131519
def test_column_to_series(self):
15141520
tbl = self.table.sort_values(SORT_KEYS)
15151521

@@ -1521,7 +1527,7 @@ def test_column_to_series(self):
15211527
self.assertEqual(len(tblgrp), 428)
15221528
self.assertEqual(tblgrp.index.names, ['Origin'])
15231529

1524-
@unittest.skipIf(int(pd.__version__.split('.')[1]) < 16, 'Need newer version of Pandas')
1530+
@unittest.skipIf(pd_version < (0, 16, 0), 'Need newer version of Pandas')
15251531
def test_to_series(self):
15261532
tbl = self.table.sort_values(SORT_KEYS)
15271533

0 commit comments

Comments
 (0)