from collections import defaultdict, deque
import random

import pandas as pd
from tqdm import tqdm

from ..data_structure import SyntheticText2OntoData


class SyntheticDataSplitter:
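    """Split a SyntheticText2OntoData corpus into train/val/test sets.

    Types are propagated to splits through type/document co-occurrence so
    that related types stay together; `split` returns the per-split terms,
    types, documents, and type-to-document maps.
    """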

    def __init__(self, synthetic_data: SyntheticText2OntoData, onto_name: str):
        self.pseudo_sentence_batches = pd.DataFrame([ps.dict() for ps in synthetic_data.pseudo_sentences])
        self.child_to_parent = synthetic_data.child_to_parent

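        # Containers and index structures: the term/type <-> doc-id lookups
        # below drive the split propagation.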
        self.documents = list()
        self.term_to_doc_id = defaultdict(set)
        self.type_to_doc_id = defaultdict(set)
        self.doc_id_to_terms = defaultdict(set)
        self.doc_id_to_types = defaultdict(set)
        for row in tqdm(self.pseudo_sentence_batches.itertuples(index=False), total=len(self.pseudo_sentence_batches)):
            doc_id = str(row.id)
            self.doc_id_to_types[doc_id] = set(row.types)
            self.doc_id_to_terms[doc_id] = set(row.terms)
            for a_type in row.types:
                self.type_to_doc_id[a_type].add(doc_id)
            for a_term in row.terms:
                self.term_to_doc_id[a_term].add(doc_id)

        self.doc_id_to_doc = {doc.id: doc for doc in synthetic_data.generated_docs}
        print(f"loaded {len(self.doc_id_to_doc)} documents!")

        total_type_count = len(set().union(*self.doc_id_to_types.values()))
        total_term_count = len(set().union(*self.doc_id_to_terms.values()))
        print(f" total type count: {total_type_count}")
        print(f" total term count: {total_term_count}")

        self.onto_name = onto_name

    def set_train_val_test_sizes(self, train_percentage: float = 0.8,
                                 val_percentage: float = 0.1,
                                 test_percentage: float = 0.1):
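        """Compute per-split quotas over types and documents.

        Each quota is the floor of percentage * total; the test split
        absorbs the integer-rounding remainder.
        """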
        # Use a tolerance instead of exact float equality: sums such as
        # 0.7 + 0.2 + 0.1 do not evaluate to exactly 1.0.
        if abs(train_percentage + val_percentage + test_percentage - 1.0) > 1e-9:
            raise ValueError("The sum of train/val/test percentages should be 1.")
        total_types = len(self.type_to_doc_id)
        total_docs = len(self.doc_id_to_doc)
        train_quota = int(train_percentage * total_types)
        val_quota = int(val_percentage * total_types)
        test_quota = total_types - train_quota - val_quota
        print(f"train_quota: {train_quota}\nval_quota: {val_quota}\ntest_quota: {test_quota}")
        split_targets = {
            'train': train_quota,
            'val': val_quota,
            'test': test_quota
        }
        train_docs_quota = int(train_percentage * total_docs)
        val_docs_quota = int(val_percentage * total_docs)
        test_docs_quota = total_docs - train_docs_quota - val_docs_quota
        print(
            f"train docs quota: {train_docs_quota}\nval docs quota: {val_docs_quota}\ntest docs quota: {test_docs_quota}")
        split_docs_targets = {
            'train': train_docs_quota,
            'val': val_docs_quota,
            'test': test_docs_quota
        }
        return split_targets, split_docs_targets

    def assign_types_with_propagation(self, split_name, split_targets, split_docs_targets,
                                      split_types, split_docs, unassigned_types, unassigned_docs, assigned_docs):
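        """Grow one split by BFS over type/document co-occurrences.

        Assigning a seed type pulls in every document that mentions it and,
        from those documents, every co-occurring type, so related types land
        in the same split. Growth stops once the split's type or document
        quota is reached.
        """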
        target_size = split_targets[split_name]
        docs_target_size = split_docs_targets[split_name]
        while len(split_types[split_name]) < target_size and len(
                split_docs[split_name]) < docs_target_size and unassigned_types:
            # unassigned_types is an insertion-ordered dict (see
            # create_train_val_test_splits), so seeds come out in the
            # seeded-shuffle order rather than in arbitrary set order.
            type_seed = next(iter(unassigned_types))
            del unassigned_types[type_seed]
            queue = deque([type_seed])
            while (queue and len(split_types[split_name]) < target_size and
                   len(split_docs[split_name]) < docs_target_size):
                current_type = queue.popleft()
                if current_type in split_types['train'] | split_types['val'] | split_types['test']:
                    continue
                split_types[split_name].add(current_type)
                # Pull in all documents that mention this type.
                for doc_id in self.type_to_doc_id.get(current_type, []):
                    if doc_id in assigned_docs:
                        continue
                    split_docs[split_name].add(doc_id)
                    assigned_docs.add(doc_id)
                    unassigned_docs.discard(doc_id)
                    for t in self.doc_id_to_types[doc_id]:
                        if t not in split_types['train'] | split_types['val'] | split_types['test']:
                            queue.append(t)
                            unassigned_types.pop(t, None)
        return split_types, split_docs, unassigned_docs, assigned_docs

    def create_train_val_test_splits(self, split_targets, split_docs_targets):
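        """Partition all documents into train/val/test splits.

        Types are first assigned split by split via BFS propagation; any
        document left unassigned afterwards goes to the split whose types it
        overlaps the most (train, if it overlaps none).
        """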
        split_types = {'train': set(), 'val': set(), 'test': set()}
        split_docs = {'train': set(), 'val': set(), 'test': set()}
        all_types = list(self.type_to_doc_id.keys())
        random.seed(25)
        random.shuffle(all_types)
        # An insertion-ordered dict acts as an ordered set here; a plain set
        # would pop arbitrary elements and defeat the seeded shuffle above.
        unassigned_types = dict.fromkeys(all_types)
        unassigned_docs = set(self.doc_id_to_doc.keys())
        assigned_docs = set()

        for split_name in ['train', 'test', 'val']:
            split_types, split_docs, unassigned_docs, assigned_docs = self.assign_types_with_propagation(
                split_name, split_targets, split_docs_targets, split_types, split_docs,
                unassigned_types, unassigned_docs, assigned_docs)

        # Assign each remaining document to the split whose already-assigned
        # types it overlaps the most; documents overlapping none go to train.
        for doc_id in unassigned_docs.copy():
            doc_types = self.doc_id_to_types[doc_id]
            doc_type_split_counts = {"train": 0, "test": 0, "val": 0}
            for a_type in doc_types:
                for split_name in ['train', 'test', 'val']:
                    if a_type in split_types[split_name]:
                        doc_type_split_counts[split_name] += 1

            total = sum(doc_type_split_counts.values())
            if total == 0:
                split_docs["train"].add(doc_id)
            else:
                max_key = max(doc_type_split_counts, key=doc_type_split_counts.get)
                split_docs[max_key].add(doc_id)
            unassigned_docs.discard(doc_id)

        assert len(unassigned_docs) == 0, "Some documents were left unassigned."

        print(f"Train: {len(split_docs['train'])} docs, {len(split_types['train'])} types")
        print(f"Val: {len(split_docs['val'])} docs, {len(split_types['val'])} types")
        print(f"Test: {len(split_docs['test'])} docs, {len(split_types['test'])} types")
        return split_docs

    def generate_split_artefacts(self, split_docs):
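        """Materialise per-split artefacts: terms with their parent types,
        types with their parents, document objects, and a type-to-document
        map flagged as extractive or abstractive.
        """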
        split_terms = {'train': set(), 'val': set(), 'test': set()}
        terms_splits = {}
        for split_name in ['train', 'val', 'test']:
            for doc_id in split_docs[split_name]:
                split_terms[split_name].update(self.doc_id_to_terms[doc_id])
            split_terms[split_name] = list(split_terms[split_name])
            terms_with_types = [{"term": term, "types": self.child_to_parent.get(term, [])}
                                for term in split_terms[split_name]]
            terms_splits[split_name] = terms_with_types

        types_splits = {}
        for split_name in ['train', 'val', 'test']:
            split_types_from_docs = set()
            for doc_id in split_docs[split_name]:
                split_types_from_docs.update(self.doc_id_to_types[doc_id])
            types_with_parents = [{"type": a_type, "parents": self.child_to_parent.get(a_type, [])}
                                  for a_type in split_types_from_docs]
            types_splits[split_name] = types_with_parents

        docs_split = {'train': [], 'val': [], 'test': []}
        split_to_text = {'train': "", 'val': "", 'test': ""}
        for split_name in ['train', 'val', 'test']:
            # Collect the pieces and join once; += on strings in a loop is
            # quadratic in the total text length.
            text_parts = []
            for doc_id in split_docs[split_name]:
                doc = self.doc_id_to_doc[doc_id]
                docs_split[split_name].append(doc)
                text_parts.append(" " + doc.title + " " + doc.text)
            split_to_text[split_name] = "".join(text_parts)

        types2docs_splits = {}
        for split_name in ['train', 'val', 'test']:
            type2doc = defaultdict(list)
            split_types_from_docs = set()
            for doc_id in split_docs[split_name]:
                split_types_from_docs.update(self.doc_id_to_types[doc_id])
            for a_type in split_types_from_docs:
                for doc_id in self.type_to_doc_id[a_type]:
                    if doc_id in split_docs[split_name]:
                        # A type counts as "extractive" when its label occurs
                        # verbatim in the split's concatenated text; otherwise
                        # it is "abstractive".
                        extraction_type = "abstractive"
                        if a_type in split_to_text[split_name]:
                            extraction_type = "extractive"
                        type2doc[a_type].append({"doc_id": doc_id, "extraction_type": extraction_type})
            types2docs_splits[split_name] = type2doc

        return terms_splits, types_splits, docs_split, types2docs_splits

    def split(self, train: float = 0.8, val: float = 0.1, test: float = 0.1):
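        """Run the full pipeline and return (terms, types, docs, types2docs)."""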
        split_targets, split_docs_targets = self.set_train_val_test_sizes(train_percentage=train,
                                                                          val_percentage=val,
                                                                          test_percentage=test)
        split_docs = self.create_train_val_test_splits(split_targets, split_docs_targets)
        terms, types, docs, types2docs = self.generate_split_artefacts(split_docs)
        return terms, types, docs, types2docs
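
# Minimal usage sketch (illustrative only, hence commented out):
# `load_synthetic_data` and the "corpus.pkl" path are hypothetical
# placeholders for however the surrounding project actually builds a
# SyntheticText2OntoData instance.
#
#     data = load_synthetic_data("corpus.pkl")  # hypothetical loader
#     splitter = SyntheticDataSplitter(data, onto_name="example_onto")
#     terms, types, docs, types2docs = splitter.split(train=0.8, val=0.1, test=0.1)
#     print(len(docs["train"]), len(docs["val"]), len(docs["test"]))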