Commit 1c5a613
DualityGap authored and tensorflow-copybara committed
Open source the libraries and implementation of 'Low-Dimensional Hyperbolic Knowledge Graph Embeddings' [1] in NSL Research repo.
[1] Chami, Ines, et al. Low-Dimensional Hyperbolic Knowledge Graph Embeddings. ACL. 2020.
PiperOrigin-RevId: 306060013
1 parent ba502a3 commit 1c5a613

23 files changed: 1923 additions, 8 deletions

research/README.md

Lines changed: 22 additions & 8 deletions
@@ -3,20 +3,34 @@
 Note that these research projects are not included in the prebuilt NSL pip
 package.
 
+## [Low-Dimensional Hyperbolic Knowledge Graph Embeddings](kg_hyp_emb)
+
+The implementations of Low-Dimensional Hyperbolic Knowledge Graph Embeddings [3]
+are provided in the `kg_hyp_emb` folder on a strict "as is" basis, without
+warranties or conditions of any kind. Also, these implementations may not be
+compatible with certain TensorFlow versions or Python versions.
+
+[3] Chami, Ines, et al. "Low-Dimensional Hyperbolic Knowledge Graph Embeddings."
+ACL 2020.
+
 ## [A2N](a2n): Attending to Neighbors for Knowledge Graph Inference
 
-The implementations of A2N [1] are provided in the `a2n` folder on a strict "as
+The implementations of A2N [2] are provided in the `a2n` folder on a strict "as
 is" basis, without warranties or conditions of any kind. Also, these
-implementations may not be compatible with certain TensorFlow versions (such as
-2.0 or above) or Python versions.
+implementations may not be compatible with certain TensorFlow versions or Python
+versions.
 
-[[1] T. Bansal, D. Juan, S. Ravi and A. McCallum. "A2N: Attending to Neighbors
+[[2] T. Bansal, D. Juan, S. Ravi and A. McCallum. "A2N: Attending to Neighbors
 for Knowledge Graph Inference." ACL
 2019](https://www.aclweb.org/anthology/P19-1431)
 
 ## [GAM](gam): Graph Agreement Models for Semi-Supervised Learning
 
-The implementations of Graph Agreement Models (GAMs) are provided in the `gam`
-folder on a strict "as is" basis, without warranties or conditions of any kind.
-Also, these implementations may not be compatible with certain TensorFlow
-versions (such as 2.0 or above) or Python versions.
+The implementations of Graph Agreement Models (GAMs) [1] are provided in the
+`gam` folder on a strict "as is" basis, without warranties or conditions of any
+kind. Also, these implementations may not be compatible with certain TensorFlow
+versions or Python versions.
+
+[[1] O. Stretcu, K. Viswanathan, D. Movshovitz-Attias, E.A. Platanios, S. Ravi,
+A. Tomkins. "Graph Agreement Models for Semi-Supervised Learning." NeurIPS
+2019](https://papers.nips.cc/paper/9076-graph-agreement-models-for-semi-supervised-learning)

research/kg_hyp_emb/README.md

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+# Knowledge Graph (KG) Embedding Library
+
+This project is a TensorFlow 2.0 implementation of Hyperbolic KG embeddings [6]
+as well as multiple state-of-the-art KG embedding models which can be trained
+for the link prediction task.
+
+## Library Overview
+
+This implementation includes the following models:
+
+Complex embeddings:
+
+* Complex [1]
+* Complex-N3 [2]
+* RotatE [3]
+
+Euclidean embeddings:
+
+* CTDecomp [2]
+* TransE [4]
+* MurE [5]
+* RotE (new)
+* RefE (new)
+* AttE (new)
+
+Hyperbolic embeddings:
+
+* TransH (new)
+* RotH (new)
+* RefH (new)
+* AttH (new)
+
+## Usage
+
+From `kgemb/`, create a Python 3.7 environment and install the dependencies:
+
+```bash
+virtualenv -p python3.7 kgenv
+```
+
+```bash
+source kgenv/bin/activate
+```
+
+```bash
+pip install -r requirements.txt
+```
+
+Then, download and pre-process the datasets:
+
+```bash
+source datasets/download.sh
+```
+
+```bash
+python datasets/process.py
+```
+
+Add the package to your local path:
+
+```bash
+KG_DIR=$(pwd)/..
+```
+
+```bash
+export PYTHONPATH="$KG_DIR:$PYTHONPATH"
+```
+
+Then, train a model using the `train.py` script. We provide an example to train
+RefE on FB15k-237:
+
+```bash
+python train.py --max_epochs 100 --dataset FB237 --model RefE --loss_fn SigmoidCrossEntropy --neg_sample_size -1 --data_dir data --optimizer Adagrad --lr 5e-2 --save_dir logs --rank 500 --entity_reg 1e-5 --rel_reg 1e-5 --patience 10 --valid 5 --save_model=false --save_logs=true --regularizer L3 --initializer GlorotNormal
+```
+
+This model should achieve around 54% Hits@10 on the FB237 test set.
+
+## References
+
+[1] Trouillon, Théo, et al. "Complex embeddings for simple link prediction."
+International Conference on Machine Learning. 2016.
+
+[2] Lacroix, Timothée, et al. "Canonical Tensor Decomposition for Knowledge Base
+Completion." International Conference on Machine Learning. 2018.
+
+[3] Sun, Zhiqing, et al. "RotatE: Knowledge graph embedding by relational
+rotation in complex space." International Conference on Learning
+Representations. 2019.
+
+[4] Bordes, Antoine, et al. "Translating embeddings for modeling
+multi-relational data." Advances in neural information processing systems. 2013.
+
+[5] Balažević, Ivana, et al. "Multi-relational Poincaré Graph Embeddings."
+Advances in neural information processing systems. 2019.
+
+[6] Chami, Ines, et al. "Low-Dimensional Hyperbolic Knowledge Graph Embeddings."
+ACL 2020.
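
Note on the README above: it quotes Hits@10 as the expected result but does not define it. As a minimal sketch, and not the repository's evaluation code, Hits@K can be read as the fraction of test queries whose correct entity is ranked within the top K candidates. The function name `hits_at_k` and the sample ranks below are hypothetical and only serve as illustration.

```python
import numpy as np


def hits_at_k(ranks, k=10):
  """Returns the fraction of 1-based ranks that are at most k."""
  ranks = np.asarray(ranks)
  return float(np.mean(ranks <= k))


# Made-up ranks of the correct entity for four test queries.
sample_ranks = [1, 3, 12, 50]
sample_hits = hits_at_k(sample_ranks, k=10)
print('Hits@10:', sample_hits)
```

Under this reading, a reported 54% Hits@10 would mean that roughly 54 out of every 100 test queries place the correct entity among the ten best-scored candidates.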

research/kg_hyp_emb/__init__.py

Whitespace-only changes.

research/kg_hyp_emb/config.py

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Default configuration parameters."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+CONFIG = {
+    'string': {
+        'dataset': ('Dataset', 'WN18RR'),
+        'model': ('Model', 'RotE'),
+        'data_dir': ('Path to data directory', 'data/'),
+        'save_dir': ('Path to logs directory', 'logs/'),
+        'loss_fn': ('Loss function to use', 'SigmoidCrossEntropy'),
+        'initializer': ('Which initializer to use', 'GlorotNormal'),
+        'regularizer': ('Regularizer', 'N3'),
+        'optimizer': ('Optimizer', 'Adam'),
+        'bias': ('Bias term', 'learn'),
+        'dtype': ('Precision to use', 'float32'),
+    },
+    'float': {
+        'lr': ('Learning rate', 1e-3),
+        'lr_decay': ('Learning rate decay', 0.96),
+        'min_lr': ('Minimum learning rate decay', 1e-5),
+        'gamma': ('Margin for distance-based losses', 0),
+        'entity_reg': ('Regularization weight for entity embeddings', 0),
+        'rel_reg': ('Regularization weight for relation embeddings', 0),
+    },
+    'integer': {
+        'patience': ('Number of validation steps before early stopping', 20),
+        'valid': ('Number of epochs before computing validation metrics', 5),
+        'checkpoint': ('Number of epochs before checkpointing the model', 5),
+        'max_epochs': ('Maximum number of epochs to train for', 400),
+        'rank': ('Embeddings dimension', 500),
+        'batch_size': ('Batch size', 500),
+        'neg_sample_size':
+            ('Negative sample size, -1 to use loss without negative sampling',
+             50),
+    },
+    'boolean': {
+        'train_c': ('Whether to train the hyperbolic curvature or not', True),
+        'debug': ('If debug is true, only use 1000 examples for'
+                  ' debugging purposes', False),
+        'save_logs':
+            ('Whether to save the training logs or print to stdout', True),
+        'save_model': ('Whether to save the model weights', False)
+    }
+}
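
Note on `config.py`: each entry maps a flag name to a `(description, default)` pair, grouped by type. The diff shown here does not include the code that consumes `CONFIG`, so the following is only a sketch of one way such a dict could be turned into command-line flags. It uses the standard-library `argparse` as an assumption (the actual `train.py` may use a different flag library), and `SAMPLE_CONFIG`, `parse_bool`, `CASTS`, and `build_parser` are hypothetical names.

```python
import argparse

# Hypothetical subset of a CONFIG-style dict:
# type group -> flag name -> (description, default value).
SAMPLE_CONFIG = {
    'string': {'model': ('Model', 'RotE')},
    'float': {'lr': ('Learning rate', 1e-3)},
    'integer': {'rank': ('Embeddings dimension', 500)},
    'boolean': {'save_model': ('Whether to save the model weights', False)},
}


def parse_bool(text):
  """Parses 'true'/'false' style strings into booleans."""
  lowered = text.lower()
  if lowered in ('true', '1', 'yes'):
    return True
  if lowered in ('false', '0', 'no'):
    return False
  raise argparse.ArgumentTypeError('Expected a boolean, got %r' % text)


# One converter per type group used in the config dict.
CASTS = {'string': str, 'float': float, 'integer': int, 'boolean': parse_bool}


def build_parser(config):
  """Registers one flag per config entry, keeping its default and help text."""
  parser = argparse.ArgumentParser()
  for group, entries in config.items():
    for name, (description, default) in entries.items():
      parser.add_argument(
          '--' + name, type=CASTS[group], default=default, help=description)
  return parser


args = build_parser(SAMPLE_CONFIG).parse_args(
    ['--model', 'RefE', '--lr', '5e-2', '--save_model=false'])
print(args.model, args.lr, args.rank, args.save_model)
```

Grouping entries by type keeps the conversion logic in one table (`CASTS` here), so adding a new flag only requires a new dict entry.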

research/kg_hyp_emb/datasets/__init__.py

Whitespace-only changes.

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Dataset class for loading and processing KG datasets."""
+
+import os
+import pickle as pkl
+
+import numpy as np
+import tensorflow as tf
+
+
+class DatasetFn(object):
+  """Knowledge Graph dataset class."""
+
+  def __init__(self, data_path, debug):
+    """Creates KG dataset object for data loading.
+
+    Args:
+      data_path: Path to directory containing train/valid/test pickle files
+        produced by process.py.
+      debug: boolean indicating whether to use debug mode or not. If true, the
+        dataset will only contain 1000 examples for debugging.
+    """
+    self.data_path = data_path
+    self.debug = debug
+    self.data = {}
+    for split in ['train', 'test', 'valid']:
+      file_path = os.path.join(self.data_path, split + '.pickle')
+      with open(file_path, 'rb') as in_file:
+        self.data[split] = pkl.load(in_file)
+    filters_path = os.path.join(self.data_path, 'to_skip.pickle')
+    with open(filters_path, 'rb') as filters_file:
+      self.to_skip = pkl.load(filters_file)
+    max_axis = np.max(self.data['train'], axis=0)
+    self.n_entities = int(max(max_axis[0], max_axis[2]) + 1)
+    self.n_predicates = int(max_axis[1] + 1) * 2
+
+  def get_filters(self):
+    """Return filter dict to compute ranking metrics in the filtered setting."""
+    return self.to_skip
+
+  def get_examples(self, split):
+    """Get examples in a split.
+
+    Args:
+      split: String indicating the split to use (train/valid/test).
+
+    Returns:
+      examples: tf.data.Dataset containing KG triples in a split.
+    """
+    examples = self.data[split]
+    if split == 'train':
+      copy = np.copy(examples)
+      tmp = np.copy(copy[:, 0])
+      copy[:, 0] = copy[:, 2]
+      copy[:, 2] = tmp
+      copy[:, 1] += self.n_predicates // 2
+      examples = np.vstack((examples, copy))
+    if self.debug:
+      examples = examples[:1000]
+    examples = examples.astype(np.int64)
+    tf_dataset = tf.data.Dataset.from_tensor_slices(examples)
+    if split == 'train':
+      tf_dataset = tf_dataset.shuffle(
+          buffer_size=examples.shape[0], reshuffle_each_iteration=True)
+    return tf_dataset
+
+  def get_shape(self):
+    """Returns KG dataset shape."""
+    return self.n_entities, self.n_predicates, self.n_entities
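
Note on `get_examples`: for the training split it appends an inverse triple for every original one. Each `(head, relation, tail)` row gains a counterpart `(tail, relation + n_predicates // 2, head)`, which is why `n_predicates` is set to twice the number of relation IDs seen in the training data. The standalone NumPy sketch below reproduces that step on made-up IDs; `add_inverse_triples` and the toy array are hypothetical and not part of the commit.

```python
import numpy as np


def add_inverse_triples(triples, n_relations):
  """Appends (tail, relation + n_relations, head) for every input triple."""
  triples = np.asarray(triples, dtype=np.int64)
  # Reversing the column order swaps head and tail and keeps the relation
  # in the middle column.
  inverse = triples[:, ::-1].copy()
  # Inverse relations get their own ID range above the original relations.
  inverse[:, 1] += n_relations
  return np.vstack((triples, inverse))


# Toy graph with 3 entities and 2 relations (IDs are made up).
toy_triples = np.array([[0, 0, 1], [2, 1, 0]])
augmented = add_inverse_triples(toy_triples, n_relations=2)
print(augmented)
```

With two original relations, the inverse relations take IDs 2 and 3, so a model sees four distinct relation embeddings and can answer head queries as tail queries on the inverse relation.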

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+#!/bin/bash
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Dataset download script using open source datasets from the kbc repository.
+wget https://dl.fbaipublicfiles.com/kbc/data.tar.gz
+tar -xvzf data.tar.gz
+rm data.tar.gz