Skip to content

Commit 3bea5a4

Browse files
committed
Merge pull request #62 from joshchang1112:master
PiperOrigin-RevId: 325933747
2 parents 96bb6b0 + 89ad6bd commit 3bea5a4

File tree

5 files changed

+404
-0
lines changed

5 files changed

+404
-0
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# URL for downloading Cora dataset.
URL=https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz

# Target folder to store and process data.
DATA_DIR=data

# Helper function to download the data.
#
# Arguments:
#   $1: URL of the file to fetch.
#   $2: Directory the downloaded file is stored in.
#
# The download is skipped when the file is already present in that directory.
function download () {
  local fileurl="${1}"
  local filedir="${2}"
  # Strip everything up to the last '/' to obtain the bare file name.
  local filename="${fileurl##*/}"
  # wget -P stores the file under ${filedir}, so the existence check has to
  # look there; checking the current directory would re-download every time.
  if [ ! -f "${filedir}/${filename}" ]; then
    echo ">>> Downloading '${filename}' from '${fileurl}' to '${filedir}'"
    # NOTE(review): --no-check-certificate disables TLS verification; drop it
    # once the host's certificate is confirmed to validate.
    wget --quiet --no-check-certificate -P "${filedir}" "${fileurl}"
  else
    echo "*** File '${filename}' exists; no need to download it."
  fi
}

# Download and unzip the dataset. Data will be at '${DATA_DIR}/cora/' folder.
download "${URL}" "${DATA_DIR}"
tar -C "${DATA_DIR}" -xvzf "${DATA_DIR}/cora.tgz"

research/gnn-survey/layers.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""GNN layers."""
15+
import tensorflow as tf
16+
17+
18+
class GraphConvLayer(tf.keras.layers.Layer):
  """Graph convolution layer: `adj @ x @ weight`, plus an optional bias."""

  def __init__(self, output_dim, bias, **kwargs):
    """Creates the layer.

    Args:
      output_dim: (int) Number of output features produced by the layer.
      bias: (bool) Whether a bias vector is added to the output.
      **kwargs: Keyword arguments forwarded to tf.keras.layers.Layer.
    """
    super().__init__(**kwargs)
    self.output_dim = output_dim
    self.bias = bias

  def build(self, input_shape):
    """Creates the weight matrix and, when enabled, the bias vector.

    Args:
      input_shape: Shapes of the layer inputs. The last dimension of the first
        entry gives the number of rows of the weight matrix.
    """
    super().build(input_shape)
    input_dim = input_shape[0][-1]
    self.weight = self.add_weight(
        name='weight',
        shape=(input_dim, self.output_dim),
        initializer='random_normal',
        trainable=True)
    if not self.bias:
      return
    self.b = self.add_weight(
        name='bias',
        shape=(self.output_dim,),
        initializer='random_normal',
        trainable=True)

  def call(self, inputs):
    """Applies the graph convolution.

    Args:
      inputs: Indexable pair whose first item is the feature input and whose
        second item is the adjacency input.

    Returns:
      The adjacency-weighted features multiplied by the weight matrix, with
      the bias added when the layer was created with `bias` enabled.
    """
    features, adjacency = inputs[0], inputs[1]
    aggregated = tf.matmul(adjacency, features)
    projected = tf.matmul(aggregated, self.weight)
    return (self.b + projected) if self.bias else projected

research/gnn-survey/models.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Modeling for GNNs."""
15+
from layers import GraphConvLayer
16+
import tensorflow as tf
17+
18+
19+
class GCNBlock(tf.keras.layers.Layer):
  """Graph convolution followed by a ReLU activation and dropout."""

  def __init__(self, hidden_dim, dropout_rate, bias, **kwargs):
    """Stores the block settings and creates its activation and dropout.

    Args:
      hidden_dim: (int) Output dimension of the block's graph convolution.
      dropout_rate: (float) Rate handed to the dropout layer.
      bias: (bool) Whether the graph convolution adds a bias.
      **kwargs: Keyword arguments forwarded to tf.keras.layers.Layer.
    """
    super().__init__(**kwargs)
    self.hidden_dim = hidden_dim
    self.dropout_rate = dropout_rate
    self.bias = bias

    self._activation = tf.keras.layers.ReLU()
    self._dropout = tf.keras.layers.Dropout(self.dropout_rate)

  def build(self, input_shape):
    """Creates the graph convolution sub-layer once the block is built.

    Args:
      input_shape: Shapes of the block inputs, passed on to the base class.
    """
    super().build(input_shape)
    self._graph_conv_layer = GraphConvLayer(self.hidden_dim, bias=self.bias)

  def call(self, inputs):
    """Runs graph convolution, then ReLU, then dropout.

    Args:
      inputs: Inputs passed unchanged to the graph convolution layer.

    Returns:
      The block output after activation and dropout.
    """
    hidden = self._graph_conv_layer(inputs)
    return self._dropout(self._activation(hidden))
