|
6 | 6 | from copy import deepcopy |
7 | 7 |
|
8 | 8 | import numpy as np |
9 | | -import ray |
| 9 | + |
| 10 | +try: |
| 11 | + import ray |
| 12 | +except ImportError: |
| 13 | + _has_ray = False |
| 14 | + from joblib import Parallel, delayed, effective_n_jobs |
| 15 | +else: |
| 16 | + _has_ray = True |
10 | 17 | from sklearn.base import BaseEstimator |
11 | 18 | from sklearn.utils.validation import check_array, check_is_fitted |
12 | 19 |
|
@@ -194,16 +201,24 @@ def _initialize_local_classifiers(self): |
194 | 201 | def _fit_digraph(self, local_mode: bool = False): |
195 | 202 | self.logger_.info("Fitting local classifiers") |
196 | 203 | if self.n_jobs > 1: |
197 | | - ray.init( |
198 | | - num_cpus=self.n_jobs, local_mode=local_mode, ignore_reinit_error=True |
199 | | - ) |
200 | | - lcpl = ray.put(self) |
201 | | - _parallel_fit = ray.remote(self._fit_classifier) |
202 | | - results = [ |
203 | | - _parallel_fit.remote(lcpl, level, self.separator_) |
204 | | - for level in range(len(self.local_classifiers_)) |
205 | | - ] |
206 | | - classifiers = ray.get(results) |
| 204 | + if _has_ray: |
| 205 | + ray.init( |
| 206 | + num_cpus=self.n_jobs, |
| 207 | + local_mode=local_mode, |
| 208 | + ignore_reinit_error=True, |
| 209 | + ) |
| 210 | + lcpl = ray.put(self) |
| 211 | + _parallel_fit = ray.remote(self._fit_classifier) |
| 212 | + results = [ |
| 213 | + _parallel_fit.remote(lcpl, level, self.separator_) |
| 214 | + for level in range(len(self.local_classifiers_)) |
| 215 | + ] |
| 216 | + classifiers = ray.get(results) |
| 217 | + else: |
| 218 | + classifiers = Parallel(n_jobs=self.n_jobs)( |
| 219 | + delayed(self._fit_classifier)(self, level, self.separator_) |
| 220 | + for level in range(len(self.local_classifiers_)) |
| 221 | + ) |
207 | 222 | else: |
208 | 223 | classifiers = [ |
209 | 224 | self._fit_classifier(self, level, self.separator_) |
|
0 commit comments