Skip to content

Commit 57b0ab3

Browse files
Merge pull request #119 from TwsThomas/fix_examples
[WIP] fix bug in examples
2 parents 3695d51 + e6ea1ed commit 57b0ab3

File tree

5 files changed

+33
-20
lines changed

5 files changed

+33
-20
lines changed

CHANGES.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
Release 0.0.7
2+
=============
3+
* **datasets.fetch_employee_salaries**: change the origin of download for employee_salaries.
4+
- The function now returns a bunch with a dataframe under the field "data",
5+
and not the path to the csv file.
6+
- The field "description" has been renamed to "DESCR".
7+
8+
19
Release 0.0.6
210
=============
311
* **SimilarityEncoder**: Fixed a bug when using the Jaro-Winkler distance as a

dirty_cat/datasets/fetching.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from collections import namedtuple
1414
import contextlib
1515
import warnings
16+
from sklearn.datasets import fetch_openml
1617

1718
from ..datasets.utils import md5_hash, _check_if_exists, \
1819
_uncompress_file, \
@@ -346,18 +347,23 @@ def fetch_employee_salaries():
346347
dict
347348
a dictionary containing:
348349
349-
- a short description of the dataset (under the ``description``
350+
- a short description of the dataset (under the ``DESCR``
350351
key)
351-
- an absolute path leading to the csv file where the data is stored
352-
locally (under the ``path`` key)
352+
- the tabular data (under the ``data`` key)
353+
- the target (under the ``target`` key)
353354
354355
References
355356
----------
356357
https://catalog.data.gov/dataset/employee-salaries-2016
357358
358359
"""
359360

360-
return fetch_dataset(EMPLOYEE_SALARIES_CONFIG, show_progress=False)
361+
data = fetch_openml(data_id=42125, as_frame=True)
362+
data.data['Current Annual Salary'] = data['target']
363+
return data
364+
365+
# The original download link is dead.
366+
# return fetch_dataset(EMPLOYEE_SALARIES_CONFIG, show_progress=False)
361367

362368

363369
def fetch_road_safety():

examples/01_investigating_dirty_categories.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from dirty_cat import datasets
1616

1717
employee_salaries = datasets.fetch_employee_salaries()
18-
print(employee_salaries['description'])
19-
data = pd.read_csv(employee_salaries['path'])
18+
print(employee_salaries['DESCR'])
19+
data = employee_salaries['data']
2020
print(data.head(n=5))
2121

2222
#########################################################################
@@ -25,7 +25,7 @@
2525

2626
#########################################################################
2727
# As we can see, some entries have many different unique values:
28-
print(data['Employee Position Title'].value_counts().sort_index())
28+
print(data['employee_position_title'].value_counts().sort_index())
2929

3030
#########################################################################
3131
# These different entries are often variations on the same entities:
@@ -47,7 +47,7 @@
4747
# To simplify understanding, we will focus on the column describing the
4848
# employee's position title:
4949
# data
50-
values = data[['Employee Position Title', 'Gender', 'Current Annual Salary']]
50+
values = data[['employee_position_title', 'gender', 'Current Annual Salary']]
5151

5252
#########################################################################
5353
# String similarity between entries
@@ -56,7 +56,7 @@
5656
# That's where our encoders get into play. In order to robustly
5757
# embed dirty semantic data, the SimilarityEncoder creates a similarity
5858
# matrix based on the 3-gram structure of the data.
59-
sorted_values = values['Employee Position Title'].sort_values().unique()
59+
sorted_values = values['employee_position_title'].sort_values().unique()
6060

6161
from dirty_cat import SimilarityEncoder
6262

@@ -142,7 +142,7 @@
142142

143143
# encoding simply a subset of the observations
144144
n_obs = 20
145-
employee_position_titles = values['Employee Position Title'].head(
145+
employee_position_titles = values['employee_position_title'].head(
146146
n_obs).to_frame()
147147
categorical_encoder = OneHotEncoder(sparse=False)
148148
one_hot_encoded = categorical_encoder.fit_transform(employee_position_titles)

examples/02_fit_predict_plot_employee_salaries.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,17 @@
2121
# We first download the dataset:
2222
from dirty_cat.datasets import fetch_employee_salaries
2323
employee_salaries = fetch_employee_salaries()
24-
print(employee_salaries['description'])
24+
print(employee_salaries['DESCR'])
25+
2526

2627
################################################################################
2728
# Then we load it:
2829
import pandas as pd
29-
df = pd.read_csv(employee_salaries['path']).astype(str)
30+
df = employee_salaries['data']
3031

3132
################################################################################
3233
# Now, let's carry out some basic preprocessing:
33-
df['Current Annual Salary'] = df['Current Annual Salary'].str.strip('$').astype(
34-
float)
35-
df['Date First Hired'] = pd.to_datetime(df['Date First Hired'])
34+
df['Date First Hired'] = pd.to_datetime(df['date_first_hired'])
3635
df['Year First Hired'] = df['Date First Hired'].apply(lambda x: x.year)
3736

3837
target_column = 'Current Annual Salary'
@@ -45,17 +44,17 @@
4544
# use one hot encoding to transform them:
4645

4746
clean_columns = {
48-
'Gender': 'one-hot',
49-
'Department Name': 'one-hot',
50-
'Assignment Category': 'one-hot',
47+
'gender': 'one-hot',
48+
'department_name': 'one-hot',
49+
'assignment_category': 'one-hot',
5150
'Year First Hired': 'numerical'}
5251

5352
#########################################################################
5453
# We then choose the categorical encoding methods we want to benchmark
5554
# and the dirty categorical variable:
5655

5756
encoding_methods = ['one-hot', 'target', 'similarity']
58-
dirty_column = 'Employee Position Title'
57+
dirty_column = 'employee_position_title'
5958
#########################################################################
6059

6160

examples/04_dimension_reduction_and_performance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def wrapped_func(*args, **kwargs):
3434
max_usage=True,
3535
retval=True)
3636
print("Run time: %.1is Memory used: %iMb"
37-
% (time() - t0, mem[0]))
37+
% (time() - t0, mem))
3838
return out
3939

4040
return wrapped_func

0 commit comments

Comments
 (0)