47+
48+
49+
class GCN(tf.keras.Model):
  """Graph convolution network for semi-supervised node classification."""

  def __init__(self, num_layers, hidden_dim, num_classes, dropout_rate, bias,
               **kwargs):
    """Initializes a GCN model.

    The model consists of `num_layers - 1` GCNBlock layers followed by one
    GraphConvLayer that produces `num_classes` outputs.

    Args:
      num_layers: (int) Number of gnn layers, counting the output layer.
      hidden_dim: (list) Output dimension of each hidden block. Needs at least
        `num_layers - 1` entries; extra entries are ignored. Entries are
        converted with `int()`, so numeric strings are accepted.
      num_classes: (int) Total number of classes.
      dropout_rate: (float) Dropout probability.
      bias: (bool) Whether bias needs to be added to gcn layers.
      **kwargs: Keyword arguments for tf.keras.Model.

    Raises:
      ValueError: If `hidden_dim` has fewer than `num_layers - 1` entries.
    """
    super().__init__(**kwargs)
    # Only the layers before the classifier need a hidden dimension, so a
    # single-layer model works with an empty `hidden_dim`.
    num_blocks = max(num_layers - 1, 0)
    if len(hidden_dim) < num_blocks:
      raise ValueError(
          'hidden_dim needs at least {} entries for num_layers={}, got {}'
          .format(num_blocks, num_layers, len(hidden_dim)))

    self.num_layers = num_layers
    self.hidden_dim = hidden_dim
    self.num_classes = num_classes
    self.dropout_rate = dropout_rate
    self.bias = bias

    # input and hidden layers
    self.gc = [
        GCNBlock(int(hidden_dim[i]), dropout_rate=dropout_rate, bias=bias)
        for i in range(num_blocks)
    ]

    # output layer
    self.classifier = GraphConvLayer(self.num_classes, bias=self.bias)

  def call(self, inputs):
    """Runs the stacked blocks and the classifier.

    Args:
      inputs: Indexable pair whose first item is the feature input and whose
        second item is the adjacency input. The same adjacency input is fed to
        every layer.

    Returns:
      The output of the final graph convolution layer.
    """
    features, adj = inputs[0], inputs[1]
    for block in self.gc:
      features = block((features, adj))
    return self.classifier((features, adj))

