Skip to content

Commit 0e60e63

Browse files
authored
Make ray an optional depencency (#45)
* Make `ray` an optional depencency Due to the large size of ray and all its dependencies, make it optional for deployments where it's not needed. hiclass will fallback to joblib.Parallel if ray is not installed and `n_jobs` is not 1.
1 parent 00a9adb commit 0e60e63

File tree

8 files changed

+431
-388
lines changed

8 files changed

+431
-388
lines changed

Pipfile

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ name = "pypi"
77
networkx = "*"
88
numpy = "*"
99
scikit-learn = "*"
10-
ray = "*"
1110

1211
[dev-packages]
1312
pytest = "*"
@@ -17,3 +16,6 @@ pytest-cov = "*"
1716
twine = "*"
1817
sphinx = "4.1.1"
1918
sphinx-rtd-theme = "0.5.2"
19+
20+
[extras]
21+
ray = "*"

Pipfile.lock

Lines changed: 307 additions & 366 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

hiclass/HierarchicalClassifier.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44

55
import networkx as nx
66
import numpy as np
7-
import ray
7+
8+
try:
9+
import ray
10+
except ImportError:
11+
_has_ray = False
12+
from joblib import Parallel, delayed, effective_n_jobs
13+
else:
14+
_has_ray = True
815
from sklearn.base import BaseEstimator
916
from sklearn.linear_model import LogisticRegression
1017

@@ -291,13 +298,21 @@ def _remove_separator(self, y):
291298

292299
def _fit_node_classifier(self, nodes, local_mode):
293300
if self.n_jobs > 1:
294-
ray.init(
295-
num_cpus=self.n_jobs, local_mode=local_mode, ignore_reinit_error=True
296-
)
297-
lcppn = ray.put(self)
298-
_parallel_fit = ray.remote(self._fit_classifier)
299-
results = [_parallel_fit.remote(lcppn, node) for node in nodes]
300-
classifiers = ray.get(results)
301+
if _has_ray:
302+
ray.init(
303+
num_cpus=self.n_jobs,
304+
local_mode=local_mode,
305+
ignore_reinit_error=True,
306+
)
307+
lcppn = ray.put(self)
308+
_parallel_fit = ray.remote(self._fit_classifier)
309+
results = [_parallel_fit.remote(lcppn, node) for node in nodes]
310+
classifiers = ray.get(results)
311+
else:
312+
classifiers = Parallel(n_jobs=self.n_jobs)(
313+
delayed(self._fit_classifier)(self, node) for node in nodes
314+
)
315+
301316
else:
302317
classifiers = [self._fit_classifier(self, node) for node in nodes]
303318
for classifier, node in zip(classifiers, nodes):

hiclass/LocalClassifierPerLevel.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
from copy import deepcopy
77

88
import numpy as np
9-
import ray
9+
10+
try:
11+
import ray
12+
except ImportError:
13+
_has_ray = False
14+
from joblib import Parallel, delayed, effective_n_jobs
15+
else:
16+
_has_ray = True
1017
from sklearn.base import BaseEstimator
1118
from sklearn.utils.validation import check_array, check_is_fitted
1219

@@ -194,16 +201,24 @@ def _initialize_local_classifiers(self):
194201
def _fit_digraph(self, local_mode: bool = False):
195202
self.logger_.info("Fitting local classifiers")
196203
if self.n_jobs > 1:
197-
ray.init(
198-
num_cpus=self.n_jobs, local_mode=local_mode, ignore_reinit_error=True
199-
)
200-
lcpl = ray.put(self)
201-
_parallel_fit = ray.remote(self._fit_classifier)
202-
results = [
203-
_parallel_fit.remote(lcpl, level, self.separator_)
204-
for level in range(len(self.local_classifiers_))
205-
]
206-
classifiers = ray.get(results)
204+
if _has_ray:
205+
ray.init(
206+
num_cpus=self.n_jobs,
207+
local_mode=local_mode,
208+
ignore_reinit_error=True,
209+
)
210+
lcpl = ray.put(self)
211+
_parallel_fit = ray.remote(self._fit_classifier)
212+
results = [
213+
_parallel_fit.remote(lcpl, level, self.separator_)
214+
for level in range(len(self.local_classifiers_))
215+
]
216+
classifiers = ray.get(results)
217+
else:
218+
classifiers = Parallel(n_jobs=self.n_jobs)(
219+
delayed(self._fit_classifier)(self, level, self.separator_)
220+
for level in range(len(self.local_classifiers_))
221+
)
207222
else:
208223
classifiers = [
209224
self._fit_classifier(self, level, self.separator_)

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
KEYWORDS = ["hierarchical classification"]
2727
DACS_SOFTWARE = "https://gitlab.com/dacs-hpi"
2828
# What packages are required for this module to be executed?
29-
REQUIRED = ["networkx", "numpy", "scikit-learn", "ray", "protobuf<4.0"]
29+
REQUIRED = ["networkx", "numpy", "scikit-learn", "protobuf<4.0"]
3030

3131
# What packages are optional?
3232
# 'fancy feature': ['django'],}
33-
EXTRAS = {}
33+
EXTRAS = {"ray": ["ray>=1.11.0"]}
3434

3535
# The rest you shouldn't have to touch too much :)
3636
# ------------------------------------------------

tests/test_LocalClassifierPerLevel.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,28 @@ def test_fit_digraph(digraph_logistic_regression):
5959
assert 1
6060

6161

62+
def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression):
63+
from joblib import Parallel, delayed
64+
65+
LocalClassifierPerLevel._has_ray = False
66+
67+
classifiers = [
68+
LogisticRegression(),
69+
LogisticRegression(),
70+
]
71+
digraph_logistic_regression.n_jobs = 2
72+
digraph_logistic_regression.local_classifiers_ = classifiers
73+
from joblib import Parallel, delayed, effective_n_jobs
74+
75+
digraph_logistic_regression._fit_digraph(local_mode=True)
76+
for classifier in digraph_logistic_regression.local_classifiers_:
77+
try:
78+
check_is_fitted(classifier)
79+
except NotFittedError as e:
80+
pytest.fail(repr(e))
81+
assert 1
82+
83+
6284
def test_fit_1_class():
6385
lcpl = LocalClassifierPerLevel(local_classifier=LogisticRegression(), n_jobs=2)
6486
y = np.array([["1", "2"]])

tests/test_LocalClassifierPerNode.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,30 @@ def test_fit_digraph(digraph_logistic_regression):
111111
assert 1
112112

113113

114+
def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression):
115+
from joblib import Parallel, delayed
116+
117+
LocalClassifierPerNode._has_ray = False
118+
119+
classifiers = {
120+
"b": {"classifier": LogisticRegression()},
121+
"c": {"classifier": LogisticRegression()},
122+
}
123+
digraph_logistic_regression.n_jobs = 2
124+
nx.set_node_attributes(digraph_logistic_regression.hierarchy_, classifiers)
125+
digraph_logistic_regression._fit_digraph(local_mode=True)
126+
with pytest.raises(KeyError):
127+
check_is_fitted(digraph_logistic_regression.hierarchy_.nodes["a"]["classifier"])
128+
for node in ["b", "c"]:
129+
try:
130+
check_is_fitted(
131+
digraph_logistic_regression.hierarchy_.nodes[node]["classifier"]
132+
)
133+
except NotFittedError as e:
134+
pytest.fail(repr(e))
135+
assert 1
136+
137+
114138
def test_fit_1_class():
115139
lcpn = LocalClassifierPerNode(local_classifier=LogisticRegression(), n_jobs=2)
116140
y = np.array([["1", "2"]])

