|
| 1 | +"""Datasets util for downloading and maintaining sample datasets.""" |
| 2 | + |
| 3 | +import requests |
| 4 | +import pandas as pd |
| 5 | +import os |
| 6 | +import tempfile |
| 7 | +import logging |
| 8 | +from sklearn.model_selection import train_test_split |
| 9 | + |
# Configure logging
# NOTE(review): basicConfig() runs at import time and configures the root
# logger of the importing process -- confirm this is intended for a library
# module rather than leaving logging configuration to the application.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use temp directory to store cached datasets
# Downloaded CSV files are written directly into this directory and reused by
# later calls when a file with the expected name already exists there.
CACHE_DIR = tempfile.gettempdir()

# Ensure cache directory exists
os.makedirs(CACHE_DIR, exist_ok=True)

# Dataset urls
# The platypus URL points at a specific gist revision (the hash is part of the
# path).
PLATYPUS_URL = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv"
HIERARCHICAL_TEXT_CLASSIFICATION_URL = (
    "https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
)
| 25 | + |
| 26 | + |
def _download_file(url, destination, timeout=30):
    """
    Download file from given URL to specified destination.

    The payload is first written to ``<destination>.part`` and then moved into
    place with ``os.replace``, so ``destination`` only ever appears once the
    whole file has been written. Callers that treat the mere existence of
    ``destination`` as "already cached" therefore never pick up a truncated
    file left behind by an interrupted write.

    Parameters
    ----------
    url : str
        Address of the file to download.
    destination : str
        Path where the downloaded file is stored.
    timeout : float or tuple, default=30
        Timeout in seconds forwarded to ``requests.get``. Without a timeout a
        stalled server would block the caller indefinitely.

    Raises
    ------
    RuntimeError
        If the HTTP request fails or returns an error status code.
    OSError
        If the file cannot be written to ``destination``.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Raise HTTPError if response code is not OK
        response.raise_for_status()
        payload = response.content
    except requests.RequestException as e:
        # Chain explicitly so the original network error stays in the traceback.
        raise RuntimeError(f"Failed to download file from {url}: {str(e)}") from e

    partial_path = f"{destination}.part"
    try:
        with open(partial_path, "wb") as f:
            f.write(payload)
        # Atomic on the same filesystem: destination is either absent or complete.
        os.replace(partial_path, destination)
    except OSError:
        # Do not leave a half-written temporary file behind.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
| 37 | + |
| 38 | + |
def load_platypus(test_size=0.3, random_state=42):
    """
    Load platypus diseases dataset.

    The CSV file is downloaded to ``CACHE_DIR`` the first time it is needed;
    later calls reuse the cached copy if it is still present.

    Parameters
    ----------
    test_size : float, default=0.3
        The proportion of the dataset to include in the test split.
    random_state : int or None, default=42
        Controls the randomness of the dataset. Pass an int for reproducible output across multiple function calls.

    Returns
    -------
    list
        List containing train-test split of inputs, in the order
        ``X_train, X_test, y_train, y_test``.

    Raises
    ------
    RuntimeError
        If failed to access or download the dataset.

    Examples
    --------
    >>> from hiclass.datasets import load_platypus
    >>> X_train, X_test, Y_train, Y_test = load_platypus()
    >>> X_train[:3]
         fever  diarrhea  stomach pain  skin rash  cough  sniffles  short breath  headache  size
    220   37.8         0             3          5      1         1             0         2  27.6
    539   37.2         0             6          1      1         1             0         3  28.4
    326   39.9         0             2          5      1         1             1         2  30.7
    >>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
    ((572, 9), (246, 9), (572,), (246,))
    """
    # Local import: only this loader needs to parse literal label values.
    import ast

    dataset_name = "platypus_diseases.csv"
    cached_file_path = os.path.join(CACHE_DIR, dataset_name)

    # Check if the file exists in the cache
    if not os.path.exists(cached_file_path):
        try:
            logger.info("Downloading platypus diseases dataset..")
            _download_file(PLATYPUS_URL, cached_file_path)
        except Exception as e:
            # Chain explicitly so the underlying failure stays in the traceback.
            raise RuntimeError(f"Failed to access or download dataset: {str(e)}") from e

    data = pd.read_csv(cached_file_path).fillna(" ")
    X = data.drop(["label"], axis=1)
    # SECURITY: the label column comes from a file fetched over the network.
    # This previously used eval(), which would execute arbitrary code embedded
    # in the CSV. ast.literal_eval only accepts Python literals (lists, strings,
    # numbers, ...).
    # NOTE(review): assumes every label cell is a plain Python literal such as
    # a list of strings -- confirm against the dataset.
    y = pd.Series([ast.literal_eval(val) for val in data["label"]])

    # Return tuple (X_train, X_test, y_train, y_test)
    return train_test_split(X, y, test_size=test_size, random_state=random_state)
| 88 | + |
| 89 | + |
def load_hierarchical_text_classification(test_size=0.3, random_state=42):
    """
    Load hierarchical text classification dataset.

    The CSV file is downloaded to ``CACHE_DIR`` the first time it is needed;
    later calls reuse the cached copy if it is still present. The ``Title``
    column is used as input and the ``Cat1``, ``Cat2`` and ``Cat3`` columns as
    the labels.

    Parameters
    ----------
    test_size : float, default=0.3
        The proportion of the dataset to include in the test split.
    random_state : int or None, default=42
        Controls the randomness of the dataset. Pass an int for reproducible output across multiple function calls.

    Returns
    -------
    list
        List containing train-test split of inputs, in the order
        ``X_train, X_test, y_train, y_test``.

    Raises
    ------
    RuntimeError
        If failed to access or download the dataset.

    Examples
    --------
    >>> from hiclass.datasets import load_hierarchical_text_classification
    >>> X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification()
    >>> X_train[:3]
    38015                               Nature's Way Selenium
    2281        Music In Motion Developmental Mobile W Remote
    36629   Twinings Ceylon Orange Pekoe Tea, Tea Bags, 20...
    Name: Title, dtype: object
    >>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
    ((28000,), (12000,), (28000, 3), (12000, 3))
    """
    dataset_name = "hierarchical_text_classification.csv"
    cached_file_path = os.path.join(CACHE_DIR, dataset_name)

    # Check if the file exists in the cache
    if not os.path.exists(cached_file_path):
        try:
            logger.info("Downloading hierarchical text classification dataset..")
            _download_file(HIERARCHICAL_TEXT_CLASSIFICATION_URL, cached_file_path)
        except Exception as e:
            # Chain explicitly so the underlying failure stays in the traceback.
            raise RuntimeError(f"Failed to access or download dataset: {str(e)}") from e

    # Missing cells are replaced by a single space in both inputs and labels.
    data = pd.read_csv(cached_file_path).fillna(" ")
    X = data["Title"]
    y = data[["Cat1", "Cat2", "Cat3"]]

    # Return tuple (X_train, X_test, y_train, y_test)
    return train_test_split(X, y, test_size=test_size, random_state=random_state)
0 commit comments