diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index 907e61cf..896828cd 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -258,11 +258,14 @@ def _fit_classifier(self, level, separator): md5 = hashlib.md5(str(level).encode("utf-8")).hexdigest() filename = f"{self.tmp_dir}/{md5}.sav" if exists(filename): - (_, classifier) = pickle.load(open(filename, "rb")) - self.logger_.info( - f"Loaded trained model for local classifier {level} from file {filename}" - ) - return classifier + try: + (_, classifier) = pickle.load(open(filename, "rb")) + self.logger_.info( + f"Loaded trained model for local classifier {level} from file {filename}" + ) + return classifier + except (pickle.UnpicklingError, EOFError): + self.logger_.error(f"Could not load model from file {filename}") self.logger_.info(f"Training local classifier {level}") X, y, sample_weight = self._remove_empty_leaves( separator, self.X_, self.y_[:, level], self.sample_weight_ diff --git a/hiclass/LocalClassifierPerNode.py b/hiclass/LocalClassifierPerNode.py index 1382c72e..65c14113 100644 --- a/hiclass/LocalClassifierPerNode.py +++ b/hiclass/LocalClassifierPerNode.py @@ -249,12 +249,17 @@ def _fit_classifier(self, node): md5 = hashlib.md5(node.encode("utf-8")).hexdigest() filename = f"{self.tmp_dir}/{md5}.sav" if exists(filename): - (_, classifier) = pickle.load(open(filename, "rb")) - self.logger_.info( - f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" - ) - return classifier - self.logger_.info(f"Training local classifier {node}") + try: + (_, classifier) = pickle.load(open(filename, "rb")) + self.logger_.info( + f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" + ) + return classifier + except (pickle.UnpicklingError, EOFError): + self.logger_.error(f"Could not load model from file {filename}") + self.logger_.info( + f"Training local classifier {str(node).split(self.separator_)[-1]}" + ) X, y, sample_weight = self.binary_policy_.get_binary_examples(node) unique_y = np.unique(y) if len(unique_y) == 1 and self.replace_classifiers: diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 47f77475..5873b52a 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -218,12 +218,17 @@ def _fit_classifier(self, node): md5 = hashlib.md5(node.encode("utf-8")).hexdigest() filename = f"{self.tmp_dir}/{md5}.sav" if exists(filename): - (_, classifier) = pickle.load(open(filename, "rb")) - self.logger_.info( - f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" - ) - return classifier - self.logger_.info(f"Training local classifier {node}") + try: + (_, classifier) = pickle.load(open(filename, "rb")) + self.logger_.info( + f"Loaded trained model for local classifier {node.split(self.separator_)[-1]} from file {filename}" + ) + return classifier + except (pickle.UnpicklingError, EOFError): + self.logger_.error(f"Could not load model from file {filename}") + self.logger_.info( + f"Training local classifier {str(node).split(self.separator_)[-1]}" + ) # get children examples X, y, sample_weight = self._get_successors(node) unique_y = np.unique(y)