Skip to content

Commit 4595264

Browse files
added datasets module, test cases and updated Read the docs examples (#117)
1 parent c436275 commit 4595264

File tree

6 files changed

+286
-34
lines changed

6 files changed

+286
-34
lines changed

docs/examples/plot_lcppn_explainer.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,11 @@
1010
"""
1111
from sklearn.ensemble import RandomForestClassifier
1212
from hiclass import LocalClassifierPerParentNode, Explainer
13-
import requests
14-
import pandas as pd
1513
import shap
14+
from hiclass.datasets import load_platypus
1615

17-
# Download training data
18-
url = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv"
19-
path = "platypus_diseases.csv"
20-
response = requests.get(url)
21-
with open(path, "wb") as file:
22-
file.write(response.content)
23-
24-
# Load training data into pandas dataframe
25-
training_data = pd.read_csv(path).fillna(" ")
26-
27-
# Define data
28-
X_train = training_data.drop(["label"], axis=1)
29-
X_test = X_train[:100] # Use first 100 samples as test set
30-
Y_train = training_data["label"]
31-
Y_train = [eval(my) for my in Y_train]
16+
# Load train and test splits
17+
X_train, X_test, Y_train, Y_test = load_platypus()
3218

3319
# Use random forest classifiers for every node
3420
rfc = RandomForestClassifier()

docs/examples/plot_parallel_training.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,15 @@
1717
"""
1818
import sys
1919
from os import cpu_count
20-
21-
import pandas as pd
22-
import requests
2320
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
2421
from sklearn.linear_model import LogisticRegression
2522
from sklearn.pipeline import Pipeline
2623

2724
from hiclass import LocalClassifierPerParentNode
25+
from hiclass.datasets import load_hierarchical_text_classification
2826

29-
30-
# Download training data
31-
url = "https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
32-
path = "train_40k.csv"
33-
response = requests.get(url)
34-
with open(path, "wb") as file:
35-
file.write(response.content)
36-
37-
# Load training data into pandas dataframe
38-
training_data = pd.read_csv(path).fillna(" ")
27+
# Load train and test splits
28+
X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification()
3929

4030
# We will use logistic regression classifiers for every parent node
4131
lr = LogisticRegression(max_iter=1000)
@@ -51,10 +41,6 @@
5141
]
5242
)
5343

54-
# Select training data
55-
X_train = training_data["Title"]
56-
Y_train = training_data[["Cat1", "Cat2", "Cat3"]]
57-
5844
# Fixes bug AttributeError: '_LoggingTee' object has no attribute 'fileno'
5945
# This only happens when building the documentation
6046
# Hence, you don't actually need it for your code to work

docs/source/api/utilities.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,23 @@ F-score
8888
^^^^^^^
8989

9090
.. autofunction:: metrics.f1
91+
92+
..................................
93+
94+
95+
Datasets
96+
----------
97+
98+
Platypus diseases dataset
99+
^^^^^^^^^^^^^^^^^^^^^^^^^^
100+
101+
.. autofunction:: datasets.load_platypus
102+
103+
..................................
104+
105+
Hierarchical text classification dataset
106+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
107+
108+
.. autofunction:: datasets.load_hierarchical_text_classification
109+
110+
..................................

hiclass/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@
2222
"Explainer",
2323
"MultiLabelLocalClassifierPerNode",
2424
"MultiLabelLocalClassifierPerParentNode",
25+
"datasets",
2526
]

hiclass/datasets.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
"""Datasets util for downloading and maintaining sample datasets."""
2+
3+
import requests
4+
import pandas as pd
5+
import os
6+
import tempfile
7+
import logging
8+
from sklearn.model_selection import train_test_split
9+
10+
# Module logger.
# NOTE(review): basicConfig at import time configures the root logger for the
# whole process; libraries normally leave that to the application. Kept as-is
# because removing it would change what importers currently see.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Downloaded datasets are cached in the system temp directory.
CACHE_DIR = tempfile.gettempdir()

# NOTE(review): the temp directory is expected to exist already; this call is
# presumably a safeguard and is kept to preserve import-time behaviour.
os.makedirs(CACHE_DIR, exist_ok=True)

# Dataset urls
PLATYPUS_URL = (
    "https://gist.githubusercontent.com/ashishpatel16/"
    "9306f8ed3ed101e7ddcb519776bcbd80/raw/"
    "1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv"
)
HIERARCHICAL_TEXT_CLASSIFICATION_URL = "https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
25+
26+
27+
def _download_file(url, destination, timeout=60):
    """
    Download the file at ``url`` and write it to ``destination``.

    Parameters
    ----------
    url : str
        Address of the file to fetch.
    destination : str
        Path the downloaded bytes are written to.
    timeout : float or tuple, default=60
        Timeout in seconds handed to ``requests.get``. Without one, a server
        that stops responding would block the caller indefinitely.

    Raises
    ------
    RuntimeError
        If the request fails or the server answers with an error status.
        The original ``requests`` exception is attached as ``__cause__``.
    """
    # Only the network call can raise RequestException, so keep the try small.
    try:
        response = requests.get(url, timeout=timeout)
        # Raise HTTPError if response code is not OK
        response.raise_for_status()
    except requests.RequestException as e:
        raise RuntimeError(f"Failed to download file from {url}: {str(e)}") from e

    with open(destination, "wb") as f:
        f.write(response.content)