research/gnn-survey/train.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright 2020 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Trains a GNN."""
15+
import time
16+
17+
from absl import app
18+
from absl import flags
19+
import tensorflow as tf
20+
21+
from utils import load_dataset, build_model, cal_acc # pylint: disable=g-multiple-import
22+
23+
# Command-line configuration for the training script.
flags.DEFINE_enum('dataset', 'cora', ['cora'],
                  'The input dataset. Avaliable dataset now: cora')
flags.DEFINE_enum('model', 'gcn', ['gcn'],
                  'GNN model. Available model now: gcn')
flags.DEFINE_float('dropout_rate', 0.5, 'Dropout probability')
# The default is a real integer so it matches the declared flag type instead
# of relying on the flag parser to convert a string.
flags.DEFINE_integer('gpu', -1, 'Gpu id, -1 means cpu only')
flags.DEFINE_float('lr', 1e-2, 'Initial learning rate')
flags.DEFINE_integer('epochs', 200, 'Number of training epochs')
flags.DEFINE_integer('num_layers', 2, 'Number of gnn layers')
# NOTE(review): list flags given on the command line presumably arrive as
# strings while this default holds an int -- confirm the consumer converts.
flags.DEFINE_list('hidden_dim', [32], 'Dimension of gnn hidden layers')
flags.DEFINE_enum('optimizer', 'adam', ['adam', 'sgd'],
                  'Optimizer for training')
flags.DEFINE_float('weight_decay', 5e-4, 'Weight for L2 regularization')
flags.DEFINE_string('save_dir', 'models/cora/gcn',
                    'Directory stores trained model')

FLAGS = flags.FLAGS
40+
41+
42+
def train(model, adj, features, labels, idx_train, idx_val, idx_test):
  """Trains a gnn model, keeps the best checkpoint and reports test accuracy.

  Every epoch performs one gradient step on the training nodes and then
  evaluates on the validation nodes. The model is saved to `FLAGS.save_dir`
  whenever the validation accuracy improves. After the last epoch the saved
  model is loaded again and scored on the test nodes.

  Args:
    model: Keras model called as `model((features, adj))`; its output is
      treated as logits.
    adj: Adjacency input passed to the model.
    features: Feature input passed to the model.
    labels: Labels, indexed with the `idx_*` arguments.
    idx_train: Index selecting the training entries of `labels` and of the
      model output.
    idx_val: Index selecting the validation entries.
    idx_test: Index selecting the test entries.
  """
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  # Start below every possible accuracy so the first epoch always writes a
  # checkpoint; otherwise the final load_model() has nothing to load when the
  # validation accuracy never rises above zero.
  best_val_acc = float('-inf')

  if FLAGS.optimizer == 'adam':
    optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr)
  elif FLAGS.optimizer == 'sgd':
    optimizer = tf.keras.optimizers.SGD(learning_rate=FLAGS.lr)

  inputs = (features, adj)
  for epoch in range(FLAGS.epochs):
    epoch_start_time = time.time()

    with tf.GradientTape() as tape:
      # training=True is required in a custom loop; without it the dropout
      # layers run in inference mode and dropout_rate has no effect.
      output = model(inputs, training=True)
      train_loss = loss_fn(labels[idx_train], output[idx_train])
      # L2 regularization
      for weight in model.trainable_weights:
        train_loss += FLAGS.weight_decay * tf.nn.l2_loss(weight)

    gradients = tape.gradient(train_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_acc = cal_acc(labels[idx_train], output[idx_train])

    # Evaluate
    output = model(inputs, training=False)
    val_loss = loss_fn(labels[idx_val], output[idx_val])
    val_acc = cal_acc(labels[idx_val], output[idx_val])

    if val_acc > best_val_acc:
      best_val_acc = val_acc
      model.save(FLAGS.save_dir)

    print('[%03d/%03d] %.2f sec(s) Train Acc: %.3f Loss: %.6f | Val Acc: %.3f loss: %.6f' %
          (epoch + 1, FLAGS.epochs, time.time() - epoch_start_time,
           train_acc, train_loss, val_acc, val_loss))

  print('Start Predicting...')
  # Score the best checkpoint, not the weights of the last epoch.
  model = tf.keras.models.load_model(FLAGS.save_dir)
  output = model(inputs, training=False)
  test_acc = cal_acc(labels[idx_test], output[idx_test])
  print('***Test Accuracy: %.3f***' % test_acc)
86+
87+
88+
def main(_):
  """Loads the dataset, builds the model and trains it on the chosen device.

  Args:
    _: Unused positional argument supplied by `app.run`.
  """
  # A gpu id of -1 selects the CPU; any other value is used as the GPU index.
  device = '/cpu:0' if FLAGS.gpu == -1 else '/gpu:{}'.format(FLAGS.gpu)

  with tf.device(device):
    tf.random.set_seed(1234)

    # Load the dataset and process features and adj matrix
    print('Loading {} dataset...'.format(FLAGS.dataset))
    dataset = load_dataset(FLAGS.dataset)
    adj, features, labels, idx_train, idx_val, idx_test = dataset
    num_classes = max(labels) + 1

    print('Build model...')
    model = build_model(FLAGS.model, FLAGS.num_layers, FLAGS.hidden_dim,
                        num_classes, FLAGS.dropout_rate)

    print('Start Training...')
    train(model, adj, features, labels, idx_train, idx_val, idx_test)


if __name__ == '__main__':
  app.run(main)

0 commit comments

Comments
 (0)