Commit 675c5f1

Merge branch 'main' into dev

# Conflicts:
#	ontolearner/base/learner.py

2 parents: cbbe462 + 67612d0

3 files changed: +40 -16 lines

ontolearner/base/learner.py

Lines changed: 15 additions & 12 deletions
@@ -236,15 +236,21 @@ def load(self, model_id: str) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left', token=self.token)
         self.tokenizer.pad_token = self.tokenizer.eos_token
         if self.device == "cpu":
-            device_map = "cpu"
+            # device_map = "cpu"
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                # device_map=device_map,
+                torch_dtype=torch.bfloat16,
+                token=self.token
+            )
         else:
             device_map = "balanced"
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            device_map=device_map,
-            torch_dtype=torch.bfloat16,
-            token=self.token
-        )
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                device_map=device_map,
+                torch_dtype=torch.bfloat16,
+                token=self.token
+            )
         self.label_mapper.fit()
 
     def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:

@@ -290,7 +296,8 @@ def generate(self, inputs: List[str], max_new_tokens: int = 50) -> List[str]:
 
         # Decode only the generated part
         decoded_outputs = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in generated_tokens]
-
+        print(decoded_outputs)
+        print(self.label_mapper.predict(decoded_outputs))
         # Map the decoded text to labels
         return self.label_mapper.predict(decoded_outputs)
 

@@ -301,9 +308,6 @@ class AutoRetriever(ABC):
     This class defines the interface for retrieval components used in ontology learning.
     Retrievers are responsible for finding semantically similar examples from training
     data to provide context for language models or to make direct predictions.
-
-    Attributes:
-        model: The loaded retrieval/embedding model instance.
     """
 
     def __init__(self) -> None:

@@ -313,7 +317,6 @@ def __init__(self) -> None:
         Sets up the basic structure with a model attribute that will be
         populated when load() is called.
         """
-        self.model: Optional[Any] = None
         self.embedding_model = None
         self.documents = []
         self.embeddings = None
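
Note on the CPU path: the diff drops the device_map argument and loads the model directly. Below is a minimal sketch (not part of the commit) of the resulting load pattern outside the class; "gpt2" is only a stand-in model id and no auth token is assumed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # stand-in id; the learner receives model_id as an argument
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# CPU path after this commit: no device_map, the model lands on the default (CPU) device
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)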

ontolearner/learner/label_mapper.py

Lines changed: 1 addition & 1 deletion
@@ -85,6 +85,6 @@ def predict(self, X: List[str]) -> List[str]:
         Returns:
             List[str]: Predicted labels.
         """
-        predictions = list(self.model.predict(X))
+        predictions = self.model.predict(X).tolist()
         self.validate_predicts(predictions)
         return predictions
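
The switch from list(...) to .tolist() matters when self.model.predict returns a NumPy array: both yield a Python list with the same values, but .tolist() converts the elements to native Python objects. A small illustration (not from the repository), assuming NumPy predictions:

import numpy as np

preds = np.array(["taxonomic", "non-taxonomic"])  # stand-in for self.model.predict(X)

print(type(list(preds)[0]))     # <class 'numpy.str_'> -- NumPy scalar elements
print(type(preds.tolist()[0]))  # <class 'str'> -- native Python strings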

ontolearner/learner/retriever.py

Lines changed: 24 additions & 3 deletions
@@ -66,7 +66,16 @@ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]:
             taxonomic_pairs = [{"parent": candidate, "child": query}
                                for query, candidates in zip(data, candidates_lst)
                                for candidate in candidates if candidate.lower() != query.lower()]
-            return taxonomic_pairs
+            taxonomic_pairs += [{"parent": query, "child": candidate}
+                                for query, candidates in zip(data, candidates_lst)
+                                for candidate in candidates if candidate.lower() != query.lower()]
+            unique_taxonomic_pairs, seen = [], set()
+            for pair in taxonomic_pairs:
+                key = (pair["parent"].lower(), pair["child"].lower())  # Directional key (parent, child)
+                if key not in seen:
+                    seen.add(key)
+                    unique_taxonomic_pairs.append(pair)
+            return unique_taxonomic_pairs
         else:
             warnings.warn("No requirement for fiting the taxonomy discovery model, the predict module will use the input data to do the fit as well.")
 

@@ -86,11 +95,23 @@ def _non_taxonomic_re(self, data: Any, test: bool = False) -> Optional[Any]:
         candidates_lst = self._retriever_predict(data=data['types'], top_k=self.top_k + 1)
         taxonomic_pairs = []
         taxonomic_pairs_query = []
+        seen = set()
         for query, candidates in zip(data['types'], candidates_lst):
             for candidate in candidates:
                 if candidate != query:
-                    taxonomic_pairs.append((query, candidate))
-                    taxonomic_pairs_query.append(f"Head: {query} \n Tail: {candidate}")
+                    # Directional pair 1: query -> candidate
+                    key1 = (query.lower(), candidate.lower())
+                    if key1 not in seen:
+                        seen.add(key1)
+                        taxonomic_pairs.append((query, candidate))
+                        taxonomic_pairs_query.append(f"Head: {query}\nTail: {candidate}")
+                    # Directional pair 2: candidate -> query
+                    key2 = (candidate.lower(), query.lower())
+                    if key2 not in seen:
+                        seen.add(key2)
+                        taxonomic_pairs.append((candidate, query))
+                        taxonomic_pairs_query.append(f"Head: {candidate}\nTail: {query}")
+
         self._retriever_fit(data=data['relations'])
         candidate_relations_lst = self._retriever_predict(data=taxonomic_pairs_query, top_k=self.top_k)
         non_taxonomic_re = [{"head": head, "tail": tail, "relation": relation}
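
Both hunks apply the same idea: keep both orientations of each pair while de-duplicating on a case-insensitive, directional (head, tail) key. A self-contained sketch of that pattern on illustrative data (not from the repository):

pairs = [
    {"parent": "Animal", "child": "dog"},
    {"parent": "animal", "child": "Dog"},   # duplicate of the first (case-insensitive)
    {"parent": "dog", "child": "Animal"},   # reverse direction, kept separately
]

unique_pairs, seen = [], set()
for pair in pairs:
    key = (pair["parent"].lower(), pair["child"].lower())  # directional key
    if key not in seen:
        seen.add(key)
        unique_pairs.append(pair)

print(unique_pairs)
# [{'parent': 'Animal', 'child': 'dog'}, {'parent': 'dog', 'child': 'Animal'}]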
