
Commit 65b439d

MR #277: metric update, type annotation, minor fix to taxonomy
2 parents 3ca7b46 + eb81b9d

3 files changed: +31 -19 lines changed

ontolearner/base/learner.py

Lines changed: 3 additions & 3 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from abc import ABC
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Dict
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 import torch.nn.functional as F
@@ -147,7 +147,7 @@ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
     def _non_taxonomic_re(self, data: Any, test: bool = False) -> Optional[Any]:
         pass
 
-    def tasks_data_former(self, data: Any, task: str, test: bool = False) -> Any:
+    def tasks_data_former(self, data: Any, task: str, test: bool = False) -> List[str | Dict[str, str]]:
         formatted_data = []
         if task == "term-typing":
             for typing in data.term_typings:
@@ -173,7 +173,7 @@ def tasks_data_former(self, data: Any, task: str, test: bool = False) -> Any:
             formatted_data = {"types": non_taxonomic_types, "relations": non_taxonomic_res}
         return formatted_data
 
-    def tasks_ground_truth_former(self, data: Any, task: str) -> Any:
+    def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, str]]:
         formatted_data = []
         if task == "term-typing":
             for typing in data.term_typings:
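The widened return annotation on `tasks_data_former` makes explicit that formed data is a list of bare strings for some tasks and keyed records for others. A hypothetical illustration of the two shapes (the concrete terms below are invented for the example; the `parent`/`child` keys come from the retriever change further down):

```python
from typing import Dict, List, Union

# Illustrative only: term-typing inputs are plain strings...
term_typing_inputs: List[str] = ["aspirin", "ibuprofen"]

# ...while taxonomy ground truth is one dict per edge, as in retriever.py below.
taxonomy_ground_truth: List[Dict[str, str]] = [
    {"parent": "drug", "child": "aspirin"},
]

FormedData = List[Union[str, Dict[str, str]]]  # the union the new annotation spells out
```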

ontolearner/evaluation/metrics.py

Lines changed: 26 additions & 13 deletions

@@ -11,13 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from typing import Dict
+from typing import List, Dict, Tuple, Set
 
 SYMMETRIC_RELATIONS = {"equivalentclass", "sameas", "disjointwith"}
 
-def text2onto_metrics(y_true, y_pred, similarity_threshold: float = 0.8) -> Dict:
-    def jaccard_similarity(a, b):
+def text2onto_metrics(y_true: List[str], y_pred: List[str], similarity_threshold: float = 0.8) -> Dict[str, float | int]:
+    def jaccard_similarity(a: str, b: str) -> float:
         set_a = set(a.lower().split())
         set_b = set(b.lower().split())
         if not set_a and not set_b:
@@ -46,10 +45,13 @@ def jaccard_similarity(a, b):
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }
 
-def term_typing_metrics(y_true, y_pred) -> Dict:
+def term_typing_metrics(y_true: List[Dict[str, List[str]]], y_pred: List[Dict[str, List[str]]]) -> Dict[str, float | int]:
     """
     Compute precision, recall, and F1-score for term typing
     using (term, type) pair-level matching instead of ID-based lookups.
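The fuzzy term matching in `text2onto_metrics` rests on the word-level Jaccard similarity whose opening lines the first hunk shows. A minimal standalone sketch of how the default `similarity_threshold` of 0.8 plays out; the hunk cuts off before the empty-string return and the threshold check, so both are assumptions here:

```python
def jaccard_similarity(a: str, b: str) -> float:
    # Word-level overlap: |A ∩ B| / |A ∪ B| over lowercased tokens.
    set_a = set(a.lower().split())
    set_b = set(b.lower().split())
    if not set_a and not set_b:
        return 1.0  # assumed: two empty strings count as identical
    return len(set_a & set_b) / len(set_a | set_b)

# With similarity_threshold = 0.8, only near-verbatim terms would match:
assert jaccard_similarity("deep learning", "deep learning") >= 0.8         # exact -> 1.0
assert jaccard_similarity("graph neural network", "neural network") < 0.8  # 2/3, rejected
```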
@@ -77,13 +79,17 @@ def term_typing_metrics(y_true, y_pred) -> Dict:
     precision = total_correct / total_predicted if total_predicted > 0 else 0.0
     recall = total_correct / total_ground_truth if total_ground_truth > 0 else 0.0
     f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
+
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }
 
-def taxonomy_discovery_metrics(y_true, y_pred) -> Dict:
+def taxonomy_discovery_metrics(y_true: List[Dict[str, str]], y_pred: List[Dict[str, str]]) -> Dict[str, float | int]:
     total_predicted = len(y_pred)
     total_ground_truth = len(y_true)
     # Convert ground truth and predictions to sets of tuples for easy comparison
@@ -102,18 +108,22 @@ def taxonomy_discovery_metrics(y_true, y_pred) -> Dict:
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }
 
-def non_taxonomic_re_metrics(y_true, y_pred) -> Dict:
-    def normalize_triple(item):
+
+def non_taxonomic_re_metrics(y_true: List[Dict[str, str]], y_pred: List[Dict[str, str]]) -> Dict[str, float | int]:
+    def normalize_triple(item: Dict[str, str]) -> Tuple[str, str, str]:
         return (
             item["head"].strip().lower(),
             item["relation"].strip().lower(),
             item["tail"].strip().lower()
         )
 
-    def expand_symmetric(triples):
+    def expand_symmetric(triples: Set[Tuple[str, str, str]]) -> Set[Tuple[str, str, str]]:
         expanded = set()
         for h, r, t in triples:
             expanded.add((h, r, t))
@@ -136,5 +146,8 @@ def expand_symmetric(triples):
     return {
         "f1_score": f1_score,
         "precision": precision,
-        "recall": recall
+        "recall": recall,
+        "total_correct": total_correct,
+        "total_predicted": total_predicted,
+        "total_ground_truth": total_ground_truth
     }
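All four metric functions now surface the raw counts (`total_correct`, `total_predicted`, `total_ground_truth`) alongside the ratios, which lets callers re-aggregate precision, recall, and F1 across datasets instead of averaging per-dataset scores. For the relation metrics specifically, `expand_symmetric` makes matching order-insensitive; a self-contained sketch, assuming the loop body cut off in the hunk simply adds the reversed triple for relations in `SYMMETRIC_RELATIONS`:

```python
SYMMETRIC_RELATIONS = {"equivalentclass", "sameas", "disjointwith"}

def expand_symmetric(triples: set[tuple[str, str, str]]) -> set[tuple[str, str, str]]:
    # For a symmetric relation, (h, r, t) and (t, r, h) state the same fact,
    # so both orientations go into the comparison set before scoring.
    expanded = set()
    for h, r, t in triples:
        expanded.add((h, r, t))
        if r in SYMMETRIC_RELATIONS:  # assumed branch, not shown in the diff
            expanded.add((t, r, h))
    return expanded

gold = expand_symmetric({("a", "sameas", "b")})
assert ("b", "sameas", "a") in gold  # a reversed prediction still counts as correct
```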

ontolearner/learner/retriever.py

Lines changed: 2 additions & 3 deletions

@@ -22,7 +22,6 @@ def __init__(self, base_retriever: Any = AutoRetriever(), top_k: int = 5, batch_
         self.retriever = base_retriever
         self.top_k = top_k
         self._is_term_typing_fit = False
-        self._is_taxonomy_discovery_fit = False
         self._batch_size = batch_size
 
     def load(self, model_id: str = "sentence-transformers/all-MiniLM-L6-v2"):
@@ -64,9 +63,9 @@ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
         if test:
             self._retriever_fit(data=data)
             candidates_lst = self._retriever_predict(data=data, top_k=self.top_k + 1)
-            taxonomic_pairs = [{"parent": query, "child": candidate}
+            taxonomic_pairs = [{"parent": candidate, "child": query}
                                for query, candidates in zip(data, candidates_lst)
-                               for candidate in candidates if candidate != query]
+                               for candidate in candidates if candidate.lower() != query.lower()]
             return taxonomic_pairs
         else:
             warnings.warn("No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well.")
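The taxonomy fix reverses the orientation of the emitted pairs, so a retrieved candidate is now treated as the parent of the query term rather than its child, and the self-match filter becomes case-insensitive. A toy reconstruction of the fixed comprehension (the query and candidate values are fabricated for illustration):

```python
queries = ["Aspirin", "Ibuprofen"]
candidates_lst = [["drug", "aspirin"], ["drug", "painkiller"]]  # toy top-k results per query

# Fixed orientation: retrieved candidate = parent, query = child; the lowercased
# comparison keeps "Aspirin" from being paired with its own retrieval "aspirin".
taxonomic_pairs = [{"parent": candidate, "child": query}
                   for query, candidates in zip(queries, candidates_lst)
                   for candidate in candidates if candidate.lower() != query.lower()]

assert {"parent": "drug", "child": "Aspirin"} in taxonomic_pairs
assert all(p["parent"].lower() != p["child"].lower() for p in taxonomic_pairs)
```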
