6 changes: 3 additions & 3 deletions ontolearner/base/learner.py
@@ -13,7 +13,7 @@
 # limitations under the License.

 from abc import ABC
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Dict
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import torch.nn.functional as F
@@ -147,7 +147,7 @@ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
     def _non_taxonomic_re(self, data: Any, test: bool = False) -> Optional[Any]:
         pass

-    def tasks_data_former(self, data: Any, task: str, test: bool = False) -> Any:
+    def tasks_data_former(self, data: Any, task: str, test: bool = False) -> List[str | Dict[str, str]]:
         formatted_data = []
         if task == "term-typing":
             for typing in data.term_typings:
@@ -173,7 +173,7 @@ def tasks_data_former(self, data: Any, task: str, test: bool = False) -> Any:
             formatted_data = {"types": non_taxonomic_types, "relations": non_taxonomic_res}
         return formatted_data

-    def tasks_ground_truth_former(self, data: Any, task: str) -> Any:
+    def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, str]]:
         formatted_data = []
         if task == "term-typing":
             for typing in data.term_typings:
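For orientation, a hypothetical sketch (invented values, not from the repository) of the two element shapes the tightened List[str | Dict[str, str]] annotation admits:

from typing import Dict, List

# In test mode the formatted items can be bare strings, e.g. terms to classify.
test_batch: List[str | Dict[str, str]] = ["enzyme", "ribosome"]

# Otherwise they can be string-to-string dicts, e.g. labeled taxonomy pairs.
train_batch: List[str | Dict[str, str]] = [
    {"parent": "macromolecule", "child": "protein"},
]

One caveat: the PEP 604 union inside a subscript is evaluated eagerly, so these annotations require Python 3.10+ unless the module adds from __future__ import annotations.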
39 changes: 26 additions & 13 deletions ontolearner/evaluation/metrics.py
@@ -11,13 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from typing import Dict
+from typing import List, Dict, Tuple, Set

 SYMMETRIC_RELATIONS = {"equivalentclass", "sameas", "disjointwith"}

-def text2onto_metrics(y_true, y_pred, similarity_threshold: float = 0.8) -> Dict:
-    def jaccard_similarity(a, b):
+def text2onto_metrics(y_true: List[str], y_pred: List[str], similarity_threshold: float = 0.8) -> Dict[str, float | int]:
+    def jaccard_similarity(a: str, b: str) -> float:
         set_a = set(a.lower().split())
         set_b = set(b.lower().split())
         if not set_a and not set_b:
@@ -46,10 +45,13 @@ def jaccard_similarity(a, b):
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }

-def term_typing_metrics(y_true, y_pred) -> Dict:
+def term_typing_metrics(y_true: List[Dict[str, List[str]]], y_pred: List[Dict[str, List[str]]]) -> Dict[str, float | int]:
     """
     Compute precision, recall, and F1-score for term typing
     using (term, type) pair-level matching instead of ID-based lookups.
@@ -77,13 +79,17 @@ def term_typing_metrics(y_true, y_pred) -> Dict:
     precision = total_correct / total_predicted if total_predicted > 0 else 0.0
     recall = total_correct / total_ground_truth if total_ground_truth > 0 else 0.0
     f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
+
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }

-def taxonomy_discovery_metrics(y_true, y_pred) -> Dict:
+def taxonomy_discovery_metrics(y_true: List[Dict[str, str]], y_pred: List[Dict[str, str]]) -> Dict[str, float | int]:
     total_predicted = len(y_pred)
     total_ground_truth = len(y_true)
     # Convert ground truth and predictions to sets of tuples for easy comparison
@@ -102,18 +108,22 @@ def taxonomy_discovery_metrics(y_true, y_pred) -> Dict:
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }

-def non_taxonomic_re_metrics(y_true, y_pred) -> Dict:
-    def normalize_triple(item):
+
+def non_taxonomic_re_metrics(y_true: List[Dict[str, str]], y_pred: List[Dict[str, str]]) -> Dict[str, float | int]:
+    def normalize_triple(item: Dict[str, str]) -> Tuple[str, str, str]:
         return (
             item["head"].strip().lower(),
             item["relation"].strip().lower(),
             item["tail"].strip().lower()
         )

-    def expand_symmetric(triples):
+    def expand_symmetric(triples: Set[Tuple[str, str, str]]) -> Set[Tuple[str, str, str]]:
         expanded = set()
         for h, r, t in triples:
             expanded.add((h, r, t))
@@ -136,5 +146,8 @@ def expand_symmetric(triples):
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }
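As a usage sketch of the enriched return value (toy inputs invented here; the import path follows this file's location, and exact (parent, child) tuple matching is assumed from the "sets of tuples" comment above):

from ontolearner.evaluation.metrics import taxonomy_discovery_metrics

y_true = [{"parent": "animal", "child": "dog"},
          {"parent": "animal", "child": "cat"}]
y_pred = [{"parent": "animal", "child": "dog"},  # matches a gold pair
          {"parent": "dog", "child": "animal"}]  # inverted, so no match

scores = taxonomy_discovery_metrics(y_true, y_pred)
# One of two predictions is correct against two gold pairs, so alongside
# precision = recall = f1_score = 0.5 the new count keys report:
# total_correct = 1, total_predicted = 2, total_ground_truth = 2

The hidden body of expand_symmetric presumably mirrors any triple whose relation is in SYMMETRIC_RELATIONS; a minimal sketch of that idea, not the repository's code:

def expand_symmetric(triples):
    # Keep every triple, and also add the reversed one for symmetric
    # relations, so ("a", "sameas", "b") also matches ("b", "sameas", "a").
    expanded = set()
    for h, r, t in triples:
        expanded.add((h, r, t))
        if r in SYMMETRIC_RELATIONS:
            expanded.add((t, r, h))
    return expanded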
5 changes: 2 additions & 3 deletions ontolearner/learner/retriever.py
@@ -22,7 +22,6 @@ def __init__(self, base_retriever: Any = AutoRetriever(), top_k: int = 5, batch_
         self.retriever = base_retriever
         self.top_k = top_k
         self._is_term_typing_fit = False
-        self._is_taxonomy_discovery_fit = False
         self._batch_size = batch_size

     def load(self, model_id: str = "sentence-transformers/all-MiniLM-L6-v2"):
@@ -64,9 +63,9 @@ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
         if test:
             self._retriever_fit(data=data)
             candidates_lst = self._retriever_predict(data=data, top_k=self.top_k + 1)
-            taxonomic_pairs = [{"parent": query, "child": candidate}
+            taxonomic_pairs = [{"parent": candidate, "child": query}
                                for query, candidates in zip(data, candidates_lst)
-                               for candidate in candidates if candidate != query]
+                               for candidate in candidates if candidate.lower() != query.lower()]
             return taxonomic_pairs
         else:
             warnings.warn("No requirement for fitting the taxonomy discovery model; the predict module will use the input data to do the fit as well.")
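A toy walk-through (invented values) of what the two changed lines do: each retrieved candidate now becomes the parent of the query rather than its child, and self-matches are filtered case-insensitively:

data = ["dog"]
candidates_lst = [["Dog", "animal"]]  # "Dog" differs from the query only by case

taxonomic_pairs = [
    {"parent": candidate, "child": query}
    for query, candidates in zip(data, candidates_lst)
    for candidate in candidates
    if candidate.lower() != query.lower()
]
# The old comprehension kept the near-duplicate as {"parent": "dog",
# "child": "Dog"} and inverted the roles; the fixed one yields only
# [{"parent": "animal", "child": "dog"}].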