Skip to content

Commit cad0bdf

Browse files
authored
Add binary policies to documentation (#55)
* Add binary policies to documentation * Increase test coverage * Fix if else for ray verification * Specify version of dependencies for tests
1 parent 8c94039 commit cad0bdf

11 files changed

+54
-42
lines changed

.github/workflows/deploy-pypi.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@ jobs:
2323
- name: Install dependencies
2424
run: |
2525
python -m pip install --upgrade pip
26-
python -m pip install pytest
27-
python -m pip install pytest-flake8
28-
python -m pip install pytest-pydocstyle
29-
python -m pip install pytest-cov
26+
python -m pip install flake8==4.0.1
27+
python -m pip install pytest==7.1.2
28+
python -m pip install pytest-flake8==1.1.1
29+
python -m pip install pytest-pydocstyle==2.3.0
30+
python -m pip install pytest-cov==3.0.0
3031
python -m pip install .
3132
- name: Test with pytest
3233
run: |

.github/workflows/test-pr.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,11 @@ jobs:
2424
- name: Install dependencies
2525
run: |
2626
python -m pip install --upgrade pip
27-
python -m pip install pytest
28-
python -m pip install pytest-flake8
29-
python -m pip install pytest-pydocstyle
30-
python -m pip install pytest-cov
27+
python -m pip install flake8==4.0.1
28+
python -m pip install pytest==7.1.2
29+
python -m pip install pytest-flake8==1.1.1
30+
python -m pip install pytest-pydocstyle==2.3.0
31+
python -m pip install pytest-cov==3.0.0
3132
python -m pip install .
3233
- name: Test with pytest
3334
run: |

CONTRIBUTING.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@ Please make sure all tests pass before submitting a pull request. It is also goo
1010

1111
## Testing the code locally
1212

13-
To test the code locally you need to install the dependencies for the library in the current environment. Additionally, you need to install the following dependencies for testing:
13+
To test the code locally you need to install the dependencies for the library in the current environment. Additionally, you need to install the dependencies for testing. All of those dependencies can be installed with:
1414

1515
```
16-
pip install pytest
17-
pip install pytest-flake8
18-
pip install pytest-pydocstyle
19-
pip install pytest-cov
16+
pip install flake8==4.0.1
17+
pip install pytest==7.1.2
18+
pip install pytest-flake8==1.1.1
19+
pip install pytest-pydocstyle==2.3.0
20+
pip install pytest-cov==3.0.0
2021
pip install -e .
2122
```
2223

docs/examples/plot_parallel_training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Larger datasets require more time for training.
88
While by default the models in HiClass are trained using a single core,
99
it is possible to train each local classifier in parallel by leveraging the library Ray [1]_.
10+
If Ray is not installed, the parallelism defaults to Joblib.
1011
In this example, we demonstrate how to train a hierarchical classifier in parallel by
1112
setting the parameter :literal:`n_jobs` to use all the cores available. Training
1213
is performed on a mock dataset from Kaggle [2]_.

hiclass/HierarchicalClassifier.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44

55
import networkx as nx
66
import numpy as np
7+
from joblib import Parallel, delayed
8+
from sklearn.base import BaseEstimator
9+
from sklearn.linear_model import LogisticRegression
710

811
try:
912
import ray
1013
except ImportError:
1114
_has_ray = False
12-
from joblib import Parallel, delayed, effective_n_jobs
1315
else:
1416
_has_ray = True
15-
from sklearn.base import BaseEstimator
16-
from sklearn.linear_model import LogisticRegression
1717

1818

1919
def make_leveled(y):
@@ -85,6 +85,7 @@ def __init__(
8585
a single unique class.
8686
n_jobs : int, default=1
8787
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
88+
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
8889
classifier_abbreviation : str, default=""
8990
The abbreviation of the local hierarchical classifier to be displayed during logging.
9091
"""
@@ -296,9 +297,11 @@ def _remove_separator(self, y):
296297
for j in range(1, y.shape[1]):
297298
y[i, j] = y[i, j].split(self.separator_)[-1]
298299

299-
def _fit_node_classifier(self, nodes, local_mode):
300+
def _fit_node_classifier(
301+
self, nodes, local_mode: bool = False, use_joblib: bool = False
302+
):
300303
if self.n_jobs > 1:
301-
if _has_ray:
304+
if _has_ray and not use_joblib:
302305
ray.init(
303306
num_cpus=self.n_jobs,
304307
local_mode=local_mode,

hiclass/LocalClassifierPerLevel.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,19 @@
66
from copy import deepcopy
77

88
import numpy as np
9+
from joblib import Parallel, delayed
10+
from sklearn.base import BaseEstimator
11+
from sklearn.utils.validation import check_array, check_is_fitted
12+
13+
from hiclass.ConstantClassifier import ConstantClassifier
14+
from hiclass.HierarchicalClassifier import HierarchicalClassifier
915

1016
try:
1117
import ray
1218
except ImportError:
1319
_has_ray = False
14-
from joblib import Parallel, delayed, effective_n_jobs
1520
else:
1621
_has_ray = True
17-
from sklearn.base import BaseEstimator
18-
from sklearn.utils.validation import check_array, check_is_fitted
19-
20-
from hiclass.ConstantClassifier import ConstantClassifier
21-
from hiclass.HierarchicalClassifier import HierarchicalClassifier
2222

2323

2424
class LocalClassifierPerLevel(BaseEstimator, HierarchicalClassifier):
@@ -67,6 +67,7 @@ def __init__(
6767
a single unique class.
6868
n_jobs : int, default=1
6969
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
70+
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
7071
"""
7172
super().__init__(
7273
local_classifier=local_classifier,
@@ -198,10 +199,10 @@ def _initialize_local_classifiers(self):
198199
]
199200
self.masks_ = [None for _ in range(self.y_.shape[1])]
200201

201-
def _fit_digraph(self, local_mode: bool = False):
202+
def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False):
202203
self.logger_.info("Fitting local classifiers")
203204
if self.n_jobs > 1:
204-
if _has_ray:
205+
if _has_ray and not use_joblib:
205206
ray.init(
206207
num_cpus=self.n_jobs,
207208
local_mode=local_mode,

hiclass/LocalClassifierPerNode.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,17 @@ def __init__(
5151
local_classifier : BaseEstimator, default=LogisticRegression
5252
The local_classifier used to create the collection of local classifiers. Needs to have fit, predict and
5353
clone methods.
54-
binary_policy : str, default="siblings"
55-
Rules for defining positive and negative training examples.
54+
binary_policy : {"exclusive", "less_exclusive", "exclusive_siblings", "inclusive", "less_inclusive", "siblings"}, str, default="siblings"
55+
Specify the rule for defining positive and negative training examples, using one of the following options:
56+
57+
- `exclusive`: Positive examples belong only to the class being considered. All classes are negative examples, except for the selected class;
58+
- `less_exclusive`: Positive examples belong only to the class being considered. All classes are negative examples, except for the selected class and its descendants;
59+
- `exclusive_siblings`: Positive examples belong only to the class being considered. All sibling classes are negative examples;
60+
- `inclusive`: Positive examples belong only to the class being considered and its descendants. All classes are negative examples, except for the selected class, its descendants and ancestors;
61+
- `less_inclusive`: Positive examples belong only to the class being considered and its descendants. All classes are negative examples, except for the selected class and its descendants;
62+
- `siblings`: Positive examples belong only to the class being considered and its descendants. All siblings and their descendant classes are negative examples.
63+
64+
See :ref:`Training Policies` for more information about the different policies.
5665
verbose : int, default=0
5766
Controls the verbosity when fitting and predicting.
5867
See https://verboselogs.readthedocs.io/en/latest/readme.html#overview-of-logging-levels
@@ -64,6 +73,7 @@ def __init__(
6473
a single unique class.
6574
n_jobs : int, default=1
6675
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
76+
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
6777
"""
6878
super().__init__(
6979
local_classifier=local_classifier,
@@ -206,12 +216,12 @@ def _initialize_local_classifiers(self):
206216
}
207217
nx.set_node_attributes(self.hierarchy_, local_classifiers)
208218

209-
def _fit_digraph(self, local_mode: bool = False):
219+
def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False):
210220
self.logger_.info("Fitting local classifiers")
211221
nodes = list(self.hierarchy_.nodes)
212222
# Remove root because it does not need to be fitted
213223
nodes.remove(self.root_)
214-
self._fit_node_classifier(nodes, local_mode)
224+
self._fit_node_classifier(nodes, local_mode, use_joblib)
215225

216226
@staticmethod
217227
def _fit_classifier(self, node):

hiclass/LocalClassifierPerParentNode.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
a single unique class.
6161
n_jobs : int, default=1
6262
The number of jobs to run in parallel. Only :code:`fit` is parallelized.
63+
If :code:`Ray` is installed it is used, otherwise it defaults to :code:`Joblib`.
6364
"""
6465
super().__init__(
6566
local_classifier=local_classifier,
@@ -199,7 +200,7 @@ def _fit_classifier(self, node):
199200
classifier.fit(X, y)
200201
return classifier
201202

202-
def _fit_digraph(self, local_mode: bool = False):
203+
def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False):
203204
self.logger_.info("Fitting local classifiers")
204205
nodes = self._get_parents()
205-
self._fit_node_classifier(nodes, local_mode)
206+
self._fit_node_classifier(nodes, local_mode, use_joblib)

tests/test_LocalClassifierPerLevel.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from sklearn.linear_model import LogisticRegression
1010
from sklearn.utils.estimator_checks import parametrize_with_checks
1111
from sklearn.utils.validation import check_is_fitted
12-
1312
from hiclass import LocalClassifierPerLevel
1413

1514

@@ -60,16 +59,14 @@ def test_fit_digraph(digraph_logistic_regression):
6059

6160

6261
def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression):
63-
LocalClassifierPerLevel._has_ray = False
64-
6562
classifiers = [
6663
LogisticRegression(),
6764
LogisticRegression(),
6865
]
6966
digraph_logistic_regression.n_jobs = 2
7067
digraph_logistic_regression.local_classifiers_ = classifiers
7168

72-
digraph_logistic_regression._fit_digraph(local_mode=True)
69+
digraph_logistic_regression._fit_digraph(local_mode=True, use_joblib=True)
7370
for classifier in digraph_logistic_regression.local_classifiers_:
7471
try:
7572
check_is_fitted(classifier)

tests/test_LocalClassifierPerNode.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,15 +112,13 @@ def test_fit_digraph(digraph_logistic_regression):
112112

113113

114114
def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression):
115-
LocalClassifierPerNode._has_ray = False
116-
117115
classifiers = {
118116
"b": {"classifier": LogisticRegression()},
119117
"c": {"classifier": LogisticRegression()},
120118
}
121119
digraph_logistic_regression.n_jobs = 2
122120
nx.set_node_attributes(digraph_logistic_regression.hierarchy_, classifiers)
123-
digraph_logistic_regression._fit_digraph(local_mode=True)
121+
digraph_logistic_regression._fit_digraph(local_mode=True, use_joblib=True)
124122
with pytest.raises(KeyError):
125123
check_is_fitted(digraph_logistic_regression.hierarchy_.nodes["a"]["classifier"])
126124
for node in ["b", "c"]:

0 commit comments

Comments
 (0)