tests/test_LocalClassifierPerParentNode.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import tempfile
3+
import builtins
34

45
import networkx as nx
56
import numpy as np
@@ -66,6 +67,29 @@ def test_fit_digraph(digraph_logistic_regression):
6667
assert 1
6768

6869

70+
def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression):
71+
from joblib import Parallel, delayed
72+
73+
LocalClassifierPerParentNode._has_ray = False
74+
75+
classifiers = {
76+
"a": {"classifier": LogisticRegression()},
77+
}
78+
digraph_logistic_regression.n_jobs = 2
79+
nx.set_node_attributes(digraph_logistic_regression.hierarchy_, classifiers)
80+
digraph_logistic_regression._fit_digraph(local_mode=True)
81+
try:
82+
check_is_fitted(digraph_logistic_regression.hierarchy_.nodes["a"]["classifier"])
83+
except NotFittedError as e:
84+
pytest.fail(repr(e))
85+
for node in ["b", "c"]:
86+
with pytest.raises(KeyError):
87+
check_is_fitted(
88+
digraph_logistic_regression.hierarchy_.nodes[node]["classifier"]
89+
)
90+
assert 1
91+
92+
6993
def test_fit_1_class():
7094
lcppn = LocalClassifierPerParentNode(
7195
local_classifier=LogisticRegression(), n_jobs=2

0 commit comments

Comments
 (0)