
Commit 9950642

Merge pull request #59 from chamii22:master
PiperOrigin-RevId: 309308335
2 parents 5c2c2af + a7e065e commit 9950642

File tree

2 files changed (+33 −11 lines)


research/kg_hyp_emb/README.md

Lines changed: 29 additions & 5 deletions
@@ -2,19 +2,20 @@
 
 This project is a Tensorflow 2.0 implementation of Hyperbolic KG embeddings [6]
 as well as multiple state-of-the-art KG embedding models which can be trained
-for the link prediction task.
+for the link prediction task. A PyTorch implementation is also available at:
+[https://github.com/HazyResearch/KGEmb](https://github.com/HazyResearch/KGEmb)
 
 ## Library Overview
 
 This implementation includes the following models:
 
-Complex embeddings:
+#### Complex embeddings:
 
 * Complex [1]
 * Complex-N3 [2]
 * RotatE [3]
 
-Euclidean embeddings:
+#### Euclidean embeddings:
 
 * CTDecomp [2]
 * TransE [4]
@@ -23,14 +24,14 @@ Euclidean embeddings:
 * RefE [6]
 * AttE [6]
 
-Hyperbolic embeddings:
+#### Hyperbolic embeddings:
 
 * TransH [6]
 * RotH [6]
 * RefH [6]
 * AttH [6]
 
-## Usage
+## Installation
 
 First, create a python 3.7 environment and install dependencies: From kgemb/
 
@@ -66,6 +67,8 @@ KG_DIR=$(pwd)/..
 export PYTHONPATH="$KG_DIR:$PYTHONPATH"
 ```
 
+## Example usage
+
 Then, train a model using the `train.py` script. We provide an example to train
 RefE on FB15k-237:
 
@@ -75,6 +78,27 @@ python train.py --max_epochs 100 --dataset FB237 --model RefE --loss_fn SigmoidC
 
 This model achieves 54% Hits@10 on the FB237 test set.
 
+## New models
+
+To add a new (complex/hyperbolic/Euclidean) Knowledge Graph embedding model,
+implement the corresponding query embedding under models/, e.g.:
+
+```
+def get_queries(self, input_tensor):
+  entity = self.entity(input_tensor[:, 0])
+  rel = self.rel(input_tensor[:, 1])
+  result = ...  # Do something here
+  return result
+```
+
+## Citation
+
+If you use this code, please cite the following paper [6]:
+
+```
+TODO: add bibtex
+```
+
 ## References
 
 [1] Trouillon, Théo, et al. "Complex embeddings for simple link prediction."
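The `get_queries` template added to the README above leaves the combination step open. As an editor's sketch only, not the repository's actual API, a TransE-style model could fill it in as follows; the `TransEQueries` class name and the plain NumPy lookup tables are assumptions standing in for the library's real embedding layers:

```python
import numpy as np

class TransEQueries:
    """Toy TransE-style model illustrating the get_queries template.

    The NumPy tables below are assumptions; the repository's models
    use trainable embedding layers instead.
    """

    def __init__(self, n_entities, n_relations, dim, seed=0):
        rng = np.random.default_rng(seed)
        self.entity_table = rng.normal(size=(n_entities, dim))
        self.rel_table = rng.normal(size=(n_relations, dim))

    def entity(self, idx):
        # Look up entity embeddings by index.
        return self.entity_table[idx]

    def rel(self, idx):
        # Look up relation embeddings by index.
        return self.rel_table[idx]

    def get_queries(self, input_tensor):
        # input_tensor rows are (head, relation, tail) index triples.
        entity = self.entity(input_tensor[:, 0])
        rel = self.rel(input_tensor[:, 1])
        result = entity + rel  # TransE: translate the head by the relation.
        return result
```

Here the "do something" step is vector addition; a rotation- or reflection-based model would combine `entity` and `rel` differently.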

research/kg_hyp_emb/datasets/process.py

Lines changed: 4 additions & 6 deletions
@@ -108,16 +108,14 @@ def process_dataset(path):
       corresponding KG triples.
     filters: Dictionary containing filters for lhs and rhs predictions.
   """
-  lhs_skip = collections.defaultdict(set)
-  rhs_skip = collections.defaultdict(set)
   ent2idx, rel2idx = get_idx(dataset_path)
   examples = {}
-  for split in ['train', 'valid', 'test']:
+  splits = ['train', 'valid', 'test']
+  for split in splits:
     dataset_file = os.path.join(path, split)
     examples[split] = to_np_array(dataset_file, ent2idx, rel2idx)
-    lhs_filters, rhs_filters = get_filters(examples[split], len(rel2idx))
-    lhs_skip.update(lhs_filters)
-    rhs_skip.update(rhs_filters)
+  all_examples = np.concatenate([examples[split] for split in splits], axis=0)
+  lhs_skip, rhs_skip = get_filters(all_examples, len(rel2idx))
   filters = {'lhs': lhs_skip, 'rhs': rhs_skip}
   return examples, filters
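The process.py change builds the filters once from the concatenated splits instead of merging per-split results, which matters because `dict.update` replaces the value set stored under a repeated (entity, relation) key rather than unioning it. A minimal sketch of the idea, with a hypothetical `get_filters` of the same shape as the repository's (the reciprocal-relation offset is an assumption, not confirmed by this diff):

```python
import collections
import numpy as np

def get_filters(examples, n_relations):
    """Hypothetical stand-in for the repository's get_filters: maps
    (entity, relation) keys to the set of observed completions."""
    lhs_filters = collections.defaultdict(set)
    rhs_filters = collections.defaultdict(set)
    for lhs, rel, rhs in examples:
        rhs_filters[(lhs, rel)].add(rhs)
        # Index the inverse direction with an offset relation id
        # (a common convention; an assumption here).
        lhs_filters[(rhs, rel + n_relations)].add(lhs)
    return lhs_filters, rhs_filters

# Filtering over the concatenated splits sees every completion for a key,
# so (0, 0) below collects completions from both train and valid.
splits = {
    'train': np.array([[0, 0, 1]]),
    'valid': np.array([[0, 0, 2]]),
    'test': np.array([[3, 0, 1]]),
}
all_examples = np.concatenate([splits[s] for s in splits], axis=0)
lhs_skip, rhs_skip = get_filters(all_examples, n_relations=1)
```

Running `get_filters` per split and calling `lhs_skip.update(...)` would have overwritten the train-set completions for `(0, 0)` with the valid-set ones; concatenating first yields the union.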
