From 097a58e975f526f13691fa1ceda6234ebf1b5181 Mon Sep 17 00:00:00 2001 From: Fabio Date: Tue, 23 Apr 2024 17:17:21 +0200 Subject: [PATCH 1/9] Add encoder --- hiclass/LocalClassifierPerLevel.py | 9 ++++++--- hiclass/LocalClassifierPerParentNode.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index 907e61cf..2420ac36 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -5,13 +5,13 @@ """ import hashlib +import numpy as np import pickle from copy import deepcopy -from os.path import exists - -import numpy as np from joblib import Parallel, delayed +from os.path import exists from sklearn.base import BaseEstimator +from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_array, check_is_fitted from hiclass.ConstantClassifier import ConstantClassifier @@ -273,6 +273,9 @@ def _fit_classifier(self, level, separator): classifier = ConstantClassifier() if not self.bert: try: + label_encoder = LabelEncoder() + label_encoder.fit(y) + y = label_encoder.transform(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 47f77475..77d674a5 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -5,13 +5,13 @@ """ import hashlib +import networkx as nx +import numpy as np import pickle from copy import deepcopy from os.path import exists - -import networkx as nx -import numpy as np from sklearn.base import BaseEstimator +from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_array, check_is_fitted from hiclass.ConstantClassifier import ConstantClassifier @@ -231,6 +231,9 @@ def _fit_classifier(self, node): classifier = ConstantClassifier() if not self.bert: try: + label_encoder = LabelEncoder() + label_encoder.fit(y) + y = label_encoder.transform(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y) From cdb81fecb8d868b3750a47ea585438c81df8848d Mon Sep 17 00:00:00 2001 From: Fabio Date: Tue, 23 Apr 2024 17:26:38 +0200 Subject: [PATCH 2/9] Add encoder --- hiclass/HierarchicalClassifier.py | 9 ++++++--- hiclass/LocalClassifierPerLevel.py | 9 +++------ hiclass/LocalClassifierPerParentNode.py | 9 +++------ 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py index 23e422ab..1bb34355 100644 --- a/hiclass/HierarchicalClassifier.py +++ b/hiclass/HierarchicalClassifier.py @@ -3,13 +3,13 @@ import abc import hashlib import logging -import pickle - import networkx as nx import numpy as np +import pickle from joblib import Parallel, delayed from sklearn.base import BaseEstimator from sklearn.linear_model import LogisticRegression +from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import _check_sample_weight try: @@ -215,7 +215,10 @@ def _disambiguate(self): child = str(self.y_[i, j]) row.append(parent + self.separator_ + child) new_y.append(np.asarray(row, dtype=np.str_)) - self.y_ = np.array(new_y) + new_y = np.array(new_y) + self.label_encoder_ = LabelEncoder() + self.label_encoder_.fit(new_y) + self.y_ = self.label_encoder_.transform(new_y) def _create_digraph(self): # Create DiGraph diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index 2420ac36..907e61cf 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -5,13 +5,13 @@ """ import hashlib -import numpy as np import pickle from copy import deepcopy -from joblib import Parallel, delayed from os.path import exists + +import numpy as np +from joblib import Parallel, delayed from sklearn.base import BaseEstimator -from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_array, check_is_fitted from hiclass.ConstantClassifier import ConstantClassifier @@ -273,9 +273,6 @@ def _fit_classifier(self, level, separator): classifier = ConstantClassifier() if not self.bert: try: - label_encoder = LabelEncoder() - label_encoder.fit(y) - y = label_encoder.transform(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 77d674a5..47f77475 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -5,13 +5,13 @@ """ import hashlib -import networkx as nx -import numpy as np import pickle from copy import deepcopy from os.path import exists + +import networkx as nx +import numpy as np from sklearn.base import BaseEstimator -from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_array, check_is_fitted from hiclass.ConstantClassifier import ConstantClassifier @@ -231,9 +231,6 @@ def _fit_classifier(self, node): classifier = ConstantClassifier() if not self.bert: try: - label_encoder = LabelEncoder() - label_encoder.fit(y) - y = label_encoder.transform(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y) From c2a8bda3d96f3d6ef5ef9ec94d842f313d721d18 Mon Sep 17 00:00:00 2001 From: Fabio Date: Tue, 23 Apr 2024 18:39:53 +0200 Subject: [PATCH 3/9] Fix encoding --- hiclass/HierarchicalClassifier.py | 18 +++++++++++------- hiclass/LocalClassifierPerLevel.py | 3 +++ hiclass/LocalClassifierPerNode.py | 9 ++++++--- hiclass/LocalClassifierPerParentNode.py | 7 +++++-- tests/test_Explainer.py | 8 ++++++++ tests/test_HierarchicalClassifier.py | 6 ++++++ tests/test_LocalClassifierPerLevel.py | 4 +++- tests/test_LocalClassifiers.py | 18 +++--------------- 8 files changed, 45 insertions(+), 28 deletions(-) diff --git a/hiclass/HierarchicalClassifier.py b/hiclass/HierarchicalClassifier.py index 1bb34355..b7c0e738 100644 --- a/hiclass/HierarchicalClassifier.py +++ b/hiclass/HierarchicalClassifier.py @@ -216,9 +216,13 @@ def _disambiguate(self): row.append(parent + self.separator_ + child) new_y.append(np.asarray(row, dtype=np.str_)) new_y = np.array(new_y) - self.label_encoder_ = LabelEncoder() - self.label_encoder_.fit(new_y) - self.y_ = self.label_encoder_.transform(new_y) + flat_y = np.unique(np.append(new_y.flatten(), "hiclass::root")) + if not self.bert: + self.label_encoder_ = LabelEncoder() + self.label_encoder_.fit(flat_y) + self.y_ = np.array( + [self.label_encoder_.transform(row) for row in new_y] + ) def _create_digraph(self): # Create DiGraph @@ -258,8 +262,8 @@ def _create_digraph_2d(self): self.logger_.info(f"Creating digraph from {rows} 2D labels") for row in range(rows): for column in range(columns - 1): - parent = self.y_[row, column].split(self.separator_)[-1] - child = self.y_[row, column + 1].split(self.separator_)[-1] + parent = self.y_[row, column] + child = self.y_[row, column + 1] if parent != "" and child != "": # Only add edge if both parent and child are not empty self.hierarchy_.add_edge( @@ -274,7 +278,7 @@ def _export_digraph(self): # Add quotes to all nodes in case the text has commas mapping = {} for node in self.hierarchy_: - mapping[node] = '"{}"'.format(node.split(self.separator_)[-1]) + mapping[node] = '"{}"'.format(node) hierarchy = nx.relabel_nodes(self.hierarchy_, mapping, copy=True) # Export DAG to CSV file self.logger_.info(f"Writing edge list to file {self.edge_list}") @@ -374,5 +378,5 @@ def _save_tmp(self, name, classifier): with open(filename, "wb") as file: pickle.dump((name, classifier), file) self.logger_.info( - f"Stored trained model for local classifier {str(name).split(self.separator_)[-1]} in file {filename}" + f"Stored trained model for local classifier {str(name)} in file {filename}" ) diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index 907e61cf..4ec47719 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -168,6 +168,9 @@ def predict(self, X): y = self._convert_to_1d(y) + if hasattr(self, "label_encoder_"): + y = np.array([self.label_encoder_.inverse_transform(row) for row in y]) + self._remove_separator(y) return y diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 1382c72e..42f7a648 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -182,7 +182,7 @@ def predict(self, X): if subset_x.shape[0] > 0: probabilities = np.zeros((subset_x.shape[0], len(successors))) for i, successor in enumerate(successors): - successor_name = str(successor).split(self.separator_)[-1] + successor_name = str(successor) self.logger_.info(f"Predicting for node '{successor_name}'") classifier = self.hierarchy_.nodes[successor]["classifier"] positive_index = np.where(classifier.classes_ == 1)[0] @@ -201,6 +201,9 @@ def predict(self, X): y = self._convert_to_1d(y) + if hasattr(self, "label_encoder_"): + y = np.array([self.label_encoder_.inverse_transform(row) for row in y]) + self._remove_separator(y) return y @@ -246,12 +249,12 @@ def _fit_digraph(self, local_mode: bool = False, use_joblib: bool = False): def _fit_classifier(self, node): classifier = self.hierarchy_.nodes[node]["classifier"] if self.tmp_dir: - md5 = hashlib.md5(node.encode("utf-8")).hexdigest() + md5 = hashlib.md5(str(node).encode("utf-8")).hexdigest() filename = f"{self.tmp_dir}/{md5}.sav" if exists(filename): (_, classifier) = pickle.load(open(filename, "rb")) self.logger_.info( - f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" + f"Loaded trained model for local classifier {node} from file {filename}" ) return classifier self.logger_.info(f"Training local classifier {node}") diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 47f77475..dcbc5ad3 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -161,6 +161,9 @@ def predict(self, X): y = self._convert_to_1d(y) + if hasattr(self, "label_encoder_"): + y = np.array([self.label_encoder_.inverse_transform(row) for row in y]) + self._remove_separator(y) return y @@ -215,12 +218,12 @@ def _get_successors(self, node): def _fit_classifier(self, node): classifier = self.hierarchy_.nodes[node]["classifier"] if self.tmp_dir: - md5 = hashlib.md5(node.encode("utf-8")).hexdigest() + md5 = hashlib.md5(str(node).encode("utf-8")).hexdigest() filename = f"{self.tmp_dir}/{md5}.sav" if exists(filename): (_, classifier) = pickle.load(open(filename, "rb")) self.logger_.info( - f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" + f"Loaded trained model for local classifier {node} from file {filename}" ) return classifier self.logger_.info(f"Training local classifier {node}") diff --git a/tests/test_Explainer.py b/tests/test_Explainer.py index 303216f6..a466e69b 100644 --- a/tests/test_Explainer.py +++ b/tests/test_Explainer.py @@ -52,6 +52,7 @@ def explainer_data_no_root(): @pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) def test_explainer_tree_lcppn(data, request): rfc = RandomForestClassifier() @@ -104,6 +105,7 @@ def test_explainer_tree_lcpn(data, request): @pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) def test_explainer_tree_lcpl(data, request): rfc = RandomForestClassifier() @@ -124,6 +126,7 @@ def test_explainer_tree_lcpl(data, request): @pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) def test_traversal_path_lcppn(data, request): x_train, x_test, y_train = request.getfixturevalue(data) @@ -146,6 +149,7 @@ def test_traversal_path_lcppn(data, request): @pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) def test_traversal_path_lcpn(data, request): x_train, x_test, y_train = request.getfixturevalue(data) @@ -168,6 +172,7 @@ def test_traversal_path_lcpn(data, request): @pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") @pytest.mark.parametrize("data", ["explainer_data", "explainer_data_no_root"]) def test_traversal_path_lcpl(data, request): x_train, x_test, y_train = request.getfixturevalue(data) @@ -205,6 +210,8 @@ def test_explain_with_xr(data, request, classifier): assert isinstance(explanations, xarray.Dataset) +@pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") @pytest.mark.parametrize( "classifier", [LocalClassifierPerParentNode, LocalClassifierPerLevel, LocalClassifierPerNode], @@ -222,6 +229,7 @@ def test_imports(classifier): @pytest.mark.skipif(not shap_installed, reason="shap not installed") +@pytest.mark.skipif(not xarray_installed, reason="xarray not installed") @pytest.mark.parametrize( "classifier", [LocalClassifierPerLevel, LocalClassifierPerParentNode, LocalClassifierPerNode], diff --git a/tests/test_HierarchicalClassifier.py b/tests/test_HierarchicalClassifier.py index 3333cf52..cb608f05 100644 --- a/tests/test_HierarchicalClassifier.py +++ b/tests/test_HierarchicalClassifier.py @@ -22,6 +22,9 @@ def test_disambiguate_str(ambiguous_node_str): [["a", "a::HiClass::Separator::b"], ["b", "b::HiClass::Separator::c"]] ) ambiguous_node_str._disambiguate() + ground_truth = np.array( + [ambiguous_node_str.label_encoder_.transform(row) for row in ground_truth] + ) assert_array_equal(ground_truth, ambiguous_node_str.y_) @@ -37,6 +40,9 @@ def test_disambiguate_int(ambiguous_node_int): [["1", "1::HiClass::Separator::2"], ["2", "2::HiClass::Separator::3"]] ) ambiguous_node_int._disambiguate() + ground_truth = np.array( + [ambiguous_node_int.label_encoder_.transform(row) for row in ground_truth] + ) assert_array_equal(ground_truth, ambiguous_node_int.y_) diff --git a/tests/test_LocalClassifierPerLevel.py b/tests/test_LocalClassifierPerLevel.py index 27312f85..ee39e90b 100644 --- a/tests/test_LocalClassifierPerLevel.py +++ b/tests/test_LocalClassifierPerLevel.py @@ -128,7 +128,9 @@ def test_fit_predict(): for level, classifier in enumerate(lcpl.local_classifiers_): try: check_is_fitted(classifier) - assert_array_equal(ground_truth[level], classifier.classes_) + assert_array_equal( + lcpl.label_encoder_.transform(ground_truth[level]), classifier.classes_ + ) except NotFittedError as e: pytest.fail(repr(e)) predictions = lcpl.predict(x) diff --git a/tests/test_LocalClassifiers.py b/tests/test_LocalClassifiers.py index abd7bddf..b47e88b4 100644 --- a/tests/test_LocalClassifiers.py +++ b/tests/test_LocalClassifiers.py @@ -63,15 +63,7 @@ def test_empty_levels(empty_levels, classifier): ["2", "2.1", ""], ["3", "3.1", "3.1.2"], ] - assert list(clf.hierarchy_.nodes) == [ - "1", - "2", - "2" + clf.separator_ + "2.1", - "3", - "3" + clf.separator_ + "3.1", - "3" + clf.separator_ + "3.1" + clf.separator_ + "3.1.2", - clf.root_, - ] + assert list(clf.hierarchy_.nodes) == [0, 1, 2, 3, 4, 5, 6, 7, 8, "hiclass::root"] assert_array_equal(ground_truth, predictions) @@ -132,12 +124,8 @@ def test_tmp_dir(classifier): x = np.array([[1, 2], [3, 4]]) y = np.array([["a", "b"], ["c", "d"]]) clf.fit(x, y) - if isinstance(clf, LocalClassifierPerLevel): - filename = "cfcd208495d565ef66e7dff9f98764da.sav" - expected_name = 0 - else: - filename = "0cc175b9c0f1b6a831c399e269772661.sav" - expected_name = "a" + filename = "cfcd208495d565ef66e7dff9f98764da.sav" + expected_name = 0 assert patcher.fs.exists(filename) (name, classifier) = pickle.load(open(filename, "rb")) assert expected_name == name From f2227f07dd50b0b04c1f6d70a207896563ae9fa6 Mon Sep 17 00:00:00 2001 From: Fabio Date: Tue, 23 Apr 2024 19:13:02 +0200 Subject: [PATCH 4/9] Add print --- hiclass/LocalClassifierPerParentNode.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index dcbc5ad3..f6eaf21b 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -234,6 +234,8 @@ def _fit_classifier(self, node): classifier = ConstantClassifier() if not self.bert: try: + print(X) + print(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y) From b6b9f8eeb11a508dcb972b02782456fff16a7ef2 Mon Sep 17 00:00:00 2001 From: Fabio Date: Tue, 23 Apr 2024 19:27:31 +0200 Subject: [PATCH 5/9] Add print --- hiclass/LocalClassifierPerLevel.py | 2 ++ hiclass/LocalClassifierPerParentNode.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index 4ec47719..ede7606a 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -275,6 +275,8 @@ def _fit_classifier(self, level, separator): if len(unique_y) == 1 and self.replace_classifiers: classifier = ConstantClassifier() if not self.bert: + self.logger_.info(X) + self.logger_.info(y) try: classifier.fit(X, y, sample_weight) except TypeError: diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index f6eaf21b..43a64178 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -233,9 +233,9 @@ def _fit_classifier(self, node): if len(unique_y) == 1 and self.replace_classifiers: classifier = ConstantClassifier() if not self.bert: + self.logger_.info(X) + self.logger_.info(y) try: - print(X) - print(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y) From 90ffd7f370a42c5a0389675b44e00004f099405d Mon Sep 17 00:00:00 2001 From: Fabio Date: Tue, 23 Apr 2024 20:37:08 +0200 Subject: [PATCH 6/9] Add multiclass --- hiclass/LocalClassifierPerParentNode.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 43a64178..1bc4b5ab 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -5,12 +5,12 @@ """ import hashlib +import networkx as nx +import numpy as np import pickle from copy import deepcopy +from cuml.multiclass import MulticlassClassifier from os.path import exists - -import networkx as nx -import numpy as np from sklearn.base import BaseEstimator from sklearn.utils.validation import check_array, check_is_fitted @@ -186,7 +186,11 @@ def _initialize_local_classifiers(self): local_classifiers = {} nodes = self._get_parents() for node in nodes: - local_classifiers[node] = {"classifier": deepcopy(self.local_classifier_)} + local_classifiers[node] = { + "classifier": MulticlassClassifier( + deepcopy(self.local_classifier_), strategy="ovr" + ) + } nx.set_node_attributes(self.hierarchy_, local_classifiers) def _get_parents(self): From c5eadd3e250c6e1d8491c442c6a3dac5c1f3d9b4 Mon Sep 17 00:00:00 2001 From: Fabio Date: Tue, 23 Apr 2024 20:50:03 +0200 Subject: [PATCH 7/9] Remove weight --- hiclass/LocalClassifierPerParentNode.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 1bc4b5ab..e5417c4a 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -239,10 +239,7 @@ def _fit_classifier(self, node): if not self.bert: self.logger_.info(X) self.logger_.info(y) - try: - classifier.fit(X, y, sample_weight) - except TypeError: - classifier.fit(X, y) + classifier.fit(X, y) else: classifier.fit(X, y) self._save_tmp(node, classifier) From 0cbe956008a358934797d0ac0414a1603849c428 Mon Sep 17 00:00:00 2001 From: Fabio Date: Tue, 23 Apr 2024 23:59:33 +0200 Subject: [PATCH 8/9] enforce gpu use --- hiclass/LocalClassifierPerLevel.py | 3 ++- hiclass/LocalClassifierPerParentNode.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index ede7606a..9cebadaa 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -221,7 +221,8 @@ def _get_successors(self, level): def _initialize_local_classifiers(self): super()._initialize_local_classifiers() self.local_classifiers_ = [ - deepcopy(self.local_classifier_) for _ in range(self.y_.shape[1]) + MulticlassClassifier(deepcopy(self.local_classifier_), strategy="ovr") + for _ in range(self.y_.shape[1]) ] self.masks_ = [None for _ in range(self.y_.shape[1])] diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index e5417c4a..266e76d9 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -9,6 +9,7 @@ import numpy as np import pickle from copy import deepcopy +from cuml.common.device_selection import using_device_type from cuml.multiclass import MulticlassClassifier from os.path import exists from sklearn.base import BaseEstimator @@ -239,7 +240,8 @@ def _fit_classifier(self, node): if not self.bert: self.logger_.info(X) self.logger_.info(y) - classifier.fit(X, y) + with using_device_type("gpu"): + classifier.fit(X, y) else: classifier.fit(X, y) self._save_tmp(node, classifier) From 62d605032b77447c9dfce3ca933cc3ded516dee8 Mon Sep 17 00:00:00 2001 From: Fabio Date: Wed, 24 Apr 2024 00:42:14 +0200 Subject: [PATCH 9/9] Remove prints --- hiclass/LocalClassifierPerParentNode.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 266e76d9..089ce0dc 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -187,11 +187,7 @@ def _initialize_local_classifiers(self): local_classifiers = {} nodes = self._get_parents() for node in nodes: - local_classifiers[node] = { - "classifier": MulticlassClassifier( - deepcopy(self.local_classifier_), strategy="ovr" - ) - } + local_classifiers[node] = {"classifier": deepcopy(self.local_classifier_)} nx.set_node_attributes(self.hierarchy_, local_classifiers) def _get_parents(self): @@ -238,8 +234,6 @@ def _fit_classifier(self, node): if len(unique_y) == 1 and self.replace_classifiers: classifier = ConstantClassifier() if not self.bert: - self.logger_.info(X) - self.logger_.info(y) with using_device_type("gpu"): classifier.fit(X, y) else: