Skip to content

Commit ae56d53

Browse files
committed
fix in process script
1 parent 5c2c2af commit ae56d53

File tree

1 file changed

+5 -7 lines changed

1 file changed

+5 -7 lines changed

research/kg_hyp_emb/datasets/process.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -108,17 +108,15 @@ def process_dataset(path):
108 108
corresponding KG triples.
109 109
filters: Dictionary containing filters for lhs and rhs predictions.
110 110
"""
111-
lhs_skip = collections.defaultdict(set)
112-
rhs_skip = collections.defaultdict(set)
113 111
ent2idx, rel2idx = get_idx(dataset_path)
114 112
examples = {}
115-
for split in ['train', 'valid', 'test']:
113+
splits = ["train", "valid", "test"]
114+
for split in splits:
116 115
dataset_file = os.path.join(path, split)
117 116
examples[split] = to_np_array(dataset_file, ent2idx, rel2idx)
118-
lhs_filters, rhs_filters = get_filters(examples[split], len(rel2idx))
119-
lhs_skip.update(lhs_filters)
120-
rhs_skip.update(rhs_filters)
121-
filters = {'lhs': lhs_skip, 'rhs': rhs_skip}
117+
all_examples = np.concatenate([examples[split] for split in splits], axis=0)
118+
lhs_skip, rhs_skip = get_filters(all_examples, len(rel2idx))
119+
filters = {"lhs": lhs_skip, "rhs": rhs_skip}
122 120
return examples, filters
123 121

124 122

0 commit comments

Comments (0)