37+
38+
39+
def load_platypus(test_size=0.3, random_state=42):
    """
    Load platypus diseases dataset.

    The CSV file is downloaded into ``CACHE_DIR`` on first use and read from
    there on later calls.

    Parameters
    ----------
    test_size : float, default=0.3
        The proportion of the dataset to include in the test split.
    random_state : int or None, default=42
        Controls the randomness of the dataset. Pass an int for reproducible output across multiple function calls.

    Returns
    -------
    list
        ``[X_train, X_test, Y_train, Y_test]`` as returned by
        ``train_test_split``; the ``X`` parts are data frames and the ``Y``
        parts are series holding one parsed label per sample.

    Raises
    ------
    RuntimeError
        If failed to access or process the dataset.

    Examples
    --------
    >>> from hiclass.datasets import load_platypus
    >>> X_train, X_test, Y_train, Y_test = load_platypus()
    >>> X_train[:3]
         fever  diarrhea  stomach pain  skin rash  cough  sniffles  short breath  headache  size
    220   37.8         0             3          5      1         1             0         2  27.6
    539   37.2         0             6          1      1         1             0         3  28.4
    326   39.9         0             2          5      1         1             1         2  30.7
    >>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
    ((572, 9), (246, 9), (572,), (246,))
    """
    # Imported locally so this function does not rely on a module-level import.
    import ast

    dataset_name = "platypus_diseases.csv"
    cached_file_path = os.path.join(CACHE_DIR, dataset_name)

    # Check if the file exists in the cache
    if not os.path.exists(cached_file_path):
        try:
            logger.info("Downloading platypus diseases dataset..")
            _download_file(PLATYPUS_URL, cached_file_path)
        except Exception as e:
            raise RuntimeError(
                f"Failed to access or download dataset: {str(e)}"
            ) from e

    data = pd.read_csv(cached_file_path).fillna(" ")
    X = data.drop(["label"], axis=1)
    # Labels are stored as Python literals in text form. literal_eval parses
    # them without executing arbitrary code from the downloaded file, which
    # eval() would do.
    y = pd.Series([ast.literal_eval(val) for val in data["label"]])

    # Return list [X_train, X_test, y_train, y_test]
    return train_test_split(X, y, test_size=test_size, random_state=random_state)
88+
89+
90+
def load_hierarchical_text_classification(test_size=0.3, random_state=42):
    """
    Load hierarchical text classification dataset.

    The CSV file is downloaded into ``CACHE_DIR`` on first use and read from
    there on later calls.

    Parameters
    ----------
    test_size : float, default=0.3
        The proportion of the dataset to include in the test split.
    random_state : int or None, default=42
        Controls the randomness of the dataset. Pass an int for reproducible output across multiple function calls.

    Returns
    -------
    list
        ``[X_train, X_test, Y_train, Y_test]`` as returned by
        ``train_test_split``; the ``X`` parts come from the ``Title`` column
        and the ``Y`` parts from the ``Cat1``, ``Cat2`` and ``Cat3`` columns.

    Raises
    ------
    RuntimeError
        If failed to access or process the dataset.

    Examples
    --------
    >>> from hiclass.datasets import load_hierarchical_text_classification
    >>> X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification()
    >>> X_train[:3]
    38015                                Nature's Way Selenium
    2281         Music In Motion Developmental Mobile W Remote
    36629    Twinings Ceylon Orange Pekoe Tea, Tea Bags, 20...
    Name: Title, dtype: object
    >>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
    ((28000,), (12000,), (28000, 3), (12000, 3))
    """
    dataset_name = "hierarchical_text_classification.csv"
    cached_file_path = os.path.join(CACHE_DIR, dataset_name)

    # Check if the file exists in the cache
    if not os.path.exists(cached_file_path):
        try:
            logger.info("Downloading hierarchical text classification dataset..")
            _download_file(HIERARCHICAL_TEXT_CLASSIFICATION_URL, cached_file_path)
        except Exception as e:
            # Chain the cause so the underlying failure stays in the traceback.
            raise RuntimeError(
                f"Failed to access or download dataset: {str(e)}"
            ) from e

    data = pd.read_csv(cached_file_path).fillna(" ")
    X = data["Title"]
    y = data[["Cat1", "Cat2", "Cat3"]]

    # Return list [X_train, X_test, y_train, y_test]
    return train_test_split(X, y, test_size=test_size, random_state=random_state)

