diff --git a/ontolearner/base/learner.py b/ontolearner/base/learner.py
index 93d6757..e0792f6 100644
--- a/ontolearner/base/learner.py
+++ b/ontolearner/base/learner.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 from abc import ABC
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Dict
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import torch.nn.functional as F
@@ -147,7 +147,7 @@ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
     def _non_taxonomic_re(self, data: Any, test: bool = False) -> Optional[Any]:
         pass

-    def tasks_data_former(self, data: Any, task: str, test: bool = False) -> Any:
+    def tasks_data_former(self, data: Any, task: str, test: bool = False) -> List[str | Dict[str, str]]:
         formatted_data = []
         if task == "term-typing":
             for typing in data.term_typings:
@@ -173,7 +173,7 @@ def tasks_data_former(self, data: Any, task: str, test: bool = False) -> Any:
             formatted_data = {"types": non_taxonomic_types, "relations": non_taxonomic_res}
         return formatted_data

-    def tasks_ground_truth_former(self, data: Any, task: str) -> Any:
+    def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, str]]:
         formatted_data = []
         if task == "term-typing":
             for typing in data.term_typings:
diff --git a/ontolearner/evaluation/metrics.py b/ontolearner/evaluation/metrics.py
index cf54343..57b2d66 100644
--- a/ontolearner/evaluation/metrics.py
+++ b/ontolearner/evaluation/metrics.py
@@ -11,13 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from typing import Dict
+from typing import List, Dict, Tuple, Set

 SYMMETRIC_RELATIONS = {"equivalentclass", "sameas", "disjointwith"}

-def text2onto_metrics(y_true, y_pred, similarity_threshold: float = 0.8) -> Dict:
-    def jaccard_similarity(a, b):
+def text2onto_metrics(y_true: List[str], y_pred: List[str], similarity_threshold: float = 0.8) -> Dict[str, float | int]:
+    def jaccard_similarity(a: str, b: str) -> float:
         set_a = set(a.lower().split())
         set_b = set(b.lower().split())
         if not set_a and not set_b:
@@ -46,10 +45,13 @@ def jaccard_similarity(a, b):
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }

-def term_typing_metrics(y_true, y_pred) -> Dict:
+def term_typing_metrics(y_true: List[Dict[str, List[str]]], y_pred: List[Dict[str, List[str]]]) -> Dict[str, float | int]:
     """
     Compute precision, recall, and F1-score for term typing
     using (term, type) pair-level matching instead of ID-based lookups.
@@ -77,13 +79,17 @@ def term_typing_metrics(y_true, y_pred) -> Dict:
     precision = total_correct / total_predicted if total_predicted > 0 else 0.0
     recall = total_correct / total_ground_truth if total_ground_truth > 0 else 0.0
     f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
+
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }

-def taxonomy_discovery_metrics(y_true, y_pred) -> Dict:
+def taxonomy_discovery_metrics(y_true: List[Dict[str, str]], y_pred: List[Dict[str, str]]) -> Dict[str, float | int]:
     total_predicted = len(y_pred)
     total_ground_truth = len(y_true)
     # Convert ground truth and predictions to sets of tuples for easy comparison
@@ -102,18 +108,22 @@ def taxonomy_discovery_metrics(y_true, y_pred) -> Dict:
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }

-def non_taxonomic_re_metrics(y_true, y_pred) -> Dict:
-    def normalize_triple(item):
+
+def non_taxonomic_re_metrics(y_true: List[Dict[str, str]], y_pred: List[Dict[str, str]]) -> Dict[str, float | int]:
+    def normalize_triple(item: Dict[str, str]) -> Tuple[str, str, str]:
         return (
             item["head"].strip().lower(),
             item["relation"].strip().lower(),
             item["tail"].strip().lower()
         )

-    def expand_symmetric(triples):
+    def expand_symmetric(triples: Set[Tuple[str, str, str]]) -> Set[Tuple[str, str, str]]:
         expanded = set()
         for h, r, t in triples:
             expanded.add((h, r, t))
@@ -136,5 +146,8 @@ def expand_symmetric(triples):
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }
diff --git a/ontolearner/learner/retriever.py b/ontolearner/learner/retriever.py
index ae78c24..59bc2ca 100644
--- a/ontolearner/learner/retriever.py
+++ b/ontolearner/learner/retriever.py
@@ -22,7 +22,6 @@ def __init__(self, base_retriever: Any = AutoRetriever(), top_k: int = 5, batch_
         self.retriever = base_retriever
         self.top_k = top_k
         self._is_term_typing_fit = False
-        self._is_taxonomy_discovery_fit = False
         self._batch_size = batch_size

     def load(self, model_id: str = "sentence-transformers/all-MiniLM-L6-v2"):
@@ -64,9 +63,9 @@ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
         if test:
             self._retriever_fit(data=data)
             candidates_lst = self._retriever_predict(data=data, top_k=self.top_k + 1)
-            taxonomic_pairs = [{"parent": query, "child": candidate}
+            taxonomic_pairs = [{"parent": candidate, "child": query}
                                for query, candidates in zip(data, candidates_lst)
-                               for candidate in candidates if candidate != query]
+                               for candidate in candidates if candidate.lower() != query.lower()]
             return taxonomic_pairs
         else:
             warnings.warn("No requirement for fitting the taxonomy discovery model, the predict module will use the input data to do the fit as well.")
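
Usage note: each metric now returns the raw counts alongside the scores. A minimal
sketch, assuming the {"parent": ..., "child": ...} dict shape produced by the
retriever above and exact set-of-tuples matching (the y_true/y_pred data here is
made up):

    from ontolearner.evaluation.metrics import taxonomy_discovery_metrics

    y_true = [{"parent": "animal", "child": "dog"},
              {"parent": "animal", "child": "cat"}]
    y_pred = [{"parent": "animal", "child": "dog"},   # matches ground truth
              {"parent": "dog", "child": "animal"}]   # inverted, so no match

    scores = taxonomy_discovery_metrics(y_true, y_pred)
    # Besides f1_score/precision/recall, the raw counts are now exposed:
    # scores["total_correct"] == 1, scores["total_predicted"] == 2,
    # scores["total_ground_truth"] == 2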
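
For non_taxonomic_re_metrics, the SYMMETRIC_RELATIONS set drives order-insensitive
matching of triples. A standalone sketch of that mechanism; the reversed-triple
branch of expand_symmetric is an assumption, since the diff cuts off inside it:

    SYMMETRIC_RELATIONS = {"equivalentclass", "sameas", "disjointwith"}

    def normalize_triple(item):
        # Same normalization as the diff: strip whitespace, lowercase.
        return (item["head"].strip().lower(),
                item["relation"].strip().lower(),
                item["tail"].strip().lower())

    def expand_symmetric(triples):
        # Assumed continuation: symmetric relations also match with head
        # and tail swapped.
        expanded = set()
        for h, r, t in triples:
            expanded.add((h, r, t))
            if r in SYMMETRIC_RELATIONS:
                expanded.add((t, r, h))
        return expanded

    gold = {normalize_triple({"head": "Water", "relation": "sameAs", "tail": "H2O"})}
    pred = {normalize_triple({"head": "H2O", "relation": "sameAs", "tail": "Water"})}
    assert pred & expand_symmetric(gold)  # reversed "sameAs" triple still matches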
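
The retriever fix flips the pair orientation (retrieved candidates become parents
of the query rather than children) and makes the self-match filter
case-insensitive. A toy run of the corrected comprehension, with hypothetical data
standing in for _retriever_predict output:

    data = ["dog", "cat"]                                    # queries
    candidates_lst = [["animal", "Dog"], ["animal", "pet"]]  # retrieved per query

    taxonomic_pairs = [{"parent": candidate, "child": query}
                       for query, candidates in zip(data, candidates_lst)
                       for candidate in candidates if candidate.lower() != query.lower()]

    # "Dog" is dropped as a case-insensitive self-match of "dog"; every surviving
    # candidate sits on the parent side:
    # [{'parent': 'animal', 'child': 'dog'},
    #  {'parent': 'animal', 'child': 'cat'},
    #  {'parent': 'pet', 'child': 'cat'}]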