diff --git a/docs/source/learners/llms4ol.rst b/docs/source/learners/llms4ol.rst index 58cd23e..820171a 100644 --- a/docs/source/learners/llms4ol.rst +++ b/docs/source/learners/llms4ol.rst @@ -31,7 +31,7 @@ LLMs4OL is a community development initiative collocated with the International - **Text2Onto** - Extract ontological terms and types from unstructured text. - **ID**: ``text-to-onto`` + **ID**: ``text2onto`` **Info**: This task focuses on extracting foundational elements (Terms and Types) from unstructured text documents to build the initial structure of an ontology. It involves recognizing domain-relevant vocabulary (Term Extraction, SubTask 1) and categorizing it appropriately (Type Extraction, SubTask 2). It bridges the gap between natural language and structured knowledge representation. diff --git a/docs/source/learners/llms4ol_challenge/alexbek_learner.rst b/docs/source/learners/llms4ol_challenge/alexbek_learner.rst index 321e280..b564596 100644 --- a/docs/source/learners/llms4ol_challenge/alexbek_learner.rst +++ b/docs/source/learners/llms4ol_challenge/alexbek_learner.rst @@ -250,3 +250,147 @@ Learn and Predict truth = cross_learner.tasks_ground_truth_former(data=test_data, task=task) metrics = evaluation_report(y_true=truth, y_pred=predicts, task=task) print(metrics) + +Text2Onto +------------------ + +Loading Ontological Data +~~~~~~~~~~~~~~~~~~~~~~ + +For the Text2Onto task, we load an ontology (via ``OM``), extract its structured content, and then generate synthetic pseudo-sentences using an LLM-backed generator (DSPy + Ollama in this example). + +.. code-block:: python + + import os + import dspy + + # Ontology loader/manager + from ontolearner.ontology import OM + + # Text2Onto utilities: synthetic generation + dataset splitting + from ontolearner.text2onto import SyntheticGenerator, SyntheticDataSplitter + + # ---- DSPy -> Ollama (LiteLLM-style) ---- + LLM_MODEL_ID = "ollama/llama3.2:3b" # use your pulled Ollama model + LLM_API_KEY = "NA" # local Ollama doesn't use a key + LLM_BASE_URL = "http://localhost:11434" # default Ollama endpoint + + dspy_llm = dspy.LM( + model=LLM_MODEL_ID, + cache=True, + max_tokens=4000, + temperature=0, + api_key=LLM_API_KEY, + base_url=LLM_BASE_URL, + ) + dspy.configure(lm=dspy_llm) + + # ---- Synthetic generation configuration ---- + pseudo_sentence_batch_size = int(os.getenv("TEXT2ONTO_BATCH", "10")) + max_worker_count_for_llm_calls = int(os.getenv("TEXT2ONTO_WORKERS", "1")) + + text2onto_synthetic_generator = SyntheticGenerator( + batch_size=pseudo_sentence_batch_size, + worker_count=max_worker_count_for_llm_calls, + ) + + # ---- Load ontology and extract structured data ---- + ontology = OM() + ontology.load() + ontological_data = ontology.extract() + + print(f"term types: {len(ontological_data.term_typings)}") + print(f"taxonomic relations: {len(ontological_data.type_taxonomies.taxonomies)}") + print(f"non-taxonomic relations: {len(ontological_data.type_non_taxonomic_relations.non_taxonomies)}") + + # ---- Generate synthetic Text2Onto samples ---- + synthetic_data = text2onto_synthetic_generator.generate( + ontological_data=ontological_data, + topic=ontology.domain, + ) + +Split Synthetic Data +~~~~~~~~~~~~~~~~~~~~ + +We split the synthetic dataset into train/val/test sets using ``SyntheticDataSplitter``. +Each split is a dict with keys: + +- ``documents`` +- ``terms`` +- ``types`` +- ``terms2docs`` +- ``terms2types`` + +.. 
code-block:: python + + splitter = SyntheticDataSplitter( + synthetic_data=synthetic_data, + onto_name=ontology.ontology_id, + ) + + train_data, val_data, test_data = splitter.train_test_val_split( + train=0.8, + val=0.0, + test=0.2, + ) + + print("TRAIN sizes:") + print(" documents:", len(train_data.get("documents", []))) + print(" terms:", len(train_data.get("terms", []))) + print(" types:", len(train_data.get("types", []))) + print(" terms2docs:", len(train_data.get("terms2docs", {}))) + print(" terms2types:", len(train_data.get("terms2types", {}))) + + print("TEST sizes:") + print(" documents:", len(test_data.get("documents", []))) + print(" terms:", len(test_data.get("terms", []))) + print(" types:", len(test_data.get("types", []))) + print(" terms2docs:", len(test_data.get("terms2docs", {}))) + print(" terms2types:", len(test_data.get("terms2types", {}))) + +Initialize Learner +~~~~~~~~~~~~~~~~~~ + +We configure a retrieval-augmented few-shot learner for the Text2Onto task. +The learner retrieves relevant synthetic examples and uses an LLM to predict structured outputs. + +.. code-block:: python + + from ontolearner.learner.text2onto import AlexbekRAGFewShotLearner + + text2onto_learner = AlexbekRAGFewShotLearner( + llm_model_id="Qwen/Qwen2.5-0.5B-Instruct", + retriever_model_id="sentence-transformers/all-MiniLM-L6-v2", + device="cpu", # set "cuda" if available + top_k=3, + max_new_tokens=256, + use_tfidf=True, + ) + +Learn and Predict +~~~~~~~~~~~~~~~~~ + +We run the end-to-end pipeline (train -> predict -> evaluate) with ``LearnerPipeline`` using the ``text2onto`` task id. + +.. code-block:: python + + from ontolearner import LearnerPipeline + + task = "text2onto" + + pipe = LearnerPipeline( + llm=text2onto_learner, + llm_id="Qwen/Qwen2.5-0.5B-Instruct", + ontologizer_data=False, + ) + + outputs = pipe( + train_data=train_data, + test_data=test_data, + task=task, + evaluate=True, + ontologizer_data=False, + ) + + print("Metrics:", outputs.get("metrics")) + print("Elapsed time:", outputs.get("elapsed_time")) diff --git a/docs/source/learners/llms4ol_challenge/sbunlp_learner.rst b/docs/source/learners/llms4ol_challenge/sbunlp_learner.rst index 860c3a4..bef83d2 100644 --- a/docs/source/learners/llms4ol_challenge/sbunlp_learner.rst +++ b/docs/source/learners/llms4ol_challenge/sbunlp_learner.rst @@ -31,6 +31,8 @@ Methodological Summary: - For **Taxonomy Discovery**, the focus was on detecting parent–child relationships between ontology terms. Due to the relational nature of this task, batch prompting was employed to efficiently handle multiple type pairs per inference, enabling the model to consider several candidate relations jointly. +- For **Text2Onto**, the objective was to extract ontology construction signals from text-like inputs: generating/using documents, identifying candidate terms, assigning types, and producing supporting mappings such as term–document and term–type associations. In OntoLearner, this is implemented by first generating synthetic pseudo-documents from an ontology (using an LLM-backed synthetic generator), then applying the SBU-NLP prompting strategy to infer structured outputs without any fine-tuning. Dataset splitting and optional Ontologizer-style processing are used to support reproducible evaluation and artifact generation. 
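+
+  A minimal sketch of the dictionaries exchanged in this task is shown below (the keys match those produced by ``SyntheticDataSplitter`` and consumed during evaluation; the concrete values are illustrative assumptions only):
+
+  .. code-block:: python
+
+      # A split as returned by SyntheticDataSplitter.train_test_val_split (illustrative values)
+      train_data = {
+          "documents": [{"doc_id": "doc_0", "title": "Sample title", "text": "Sample text ..."}],
+          "terms": ["solar cell"],
+          "types": ["Device"],
+          "terms2docs": {"solar cell": ["doc_0"]},
+          "terms2types": {"solar cell": ["Device"]},
+      }
+
+      # Predictions and gold labels are document-level pairs scored by text2onto_metrics
+      predictions = {
+          "terms": [{"doc_id": "doc_0", "term": "solar cell"}],
+          "types": [{"doc_id": "doc_0", "type": "Device"}],
+      }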
+ Term Typing ----------------------- @@ -179,3 +181,147 @@ Learn and Predict # Evaluate taxonomy discovery performance metrics = evaluation_report(y_true=truth, y_pred=predicts, task=task) print(metrics) + +Text2Onto +------------------ + +Loading Ontological Data +~~~~~~~~~~~~~~~~~~~~~~ + +For the Text2Onto task, we load an ontology (via ``OM``), extract its structured content, and generate synthetic pseudo-sentences using an LLM-backed generator (DSPy + Ollama in this example). + +.. code-block:: python + + import os + import dspy + + # Import ontology loader/manager and Text2Onto utilities + from ontolearner.ontology import OM + from ontolearner.text2onto import SyntheticGenerator, SyntheticDataSplitter + + # ---- DSPy -> Ollama (LiteLLM-style) ---- + LLM_MODEL_ID = "ollama/llama3.2:3b" + LLM_API_KEY = "NA" # local Ollama doesn't use a key + LLM_BASE_URL = "http://localhost:11434" # default Ollama endpoint + + dspy_llm = dspy.LM( + model=LLM_MODEL_ID, + cache=True, + max_tokens=4000, + temperature=0, + api_key=LLM_API_KEY, + base_url=LLM_BASE_URL, + ) + dspy.configure(lm=dspy_llm) + + # ---- Synthetic generation configuration ---- + batch_size = int(os.getenv("TEXT2ONTO_BATCH", "10")) + worker_count = int(os.getenv("TEXT2ONTO_WORKERS", "1")) + + text2onto_synthetic_generator = SyntheticGenerator( + batch_size=batch_size, + worker_count=worker_count, + ) + + # ---- Load ontology and extract structured data ---- + ontology = OM() + ontology.load() + ontological_data = ontology.extract() + + # Optional sanity checks to verify what was extracted from the ontology + print(f"term types: {len(ontological_data.term_typings)}") + print(f"taxonomic relations: {len(ontological_data.type_taxonomies.taxonomies)}") + print(f"non-taxonomic relations: {len(ontological_data.type_non_taxonomic_relations.non_taxonomies)}") + + # ---- Generate synthetic Text2Onto samples ---- + synthetic_data = text2onto_synthetic_generator.generate( + ontological_data=ontological_data, + topic=ontology.domain, + ) + +Split Synthetic Data +~~~~~~~~~~~~~~~~~~~~ + +We split the synthetic dataset into train/val/test sets using ``SyntheticDataSplitter``. +Each split is a dict with keys: + +- ``documents`` +- ``terms`` +- ``types`` +- ``terms2docs`` +- ``terms2types`` + +.. code-block:: python + + splitter = SyntheticDataSplitter( + synthetic_data=synthetic_data, + onto_name=ontology.ontology_id, + ) + + train_data, val_data, test_data = splitter.train_test_val_split( + train=0.8, + val=0.0, + test=0.2, + ) + + print("TRAIN sizes:") + print(" documents:", len(train_data.get("documents", []))) + print(" terms:", len(train_data.get("terms", []))) + print(" types:", len(train_data.get("types", []))) + print(" terms2docs:", len(train_data.get("terms2docs", {}))) + print(" terms2types:", len(train_data.get("terms2types", {}))) + + print("TEST sizes:") + print(" documents:", len(test_data.get("documents", []))) + print(" terms:", len(test_data.get("terms", []))) + print(" types:", len(test_data.get("types", []))) + print(" terms2docs:", len(test_data.get("terms2docs", {}))) + print(" terms2types:", len(test_data.get("terms2types", {}))) + +Initialize Learner +~~~~~~~~~~~~~~~~~~ + +We configure the SBU-NLP few-shot learner for the Text2Onto task. +This learner uses an LLM to produce predictions from the synthetic Text2Onto-style samples. + +.. 
code-block:: python + + from ontolearner.learner.text2onto import SBUNLPFewShotLearner + + text2onto_learner = SBUNLPFewShotLearner( + llm_model_id="Qwen/Qwen2.5-0.5B-Instruct", + device="cpu", # set "cuda" if available + max_new_tokens=256, + output_dir="./results/", + ) + +Learn and Predict +~~~~~~~~~~~~~~~~~ + +We run the end-to-end pipeline (train -> predict -> evaluate) with ``LearnerPipeline`` using the ``text2onto`` task id. + +.. code-block:: python + + from ontolearner import LearnerPipeline + + task = "text2onto" + + pipe = LearnerPipeline( + llm=text2onto_learner, + llm_id="Qwen/Qwen2.5-0.5B-Instruct", + ontologizer_data=False, + ) + + outputs = pipe( + train_data=train_data, + test_data=test_data, + task=task, + evaluate=True, + ontologizer_data=True, + ) + + print("Metrics:", outputs.get("metrics")) + print("Elapsed time:", outputs.get("elapsed_time")) + + # Print all returned outputs (often includes predictions/artifacts/logs) + print(outputs) diff --git a/examples/llm_learner_alexbek_text2onto.py b/examples/llm_learner_alexbek_text2onto.py index 69282a9..fe36bec 100644 --- a/examples/llm_learner_alexbek_text2onto.py +++ b/examples/llm_learner_alexbek_text2onto.py @@ -1,84 +1,111 @@ import os -import json -import torch +import dspy -# LocalAutoLLM handles model loading/generation; AlexbekFewShotLearner provides fit/predict APIs -from ontolearner.learner.text2onto.alexbek import LocalAutoLLM, AlexbekFewShotLearner +# Import ontology loader/manager +from ontolearner.ontology import OM -# Local folder where the dataset is stored (relative to this script) -DATA_DIR = "./dataset_llms4ol_2025/TaskA-Text2Onto/ecology" +# Import Text2Onto utilities: synthetic sample generation + dataset splitting +from ontolearner.text2onto import SyntheticGenerator, SyntheticDataSplitter -# Input paths (already saved) -TRAIN_DOCS_PATH = os.path.join(DATA_DIR, "train", "documents.jsonl") -TRAIN_TERMS2DOCS_PATH = os.path.join(DATA_DIR, "train", "terms2docs.json") -TEST_DOCS_FULL_PATH = os.path.join( - DATA_DIR, "test", "text2onto_ecology_test_documents.jsonl" -) +# Import pipeline orchestrator + the specific Few-Shot learner you want to run +from ontolearner import LearnerPipeline +from ontolearner.learner.text2onto import AlexbekRAGFewShotLearner + +# ---- DSPy -> Ollama (LiteLLM-style) ---- +# Configure DSPy to send prompts to a locally running Ollama server (via LiteLLM-compatible args). +LLM_MODEL_ID = "ollama/llama3.2:3b" # use your pulled Ollama model +LLM_API_KEY = "NA" # Ollama local doesn't use a key; kept for interface compatibility +LLM_BASE_URL = "http://localhost:11434" # default Ollama server endpoint -# Output paths -DOC_TERMS_OUT_PATH = os.path.join( - DATA_DIR, "test", "extracted_terms_ecology.fast.jsonl" +# Create the DSPy language model wrapper. +# Note: DSPy uses LiteLLM-style parameters under the hood when given model/base_url/api_key. +dspy_llm = dspy.LM( + model=LLM_MODEL_ID, + cache=True, # cache generations to speed up repeated runs + max_tokens=4000, # generous context for synthetic generation prompts + temperature=0, # deterministic output; helpful for reproducibility + api_key=LLM_API_KEY, + base_url=LLM_BASE_URL, ) -TERMS2TYPES_OUT_PATH = os.path.join( - DATA_DIR, "test", "terms2types_pred_ecology.fast.json" + +# Register the LM globally so DSPy modules (and generator internals) use it. 
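+# dspy.configure() makes this LM the process-wide default, so every DSPy module created
+# afterwards (including the SyntheticGenerator internals) routes its prompts through it.
+# This assumes an Ollama server is already running at LLM_BASE_URL and that the model
+# has been pulled beforehand (e.g. `ollama pull llama3.2:3b`).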
+dspy.configure(lm=dspy_llm) + +# ---- Synthetic generation configuration ---- +# Allow scaling generation without editing code by using environment variables: +# TEXT2ONTO_BATCH=20 TEXT2ONTO_WORKERS=2 python script.py +pseudo_sentence_batch_size = int(os.getenv("TEXT2ONTO_BATCH", "10")) +max_worker_count_for_llm_calls = int(os.getenv("TEXT2ONTO_WORKERS", "1")) + +# Instantiate the generator that turns ontology structures into pseudo-text samples. +text2onto_synthetic_generator = SyntheticGenerator( + batch_size=pseudo_sentence_batch_size, # number of samples requested per batch + worker_count=max_worker_count_for_llm_calls, # parallel LLM calls (increase if your machine can handle it) ) -TYPES2DOCS_OUT_PATH = os.path.join( - DATA_DIR, "test", "types2docs_pred_ecology.fast.json" + +# ---- Load ontology and extract structured data ---- +# OM loads the ontology configured in your OntoLearner setup and exposes domain metadata. +ontology = OM() +ontology.load() +ontological_data = ontology.extract() # structured: term typings, taxonomies, relations, etc. + +# ---- Generate synthetic Text2Onto samples ---- +# Uses the extracted ontology structures + domain/topic to create synthetic training examples. +synthetic_data = text2onto_synthetic_generator.generate( + ontological_data=ontological_data, + topic=ontology.domain ) -# Device selection -DEVICE = ( - "cuda" - if torch.cuda.is_available() - else ("mps" if torch.backends.mps.is_available() else "cpu") +# ---- Dataset splitter ---- +# Wrap the synthetic dataset with a splitter utility for reproducible partitioning. +splitter = SyntheticDataSplitter( + synthetic_data=synthetic_data, + onto_name=ontology.ontology_id # used to tag/identify outputs for this ontology ) -# Model config -MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct" -LOAD_IN_4BIT = DEVICE == "cuda" # 4-bit helps on GPU +# Optional sanity checks to verify what was extracted from the ontology. +print(f"term types: {len(ontological_data.term_typings)}") +print(f"taxonomic relations: {len(ontological_data.type_taxonomies.taxonomies)}") +print(f"non-taxonomic relations: {len(ontological_data.type_non_taxonomic_relations.non_taxonomies)}") -# 1) Load LLM -llm = LocalAutoLLM(device=DEVICE) -llm.load(MODEL_ID, load_in_4bit=LOAD_IN_4BIT) +# ---- Split into train/val/test ---- +# val=0.0 keeps the API consistent while skipping validation split for this run. +train_data, val_data, test_data = splitter.train_test_val_split(train=0.8, val=0.0, test=0.2) -# 2) Build few-shot exemplars from training split -learner = AlexbekFewShotLearner(model=llm, device=DEVICE) -learner.fit( - train_docs_jsonl=TRAIN_DOCS_PATH, - terms2doc_json=TRAIN_TERMS2DOCS_PATH, - # use defaults for sample size/seed +# ---- Configure the Few-Shot learner for Text2Onto ---- +# This learner will be used by the pipeline to learn/predict from Text2Onto-style samples. 
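+# The learner works in two retrieval-augmented stages:
+#   1) doc -> terms: retrieve the top_k most similar training documents and use them as
+#      few-shot examples before asking the LLM for JSON of the form {"terms": [...]}.
+#   2) term -> types: retrieve similar training terms and ask for JSON {"types": [...]}.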
+text2ontolearner = AlexbekRAGFewShotLearner( + llm_model_id="Qwen/Qwen2.5-0.5B-Instruct", # generator model used inside the learner + retriever_model_id="sentence-transformers/all-MiniLM-L6-v2", # embedding model for retrieval + device="cpu", # set "cuda" if you have GPU support + top_k=3, # number of retrieved examples/chunks + max_new_tokens=256, # response length for the learner's generator + use_tfidf=True, # optional lexical retrieval alongside embeddings ) -# 3) Predict terms per test document -os.makedirs(os.path.dirname(DOC_TERMS_OUT_PATH), exist_ok=True) -num_written_doc_terms = learner.predict_terms( - docs_test_jsonl=TEST_DOCS_FULL_PATH, - out_jsonl=DOC_TERMS_OUT_PATH, - # use defaults for max_new_tokens and few_shot_k -) -print(f"[terms] wrote {num_written_doc_terms} lines → {DOC_TERMS_OUT_PATH}") - -# 4) Predict types for extracted terms, using the JSONL we just wrote -typing_summary = learner.predict_types_from_terms( - doc_terms_jsonl=DOC_TERMS_OUT_PATH, # read the predictions directly - doc_terms_list=None, # (not needed when doc_terms_jsonl is provided) - model_id=MODEL_ID, # reuse the same small model - out_terms2types=TERMS2TYPES_OUT_PATH, - out_types2docs=TYPES2DOCS_OUT_PATH, - # use defaults for everything else +# ---- Build pipeline ---- +# LearnerPipeline orchestrates training/prediction/evaluation for the chosen task. +pipe = LearnerPipeline( + llm=text2ontolearner, # the learner implementation used by the pipeline + llm_id="Qwen/Qwen2.5-0.5B-Instruct", # label/id recorded with results + ontologizer_data=False, # whether to run Ontologizer-related processing ) -print( - f"[types] {typing_summary['unique_terms']} unique terms | {typing_summary['types_count']} types" +# ---- Run end-to-end (train -> predict -> evaluate) ---- +outputs = pipe( + train_data=train_data, + test_data=test_data, + task="text2onto", + evaluate=True, # compute evaluation metrics on the test set + ontologizer_data=False, # keep consistent with pipeline setting above ) -print(f"[saved] {TERMS2TYPES_OUT_PATH}") -print(f"[saved] {TYPES2DOCS_OUT_PATH}") - -# 5) Small preview of term→types -try: - with open(TERMS2TYPES_OUT_PATH, "r", encoding="utf-8") as fin: - preview = json.load(fin)[:3] - print("[preview] first 3:") - print(json.dumps(preview, ensure_ascii=False, indent=2)) -except Exception as e: - print(f"[preview] skipped: {e}") + +# ---- Display results ---- +# Metrics typically include task-specific scores (depends on OntoLearner implementation). +print("Metrics:", outputs.get("metrics")) + +# Total elapsed time for training + prediction + evaluation. +print("Elapsed time:", outputs["elapsed_time"]) + +# Print everything returned (often includes predictions, logs, artifacts, etc.) +print(outputs) diff --git a/examples/llm_learner_sbunlp_text2onto.py b/examples/llm_learner_sbunlp_text2onto.py index cff543c..03cba2b 100644 --- a/examples/llm_learner_sbunlp_text2onto.py +++ b/examples/llm_learner_sbunlp_text2onto.py @@ -1,88 +1,108 @@ import os -import torch - -# Import all the required classes -from ontolearner import SBUNLPText2OntoLearner -from ontolearner.learner.text2onto.sbunlp import LocalAutoLLM - -# Local folder where the dataset is stored -# This path is relative to the directory where the script is executed -# (e.g., E:\OntoLearner\examples) -LOCAL_DATA_DIR = "./dataset_llms4ol_2025/TaskA-Text2Onto/ecology" - -# Ensure the base directories exist -# Creates the train and test subdirectories if they don't already exist. 
-os.makedirs(os.path.join(LOCAL_DATA_DIR, "train"), exist_ok=True) -os.makedirs(os.path.join(LOCAL_DATA_DIR, "test"), exist_ok=True) - -# Define local file paths: POINTING TO ALREADY SAVED FILES -# These files are used as input for the Fit and Predict phases. -DOCS_ALL_PATH = "./dataset_llms4ol_2025/TaskA-Text2Onto/ecology/train/documents.jsonl" -TERMS2DOC_PATH = "./dataset_llms4ol_2025/TaskA-Text2Onto/ecology/train/terms2docs.json" -DOCS_TEST_PATH = "./dataset_llms4ol_2025/TaskA-Text2Onto/ecology/test/text2onto_ecology_test_documents.jsonl" - -# Output files for predictions (saved directly under LOCAL_DATA_DIR/test) -# These files will be created by the predict_terms/types methods. -TERMS_PRED_OUT = ( - "./dataset_llms4ol_2025/TaskA-Text2Onto/ecology/test/extracted_terms_ecology.jsonl" +import dspy + +# Import ontology loader/manager and Text2Onto utilities +from ontolearner.ontology import OM +from ontolearner.text2onto import SyntheticGenerator, SyntheticDataSplitter + +# Import the pipeline orchestrator and the specific Few-Shot learner for Text2Onto +from ontolearner import LearnerPipeline +from ontolearner.learner.text2onto import SBUNLPFewShotLearner + +# ---- DSPy -> Ollama (LiteLLM-style) ---- +# Configure DSPy to send prompts to a locally running Ollama server. +LLM_MODEL_ID = "ollama/llama3.2:3b" +LLM_API_KEY = "NA" # Ollama local doesn't use a key; kept for interface compatibility. +LLM_BASE_URL = "http://localhost:11434" # default Ollama endpoint + +# Create the DSPy language model wrapper (LiteLLM-compatible settings) +dspy_llm = dspy.LM( + model=LLM_MODEL_ID, + cache=True, # cache generations to speed up iterative runs + max_tokens=4000, + temperature=0, # deterministic output; useful for reproducible synthetic data + api_key=LLM_API_KEY, + base_url=LLM_BASE_URL, ) -TYPES_PRED_OUT = ( - "./dataset_llms4ol_2025/TaskA-Text2Onto/ecology/test/extracted_types_ecology.jsonl" + +# Register the LM globally so DSPy modules (and generator internals) use it +dspy.configure(lm=dspy_llm) + +# ---- Synthetic generation configuration ---- +# Allow scaling generation without code edits via environment variables: +# TEXT2ONTO_BATCH=20 TEXT2ONTO_WORKERS=2 python script.py +batch_size = int(os.getenv("TEXT2ONTO_BATCH", "10")) +worker_count = int(os.getenv("TEXT2ONTO_WORKERS", "1")) + +# Instantiate the generator that turns ontology structures into pseudo-text samples +text2onto_synthetic_generator = SyntheticGenerator( + batch_size=batch_size, # number of samples requested per batch + worker_count=worker_count, # parallel LLM calls (increase if your machine can handle it) +) + +# ---- Load ontology and extract structured data ---- +# OM loads the ontology configured in your OntoLearner setup and exposes its domain metadata. +ontology = OM() +ontology.load() +ontological_data = ontology.extract() # structured: term typings, taxonomies, relations, etc. + +# ---- Generate synthetic Text2Onto samples ---- +# Uses the ontology's extracted structures + domain/topic to create synthetic training examples. 
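+# Note: this is where the synthetic-generation LLM calls happen (through the DSPy/Ollama
+# model configured above), so runtime scales with the ontology size and with the
+# TEXT2ONTO_BATCH / TEXT2ONTO_WORKERS settings; cache=True on dspy.LM speeds up repeated
+# runs with the same inputs.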
+synthetic_data = text2onto_synthetic_generator.generate( + ontological_data=ontological_data, + topic=ontology.domain, +) + +# Optional sanity checks to verify what was extracted from the ontology +print(f"term types: {len(ontological_data.term_typings)}") +print(f"taxonomic relations: {len(ontological_data.type_taxonomies.taxonomies)}") +print(f"non-taxonomic relations: {len(ontological_data.type_non_taxonomic_relations.non_taxonomies)}") + +# ---- Split into train/val/test ---- +# Wrap the synthetic dataset with a splitter utility for reproducible partitioning. +splitter = SyntheticDataSplitter( + synthetic_data=synthetic_data, + onto_name=ontology.ontology_id, # used to tag/identify outputs for this ontology ) -# Initialize and Load Learner --- -MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" -# Determine the device for inference (GPU or CPU) -DEVICE = "cuda" if torch.cuda.is_available() else "cpu" - -# Instantiate the underlying LLM helper -# (LocalAutoLLM handles model loading and generation) -llm_model_helper = LocalAutoLLM(device=DEVICE) - -# Instantiate the main learner class, passing the LLM helper to its constructor -learner = SBUNLPText2OntoLearner(model=llm_model_helper, device=DEVICE) - -# Load the model (This calls llm_model_helper.load) -LOAD_IN_4BIT = torch.cuda.is_available() -learner.model.load(MODEL_ID, load_in_4bit=LOAD_IN_4BIT) - -# Build Few-Shot Exemplars (Fit Phase) -# The fit method uses the local data paths to build the in-context learning prompts. -learner.fit( - train_docs_jsonl=DOCS_ALL_PATH, - terms2doc_json=TERMS2DOC_PATH, - sample_size=28, - seed=123, # Seed for stratified random sampling stability +# Create splits for training and evaluation. +# val=0.0 keeps the API consistent while skipping validation split in this run. +train_data, val_data, test_data = splitter.train_test_val_split( + train=0.8, + val=0.0, + test=0.2, ) -MAX_NEW_TOKENS = 100 +# ---- Configure the Few-Shot learner for Text2Onto ---- +# This learner will be used by the pipeline to learn/predict from Text2Onto-style samples. +text2ontolearner = SBUNLPFewShotLearner( + llm_model_id="Qwen/Qwen2.5-0.5B-Instruct", + device="cpu", + max_new_tokens=256, +) -terms_written = learner.predict_terms( - docs_test_jsonl=DOCS_TEST_PATH, - out_jsonl=TERMS_PRED_OUT, - max_new_tokens=MAX_NEW_TOKENS, +# Build pipeline and run +# Build the pipeline, passing the Few-Shot Learner. +pipe = LearnerPipeline( + llm=text2ontolearner, + llm_id="Qwen/Qwen2.5-0.5B-Instruct", + ontologizer_data=False, ) -print(f"✅ Term Extraction Complete. Wrote {terms_written} prediction lines.") -# Type Extraction subtask -types_written = learner.predict_types( - docs_test_jsonl=DOCS_TEST_PATH, - out_jsonl=TYPES_PRED_OUT, - max_new_tokens=MAX_NEW_TOKENS, +# Run the full learning pipeline on the text2onto task +outputs = pipe( + train_data=train_data, + test_data=test_data, + task="text2onto", + evaluate=True, + ontologizer_data=True, ) -print(f"✅ Type Extraction Complete. 
Wrote {types_written} prediction lines.") - -try: - # Evaluate Term Extraction using the custom F1 function and gold data - f1_term = learner.evaluate_extraction_f1(TERMS2DOC_PATH, TERMS_PRED_OUT, key="term") - print(f"Final Term Extraction F1: {f1_term:.4f}") - - # Evaluate Type Extraction - f1_type = learner.evaluate_extraction_f1(TERMS2DOC_PATH, TYPES_PRED_OUT, key="type") - print(f"Final Type Extraction F1: {f1_type:.4f}") - -except Exception as e: - # Catches errors like missing sklearn (ImportError) or missing prediction files (FileNotFoundError) - print( - f"❌ Evaluation Error: {e}. Ensure sklearn is installed and prediction files were created." - ) + +# Display the evaluation results +print("Metrics:", outputs.get("metrics")) + +# Display total elapsed time for training + prediction + evaluation +print("Elapsed time:", outputs["elapsed_time"]) + +# Print all returned outputs (include predictions) +print(outputs) diff --git a/examples/text2onto.py b/examples/text2onto.py index c67bb5f..03190d5 100644 --- a/examples/text2onto.py +++ b/examples/text2onto.py @@ -58,20 +58,33 @@ onto_name=ontology.ontology_id ) -# Split the synthetic data into train/val/test for each component -terms, types, docs, types2docs = splitter.split(train=0.8, val=0.1, test=0.1) +# split the train, val, test +train_data, val_data, test_data = splitter.train_test_val_split( + train=0.8, + val=0.0, + test=0.2, +) -# Print how many items exist in each split for terms -print("Terms:") -for split in terms: - print(f" {split}: {len(terms[split])}") +# print train split +print("\nTRAIN split:") +print(" docs:", len(train_data.get("documents", []))) +print(" terms:", len(train_data.get("terms", []))) +print(" types:", len(train_data.get("types", []))) +print(" terms2docs:", len(train_data.get("terms2docs", {}))) +print(" terms2types:", len(train_data.get("terms2types", {}))) -# Print how many items exist in each split for types -print("Types:") -for split in types: - print(f" {split}: {len(types[split])}") +# print val split +print("\nVAL split:") +print(" docs:", len(val_data.get("documents", []))) +print(" terms:", len(val_data.get("terms", []))) +print(" types:", len(val_data.get("types", []))) +print(" terms2docs:", len(val_data.get("terms2docs", {}))) +print(" terms2types:", len(val_data.get("terms2types", {}))) -# Print how many items exist in each split for docs -print("Docs:") -for split in docs: - print(f" {split}: {len(docs[split])}") +# print test split +print("\nTEST split:") +print(" docs:", len(test_data.get("documents", []))) +print(" terms:", len(test_data.get("terms", []))) +print(" types:", len(test_data.get("types", []))) +print(" terms2docs:", len(test_data.get("terms2docs", {}))) +print(" terms2types:", len(test_data.get("terms2types", {}))) diff --git a/ontolearner/base/learner.py b/ontolearner/base/learner.py index c410915..46acd88 100644 --- a/ontolearner/base/learner.py +++ b/ontolearner/base/learner.py @@ -18,6 +18,7 @@ import torch import torch.nn.functional as F from sentence_transformers import SentenceTransformer +from collections import defaultdict class AutoLearner(ABC): """ @@ -70,6 +71,7 @@ def fit(self, train_data: Any, task: str, ontologizer: bool=True): - "term-typing": Predict semantic types for terms - "taxonomy-discovery": Identify hierarchical relationships - "non-taxonomy-discovery": Identify non-hierarchical relationships + - "text2onto" : Extract ontology terms and their semantic types from documents Raises: NotImplementedError: If not implemented by concrete class. 
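+
+        Note:
+            For ``text2onto``, ``train_data`` is expected to be a dict with the keys
+            ``documents``, ``terms``, ``types``, ``terms2docs`` and ``terms2types``,
+            as returned by ``SyntheticDataSplitter.train_test_val_split``.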
@@ -81,6 +83,8 @@ def fit(self, train_data: Any, task: str, ontologizer: bool=True): self._taxonomy_discovery(train_data, test=False) elif task == 'non-taxonomic-re': self._non_taxonomic_re(train_data, test=False) + elif task == 'text2onto': + self._text2onto(train_data, test=False) else: raise ValueError(f"{task} is not a valid task.") @@ -103,6 +107,7 @@ def predict(self, eval_data: Any, task: str, ontologizer: bool=True) -> Any: - term-typing: List of predicted types for each term - taxonomy-discovery: Boolean predictions for relationships - non-taxonomy-discovery: Predicted relation types + - text2onto : Extract ontology terms and their semantic types from documents Raises: NotImplementedError: If not implemented by concrete class. @@ -115,6 +120,8 @@ def predict(self, eval_data: Any, task: str, ontologizer: bool=True) -> Any: return self._taxonomy_discovery(eval_data, test=True) elif task == 'non-taxonomic-re': return self._non_taxonomic_re(eval_data, test=True) + elif task == 'text2onto': + return self._text2onto(eval_data, test=True) else: raise ValueError(f"{task} is not a valid task.") @@ -147,6 +154,9 @@ def _taxonomy_discovery(self, data: Any, test: bool = False) -> Optional[Any]: def _non_taxonomic_re(self, data: Any, test: bool = False) -> Optional[Any]: pass + def _text2onto(self, data: Any, test: bool = False) -> Optional[Any]: + pass + def tasks_data_former(self, data: Any, task: str, test: bool = False) -> List[str | Dict[str, str]]: formatted_data = [] if task == "term-typing": @@ -171,6 +181,7 @@ def tasks_data_former(self, data: Any, task: str, test: bool = False) -> List[st non_taxonomic_types = list(set(non_taxonomic_types)) non_taxonomic_res = list(set(non_taxonomic_res)) formatted_data = {"types": non_taxonomic_types, "relations": non_taxonomic_res} + return formatted_data def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, str]]: @@ -186,6 +197,26 @@ def tasks_ground_truth_former(self, data: Any, task: str) -> List[Dict[str, str] formatted_data.append({"head": non_taxonomic_triplets.head, "tail": non_taxonomic_triplets.tail, "relation": non_taxonomic_triplets.relation}) + if task == "text2onto": + terms2docs = data.get("terms2docs", {}) or {} + terms2types = data.get("terms2types", {}) or {} + + # gold doc→terms + gold_terms = [] + for term, doc_ids in terms2docs.items(): + for doc_id in doc_ids or []: + gold_terms.append({"doc_id": doc_id, "term": term}) + + # gold doc→types derived via doc→terms + term→types + doc2types = defaultdict(set) + for term, doc_ids in terms2docs.items(): + for doc_id in doc_ids or []: + for ty in (terms2types.get(term, []) or []): + if isinstance(ty, str) and ty.strip(): + doc2types[doc_id].add(ty.strip()) + gold_types = [{"doc_id": doc_id, "type": ty} for doc_id, tys in doc2types.items() for ty in tys] + return {"terms": gold_terms, "types": gold_types} + return formatted_data class AutoLLM(ABC): diff --git a/ontolearner/evaluation/metrics.py b/ontolearner/evaluation/metrics.py index 57b2d66..52340ce 100644 --- a/ontolearner/evaluation/metrics.py +++ b/ontolearner/evaluation/metrics.py @@ -11,44 +11,84 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import List, Dict, Tuple, Set +from typing import List, Dict, Tuple, Set, Any, Union SYMMETRIC_RELATIONS = {"equivalentclass", "sameas", "disjointwith"} -def text2onto_metrics(y_true: List[str], y_pred: List[str], similarity_threshold: float = 0.8) -> Dict[str, float | int]: - def jaccard_similarity(a: str, b: str) -> float: - set_a = set(a.lower().split()) - set_b = set(b.lower().split()) - if not set_a and not set_b: +def text2onto_metrics( + y_true: Dict[str, Any], + y_pred: Dict[str, Any], + similarity_threshold: float = 0.8 +) -> Dict[str, Any]: + """ + Expects: + y_true = {"terms": [{"doc_id": str, "term": str}, ...], + "types": [{"doc_id": str, "type": str}, ...]} + y_pred = same shape + + Returns: + {"terms": {...}, "types": {...}} + """ + + def jaccard_similarity(text_a: str, text_b: str) -> float: + tokens_a = set(text_a.lower().split()) + tokens_b = set(text_b.lower().split()) + if not tokens_a and not tokens_b: return 1.0 - return len(set_a & set_b) / len(set_a | set_b) - - matched_gt_indices = set() - matched_pred_indices = set() - for i, pred_label in enumerate(y_pred): - for j, gt_label in enumerate(y_true): - if j in matched_gt_indices: - continue - sim = jaccard_similarity(pred_label, gt_label) - if sim >= similarity_threshold: - matched_pred_indices.add(i) - matched_gt_indices.add(j) - break # each gt matched once - - total_correct = len(matched_pred_indices) - total_predicted = len(y_pred) - total_ground_truth = len(y_true) + return len(tokens_a & tokens_b) / len(tokens_a | tokens_b) + + def pairs_to_strings(rows: List[Dict[str, str]], value_key: str) -> List[str]: + paired_strings: List[str] = [] + for row in rows or []: + doc_id = (row.get("doc_id") or "").strip() + value = (row.get(value_key) or "").strip() + if doc_id and value: + # keep doc association + allow token Jaccard + paired_strings.append(f"{doc_id} {value}") + return paired_strings + + def score_list(ground_truth_items: List[str], predicted_items: List[str]) -> Dict[str, Union[float, int]]: + matched_ground_truth_indices: Set[int] = set() + matched_predicted_indices: Set[int] = set() + + for predicted_index, predicted_item in enumerate(predicted_items): + for ground_truth_index, ground_truth_item in enumerate(ground_truth_items): + if ground_truth_index in matched_ground_truth_indices: + continue + + if jaccard_similarity(predicted_item, ground_truth_item) >= similarity_threshold: + matched_predicted_indices.add(predicted_index) + matched_ground_truth_indices.add(ground_truth_index) + break + + total_correct = len(matched_predicted_indices) + total_predicted = len(predicted_items) + total_ground_truth = len(ground_truth_items) + + precision = total_correct / total_predicted if total_predicted else 0.0 + recall = total_correct / total_ground_truth if total_ground_truth else 0.0 + f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0 + + return { + "f1_score": f1, + "precision": precision, + "recall": recall, + "total_correct": total_correct, + "total_predicted": total_predicted, + "total_ground_truth": total_ground_truth, + } + + ground_truth_terms = pairs_to_strings(y_true.get("terms", []), "term") + predicted_terms = pairs_to_strings(y_pred.get("terms", []), "term") + ground_truth_types = pairs_to_strings(y_true.get("types", []), "type") + predicted_types = pairs_to_strings(y_pred.get("types", []), "type") + + terms_metrics = score_list(ground_truth_terms, predicted_terms) + types_metrics = score_list(ground_truth_types, predicted_types) - precision = 
total_correct / total_predicted if total_predicted > 0 else 0 - recall = total_correct / total_ground_truth if total_ground_truth > 0 else 0 - f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0 return { - "f1_score": f1_score, - "precision": precision, - "recall": recall, - "total_correct": total_correct, - "total_predicted": total_predicted, - "total_ground_truth": total_ground_truth + "terms": terms_metrics, + "types": types_metrics, } def term_typing_metrics(y_true: List[Dict[str, List[str]]], y_pred: List[Dict[str, List[str]]]) -> Dict[str, float | int]: diff --git a/ontolearner/learner/text2onto/__init__.py b/ontolearner/learner/text2onto/__init__.py index 489853b..af31523 100644 --- a/ontolearner/learner/text2onto/__init__.py +++ b/ontolearner/learner/text2onto/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .alexbek import AlexbekFewShotLearner +from .alexbek import AlexbekRAGFewShotLearner from .sbunlp import SBUNLPFewShotLearner diff --git a/ontolearner/learner/text2onto/alexbek.py b/ontolearner/learner/text2onto/alexbek.py index f1692f7..8dee17a 100644 --- a/ontolearner/learner/text2onto/alexbek.py +++ b/ontolearner/learner/text2onto/alexbek.py @@ -12,1208 +12,587 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Tuple, Iterable import json -from json.decoder import JSONDecodeError -import os -import random import re +from typing import Any, Dict, List, Optional +from collections import defaultdict import torch from transformers import AutoTokenizer, AutoModelForCausalLM -from ...base import AutoLearner, AutoLLM +from ...base import AutoLearner, AutoRetriever -try: - from outlines.models import Transformers as OutlinesTFModel - from outlines.generate import json as outlines_generate_json - from pydantic import BaseModel - - class _PredictedTypesSchema(BaseModel): - """Schema used when generating structured JSON { "types": [...] }.""" - - types: List[str] - - OUTLINES_AVAILABLE: bool = True -except Exception: - # If outlines is unavailable, we will fall back to greedy decoding + regex parsing. - OUTLINES_AVAILABLE = False - _PredictedTypesSchema = None - OutlinesTFModel = None - outlines_generate_json = None - - -class LocalAutoLLM(AutoLLM): +class AlexbekRAGFewShotLearner(AutoLearner): """ - Minimal local LLM helper. - - - Inherits AutoLLM but overrides load/generate to avoid label_mapper. - - Optional 4-bit loading with `load_in_4bit=True` in .load(). - - Greedy decoding by default (deterministic). 
+ What it does (2-stage): + 1) doc -> terms + - retrieve top-k similar TRAIN documents (each has gold OL terms) + - build a few-shot chat prompt: (doc -> {"terms":[...]}) examples + target doc + - generate JSON {"terms":[...]} and parse it + + 2) term -> types + - retrieve top-k similar TRAIN terms (each has gold types) + - build a few-shot chat prompt: (term -> {"types":[...]}) examples + target term + - generate JSON {"types":[...]} and parse it + + Training behavior (fit): + - builds two retrieval indices: + * doc_retriever index over JSON strings of train docs (with "OL" field = gold terms) + * term_retriever index over JSON strings of train term->types examples + + Prediction behavior (predict): + - returns a dict compatible with OntoLearner evaluation_report: + { + "terms": [{"doc_id": "...", "term": "..."}, ...], + "types": [{"doc_id": "...", "type": "..."}, ...], + } + + Expected data format for task="text2onto": + data = { + "documents": [ {"id"/"doc_id": str, "title": str, "text": str, ...}, ... ], + "terms2docs": { term(str): [doc_id(str), ...], ... } + "terms2types": { term(str): [type(str), ...], ... } + } + + IMPORTANT: + - LearnerPipeline calls learner.load(model_id=llm_id). We accept that and override llm_model_id. + - We override tasks_data_former() so AutoLearner.fit/predict does NOT rewrite text2onto dicts. + - Device placement: we put the model exactly on the device string the user provides + ("cpu", "cuda", "cuda:0", "cuda:1", ...). No device_map="auto". """ - def __init__(self, device: str = "cpu", token: str = "") -> None: + TERM2TYPES_SYSTEM_PROMPT = ( + "You are an expert in ontology and semantic type classification. Your task is to predict " + "the semantic types for given terms based on their context and similar examples.\n\n" + "Given a term, you should predict its semantic types from the domain-specific ontology. " + "Use the provided examples to understand the patterns and relationships between terms and their types.\n\n" + "Output your response as a JSON object with the following structure:\n" + '{\n "types": ["type1", "type2", ...]\n}\n\n' + "The types should be relevant semantic categories that best describe the given term." 
+ ) + + DOC2TERMS_SYSTEM_PROMPT = ( + "You are an expert in ontology term extraction.\n\n" + "TASK: Extract specific, relevant ontology terms from scientific documents.\n\n" + "INSTRUCTIONS:\n" + "- The following conversation contains few-shot examples showing correct term extraction patterns\n" + "- Study these examples carefully to understand the extraction style and approach\n" + "- Follow the EXACT same pattern and style demonstrated in the examples\n" + "- Extract only terms that actually appear in the document text\n" + "- Focus on domain-specific terminology, concepts, and technical terms\n\n" + "- The first three user-assistant conversation pairs serve as few-shot examples\n" + "- Each example shows: user provides a document, assistant extracts relevant terms\n" + "- Pay attention to the extraction patterns and term selection criteria in these examples\n\n" + "DO:\n" + "- Extract terms that are EXPLICITLY mentioned in the LAST document\n" + "- Follow the SAME extraction pattern as shown in examples\n" + "- Return unique terms without duplicates\n" + "- Use the same JSON format as demonstrated\n\n" + "DON'T:\n" + "- Hallucinate or invent terms not present in last the document\n" + "- Repeat the same term multiple times\n" + "- Deviate from the extraction style shown in examples\n\n" + "OUTPUT FORMAT: Return a JSON object with a single field 'terms' containing a list of extracted terms." + ) + + def __init__( + self, + llm_model_id: str, + retriever_model_id: str = "sentence-transformers/all-MiniLM-L6-v2", + device: str = "cpu", + top_k: int = 3, + max_new_tokens: int = 256, + max_input_length: int = 2048, + use_tfidf: bool = False, + seed: int = 42, + restrict_to_known_types: bool = True, + hf_token: str = "", + local_files_only: bool = False, + **kwargs: Any, + ): """ - Initialize the local LLM holder. - Parameters ---------- - device : str - Execution device: "cpu" or "cuda". - token : str - Optional auth token for private model hubs. - """ - super().__init__(label_mapper=None, device=device, token=token) + llm_model_id: + HuggingFace model id OR local path to a downloaded model directory. + retriever_model_id: + SentenceTransformer model id OR local path to a downloaded SBERT directory. + device: + Exact device string to place model on ("cpu", "cuda", "cuda:0", ...). + top_k: + Number of retrieved examples for few-shot prompting in each stage. + max_new_tokens: + Max tokens to generate for each prompt. + max_input_length: + Max prompt length before truncation. + use_tfidf: + If docs include TF-IDF suggestions (key "TF-IDF" or "tfidf_terms"), include them in prompts. + seed: + Seed for reproducibility. + restrict_to_known_types: + If True, append allowed type label list (from training) to system prompt in term->types stage. + This helps exact-match evaluation by discouraging invented labels. + hf_token: + HuggingFace token for gated models (optional). + local_files_only: + If True, Transformers will not try to reach the internet (requires local cache / local path). 
+ """ + super().__init__(**kwargs) + + self.llm_model_id: str = llm_model_id + self.retriever_model_id: str = retriever_model_id + self.device: str = device + self.top_k: int = int(top_k) + self.max_new_tokens: int = int(max_new_tokens) + self.max_input_length: int = int(max_input_length) + self.use_tfidf: bool = bool(use_tfidf) + self.seed: int = int(seed) + self.restrict_to_known_types: bool = bool(restrict_to_known_types) + self.hf_token: str = hf_token or "" + self.local_files_only: bool = bool(local_files_only) + self.model: Optional[AutoModelForCausalLM] = None self.tokenizer: Optional[AutoTokenizer] = None + self._loaded: bool = False - def load(self, model_id: str, *, load_in_4bit: bool = False) -> None: - """ - Load a Hugging Face causal model + tokenizer and set deterministic - generation defaults. - - Parameters - ---------- - model_id : str - Model identifier resolvable by HF `from_pretrained`. - load_in_4bit : bool - If True and bitsandbytes is available, load using 4-bit quantization. - """ - # Tokenizer - self.tokenizer = AutoTokenizer.from_pretrained( - model_id, padding_side="left", token=self.token - ) - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token - - # Model (optionally quantized) - if load_in_4bit: - from transformers import BitsAndBytesConfig + # Internal retrievers (always used in method-1, even in "llm-only" pipeline mode) + self.doc_retriever = AutoRetriever() + self.term_retriever = AutoRetriever() - quantization_config = BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_use_double_quant=True, - bnb_4bit_compute_dtype=torch.bfloat16, - ) - self.model = AutoModelForCausalLM.from_pretrained( - model_id, - device_map="auto", - quantization_config=quantization_config, - token=self.token, - ) - else: - device_map = ( - "auto" if (self.device != "cpu" and torch.cuda.is_available()) else None - ) - self.model = AutoModelForCausalLM.from_pretrained( - model_id, - device_map=device_map, - torch_dtype=torch.bfloat16 - if torch.cuda.is_available() - else torch.float32, - token=self.token, - ) + # Indexed corpora as JSON strings + self._doc_examples_json: List[str] = [] + self._term_examples_json: List[str] = [] - # Deterministic generation defaults - generation_cfg = self.model.generation_config - generation_cfg.do_sample = False - generation_cfg.temperature = None - generation_cfg.top_k = None - generation_cfg.top_p = None - generation_cfg.num_beams = 1 + # Cached allowed type labels (for optional restriction) + self._allowed_types: List[str] = [] - def generate(self, prompts: List[str], max_new_tokens: int = 128) -> List[str]: + def tasks_data_former(self, data: Any, task: str, test: bool = False): """ - Greedy-generate continuations for a list of prompts. + Override base formatter: for task='text2onto' return data unchanged. + """ + if task == "text2onto": + return data + return super().tasks_data_former(data=data, task=task, test=test) - Parameters - ---------- - prompts : List[str] - Prompts to generate for (batched). - max_new_tokens : int - Maximum number of new tokens per continuation. - - Returns - ------- - List[str] - Decoded new-token texts (no special tokens, stripped). + def load(self, **kwargs: Any): """ - if self.model is None or self.tokenizer is None: - raise RuntimeError( - "Call .load(model_id) on LocalAutoLLM before generate()." 
- ) + Called by LearnerPipeline as: learner.load(model_id=llm_id) - tokenized_batch = self.tokenizer( - prompts, return_tensors="pt", padding=True, truncation=True - ) - input_seq_len = tokenized_batch["input_ids"].shape[1] - tokenized_batch = { - k: v.to(self.model.device) for k, v in tokenized_batch.items() - } + We accept overrides via kwargs: + - model_id / llm_model_id + - device, top_k, max_new_tokens, max_input_length, use_tfidf, seed, restrict_to_known_types + - hf_token, local_files_only + """ + model_id = kwargs.get("model_id") or kwargs.get("llm_model_id") + if model_id: + self.llm_model_id = str(model_id) - with torch.no_grad(): - outputs = self.model.generate( - **tokenized_batch, - max_new_tokens=max_new_tokens, - pad_token_id=self.tokenizer.eos_token_id, - do_sample=False, - num_beams=1, - ) + for k in [ + "device", + "top_k", + "max_new_tokens", + "max_input_length", + "use_tfidf", + "seed", + "restrict_to_known_types", + "hf_token", + "local_files_only", + "retriever_model_id", + ]: + if k in kwargs: + setattr(self, k, kwargs[k]) - # Only return the newly generated part for each row in the batch - continuation_token_ids = outputs[:, input_seq_len:] - return [ - self.tokenizer.decode(row, skip_special_tokens=True).strip() - for row in continuation_token_ids - ] + if self._loaded: + return + torch.manual_seed(self.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(self.seed) -class AlexbekFewShotLearner(AutoLearner): - """ - Text2Onto learner for LLMS4OL Task A (term & type extraction). - - Public API (A1 + convenience): - - fit(train_docs_jsonl, terms2doc_json, sample_size=24, seed=42) - - predict_terms(docs_test_jsonl, out_jsonl, max_new_tokens=128, few_shot_k=6) -> int - - predict_types(docs_test_jsonl, out_jsonl, max_new_tokens=128, few_shot_k=6) -> int - - evaluate_extraction_f1(gold_item2docs_json, preds_jsonl, key="term"|"type") -> float - - Option A (A2, term→types) bridge: - - predict_types_from_terms_option_a(...) - Reads your A1 results (docs→terms), predicts types for each term, and - writes two files: terms2types_pred.json + types2docs_pred.json - """ + dev = str(self.device).strip() + if dev.startswith("cuda") and not torch.cuda.is_available(): + raise RuntimeError(f"Device was set to '{dev}', but CUDA is not available.") - def __init__(self, model: LocalAutoLLM, device: str = "cpu", **_: Any) -> None: - """ - Initialize learner state and canned prompts. + dtype = torch.bfloat16 if dev.startswith("cuda") else torch.float32 - Parameters - ---------- - model : LocalAutoLLM - Loaded local LLM helper instance. - device : str - Device name ("cpu" or "cuda"). 
- """ - super().__init__(**_) - self.model = model - self.device = device - - # Few-shot exemplars for A1 (Docs→Terms) and for Docs→Types: - # Each exemplar is a tuple: (title, text, gold_list) - self._fewshot_terms_docs: List[Tuple[str, str, List[str]]] = [] - self._fewshot_types_docs: List[Tuple[str, str, List[str]]] = [] - - # System prompts - self._system_prompt_terms = ( - "You are an expert in ontology term extraction.\n" - "Extract only terms that explicitly appear in the document.\n" - 'Answer strictly as JSON: {"terms": ["..."]}\n' - ) - self._system_prompt_types = ( - "You are an expert in ontology type classification.\n" - "List ontology *types* that characterize the document’s terminology.\n" - 'Answer strictly as JSON: {"types": ["..."]}\n' - ) + tok_kwargs: Dict[str, Any] = {"local_files_only": self.local_files_only} + if self.hf_token: + tok_kwargs["token"] = self.hf_token + try: + self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_id, **tok_kwargs) + except TypeError: + tok_kwargs.pop("token", None) + if self.hf_token: + tok_kwargs["use_auth_token"] = self.hf_token + self.tokenizer = AutoTokenizer.from_pretrained(self.llm_model_id, **tok_kwargs) - # Compiled regex for robust JSON extraction from LLM outputs - self._json_object_regex = re.compile(r"\{[^{}]*\}", re.S) - self._json_array_regex = re.compile(r"\[[^\]]*\]", re.S) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token - # Term→Types (Option A) specific prompt - self._system_prompt_term_to_types = ( - "You are an expert in ontology and semantic type classification.\n" - "Given a term, predict its semantic types from the domain-specific ontology.\n" - 'Answer strictly as JSON:\n{"types": ["type1", "type2", "..."]}' - ) - def fit( - self, - *, - train_docs_jsonl: str, - terms2doc_json: str, - sample_size: int = 24, - seed: int = 42, - ) -> None: - """ - Build internal few-shot exemplars from a labeled training split. + model_kwargs: Dict[str, Any] = {"local_files_only": self.local_files_only} + if self.hf_token: + model_kwargs["token"] = self.hf_token - Parameters - ---------- - train_docs_jsonl : str - Path to JSONL (or tolerant JSON/JSONL) with train documents. - terms2doc_json : str - JSON mapping item -> [doc_id,...]; "item" can be a term or type. - sample_size : int - Number of exemplar documents to keep for few-shot prompting. - seed : int - RNG seed for reproducible sampling. 
- """ - rng = random.Random(seed) - - # Load documents and map doc_id -> row - document_map = self._load_documents_jsonl(train_docs_jsonl) - if not document_map: - raise FileNotFoundError(f"No documents found in: {train_docs_jsonl}") - - # Load item -> [doc_ids] - item_to_docs_map = self._load_json(terms2doc_json) - if not isinstance(item_to_docs_map, dict): - raise ValueError( - f"{terms2doc_json} must be a JSON dict mapping item -> [doc_ids]" + try: + self.model = AutoModelForCausalLM.from_pretrained( + self.llm_model_id, + dtype=dtype, + **model_kwargs, ) - - # Reverse mapping: doc_id -> [items] - doc_id_to_items_map: Dict[str, List[str]] = {} - for item_label, doc_id_list in item_to_docs_map.items(): - for doc_id in doc_id_list: - doc_id_to_items_map.setdefault(doc_id, []).append(item_label) - - # Build candidate exemplars (title, text, gold_list) - exemplar_candidates: List[Tuple[str, str, List[str]]] = [] - for doc_id, labeled_items in doc_id_to_items_map.items(): - doc_row = document_map.get(doc_id) - if not doc_row: - continue - doc_title = str(doc_row.get("title", "")) # be defensive (may be None) - doc_text = self._to_text( - doc_row.get("text", "") - ) # string-ify list if needed - if not doc_text: - continue - gold_items = self._unique_preserve( - [s for s in labeled_items if isinstance(s, str)] + except TypeError: + model_kwargs.pop("token", None) + if self.hf_token: + model_kwargs["use_auth_token"] = self.hf_token + self.model = AutoModelForCausalLM.from_pretrained( + self.llm_model_id, + torch_dtype=dtype, + **model_kwargs, ) - if gold_items: - exemplar_candidates.append((doc_title, doc_text, gold_items)) - if not exemplar_candidates: - raise RuntimeError( - "No candidate docs with items found to build few-shot exemplars." - ) + self.model = self.model.to(dev) - chosen_exemplars = rng.sample( - exemplar_candidates, k=min(sample_size, len(exemplar_candidates)) - ) - # Reuse exemplars for both docs→terms and docs→types prompting - self._fewshot_terms_docs = chosen_exemplars - self._fewshot_types_docs = chosen_exemplars + self.doc_retriever.load(self.retriever_model_id) + self.term_retriever.load(self.retriever_model_id) - def predict_terms( - self, - *, - docs_test_jsonl: str, - out_jsonl: str, - max_new_tokens: int = 128, - few_shot_k: int = 6, - ) -> int: - """ - Extract terms that explicitly appear in each document. + self._loaded = True - Writes one JSON object per line: - {"id": "", "terms": ["...", "...", ...]} - Parameters - ---------- - docs_test_jsonl : str - Path to test/dev documents in JSONL or tolerant JSON/JSONL. - out_jsonl : str - Output JSONL path where predictions are written (one line per doc). - max_new_tokens : int - Max generation length. - few_shot_k : int - Number of few-shot exemplars to prepend per prompt. - - Returns - ------- - int - Number of lines written (i.e., number of processed documents). 
+ def _format_doc(self, title: str, text: str, tfidf: Optional[List[str]] = None) -> str: """ - if self.model is None or self.model.model is None: - raise RuntimeError("Load a model first: learner.model.load(MODEL_ID, ...)") - - test_documents = self._load_documents_jsonl(docs_test_jsonl) - prompts: List[str] = [] - document_order: List[str] = [] - - for document_id, document_row in test_documents.items(): - title = str(document_row.get("title", "")) - text = self._to_text(document_row.get("text", "")) - - fewshot_block = self._format_fewshot_block( - self._system_prompt_terms, - self._fewshot_terms_docs, - key="terms", - k=few_shot_k, - ) - user_block = self._format_user_block(title, text) - - prompts.append(f"{fewshot_block}\n{user_block}\nAssistant:") - document_order.append(document_id) - - generations = self.model.generate(prompts, max_new_tokens=max_new_tokens) - parsed_term_lists = [ - self._parse_json_list(generated, key="terms") for generated in generations - ] - - os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True) - lines_written = 0 - with open(out_jsonl, "w", encoding="utf-8") as fp_out: - for document_id, term_list in zip(document_order, parsed_term_lists): - payload = {"id": document_id, "terms": self._unique_preserve(term_list)} - fp_out.write(json.dumps(payload, ensure_ascii=False) + "\n") - lines_written += 1 - return lines_written - - def predict_types( - self, - *, - docs_test_jsonl: str, - out_jsonl: str, - max_new_tokens: int = 128, - few_shot_k: int = 6, - ) -> int: + Format doc as the retriever query and as the user prompt content. """ - Predict ontology types that characterize each document’s terminology. - - Writes one JSON object per line: - {"id": "", "types": ["...", "...", ...]} + s = f"Title: {title}\n\nText:\n{text}" + if tfidf: + s += f"\n\nTF-IDF based suggestions: {tfidf}" + return s - Parameters - ---------- - docs_test_jsonl : str - Path to test/dev documents in JSONL or tolerant JSON/JSONL. - out_jsonl : str - Output JSONL path where predictions are written (one line per doc). - max_new_tokens : int - Max generation length. - few_shot_k : int - Number of few-shot exemplars to prepend per prompt. - - Returns - ------- - int - Number of lines written (i.e., number of processed documents). 
+ def _apply_chat_template(self, conversation: List[Dict[str, str]]) -> str: """ - if self.model is None or self.model.model is None: - raise RuntimeError("Load a model first: learner.model.load(MODEL_ID, ...)") - - test_documents = self._load_documents_jsonl(docs_test_jsonl) - prompts: List[str] = [] - document_order: List[str] = [] - - for document_id, document_row in test_documents.items(): - title = str(document_row.get("title", "")) - text = self._to_text(document_row.get("text", "")) - - fewshot_block = self._format_fewshot_block( - self._system_prompt_types, - self._fewshot_types_docs, - key="types", - k=few_shot_k, - ) - user_block = self._format_user_block(title, text) - - prompts.append(f"{fewshot_block}\n{user_block}\nAssistant:") - document_order.append(document_id) - - generations = self.model.generate(prompts, max_new_tokens=max_new_tokens) - parsed_type_lists = [ - self._parse_json_list(generated, key="types") for generated in generations - ] - - os.makedirs(os.path.dirname(out_jsonl) or ".", exist_ok=True) - lines_written = 0 - with open(out_jsonl, "w", encoding="utf-8") as fp_out: - for document_id, type_list in zip(document_order, parsed_type_lists): - payload = {"id": document_id, "types": self._unique_preserve(type_list)} - fp_out.write(json.dumps(payload, ensure_ascii=False) + "\n") - lines_written += 1 - return lines_written - - def evaluate_extraction_f1( - self, - gold_item2docs_json: str, - preds_jsonl: str, - *, - key: str = "term", - ) -> float: + Convert conversation into a single prompt string using the tokenizer's chat template if available. """ - Compute micro-F1 over (doc_id, item) pairs. - - Parameters - ---------- - gold_item2docs_json : str - JSON mapping item -> [doc_ids]. - preds_jsonl : str - JSONL lines like {"id": "...", "terms":[...]} or {"id":"...","types":[...]}. - key : str - "term" or "type" depending on what you are evaluating. - - Returns - ------- - float - Micro-averaged F1 score. 
- """ - item_to_doc_ids: Dict[str, List[str]] = self._load_json(gold_item2docs_json) - - # Build gold: doc_id -> set(items) - gold_doc_to_items: Dict[str, set] = {} - for item_label, doc_id_list in item_to_doc_ids.items(): - for document_id in doc_id_list: - gold_doc_to_items.setdefault(document_id, set()).add( - self._norm(item_label) - ) - - # Build predictions: doc_id -> set(items) - pred_doc_to_items: Dict[str, set] = {} - with open(preds_jsonl, "r", encoding="utf-8") as fp_in: - for line in fp_in: - row = json.loads(line.strip()) - document_id = str(row.get("id", "")) - items_list = row.get("terms" if key == "term" else "types", []) - pred_doc_to_items[document_id] = { - self._norm(x) for x in items_list if isinstance(x, str) - } + assert self.tokenizer is not None + if hasattr(self.tokenizer, "apply_chat_template"): + return self.tokenizer.apply_chat_template( + conversation, add_generation_prompt=True, tokenize=False + ) - # Micro counts - true_positive = false_positive = false_negative = 0 - all_document_ids = set(gold_doc_to_items.keys()) | set(pred_doc_to_items.keys()) - for document_id in all_document_ids: - gold_set = gold_doc_to_items.get(document_id, set()) - pred_set = pred_doc_to_items.get(document_id, set()) - true_positive += len(gold_set & pred_set) - false_positive += len(pred_set - gold_set) - false_negative += len(gold_set - pred_set) - - precision = ( - true_positive / (true_positive + false_positive) - if (true_positive + false_positive) - else 0.0 - ) - recall = ( - true_positive / (true_positive + false_negative) - if (true_positive + false_negative) - else 0.0 - ) - f1 = ( - 2 * precision * recall / (precision + recall) - if (precision + recall) - else 0.0 - ) - return f1 + parts = [] + for t in conversation: + parts.append(f"{t['role'].upper()}:\n{t['content']}\n") + parts.append("ASSISTANT:\n") + return "\n".join(parts) + + def _extract_first_json_obj(self, text: str) -> Optional[dict]: + """ + Extract the first valid JSON object from generated text by scanning balanced {...}. + """ + starts = [i for i, ch in enumerate(text) if ch == "{"] + + for s in starts: + depth = 0 + for e in range(s, len(text)): + if text[e] == "{": + depth += 1 + elif text[e] == "}": + depth -= 1 + if depth == 0: + candidate = text[s : e + 1].strip().replace("\n", " ") + try: + return json.loads(candidate) + except Exception: + try: + candidate2 = re.sub(r"'", '"', candidate) + return json.loads(candidate2) + except Exception: + pass + break + return None + + def _dedup_clean(self, items: List[str]) -> List[str]: + """ + Normalize and deduplicate strings (case-insensitive). + """ + out: List[str] = [] + seen = set() + for x in items or []: + if not isinstance(x, str): + continue + x2 = re.sub(r"\s+", " ", x.strip()) + if not x2: + continue + k = x2.lower() + if k in seen: + continue + seen.add(k) + out.append(x2) + return out - def predict_types_from_terms( - self, - *, - doc_terms_jsonl: Optional[str] = None, # formerly a1_results_jsonl - doc_terms_list: Optional[List[Dict]] = None, # formerly a1_results_list - few_shot_jsonl: Optional[ - str - ] = None, # JSONL lines: {"term":"...", "types":[...]} - rag_terms_json: Optional[ - str - ] = None, # JSON list; items may contain "term" and "RAG":[...] 
- random_few_shot: Optional[int] = 3, - model_id: str = "Qwen/Qwen2.5-1.5B-Instruct", - use_structured_output: bool = True, - seed: int = 42, - out_terms2types: str = "terms2types_pred.json", - out_types2docs: str = "types2docs_pred.json", - ) -> Dict[str, Any]: + def _doc_id(self, d: Dict[str, Any]) -> str: + """ + Extract doc_id from common keys: doc_id, id, docid. """ - Predict types for each unique term extracted per document and derive a types→docs map. + return str(d.get("doc_id") or d.get("id") or d.get("docid") or "") - Parameters - ---------- - doc_terms_jsonl : Optional[str] - Path to JSONL with lines like {"id": "...", "terms": [...]} or a JSON with {"results":[...]}. - doc_terms_list : Optional[List[Dict]] - In-memory results like [{"id":"...","extracted_terms":[...]}] or {"id":"...","terms":[...]}. - few_shot_jsonl : Optional[str] - Global few-shot exemplars: one JSON object per line with {"term": "...", "types":[...]}. - rag_terms_json : Optional[str] - Optional per-term RAG exemplars: a JSON list of {"term": "...", "RAG":[{"term": "...", "types":[...]}]}. - random_few_shot : Optional[int] - If provided, randomly select up to this many few-shot examples for each prediction. - model_id : str - HF model id used specifically for term→types predictions. - use_structured_output : bool - If True and outlines is available, enforce structured {"types":[...]} output. - seed : int - Random seed for reproducibility. - out_terms2types : str - Output JSON path for list of {"term": "...", "predicted_types":[...]}. - out_types2docs : str - Output JSON path for dict {"TYPE":[doc_ids,...], ...}. - - Returns - ------- - Dict[str, Any] - Summary with predictions and counts. + def _extract_documents(self, data: Any) -> List[Dict[str, Any]]: """ - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + Accept list-of-docs OR dict with 'documents'/'docs'. + """ + if isinstance(data, list): + return data + if isinstance(data, dict): + if isinstance(data.get("documents"), list): + return data["documents"] + if isinstance(data.get("docs"), list): + return data["docs"] + raise ValueError("Expected dict with 'documents' (or 'docs'), or a list of docs.") - # Load normalized document→terms results - doc_term_extractions = self._load_doc_term_extractions( - results_json_path=doc_terms_jsonl, - in_memory_results=doc_terms_list, - ) - if not doc_term_extractions: - raise ValueError( - "No document→terms results provided (doc_terms_jsonl/doc_terms_list)." - ) + def _normalize_terms2docs(self, raw_terms2docs: Any, docs: List[Dict[str, Any]]) -> Dict[str, List[str]]: + """ + Normalize mapping to: term -> [doc_id, ...]. 
- # Prepare unique term list and term→doc occurrences - unique_terms = self._collect_unique_terms_from_extractions(doc_term_extractions) - term_to_doc_ids_map = self._build_term_to_doc_ids(doc_term_extractions) - - # Load optional global few-shot examples - global_few_shot_examples: List[Dict] = [] - if few_shot_jsonl and os.path.exists(few_shot_jsonl): - with open(few_shot_jsonl, "r", encoding="utf-8") as few_shot_file: - for raw_line in few_shot_file: - raw_line = raw_line.strip() - if not raw_line: - continue - try: - json_obj = json.loads(raw_line) - except Exception: - continue - if ( - isinstance(json_obj, dict) - and "term" in json_obj - and "types" in json_obj - ): - global_few_shot_examples.append(json_obj) - - # Optional per-term RAG examples: {normalized_term -> [examples]} - rag_examples_lookup: Dict[str, List[Dict]] = {} - if rag_terms_json and os.path.exists(rag_terms_json): - try: - rag_payload = self._load_json(rag_terms_json) - if isinstance(rag_payload, list): - for rag_item in rag_payload: - if isinstance(rag_item, dict): - normalized_term = self._normalize_term( - rag_item.get("term", "") - ) - rag_examples_lookup[normalized_term] = rag_item.get( - "RAG", [] - ) - except Exception: - pass + If caller accidentally provides inverted mapping: doc_id -> [term, ...], + we detect it (keys mostly match doc_ids) and invert it. + """ + if not isinstance(raw_terms2docs, dict) or not raw_terms2docs: + return {} - # Load a small chat LLM dedicated to Term→Types - typing_model, typing_tokenizer = self._load_llm_for_types(model_id) + doc_ids = {self._doc_id(d) for d in docs} + doc_ids.discard("") - # Predict types per term - term_to_predicted_types_list: List[Dict] = [] - for term_text in unique_terms: - normalized_term = self._normalize_term(term_text) + keys = list(raw_terms2docs.keys()) + sample = keys[:200] + hits = sum(1 for k in sample if str(k) in doc_ids) - # Prefer per-term RAG for this term, else use global few-shot - few_shot_examples_for_term = ( - rag_examples_lookup.get(normalized_term, None) - or global_few_shot_examples - ) + if sample and hits >= int(0.6 * len(sample)): + term2docs: Dict[str, List[str]] = defaultdict(list) + for did, terms in raw_terms2docs.items(): + did = str(did) + if did not in doc_ids: + continue + for t in (terms or []): + if isinstance(t, str) and t.strip(): + term2docs[t.strip()].append(did) + return {t: self._dedup_clean(ds) for t, ds in term2docs.items()} + + norm: Dict[str, List[str]] = {} + for term, doc_list in raw_terms2docs.items(): + if not isinstance(term, str) or not term.strip(): + continue + docs_norm = [str(d) for d in (doc_list or []) if str(d)] + if docs_norm: + norm[term.strip()] = self._dedup_clean(docs_norm) + return norm - # Build conversation and prompt - conversation_messages = self._build_conv_for_type_infer( - term=term_text, - few_shot_examples=few_shot_examples_for_term, - random_k=random_few_shot, - ) - typing_prompt_string = self._apply_chat_template_safe_types( - typing_tokenizer, conversation_messages - ) + def _generate(self, prompt: str) -> str: + """ + Deterministic single-prompt generation (no sampling). + Returns decoded completion only. 
+ """ + assert self.model is not None and self.tokenizer is not None - predicted_types: List[str] = [] - raw_generation_text: str = "" - - # Structured JSON path (if requested and available) - if ( - use_structured_output - and OUTLINES_AVAILABLE - and _PredictedTypesSchema is not None - ): - try: - outlines_model = OutlinesTFModel(typing_model, typing_tokenizer) # type: ignore - generator = outlines_generate_json( - outlines_model, _PredictedTypesSchema - ) # type: ignore - structured = generator(typing_prompt_string, max_tokens=512) - predicted_types = [ - label for label in structured.types if isinstance(label, str) - ] - raw_generation_text = json.dumps( - {"types": predicted_types}, ensure_ascii=False - ) - except Exception: - # Fall back to greedy decoding - use_structured_output = False - - # Greedy decode fallback - if ( - not use_structured_output - or not OUTLINES_AVAILABLE - or _PredictedTypesSchema is None - ): - tokenized_prompt = typing_tokenizer( - typing_prompt_string, - return_tensors="pt", - truncation=True, - max_length=2048, - ) - if torch.cuda.is_available(): - tokenized_prompt = { - name: tensor.cuda() for name, tensor in tokenized_prompt.items() - } - with torch.no_grad(): - output_ids = typing_model.generate( - **tokenized_prompt, - max_new_tokens=256, - do_sample=False, - num_beams=1, - pad_token_id=typing_tokenizer.eos_token_id, - ) - new_token_span = output_ids[0][tokenized_prompt["input_ids"].shape[1] :] - raw_generation_text = typing_tokenizer.decode( - new_token_span, skip_special_tokens=True - ) - predicted_types = self._extract_types_from_text(raw_generation_text) - - term_to_predicted_types_list.append( - { - "term": term_text, - "predicted_types": sorted(set(predicted_types)), - } - ) + enc = self.tokenizer( + prompt, + return_tensors="pt", + truncation=True, + max_length=self.max_input_length, + ) + enc = {k: v.to(self.model.device) for k, v in enc.items()} - # 7) Build types→docs from (term→types) and (term→docs) - types_to_doc_id_set: Dict[str, set] = {} - for term_prediction in term_to_predicted_types_list: - normalized_term = self._normalize_term(term_prediction["term"]) - doc_ids_for_term = term_to_doc_ids_map.get(normalized_term, []) - for type_label in term_prediction.get("predicted_types", []): - types_to_doc_id_set.setdefault(type_label, set()).update( - doc_ids_for_term - ) - - types_to_doc_ids: Dict[str, List[str]] = { - type_label: sorted(doc_id_set) - for type_label, doc_id_set in types_to_doc_id_set.items() - } - - # 8) Save outputs - os.makedirs(os.path.dirname(out_terms2types) or ".", exist_ok=True) - with open(out_terms2types, "w", encoding="utf-8") as fp_terms2types: - json.dump( - term_to_predicted_types_list, - fp_terms2types, - ensure_ascii=False, - indent=2, + with torch.no_grad(): + out = self.model.generate( + **enc, + max_new_tokens=self.max_new_tokens, + do_sample=False, + num_beams=1, + pad_token_id=self.tokenizer.eos_token_id, ) - os.makedirs(os.path.dirname(out_types2docs) or ".", exist_ok=True) - with open(out_types2docs, "w", encoding="utf-8") as fp_types2docs: - json.dump(types_to_doc_ids, fp_types2docs, ensure_ascii=False, indent=2) + gen_tokens = out[0][enc["input_ids"].shape[1] :] + return self.tokenizer.decode(gen_tokens, skip_special_tokens=True).strip() - # Cleanup VRAM if any - del typing_model, typing_tokenizer - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return { - "terms2types_pred": term_to_predicted_types_list, - "types2docs_pred": types_to_doc_ids, - "unique_terms": len(unique_terms), - 
"types_count": len(types_to_doc_ids), - } - - def _load_json(self, path: str) -> Dict[str, Any]: - """Load a JSON file from disk and return its parsed object.""" - with open(path, "r", encoding="utf-8") as file_obj: - return json.load(file_obj) - - def _iter_json_objects(self, blob: str) -> Iterable[Dict[str, Any]]: + def _retrieve_doc_fewshot(self, doc: Dict[str, Any]) -> List[Dict[str, Any]]: """ - Iterate over *all* JSON objects found inside a string. - - Supports cases where multiple JSON objects are concatenated back-to-back - in a single line. It skips stray commas/whitespace between objects. - - Parameters - ---------- - blob : str - A string that may contain one or more JSON objects. - - Yields - ------ - Dict[str, Any] - Each parsed JSON object. + Retrieve top-k doc examples (JSON dicts) for few-shot doc->terms prompting. """ - json_decoder = json.JSONDecoder() - cursor_index, text_length = 0, len(blob) - while cursor_index < text_length: - # Skip whitespace/commas between objects - while cursor_index < text_length and blob[cursor_index] in " \t\r\n,": - cursor_index += 1 - if cursor_index >= text_length: - break + q = self._format_doc(doc.get("title", ""), doc.get("text", "")) + hits = self.doc_retriever.retrieve([q], top_k=self.top_k)[0] + + out: List[Dict[str, Any]] = [] + for h in hits: try: - json_obj, end_index = json_decoder.raw_decode(blob, idx=cursor_index) - except JSONDecodeError: - # Can't decode from this position; stop scanning this chunk - break - yield json_obj - cursor_index = end_index - - def _load_documents_jsonl(self, path: str) -> Dict[str, Dict[str, Any]]: + out.append(json.loads(h)) + except Exception: + continue + return out + + def _retrieve_term_fewshot(self, term: str) -> List[Dict[str, Any]]: """ - Robust reader that supports: - • True JSONL (one object per line) - • Lines with multiple concatenated JSON objects - • Whole file as a JSON array - - Returns - ------- - Dict[str, Dict[str, Any]] - Mapping doc_id -> full document row. + Retrieve top-k term examples (JSON dicts) for few-shot term->types prompting. """ - documents_by_id: Dict[str, Dict[str, Any]] = {} + hits = self.term_retriever.retrieve([term], top_k=self.top_k)[0] - with open(path, "r", encoding="utf-8") as file_obj: - content = file_obj.read().strip() - - # Case A: whole-file JSON array - if content.startswith("["): + out: List[Dict[str, Any]] = [] + for h in hits: try: - json_array = json.loads(content) - if isinstance(json_array, list): - for record in json_array: - if not isinstance(record, dict): - continue - document_id = str( - record.get("id") - or record.get("doc_id") - or (record.get("doc") or {}).get("id") - or "" - ) - if document_id: - documents_by_id[document_id] = record - return documents_by_id + out.append(json.loads(h)) except Exception: - # Fall back to line-wise handling if array parsing fails - pass - - # Case B: treat as JSONL-ish; parse *all* objects per line - for raw_line in content.splitlines(): - line = raw_line.strip() - if not line: continue - for record in self._iter_json_objects(line): - if not isinstance(record, dict): - continue - document_id = str( - record.get("id") - or record.get("doc_id") - or (record.get("doc") or {}).get("id") - or "" - ) - if document_id: - documents_by_id[document_id] = record - - return documents_by_id - - def _to_text(self, text_field: Any) -> str: - """ - Convert a 'text' field into a single string (handles list-of-strings). - - Parameters - ---------- - text_field : Any - The value found under "text" in the dataset row. 
+ return out - Returns - ------- - str - A single-string representation of the text. + def _doc_to_terms(self, doc: Dict[str, Any]) -> List[str]: """ - if isinstance(text_field, str): - return text_field - if isinstance(text_field, list): - return " ".join(str(part) for part in text_field) - return str(text_field) if text_field is not None else "" - - def _unique_preserve(self, values: List[str]) -> List[str]: + Predict terms for a document using few-shot prompting + doc retrieval. """ - Deduplicate values while preserving the original order. + fewshot = self._retrieve_doc_fewshot(doc) - Parameters - ---------- - values : List[str] - Sequence possibly containing duplicates. + convo: List[Dict[str, str]] = [{"role": "system", "content": self.DOC2TERMS_SYSTEM_PROMPT}] - Returns - ------- - List[str] - Sequence without duplicates, order preserved. - """ - seen_values: set = set() - ordered_values: List[str] = [] - for candidate in values: - if candidate not in seen_values: - seen_values.add(candidate) - ordered_values.append(candidate) - return ordered_values - - def _norm(self, text: str) -> str: - """ - Lowercased, single-spaced normalization (for comparisons). + for ex in fewshot: + ex_tfidf = ex.get("TF-IDF") or ex.get("tfidf_terms") or [] + convo += [ + { + "role": "user", + "content": self._format_doc( + ex.get("title", ""), + ex.get("text", ""), + ex_tfidf if self.use_tfidf else None, + ), + }, + { + "role": "assistant", + "content": json.dumps({"terms": ex.get("OL", [])}, ensure_ascii=False), + }, + ] - Parameters - ---------- - text : str - Input string. + tfidf = doc.get("TF-IDF") or doc.get("tfidf_terms") or [] + convo.append( + { + "role": "user", + "content": self._format_doc( + doc.get("title", ""), + doc.get("text", ""), + tfidf if self.use_tfidf else None, + ), + } + ) - Returns - ------- - str - Normalized string. - """ - return " ".join(text.lower().split()) + prompt = self._apply_chat_template(convo) + gen = self._generate(prompt) + parsed = self._extract_first_json_obj(gen) or {} + return self._dedup_clean(parsed.get("terms", [])) - def _normalize_term(self, term: str) -> str: + def _term_to_types(self, term: str) -> List[str]: + """ + Predict types for a term using few-shot prompting + term retrieval. """ - Normalization tailored for term keys / lookups. + fewshot = self._retrieve_term_fewshot(term) - Parameters - ---------- - term : str - Term to normalize. + system = self.TERM2TYPES_SYSTEM_PROMPT + if self.restrict_to_known_types and self._allowed_types: + allowed_block = "\n".join(f"- {t}" for t in self._allowed_types) + system = ( + system + + "\n\nIMPORTANT CONSTRAINT:\n" + + "Choose ONLY from the following valid ontology types (do not invent new labels):\n" + + allowed_block + ) - Returns - ------- - str - Lowercased, trimmed and single-spaced term. - """ - return " ".join(str(term).strip().split()).lower() + convo: List[Dict[str, str]] = [{"role": "system", "content": system}] - def _format_fewshot_block( - self, - system_prompt: str, - fewshot_examples: List[Tuple[str, str, List[str]]], - *, - key: str, - k: int = 6, - ) -> str: - """ - Render a few-shot block like: + for ex in fewshot: + convo += [ + {"role": "user", "content": f"Term: {ex.get('term','')}"}, + { + "role": "assistant", + "content": json.dumps({"types": ex.get("types", [])}, ensure_ascii=False), + }, + ] - + convo.append({"role": "user", "content": f"Term: {term}"}) - ### Example - User: - Title: ... 
- - Assistant: - {"terms": [...]} or {"types": [...]} + prompt = self._apply_chat_template(convo) + gen = self._generate(prompt) + parsed = self._extract_first_json_obj(gen) or {} + return self._dedup_clean(parsed.get("types", [])) - Parameters - ---------- - system_prompt : str - Instructional system text to prepend. - fewshot_examples : List[Tuple[str, str, List[str]]] - Examples as (title, text, labels_list). - key : str - Either "terms" or "types" depending on the task. - k : int - Number of examples to include. - - Returns - ------- - str - Formatted few-shot block text. + def _text2onto(self, data: Any, test: bool = False) -> Optional[Any]: """ - lines: List[str] = [system_prompt.strip(), ""] - for example_title, example_text, gold_list in fewshot_examples[:k]: - lines.append("### Example") - lines.append(f"User:\nTitle: {example_title}\n{example_text}") - lines.append( - f'Assistant:\n{{"{key}": ' - + json.dumps(gold_list, ensure_ascii=False) - + "}" - ) - return "\n".join(lines) + Train or predict for task="text2onto". - def _format_user_block(self, title: str, text: str) -> str: + Returns: + - training: None + - prediction: {"terms": [...], "types": [...]} """ - Format the 'Task' block for the current document. + if not self._loaded: + self.load(model_id=self.llm_model_id, device=self.device) - Parameters - ---------- - title : str - Document title. - text : str - Document text (single string). - - Returns - ------- - str - Formatted user block. - """ - return f"### Task\nUser:\nTitle: {title}\n{text}" + if not isinstance(data, dict): + raise ValueError("text2onto expects a dict with documents + mappings.") - def _parse_json_list(self, generated_text: str, *, key: str) -> List[str]: - """ - Extract a list from model output, trying: - 1) JSON object with the key ({"terms":[...]} or {"types":[...]}). - 2) Any top-level JSON array. - 3) Fallback: comma-split. + docs = self._extract_documents(data) - Parameters - ---------- - generated_text : str - Raw generation text to parse. - key : str - "terms" or "types". - - Returns - ------- - List[str] - Parsed strings (best-effort). - """ - # 1) Try a JSON object and read key - try: - object_match = self._json_object_regex.search(generated_text) - if object_match: - json_obj = json.loads(object_match.group(0)) - json_array = json_obj.get(key) - if isinstance(json_array, list): - return [value for value in json_array if isinstance(value, str)] - except Exception: - pass - - # 2) Any JSON array - try: - array_match = self._json_array_regex.search(generated_text) - if array_match: - json_array = json.loads(array_match.group(0)) - if isinstance(json_array, list): - return [value for value in json_array if isinstance(value, str)] - except Exception: - pass - - # 3) Fallback: comma-split (last resort) - if "," in generated_text: - return [ - part.strip().strip('"').strip("'") - for part in generated_text.split(",") - if part.strip() - ] - return [] + raw_terms2docs = data.get("terms2docs") or data.get("term2docs") or {} + terms2types = data.get("terms2types") or data.get("term2types") or {} - def _apply_chat_template_safe_types( - self, tokenizer: AutoTokenizer, messages: List[Dict[str, str]] - ) -> str: - """ - Safely build a prompt string for chat models. Uses the model's chat template - when available; otherwise falls back to a simple concatenation. 
- """ - try: - return tokenizer.apply_chat_template( - messages, add_generation_prompt=True, tokenize=False - ) - except Exception: - system_text = next( - (m["content"] for m in messages if m.get("role") == "system"), "" - ) - last_user_text = next( - (m["content"] for m in reversed(messages) if m.get("role") == "user"), - "", - ) - return f"{system_text}\n\nUser:\n{last_user_text}\n\nAssistant:" + terms2docs = self._normalize_terms2docs(raw_terms2docs, docs) - def _build_conv_for_type_infer( - self, - term: str, - few_shot_examples: Optional[List[Dict]] = None, - random_k: Optional[int] = None, - ) -> List[Dict[str, str]]: - """ - Create a chat-style conversation for a single term→types query, - optionally prepending few-shot examples. - """ - messages: List[Dict[str, str]] = [ - {"role": "system", "content": self._system_prompt_term_to_types} - ] - examples = list(few_shot_examples or []) - if random_k and len(examples) > random_k: - import random as _rnd - - examples = _rnd.sample(examples, random_k) - for exemplar in examples: - example_term = exemplar.get("term", "") - example_types = exemplar.get("types", []) - messages.append({"role": "user", "content": f"Term: {example_term}"}) - messages.append( + if not test: + self._allowed_types = sorted( { - "role": "assistant", - "content": json.dumps({"types": example_types}, ensure_ascii=False), + ty.strip() + for tys in (terms2types or {}).values() + for ty in (tys or []) + if isinstance(ty, str) and ty.strip() } ) - messages.append({"role": "user", "content": f"Term: {term}"}) - return messages - def _extract_types_from_text(self, generated_text: str) -> List[str]: - """ - Parse {"types":[...]} from a free-form generation. - """ - try: - object_match = re.search(r'\{[^}]*"types"[^}]*\}', generated_text) - if object_match: - json_obj = json.loads(object_match.group(0)) - types_array = json_obj.get("types", []) - return [ - type_label - for type_label in types_array - if isinstance(type_label, str) - ] - except Exception: - pass - return [] - - def _load_llm_for_types( - self, model_id: str - ) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: - """ - Load a *separate* small chat model for Term→Types (keeps LocalAutoLLM untouched). 
- """ - tokenizer = AutoTokenizer.from_pretrained(model_id) - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - model = AutoModelForCausalLM.from_pretrained( - model_id, - torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, - device_map="auto" if torch.cuda.is_available() else None, - ) - return model, tokenizer + # build doc->terms from term->docs + doc2terms: Dict[str, List[str]] = defaultdict(list) + for term, doc_ids in (terms2docs or {}).items(): + for did in (doc_ids or []): + doc2terms[str(did)].append(term) + + # doc few-shot corpus: doc + gold OL terms + doc_examples: List[Dict[str, Any]] = [] + for d in docs: + did = self._doc_id(d) + ex = dict(d) + ex["doc_id"] = did + ex["OL"] = self._dedup_clean(doc2terms.get(did, [])) + doc_examples.append(ex) + + # term few-shot corpus: term + gold types + term_examples = [ + {"term": t, "types": self._dedup_clean(tys)} + for t, tys in (terms2types or {}).items() + ] - def _load_doc_term_extractions( - self, - *, - results_json_path: Optional[str] = None, - in_memory_results: Optional[List[Dict]] = None, - ) -> List[Dict]: - """ - Normalize document→terms outputs to a list of: - {"id": "", "extracted_terms": ["...", ...]} - - Accepts either: - - in_memory_results (list of dicts) - - results_json_path pointing to: - • a JSONL file with lines: {"id": "...", "terms": [...]} - • OR a JSON file with {"results":[{"id":..., "extracted_terms": [...]}, ...]} - • OR a JSON list of dicts - """ - normalized_records: List[Dict] = [] + # store as JSON strings so retrievers return parseable strings + self._doc_examples_json = [json.dumps(ex, ensure_ascii=False) for ex in doc_examples] + self._term_examples_json = [json.dumps(ex, ensure_ascii=False) for ex in term_examples] - def _coerce_to_record(source_row: Dict) -> Optional[Dict]: - document_id = str(source_row.get("id", "")) or str( - source_row.get("doc_id", "") - ) - if not document_id: - return None - terms = source_row.get("extracted_terms") - if terms is None: - terms = source_row.get("terms") - if ( - terms is None - and "payload" in source_row - and isinstance(source_row["payload"], dict) - ): - terms = source_row["payload"].get("terms") - if not isinstance(terms, list): - terms = [] - return { - "id": document_id, - "extracted_terms": [t for t in terms if isinstance(t, str)], - } + # index retrievers + self.doc_retriever.index(self._doc_examples_json) + self.term_retriever.index(self._term_examples_json) + return None - if in_memory_results is not None: - for source_row in in_memory_results: - coerced_record = _coerce_to_record(source_row) - if coerced_record: - normalized_records.append(coerced_record) - return normalized_records - - if not results_json_path: - raise ValueError("Provide results_json_path or in_memory_results") - - # Detect JSON vs JSONL by extension (best-effort) - if results_json_path.endswith(".jsonl"): - with open(results_json_path, "r", encoding="utf-8") as file_in: - for raw_line in file_in: - raw_line = raw_line.strip() - if not raw_line: - continue - # Multiple concatenated objects per line? Iterate them all. 
- for json_obj in self._iter_json_objects(raw_line): - if isinstance(json_obj, dict): - coerced_record = _coerce_to_record(json_obj) - if coerced_record: - normalized_records.append(coerced_record) - else: - payload_obj = self._load_json(results_json_path) - if isinstance(payload_obj, dict) and "results" in payload_obj: - for source_row in payload_obj["results"]: - coerced_record = _coerce_to_record(source_row) - if coerced_record: - normalized_records.append(coerced_record) - elif isinstance(payload_obj, list): - for source_row in payload_obj: - if isinstance(source_row, dict): - coerced_record = _coerce_to_record(source_row) - if coerced_record: - normalized_records.append(coerced_record) - - return normalized_records - - def _collect_unique_terms_from_extractions( - self, doc_term_extractions: List[Dict] - ) -> List[str]: - """ - Collect unique terms (original casing) from normalized document→terms results. - """ - seen_normalized_terms: set = set() - ordered_unique_terms: List[str] = [] - for record in doc_term_extractions: - for term_text in record.get("extracted_terms", []): - normalized = self._normalize_term(term_text) - if normalized and normalized not in seen_normalized_terms: - seen_normalized_terms.add(normalized) - ordered_unique_terms.append(term_text.strip()) - return ordered_unique_terms - - def _build_term_to_doc_ids( - self, doc_term_extractions: List[Dict] - ) -> Dict[str, List[str]]: - """ - Build lookup: normalized_term -> sorted unique list of doc_ids. - """ - term_to_doc_set: Dict[str, set] = {} - for record in doc_term_extractions: - document_id = str(record.get("id", "")) - for term_text in record.get("extracted_terms", []): - normalized = self._normalize_term(term_text) - if not normalized or not document_id: - continue - term_to_doc_set.setdefault(normalized, set()).add(document_id) - return { - normalized_term: sorted(doc_ids) - for normalized_term, doc_ids in term_to_doc_set.items() - } + doc2terms_pred: Dict[str, List[str]] = {} + for d in docs: + did = self._doc_id(d) + doc2terms_pred[did] = self._doc_to_terms(d) + + unique_terms = sorted({t for ts in doc2terms_pred.values() for t in ts}) + term2types_pred: Dict[str, List[str]] = {t: self._term_to_types(t) for t in unique_terms} + + doc2types_pred: Dict[str, List[str]] = {} + for did, terms in doc2terms_pred.items(): + tys: List[str] = [] + for t in terms: + tys.extend(term2types_pred.get(t, [])) + doc2types_pred[did] = self._dedup_clean(tys) + + pred_terms = [{"doc_id": did, "term": t} for did, ts in doc2terms_pred.items() for t in ts] + pred_types = [{"doc_id": did, "type": ty} for did, tys in doc2types_pred.items() for ty in tys] + + return {"terms": pred_terms, "types": pred_types} diff --git a/ontolearner/learner/text2onto/sbunlp.py b/ontolearner/learner/text2onto/sbunlp.py index 49067e2..a7f598e 100644 --- a/ontolearner/learner/text2onto/sbunlp.py +++ b/ontolearner/learner/text2onto/sbunlp.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -#      https://opensource.org/licenses/MIT +# https://opensource.org/licenses/MIT # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,587 +12,592 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json
-import random
-import re
 import ast
 import gc
-from typing import Any, Dict, List, Optional, Set, Tuple
+import random
+import re
 from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Optional, Set

 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

-from ...base import AutoLearner, AutoLLM
+from ...base import AutoLearner

-
-# -----------------------------------------------------------------------------
-# Concrete AutoLLM: local HF wrapper that follows the AutoLLM interface
-# -----------------------------------------------------------------------------
-class LocalAutoLLM(AutoLLM):
+class SBUNLPFewShotLearner(AutoLearner):
     """
-    Handles loading and generation for a Hugging Face Causal Language Model (Qwen/TinyLlama).
-    Uses 4-bit quantization for efficiency and greedy decoding by default.
+    Public API expected by the pipeline:
+    - `load(model_id=...)`
+    - `fit(train_data, task=..., ontologizer=...)`
+    - `predict(test_data, task=..., ontologizer=...)`
+
+    Expected input bundle format (train/test):
+    - "documents": list of dicts, each with keys: {"id", "title", "text"}
+    - "terms2docs": dict mapping term -> list of doc_ids
+    - "terms2types": optional dict mapping term -> list of types
+
+    Prediction output payload (pipeline wraps this):
+    - {"terms": [{"doc_id": str, "term": str}, ...],
+       "types": [{"doc_id": str, "type": str}, ...]}
     """

     def __init__(
-        self, label_mapper: Any = None, device: str = "cpu", token: str = ""
-    ) -> None:
-        super().__init__(label_mapper=label_mapper, device=device, token=token)
-        self.model = None
-        self.tokenizer = None
-
-    def load(
         self,
-        model_id: str,
+        llm_model_id: Optional[str] = None,
+        device: str = "cpu",
         load_in_4bit: bool = False,
-        dtype: str = "auto",
+        max_new_tokens: int = 256,
         trust_remote_code: bool = True,
-    ):
-        """Load tokenizer + model, applying 4-bit quantization if specified and possible."""
+    ) -> None:
+        """
+        Initialize the few-shot learner.
+
+        Args:
+            llm_model_id: Default HF model id to load if `load()` is called without one.
+            device: "cpu" or a CUDA device identifier (e.g. "cuda").
+            load_in_4bit: Whether to attempt 4-bit quantized loading (bitsandbytes).
+            max_new_tokens: Maximum tokens to generate per prompt.
+            trust_remote_code: Forwarded to HF loaders (use with caution).
+ """ + super().__init__() + self.device = device + self.max_new_tokens = int(max_new_tokens) - # Determine the target data type (default to float32 for CPU, float16 for GPU) - torch_dtype_val = torch.float16 if torch.cuda.is_available() else torch.float32 + self._default_model_id = llm_model_id + self._load_in_4bit_default = bool(load_in_4bit) + self._trust_remote_code_default = bool(trust_remote_code) - # Load the tokenizer - self.tokenizer = AutoTokenizer.from_pretrained( - model_id, trust_remote_code=trust_remote_code - ) - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token + # HF objects + self.model: Optional[AutoModelForCausalLM] = None + self.tokenizer: Optional[AutoTokenizer] = None + + self._is_loaded = False + self._loaded_model_id: Optional[str] = None - quant_config = None + # Cached few-shot example blocks built during `fit()` + self.few_shot_terms_block: str = "" + self.few_shot_types_block: str = "" + + def load(self, model_id: Optional[str] = None, **kwargs: Any) -> None: + """ + Load the underlying HF causal LM and tokenizer. + + LearnerPipeline typically calls: `learner.load(model_id=llm_id)`. + + Args: + model_id: HF model id. If None, uses `llm_model_id` from __init__. + **kwargs: + load_in_4bit: override default 4-bit loading. + trust_remote_code: override default trust_remote_code. + """ + resolved_model_id = model_id or self._default_model_id + if not resolved_model_id: + raise ValueError( + f"No model_id provided to {self.__class__.__name__}.load() and no llm_model_id in __init__." + ) + + load_in_4bit = bool(kwargs.get("load_in_4bit", self._load_in_4bit_default)) + trust_remote_code = bool(kwargs.get("trust_remote_code", self._trust_remote_code_default)) + + # Avoid re-loading same model + if self._is_loaded and self._loaded_model_id == resolved_model_id: + return + + torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + + tokenizer = AutoTokenizer.from_pretrained(resolved_model_id, trust_remote_code=trust_remote_code) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + self.tokenizer = tokenizer + + quantization_config = None if load_in_4bit: - # Configure BitsAndBytes for 4-bit loading - quant_config = BitsAndBytesConfig( + quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", ) - if torch_dtype_val is None: - torch_dtype_val = torch.float16 + torch_dtype = torch.float16 - # Set device mapping (auto for multi-GPU or single GPU, explicit CPU otherwise) device_map = "auto" if (self.device != "cpu") else {"": "cpu"} - # Load the Causal Language Model - self.model = AutoModelForCausalLM.from_pretrained( - model_id, + model = AutoModelForCausalLM.from_pretrained( + resolved_model_id, device_map=device_map, - torch_dtype=torch_dtype_val, - quantization_config=quant_config, + torch_dtype=torch_dtype, + quantization_config=quantization_config, trust_remote_code=trust_remote_code, ) - # Ensure model is on the correct device (redundant if device_map="auto" but safe) if self.device == "cpu": - self.model.to("cpu") + model.to("cpu") - def generate( - self, - inputs: List[str], - max_new_tokens: int = 64, - temperature: float = 0.0, - top_p: float = 1.0, - ) -> List[str]: - """Generate continuations for a list of prompts, returning only the generated part.""" - if self.model is None or self.tokenizer is None: - raise RuntimeError("Model/tokenizer not loaded. 
Call .load() first.") + self.model = model + self._is_loaded = True + self._loaded_model_id = resolved_model_id + + def _invert_terms_to_docs_mapping(self, terms_to_documents: Dict[str, List[str]]) -> Dict[str, List[str]]: + """ + Convert term->docs mapping to doc->terms mapping. - # --- Generation Setup --- - # Tokenize batch (padding is essential for batch inference) - enc = self.tokenizer(inputs, return_tensors="pt", padding=True, truncation=True) - input_ids = enc["input_ids"] - attention_mask = enc["attention_mask"] + Args: + terms_to_documents: Mapping from term to list of document IDs. - # Move tensors to the model's device (e.g., cuda:0) - model_device = next(self.model.parameters()).device - input_ids = input_ids.to(model_device) - attention_mask = attention_mask.to(model_device) + Returns: + Mapping from document ID to list of terms associated with it. + """ + document_to_terms: DefaultDict[str, List[str]] = defaultdict(list) + for term, document_ids in (terms_to_documents or {}).items(): + for document_id in document_ids or []: + document_to_terms[str(document_id)].append(str(term)) + return dict(document_to_terms) - # --- Generate --- - with torch.no_grad(): - outputs = self.model.generate( - input_ids=input_ids, - attention_mask=attention_mask, - max_new_tokens=max_new_tokens, - do_sample=( - temperature > 0.0 - ), # Use greedy decoding if temperature is 0.0 - temperature=temperature, - top_p=top_p, - pad_token_id=self.tokenizer.eos_token_id, - ) + def _derive_document_to_types( + self, + terms_to_documents: Dict[str, List[str]], + terms_to_types: Optional[Dict[str, List[str]]], + ) -> Dict[str, List[str]]: + """ + Derive doc->types mapping using (terms->docs) and (terms->types). - # --- Post-processing: Extract only the generated tail --- - decoded_outputs: List[str] = [] - for i, output_ids in enumerate(outputs): - full_decoded_text = self.tokenizer.decode( - output_ids, skip_special_tokens=True - ) - prompt_text = self.tokenizer.decode(input_ids[i], skip_special_tokens=True) + Args: + terms_to_documents: term -> [doc_id...] + terms_to_types: term -> [type...] - # Safely strip the prompt text from the full output - if full_decoded_text.startswith(prompt_text): - generated_tail = full_decoded_text[len(prompt_text) :].strip() - else: - # Fallback extraction (less robust if padding affects token indices) - prompt_len = input_ids.shape[1] - generated_tail = self.tokenizer.decode( - output_ids[prompt_len:], skip_special_tokens=True - ).strip() - decoded_outputs.append(generated_tail) + Returns: + doc_id -> sorted list of unique types. + """ + if not terms_to_types: + return {} - return decoded_outputs + document_to_types: DefaultDict[str, Set[str]] = defaultdict(set) + for term, document_ids in (terms_to_documents or {}).items(): + candidate_types = terms_to_types.get(term, []) or [] + for document_id in document_ids or []: + for candidate_type in candidate_types: + if isinstance(candidate_type, str) and candidate_type.strip(): + document_to_types[str(document_id)].add(candidate_type.strip()) -# ----------------------------------------------------------------------------- -# Main Learner: SBUNLPFewShotLearner (Task A Text2Onto) -# ----------------------------------------------------------------------------- -class SBUNLPFewShotLearner(AutoLearner): - """ - Concrete learner implementing the Task A Text2Onto pipeline (Term and Type Extraction). - It uses Few-Shot prompts generated from training data for inference. 
- """ + return {doc_id: sorted(list(type_set)) for doc_id, type_set in document_to_types.items()} - def __init__(self, model: Optional[AutoLLM] = None, device: str = "cpu"): - super().__init__() - # self.model is an instance of LocalAutoLLM - self.model = model or LocalAutoLLM(device=device) - self.device = device - # Cached in-memory prompt blocks built during the fit phase - self.fewshot_terms_block: str = "" - self.fewshot_types_block: str = "" + def _truncate_text(self, text: str, max_chars: int) -> str: + """ + Truncate text to a maximum number of characters (adds an ellipsis when truncated). + + Args: + text: Input text. + max_chars: Maximum characters to keep. If <= 0, returns the original text. + + Returns: + Truncated or original text. + """ + if not max_chars or max_chars <= 0 or not text: + return text or "" + return (text[:max_chars] + "…") if len(text) > max_chars else text - # --- Few-shot construction (terms) --- - def build_stratified_fewshot_prompt( + def build_few_shot_terms_block( self, - documents_path: str, - terms_path: str, + documents: List[Dict[str, Any]], + terms_to_documents: Dict[str, List[str]], sample_size: int = 28, seed: int = 123, max_chars_per_text: int = 1200, ) -> str: """ - Builds the few-shot exemplar block for Term Extraction using stratified sampling. + Build and cache the few-shot block for term extraction. + + Strategy: + - Create strata by associated terms (doc -> associated term list). + - Sample proportionally across strata. + - Deduplicate by document id and top up from remaining docs if needed. + + Args: + documents: Documents with keys: {"id","title","text"}. + terms_to_documents: Mapping term -> list of doc IDs. + sample_size: Desired number of examples in the block. + seed: RNG seed (local to this call). + max_chars_per_text: Text truncation limit per example. + + Returns: + The formatted few-shot example block string. 
""" - random.seed(seed) - - # Read documents (JSONL) into a list - corpus_documents: List[Dict[str, Any]] = [] - with open(documents_path, "r", encoding="utf-8") as file_handle: - for line in file_handle: - if line.strip(): - corpus_documents.append(json.loads(line)) - - num_total_docs = len(corpus_documents) - num_sample_docs = min(sample_size, num_total_docs) - - # Load the map of term -> [list of document IDs] - with open(terms_path, "r", encoding="utf-8") as file_handle: - term_to_doc_map = json.load(file_handle) - - # Invert map: document ID -> [list of terms] - doc_id_to_terms_map = defaultdict(list) - for term, doc_ids in term_to_doc_map.items(): - for doc_id in doc_ids: - doc_id_to_terms_map[doc_id].append(term) - - # Define strata (groups of documents associated with specific terms) - strata_map = defaultdict(list) - for doc in corpus_documents: - doc_id = doc.get("id", "") - associated_terms = doc_id_to_terms_map.get(doc_id, ["no_term"]) + rng = random.Random(seed) + + document_to_terms = self._invert_terms_to_docs_mapping(terms_to_documents) + total_documents = len(documents) + target_sample_count = min(int(sample_size), total_documents) + + strata: DefaultDict[str, List[Dict[str, Any]]] = defaultdict(list) + for document in documents: + document_id = str(document.get("id", "")) + associated_terms = document_to_terms.get(document_id, ["no_term"]) for term in associated_terms: - strata_map[term].append(doc) + strata[str(term)].append(document) - # Perform proportional sampling across strata sampled_documents: List[Dict[str, Any]] = [] - for term_str, stratum_docs in strata_map.items(): - num_stratum_docs = len(stratum_docs) - if num_stratum_docs == 0: + for docs_in_stratum in strata.values(): + if not docs_in_stratum: continue - - # Calculate proportional sample size - proportion = num_stratum_docs / num_total_docs - num_to_sample_from_stratum = int(num_sample_docs * proportion) - - if num_to_sample_from_stratum > 0: - sampled_documents.extend( - random.sample( - stratum_docs, min(num_to_sample_from_stratum, num_stratum_docs) - ) + proportion = len(docs_in_stratum) / max(1, total_documents) + stratum_quota = int(target_sample_count * proportion) + if stratum_quota > 0: + sampled_documents.extend(rng.sample(docs_in_stratum, min(stratum_quota, len(docs_in_stratum)))) + + sampled_by_id = {str(d.get("id", "")): d for d in sampled_documents if d.get("id", "")} + final_documents = list(sampled_by_id.values()) + + if len(final_documents) > target_sample_count: + final_documents = rng.sample(final_documents, target_sample_count) + elif len(final_documents) < target_sample_count: + remaining_documents = [d for d in documents if str(d.get("id", "")) not in sampled_by_id] + additional_needed = min(target_sample_count - len(final_documents), len(remaining_documents)) + if additional_needed > 0: + final_documents.extend(rng.sample(remaining_documents, additional_needed)) + + lines: List[str] = [] + for document in final_documents: + document_id = str(document.get("id", "")) + title = str(document.get("title", "")) + text = self._truncate_text(str(document.get("text", "")), max_chars_per_text) + associated_terms = document_to_terms.get(document_id, []) + + lines.append( + "Document ID: {doc_id}\n" + "Title: {title}\n" + "Text: {text}\n" + "Associated Terms: {terms}\n" + "----------------------------------------".format( + doc_id=document_id, + title=title, + text=text, + terms=associated_terms, ) - - # Deduplicate sampled documents by ID and adjust count to exactly 'sample_size' - 
unique_docs_by_id = {} - for doc in sampled_documents: - unique_docs_by_id[doc.get("id", "")] = doc - - final_sample_docs = list(unique_docs_by_id.values()) - - if len(final_sample_docs) > num_sample_docs: - final_sample_docs = random.sample(final_sample_docs, num_sample_docs) - elif len(final_sample_docs) < num_sample_docs: - remaining_docs = [ - d for d in corpus_documents if d.get("id", "") not in unique_docs_by_id - ] - needed_count = min( - num_sample_docs - len(final_sample_docs), len(remaining_docs) - ) - final_sample_docs.extend(random.sample(remaining_docs, needed_count)) - - # Format the few-shot exemplar text block - prompt_lines: List[str] = [] - for doc in final_sample_docs: - doc_id = doc.get("id", "") - title = doc.get("title", "") - text = doc.get("text", "") - - # Truncate text if it exceeds the maximum character limit - if max_chars_per_text and len(text) > max_chars_per_text: - text = text[:max_chars_per_text] + "…" - - associated_terms = doc_id_to_terms_map.get(doc_id, []) - prompt_lines.append( - f"Document ID: {doc_id}\nTitle: {title}\nText: {text}\nAssociated Terms: {associated_terms}\n----------------------------------------" ) - prompt_block = "\n".join(prompt_lines) - self.fewshot_terms_block = prompt_block - return prompt_block + self.few_shot_terms_block = "\n".join(lines) + return self.few_shot_terms_block - # --- Few-shot construction (types) --- - def build_types_fewshot_block( + def build_few_shot_types_block( self, - docs_jsonl: str, - terms2doc_json: str, - sample_per_term: int = 1, - full_word: bool = True, - case_sensitive: bool = True, + documents: List[Dict[str, Any]], + terms_to_documents: Dict[str, List[str]], + terms_to_types: Optional[Dict[str, List[str]]] = None, + sample_size: int = 28, + seed: int = 123, max_chars_per_text: int = 800, ) -> str: """ - Builds the few-shot block for Type Extraction. - This method samples documents based on finding an associated term/type within the text. + Build and cache the few-shot block for type (class) extraction. + + Prefers doc->types derived from `terms_to_types`; if absent, falls back to treating + associated terms as "types" for stratification (behavior-preserving fallback). + + Args: + documents: Documents with keys: {"id","title","text"}. + terms_to_documents: Mapping term -> list of doc IDs. + terms_to_types: Optional mapping term -> list of types. + sample_size: Desired number of examples in the block. + seed: RNG seed (local to this call). + max_chars_per_text: Text truncation limit per example. + + Returns: + The formatted few-shot example block string. """ - # Load documents into dict by ID - docs_by_id = {} - with open(docs_jsonl, "r", encoding="utf-8") as file_handle: - for line in file_handle: - line_stripped = line.strip() - if line_stripped: - try: - doc = json.loads(line_stripped) - doc_id = doc.get("id", "") - if doc_id: - docs_by_id[doc_id] = doc - except json.JSONDecodeError: - continue - - # Load term -> [doc_id,...] 
map - with open(terms2doc_json, "r", encoding="utf-8") as file_handle: - term_to_doc_map = json.load(file_handle) - - flags = 0 if case_sensitive else re.IGNORECASE - prompt_lines: List[str] = [] - - # Iterate over terms (which act as types in this context) - for term, doc_ids in term_to_doc_map.items(): - escaped_term = re.escape(term) - # Create regex pattern for matching the term in the text - pattern = rf"\b{escaped_term}\b" if full_word else escaped_term - term_regex = re.compile(pattern, flags=flags) - - picked_count = 0 - for doc_id in doc_ids: - doc = docs_by_id.get(doc_id) - if not doc: - continue - - title = doc.get("title", "") - text = doc.get("text", "") - - # Check if the term/type is actually present in the document text/title - if term_regex.search(f"{title} {text}"): - text_content = text - - # Truncate text if necessary - if max_chars_per_text and len(text_content) > max_chars_per_text: - text_content = text_content[:max_chars_per_text] + "…" - - # Escape single quotes in the term for Python list formatting in the prompt - term_for_prompt = term.replace("'", "\\'") - - prompt_lines.append( - f"Document ID: {doc_id}\nTitle: {title}\nText: {text_content}\nAssociated Types: ['{term_for_prompt}']\n----------------------------------------" - ) - picked_count += 1 - - if picked_count >= sample_per_term: - break # Move to the next term - - prompt_block = "\n".join(prompt_lines) - self.fewshot_types_block = prompt_block - return prompt_block + rng = random.Random(seed) - def fit( - self, - train_docs_jsonl: str, - terms2doc_json: str, - sample_size: int = 28, - seed: int = 123, - ) -> None: + documents_by_id = {str(d.get("id", "")): d for d in documents if d.get("id", "")} + + document_to_types = self._derive_document_to_types(terms_to_documents, terms_to_types) + if not document_to_types: + document_to_types = self._invert_terms_to_docs_mapping(terms_to_documents) + + type_to_documents: DefaultDict[str, List[Dict[str, Any]]] = defaultdict(list) + for document_id, candidate_types in document_to_types.items(): + document = documents_by_id.get(document_id) + if not document: + continue + for candidate_type in candidate_types: + type_to_documents[str(candidate_type)].append(document) + + total_documents = len(documents) + target_sample_count = min(int(sample_size), total_documents) + + sampled_documents: List[Dict[str, Any]] = [] + for docs_in_stratum in type_to_documents.values(): + if not docs_in_stratum: + continue + proportion = len(docs_in_stratum) / max(1, total_documents) + stratum_quota = int(target_sample_count * proportion) + if stratum_quota > 0: + sampled_documents.extend(rng.sample(docs_in_stratum, min(stratum_quota, len(docs_in_stratum)))) + + sampled_by_id = {str(d.get("id", "")): d for d in sampled_documents if d.get("id", "")} + final_documents = list(sampled_by_id.values()) + + if len(final_documents) > target_sample_count: + final_documents = rng.sample(final_documents, target_sample_count) + elif len(final_documents) < target_sample_count: + remaining_documents = [d for d in documents if str(d.get("id", "")) not in sampled_by_id] + additional_needed = min(target_sample_count - len(final_documents), len(remaining_documents)) + if additional_needed > 0: + final_documents.extend(rng.sample(remaining_documents, additional_needed)) + + lines: List[str] = [] + for document in final_documents: + document_id = str(document.get("id", "")) + title = str(document.get("title", "")) + text = self._truncate_text(str(document.get("text", "")), max_chars_per_text) + + 
associated_types = document_to_types.get(document_id, []) + associated_types_escaped = [t.replace("'", "\\'") for t in associated_types] + + lines.append( + "Document ID: {doc_id}\n" + "Title: {title}\n" + "Text: {text}\n" + "Associated Types: {types}\n" + "----------------------------------------".format( + doc_id=document_id, + title=title, + text=text, + types=associated_types_escaped, + ) + ) + + self.few_shot_types_block = "\n".join(lines) + return self.few_shot_types_block + + def _format_term_prompt(self, example_block: str, title: str, text: str) -> str: """ - Fit phase: Builds and caches the few-shot prompt blocks from the training files. - No model training occurs (Few-Shot/In-Context Learning). + Format a prompt for term extraction. + + Args: + example_block: Few-shot examples block. + title: Document title. + text: Document text. + + Returns: + Prompt string. """ - # Build prompt block for Term extraction - _ = self.build_stratified_fewshot_prompt( - train_docs_jsonl, terms2doc_json, sample_size=sample_size, seed=seed + return ( + f"{example_block}\n" + "[var]\n" + f"Title: {title}\n" + f"Text: {text}\n" + "[var]\n" + "Extract all relevant terms that could form the basis of an ontology from the above document.\n" + "Return ONLY a Python list like ['term1', 'term2', ...] and nothing else.\n" + "If no terms are found, return [].\n" ) - # Build prompt block for Type extraction - _ = self.build_types_fewshot_block( - train_docs_jsonl, terms2doc_json, sample_per_term=1 + + def _format_type_prompt(self, example_block: str, title: str, text: str) -> str: + """ + Format a prompt for type (class) extraction. + + Args: + example_block: Few-shot examples block. + title: Document title. + text: Document text. + + Returns: + Prompt string. + """ + return ( + f"{example_block}\n" + "[var]\n" + f"Title: {title}\n" + f"Text: {text}\n" + "[var]\n" + "Extract all relevant TYPES mentioned in the above document that could serve as ontology classes.\n" + "Only consider content inside the [var] ... [var] block.\n" + "Return ONLY a valid Python list like ['type1', 'type2'] and nothing else. If none, return [].\n" ) - # ------------------------- - # Inference helpers (prompt construction and output parsing) - # ------------------------- - def _build_term_prompt(self, example_block: str, title: str, text: str) -> str: - """Constructs the full prompt for Term Extraction.""" - return f"""{example_block} - [var] - Title: {title} - Text: {text} - [var] - Extract all relevant terms that could form the basis of an ontology from the above document. - Return ONLY a Python list like ['term1', 'term2', ...] and nothing else. - If no terms are found, return []. - """ - - def _build_type_prompt(self, example_block: str, title: str, text: str) -> str: - """Constructs the full prompt for Type Extraction.""" - return f"""{example_block} - [var] - Title: {title} - Text: {text} - [var] - Extract all relevant TYPES mentioned in the above document that could serve as ontology classes. - Only consider content inside the [var] ... [var] block. - Return ONLY a valid Python list like ['type1', 'type2'] and nothing else. If none, return []. - """ - - def _parse_list_like(self, raw_string: str) -> List[str]: - """Try to extract a Python list of strings from model output robustly.""" - processed_string = raw_string.strip() - if processed_string in ("[]", ""): + def _parse_python_list_of_strings(self, raw_text: str) -> List[str]: + """ + Parse an LLM response intended to be a Python list of strings. 
+ + This parser is intentionally tolerant: + 1) Try literal_eval on the full string + 2) Else extract the first [...] block and literal_eval it + 3) Else fallback to extracting quoted strings + + Args: + raw_text: Model output. + + Returns: + List of strings (possibly empty). + """ + stripped = (raw_text or "").strip() + if stripped in ("", "[]"): return [] - # 1. Try direct evaluation try: - parsed_value = ast.literal_eval(processed_string) - if isinstance(parsed_value, list): - # Filter to ensure only strings are returned - return [item for item in parsed_value if isinstance(item, str)] + parsed = ast.literal_eval(stripped) + if isinstance(parsed, list): + return [item for item in parsed if isinstance(item, str)] except Exception: pass - # 2. Try finding and evaluating text within outermost brackets [ ... ] - bracket_match = re.search(r"\[[\s\S]*?\]", processed_string) - if bracket_match: + match = re.search(r"\[[\s\S]*?\]", stripped) + if match: try: - parsed_value = ast.literal_eval(bracket_match.group(0)) - if isinstance(parsed_value, list): - return [item for item in parsed_value if isinstance(item, str)] + parsed = ast.literal_eval(match.group(0)) + if isinstance(parsed, list): + return [item for item in parsed if isinstance(item, str)] except Exception: pass - # 3. Fallback: Find comma-separated quoted substrings (less robust, but catches errors) - # Finds content inside either single quotes ('...') or double quotes ("...") - quoted_matches = re.findall(r"'([^']+)'|\"([^\"]+)\"", processed_string) - flattened_list = [a_match or b_match for a_match, b_match in quoted_matches] - return flattened_list - - def _call_model_one(self, prompt: str, max_new_tokens: int = 120) -> str: - """Calls the underlying LocalAutoLLM for a single prompt. Returns the raw tail output.""" - # self.model is an instance of LocalAutoLLM - model_output = self.model.generate( - [prompt], max_new_tokens=max_new_tokens, temperature=0.0, top_p=1.0 - ) - return model_output[0] if model_output else "" + quoted = re.findall(r"'([^']+)'|\"([^\"]+)\"", stripped) + return [a or b for a, b in quoted] - def predict_terms( - self, - docs_test_jsonl: str, - out_jsonl: str, - max_lines: int = -1, - max_new_tokens: int = 120, - ) -> int: + def _generate_completion(self, prompt_text: str) -> str: """ - Runs Term Extraction on the test documents and saves results to a JSONL file. - Returns: The count of individual terms written. + Generate a completion for a single prompt (deterministic decoding). + + Args: + prompt_text: Full prompt to send to the model. + + Returns: + The generated completion text (prompt stripped where possible). """ - if not self.fewshot_terms_block: - raise RuntimeError("Few-shot block for terms is empty. 
Call fit() first.") - - num_written_terms = 0 - with ( - open(docs_test_jsonl, "r", encoding="utf-8") as file_in, - open(out_jsonl, "w", encoding="utf-8") as file_out, - ): - for line_index, line in enumerate(file_in, start=1): - if 0 < max_lines < line_index: - break - - try: - document = json.loads(line.strip()) - except Exception: - continue # Skip malformed JSON lines - - doc_id = document.get("id", "unknown") - title = document.get("title", "") - text = document.get("text", "") - - # Construct and call model - prompt = self._build_term_prompt(self.fewshot_terms_block, title, text) - raw_output = self._call_model_one(prompt, max_new_tokens=max_new_tokens) - predicted_terms = self._parse_list_like(raw_output) - - # Write extracted terms - for term_or_type in predicted_terms: - if isinstance(term_or_type, str) and term_or_type.strip(): - file_out.write( - json.dumps({"doc_id": doc_id, "term": term_or_type.strip()}) - + "\n" - ) - num_written_terms += 1 - - # Lightweight memory management for long runs - if line_index % 50 == 0: - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return num_written_terms - - def predict_types( + if self.model is None or self.tokenizer is None: + raise RuntimeError("Model/tokenizer not loaded. Call .load() first.") + + encoded = self.tokenizer([prompt_text], return_tensors="pt", padding=True, truncation=True) + input_ids = encoded["input_ids"] + attention_mask = encoded["attention_mask"] + + model_device = next(self.model.parameters()).device + input_ids = input_ids.to(model_device) + attention_mask = attention_mask.to(model_device) + + with torch.no_grad(): + output_ids = self.model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_new_tokens=self.max_new_tokens, + do_sample=False, + temperature=0.0, + top_p=1.0, + pad_token_id=self.tokenizer.eos_token_id, + )[0] + + decoded_full = self.tokenizer.decode(output_ids, skip_special_tokens=True) + decoded_prompt = self.tokenizer.decode(input_ids[0], skip_special_tokens=True) + + if decoded_full.startswith(decoded_prompt): + return decoded_full[len(decoded_prompt) :].strip() + + prompt_token_count = int(attention_mask[0].sum().item()) + return self.tokenizer.decode(output_ids[prompt_token_count:], skip_special_tokens=True).strip() + + def fit( self, - docs_test_jsonl: str, - out_jsonl: str, - max_lines: int = -1, - max_new_tokens: int = 120, - ) -> int: - """ - Runs Type Extraction on the test documents and saves results to a JSONL file. - Returns: The count of individual types written. - """ - if not self.fewshot_types_block: - raise RuntimeError("Few-shot block for types is empty. 
Call fit() first.") - - num_written_types = 0 - with ( - open(docs_test_jsonl, "r", encoding="utf-8") as file_in, - open(out_jsonl, "w", encoding="utf-8") as file_out, - ): - for line_index, line in enumerate(file_in, start=1): - if 0 < max_lines < line_index: - break - - try: - document = json.loads(line.strip()) - except Exception: - continue # Skip malformed JSON lines - - doc_id = document.get("id", "unknown") - title = document.get("title", "") - text = document.get("text", "") - - # Construct and call model using the dedicated type prompt block - prompt = self._build_type_prompt(self.fewshot_types_block, title, text) - raw_output = self._call_model_one(prompt, max_new_tokens=max_new_tokens) - predicted_types = self._parse_list_like(raw_output) - - # Write extracted types - for term_or_type in predicted_types: - if isinstance(term_or_type, str) and term_or_type.strip(): - file_out.write( - json.dumps({"doc_id": doc_id, "type": term_or_type.strip()}) - + "\n" - ) - num_written_types += 1 - - if line_index % 50 == 0: - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - - return num_written_types - - # --- Evaluation utilities (unchanged from prior definition, added docstrings) --- - def load_gold_pairs(self, terms2doc_path: str) -> Set[Tuple[str, str]]: - """Convert terms2docs JSON into a set of unique (doc_id, term) pairs, lowercased.""" - gold_pairs = set() - with open(terms2doc_path, "r", encoding="utf-8") as file_handle: - term_to_doc_map = json.load(file_handle) - - for term, doc_ids in term_to_doc_map.items(): - clean_term = term.strip().lower() - for doc_id in doc_ids: - gold_pairs.add((doc_id, clean_term)) - return gold_pairs - - def load_predicted_pairs( - self, predicted_jsonl_path: str, key: str = "term" - ) -> Set[Tuple[str, str]]: - """Load predicted (doc_id, term/type) pairs from a JSONL file, lowercased.""" - predicted_pairs = set() - with open(predicted_jsonl_path, "r", encoding="utf-8") as file_handle: - for line in file_handle: - try: - entry = json.loads(line.strip()) - except Exception: - continue - doc_id = entry.get("doc_id") - value = entry.get(key) - if doc_id and value: - predicted_pairs.add((doc_id, value.strip().lower())) - return predicted_pairs - - def evaluate_extraction_f1( - self, terms2doc_path: str, predicted_jsonl: str, key: str = "term" - ) -> float: + train_data: Any, + task: str = "text2onto", + ontologizer: bool = False, + **kwargs: Any, + ) -> None: """ - Computes set-based binary Precision, Recall, and F1 score against the gold pairs. + Build and cache few-shot blocks from the training split. + + Args: + train_data: A split bundle dict. Must contain "documents" and "terms2docs". + task: Must be "text2onto". + ontologizer: Unused here (kept for signature compatibility). + **kwargs: + sample_size: Few-shot sample size per block (default 28). + seed: RNG seed (default 123). 
""" - # Load the ground truth and predictions - gold_set = self.load_gold_pairs(terms2doc_path) - predicted_set = self.load_predicted_pairs(predicted_jsonl, key=key) + if task != "text2onto": + raise ValueError(f"{self.__class__.__name__} only supports task='text2onto' (got {task!r}).") - # Build combined universe of all pairs for score calculation - all_pairs = sorted(gold_set | predicted_set) + if not self._is_loaded: + self.load(model_id=self._default_model_id) - # Create binary labels (1=present, 0=absent) - y_true = [1 if pair in gold_set else 0 for pair in all_pairs] - y_pred = [1 if pair in predicted_set else 0 for pair in all_pairs] + documents: List[Dict[str, Any]] = train_data.get("documents", []) or [] + terms_to_documents: Dict[str, List[str]] = train_data.get("terms2docs", {}) or {} + terms_to_types: Optional[Dict[str, List[str]]] = train_data.get("terms2types", None) - # Use scikit-learn for metric calculation - from sklearn.metrics import precision_recall_fscore_support + sample_size = int(kwargs.get("sample_size", 28)) + seed = int(kwargs.get("seed", 123)) - precision, recall, f1, _ = precision_recall_fscore_support( - y_true, y_pred, average="binary", zero_division=0 + self.build_few_shot_terms_block( + documents=documents, + terms_to_documents=terms_to_documents, + sample_size=sample_size, + seed=seed, + ) + self.build_few_shot_types_block( + documents=documents, + terms_to_documents=terms_to_documents, + terms_to_types=terms_to_types, + sample_size=sample_size, + seed=seed, ) - # Display results - num_true_positives = len(gold_set & predicted_set) + def predict( + self, + test_data: Any, + task: str = "text2onto", + ontologizer: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + """ + Run term/type extraction over test documents. - print("\n📊 Evaluation Results:") - print(f" ✅ Precision: {precision:.4f}") - print(f" ✅ Recall: {recall:.4f}") - print(f" ✅ F1 Score: {f1:.4f}") - print(f" 📌 Gold pairs: {len(gold_set)}") - print(f" 📌 Predicted pairs:{len(predicted_set)}") - print(f" 🎯 True Positives: {num_true_positives}") + Args: + test_data: A split bundle dict. Must contain "documents". + task: Must be "text2onto". + ontologizer: Unused here (kept for signature compatibility). + **kwargs: + max_docs: If > 0, limit number of docs processed. - return float(f1) + Returns: + Prediction payload dict: {"terms": [...], "types": [...]}. + """ + if task != "text2onto": + raise ValueError(f"{self.__class__.__name__} only supports task='text2onto' (got {task!r}).") + + if not self.few_shot_terms_block or not self.few_shot_types_block: + raise RuntimeError("Few-shot blocks are empty. 
Pipeline should call fit() before predict().")
+
+        max_docs = int(kwargs.get("max_docs", -1))
+        documents: List[Dict[str, Any]] = test_data.get("documents", []) or []
+        if max_docs > 0:
+            documents = documents[:max_docs]
+
+        term_predictions: List[Dict[str, str]] = []
+        type_predictions: List[Dict[str, str]] = []
+
+        for doc_index, document in enumerate(documents, start=1):
+            document_id = str(document.get("id", "unknown"))
+            title = str(document.get("title", ""))
+            text = str(document.get("text", ""))
+
+            term_prompt = self._format_term_prompt(self.few_shot_terms_block, title, text)
+            extracted_terms = self._parse_python_list_of_strings(self._generate_completion(term_prompt))
+            for term in extracted_terms:
+                normalized_term = (term or "").strip()
+                if normalized_term:
+                    term_predictions.append({"doc_id": document_id, "term": normalized_term})
+
+            type_prompt = self._format_type_prompt(self.few_shot_types_block, title, text)
+            extracted_types = self._parse_python_list_of_strings(self._generate_completion(type_prompt))
+            for extracted_type in extracted_types:
+                normalized_type = (extracted_type or "").strip()
+                if normalized_type:
+                    type_predictions.append({"doc_id": document_id, "type": normalized_type})
+
+            if doc_index % 50 == 0:
+                gc.collect()
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
+        # IMPORTANT: return only the prediction payload; LearnerPipeline wraps it.
+        return {"terms": term_predictions, "types": type_predictions}
diff --git a/ontolearner/text2onto/splitter.py b/ontolearner/text2onto/splitter.py
index 3555511..cdc9e15 100644
--- a/ontolearner/text2onto/splitter.py
+++ b/ontolearner/text2onto/splitter.py
@@ -200,10 +200,73 @@ def generate_split_artefacts(self, split_docs):
 
         return terms_splits, types_splits, docs_split, types2docs_splits
 
-    def split(self, train: float = 0.8, val: float = 0.1, test: float = 0.1):
-        split_targets, split_docs_targets = self.set_train_val_test_sizes(train_percentage=train,
-                                                                          val_percentage=val,
-                                                                          test_percentage=test)
+    def split_fine_grained(self, doc_ids):
+        """
+        Build a single split bundle containing only:
+            - documents
+            - terms
+            - types
+            - terms2docs
+            - terms2types
+        """
+        # normalize to string ids (constructor uses str(row.id))
+        doc_ids = {str(d) for d in (doc_ids or [])}
+
+        # gather documents and collect their terms/types
+        docs = []
+        terms_set = set()
+        types_set = set()
+
+        for doc_id in doc_ids:
+            doc = self.doc_id_to_doc[doc_id]
+            docs.append({"id": str(doc.id), "title": doc.title, "text": doc.text})
+
+            terms_set.update(self.doc_id_to_terms[doc_id])
+            types_set.update(self.doc_id_to_types[doc_id])
+
+        terms = sorted(terms_set)
+        types = sorted(types_set)
+
+        # terms2docs: use the constructor-built mapping and restrict to this split's doc_ids
+        terms2docs = {
+            term: sorted(list(self.term_to_doc_id.get(term, set()) & doc_ids))
+            for term in terms
+        }
+
+        # terms2types: ontology lookup (term -> parent types)
+        terms2types = {term: self.child_to_parent.get(term, []) for term in terms}
+
+        return {
+            "documents": docs,
+            "terms": terms,
+            "types": types,
+            "terms2docs": terms2docs,
+            "terms2types": terms2types,
+        }
+
+    def train_test_val_split(self, train: float = 0.8, val: float = 0.1, test: float = 0.1):
+        """
+        Returns:
+            train_split, val_split, test_split
+
+        Each split is a dict with keys:
+            - "documents"
+            - "terms"
+            - "types"
+            - "terms2docs"
+            - "terms2types"
+        """
+        # compute which docs go to which split
+        split_targets, split_docs_targets = self.set_train_val_test_sizes(
+            train_percentage=train,
+            val_percentage=val,
+            test_percentage=test,
+        )
         split_docs = self.create_train_val_test_splits(split_targets, split_docs_targets)
-        terms, types, docs, types2docs = self.generate_split_artefacts(split_docs)
-        return terms, types, docs, types2docs
+        # split_docs: {"train": set(doc_ids), "val": set(doc_ids), "test": set(doc_ids)}
+
+        train_split = self.split_fine_grained(split_docs.get("train", set()))
+        val_split = self.split_fine_grained(split_docs.get("val", set()))
+        test_split = self.split_fine_grained(split_docs.get("test", set()))
+
+        return train_split, val_split, test_split
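
For orientation, a minimal usage sketch of the splitter API introduced in the ``splitter.py`` change above. It assumes a ``SyntheticDataSplitter`` instance (named ``splitter`` here) has already been constructed; the document ids passed to ``split_fine_grained`` are hypothetical and only illustrate the call shape, not a real dataset.

.. code-block:: python

    # Minimal sketch, assuming `splitter` is an existing SyntheticDataSplitter.
    train_split, val_split, test_split = splitter.train_test_val_split(
        train=0.8,
        val=0.1,
        test=0.1,
    )

    # Each bundle is a dict keyed by "documents", "terms", "types",
    # "terms2docs", and "terms2types".
    for name, split in (("train", train_split), ("val", val_split), ("test", test_split)):
        print(
            f"{name}: "
            f"{len(split['documents'])} documents, "
            f"{len(split['terms'])} terms, "
            f"{len(split['types'])} types"
        )

    # split_fine_grained builds a single bundle for an explicit set of
    # document ids (the ids below are made up for illustration).
    custom_bundle = splitter.split_fine_grained({"doc-001", "doc-002"})
    print(len(custom_bundle["documents"]), "documents in the custom bundle")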