tests/test_Datasets.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import numpy as np
2+
import pytest
3+
4+
import hiclass.datasets
5+
from hiclass.datasets import load_platypus, load_hierarchical_text_classification
6+
import os
7+
import tempfile
8+
9+
10+
def test_load_platypus_output_shape():
    """Features and labels must have the same number of rows in each split."""
    splits = load_platypus(test_size=0.2, random_state=42)
    train_features, test_features, train_labels, test_labels = splits

    assert train_features.shape[0] == train_labels.shape[0]
    assert test_features.shape[0] == test_labels.shape[0]
14+
15+
16+
def test_load_platypus_random_state():
    """Two loads with the same seed must produce the same split."""
    split_options = {"test_size": 0.2, "random_state": 42}
    first = load_platypus(**split_options)
    second = load_platypus(**split_options)

    # Feature frames are compared by value ...
    for features_a, features_b in zip(first[:2], second[:2]):
        assert (features_a.values == features_b.values).all()

    # ... label series by the row indices they were drawn from.
    for labels_a, labels_b in zip(first[2:], second[2:]):
        assert (labels_a.index == labels_b.index).all()
27+
28+
29+
def test_load_hierarchical_text_classification_shape():
    """Features and labels must have the same number of rows in each split."""
    splits = load_hierarchical_text_classification(test_size=0.2, random_state=42)
    train_features, test_features, train_labels, test_labels = splits

    assert train_features.shape[0] == train_labels.shape[0]
    assert test_features.shape[0] == test_labels.shape[0]
35+
36+
37+
def test_load_hierarchical_text_classification_random_state():
    """Two loads with the same seed must produce the same split."""
    split_options = {"test_size": 0.2, "random_state": 42}
    first = load_hierarchical_text_classification(**split_options)
    second = load_hierarchical_text_classification(**split_options)

    # Title series are compared element-wise ...
    for titles_a, titles_b in zip(first[:2], second[:2]):
        assert (titles_a == titles_b).all()

    # ... label frames by the row indices they were drawn from.
    for labels_a, labels_b in zip(first[2:], second[2:]):
        assert (labels_a.index == labels_b.index).all()
48+
49+
50+
def test_load_hierarchical_text_classification_file_exists():
    """Loading with an empty cache must leave the CSV in the temp directory."""
    dataset_name = "hierarchical_text_classification.csv"
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

    # Start from an empty cache so the loader has to download the file.
    if os.path.exists(cached_file_path):
        os.remove(cached_file_path)

    # No guard around the assertion: the test must never pass without
    # actually checking that the file was created.
    load_hierarchical_text_classification()
    assert os.path.exists(cached_file_path)
60+
61+
62+
def test_load_platypus_file_exists():
    """Loading with an empty cache must leave the CSV in the temp directory."""
    dataset_name = "platypus_diseases.csv"
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

    # Start from an empty cache so the loader has to download the file.
    if os.path.exists(cached_file_path):
        os.remove(cached_file_path)

    # No guard around the assertion: the test must never pass without
    # actually checking that the file was created.
    load_platypus()
    assert os.path.exists(cached_file_path)
72+
73+
74+
def test_download_dataset():
    """_download_file must create the destination file."""
    dataset_name = "platypus_diseases_test.csv"
    url = hiclass.datasets.PLATYPUS_URL
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

    if os.path.exists(cached_file_path):
        os.remove(cached_file_path)

    try:
        hiclass.datasets._download_file(url, cached_file_path)
        assert os.path.exists(cached_file_path)
    finally:
        # This file is used by this test only; do not leave it behind.
        if os.path.exists(cached_file_path):
            os.remove(cached_file_path)
85+
86+
87+
def test_download_error_load_platypus():
    """An unusable URL must surface as RuntimeError when nothing is cached."""
    dataset_name = "platypus_diseases.csv"
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

    # Without a cached copy the loader is forced to attempt the download.
    if os.path.exists(cached_file_path):
        os.remove(cached_file_path)

    backup_url = hiclass.datasets.PLATYPUS_URL
    hiclass.datasets.PLATYPUS_URL = ""
    try:
        with pytest.raises(RuntimeError):
            load_platypus()
    finally:
        # Restore the URL even if the assertion fails, so that later tests
        # are not run against the empty URL.
        hiclass.datasets.PLATYPUS_URL = backup_url
101+
102+
103+
def test_download_error_load_hierarchical_text():
    """An unusable URL must surface as RuntimeError when nothing is cached."""
    dataset_name = "hierarchical_text_classification.csv"
    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)

    # Without a cached copy the loader is forced to attempt the download.
    if os.path.exists(cached_file_path):
        os.remove(cached_file_path)

    backup_url = hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL
    hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = ""
    try:
        with pytest.raises(RuntimeError):
            load_hierarchical_text_classification()
    finally:
        # Restore the URL even if the assertion fails, so that later tests
        # are not run against the empty URL.
        hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = backup_url
117+
118+
119+
def test_url_links():
    """Both dataset URL constants must be non-empty."""
    dataset_urls = (
        hiclass.datasets.PLATYPUS_URL,
        hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL,
    )
    for url in dataset_urls:
        assert url != ""

0 commit comments

Comments
 (0)