Skip to content

Commit 89ad6bd

Browse files
committed
Pass keyword arguments to class init, Modify DEFINE_flags
1 parent 4b84f20 commit 89ad6bd

File tree

4 files changed

+56
-37
lines changed

4 files changed

+56
-37
lines changed
Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,34 @@
11
import tensorflow as tf
22

33
class GraphConvLayer(tf.keras.layers.Layer):
4+
""" Single graph convolution layer
45
5-
def __init__(self, features):
6+
Args:
7+
input_dim (int): Input dimension of gcn layer
8+
output_dim (int): Output dimension of gcn layer
9+
bias (bool): Whether bias needs to be added to the layer
10+
"""
11+
def __init__(self, **kwargs):
612
super(GraphConvLayer, self).__init__()
7-
self.in_feat = features['input_dim']
8-
self.out_feat = features['output_dim']
9-
self.b = features['bias']
10-
13+
for key, item in kwargs.items():
14+
setattr(self, key, item)
15+
1116
def build(self, input_shape):
1217
self.weight = self.add_weight(name="weight",
13-
shape=(self.in_feat, self.out_feat),
18+
shape=(self.input_dim, self.output_dim),
1419
initializer='random_normal',
1520
trainable=True)
16-
if self.b:
17-
self.bias = self.add_weight(name="bias",
18-
shape=(self.out_feat,),
21+
if self.bias:
22+
self.b = self.add_weight(name="bias",
23+
shape=(self.output_dim,),
1924
initializer='random_normal',
2025
trainable=True)
2126
def call(self, inputs):
2227
x, adj = inputs[0], inputs[1]
23-
x = tf.matmul(adj, x)
28+
x = tf.matmul(adj, x)
2429
outputs = tf.matmul(x, self.weight)
25-
if self.b:
26-
return self.bias + outputs
30+
if self.bias:
31+
return self.b + outputs
2732
else:
2833
return outputs
2934

research/gnn-survey-paper/models.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,37 +2,47 @@
22
from layers import GraphConvLayer
33

44
class GCN(tf.keras.Model):
5-
6-
def __init__(self, features_dim, num_layers, hidden_dim, num_classes, dropout_rate, bias=True):
5+
"""Graph convolution network for semi-supevised node classification.
6+
7+
Args:
8+
features_dim (int): Dimension of input features
9+
num_layers (int): Number of gnn layers
10+
hidden_dim (list): List of hidden layers dimension
11+
num_classes (int): Total number of classes
12+
dropout_prob (float): Dropout probability
13+
bias (bool): Whether bias needs to be added to gcn layers
14+
"""
15+
16+
def __init__(self, **kwargs):
717
super(GCN, self).__init__()
8-
9-
self.num_layers = num_layers
10-
self.bias = bias
11-
18+
19+
for key, item in kwargs.items():
20+
setattr(self, key, item)
21+
1222
self.gc = []
1323
# input layer
1424
single_gc = tf.keras.Sequential()
15-
single_gc.add(GraphConvLayer({"input_dim": features_dim,
16-
"output_dim": hidden_dim[0],
17-
"bias": bias}))
25+
single_gc.add(GraphConvLayer(input_dim=self.features_dim,
26+
output_dim=self.hidden_dim[0],
27+
bias=self.bias))
1828
single_gc.add(tf.keras.layers.ReLU())
19-
single_gc.add(tf.keras.layers.Dropout(dropout_rate))
29+
single_gc.add(tf.keras.layers.Dropout(self.dropout_prob))
2030
self.gc.append(single_gc)
2131

2232
# hidden layers
23-
for i in range(0, num_layers-2):
33+
for i in range(0, self.num_layers-2):
2434
single_gc = tf.keras.Sequential()
25-
single_gc.add(GraphConvLayer({"input_dim": hidden_dim[i],
26-
"output_dim": hidden_dim[i+1],
27-
"bias": bias}))
35+
single_gc.add(GraphConvLayer(input_dim=self.hidden_dim[i],
36+
output_dim=self.hidden_dim[i+1],
37+
bias=self.bias))
2838
single_gc.add(tf.keras.layers.ReLU())
29-
single_gc.add(tf.keras.layers.Dropout(dropout_rate))
39+
single_gc.add(tf.keras.layers.Dropout(self.dropout_prob))
3040
self.gc.append(single_gc)
3141

3242
# output layer
33-
self.classifier = GraphConvLayer({"input_dim": hidden_dim[-1],
34-
"output_dim": num_classes,
35-
"bias": bias})
43+
self.classifier = GraphConvLayer(input_dim=self.hidden_dim[-1],
44+
output_dim=self.num_classes,
45+
bias=self.bias)
3646

3747
def call(self, inputs):
3848
features, adj = inputs[0], inputs[1]

research/gnn-survey-paper/train.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,17 @@
88

99
from utils import load_dataset, build_model, cal_acc
1010

11-
flags.DEFINE_string('dataset', 'cora',
11+
flags.DEFINE_enum('dataset', 'cora', ['cora'],
1212
'The input dataset. Avaliable dataset now: cora')
13-
flags.DEFINE_string('model', 'gcn',
14-
'GNN model. Available model now: gcn')
13+
flags.DEFINE_enum('model', 'gcn', ['gcn'],
14+
'GNN model. Available model now: gcn')
1515
flags.DEFINE_float('dropout', 0.5, 'Dropout probability')
1616
flags.DEFINE_integer('gpu', '-1', 'Gpu id, -1 means cpu only')
1717
flags.DEFINE_float('lr', 1e-2, 'Initial learning rate')
1818
flags.DEFINE_integer('epochs', 200, 'Number of training epochs')
1919
flags.DEFINE_integer('num_layers', 2, 'Number of gnn layers')
2020
flags.DEFINE_list('hidden_dim', [32], 'Dimension of gnn hidden layers')
21-
flags.DEFINE_string('optimizer', 'adam', 'Optimizer for training')
21+
flags.DEFINE_enum('optimizer', 'adam', ['adam', 'sgd'], 'Optimizer for training')
2222
flags.DEFINE_float('weight_decay', 5e-4, 'Weight for L2 regularization')
2323
flags.DEFINE_string('save_dir', 'models/cora/gcn', 'Directory stores trained model')
2424

@@ -33,6 +33,7 @@ def train(model, adj, features, labels, idx_train, idx_val, idx_test):
3333
optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr)
3434
elif FLAGS.optimizer == 'sgd':
3535
optimizer = tf.keras.optimizers.SGD(learning_rate=FLAGS.lr)
36+
3637

3738
inputs = (features, adj)
3839
for epoch in range(FLAGS.epochs):
@@ -70,8 +71,6 @@ def train(model, adj, features, labels, idx_train, idx_val, idx_test):
7071
print("***Test Accuracy: %.3f***"% (test_acc))
7172

7273

73-
74-
7574
def main(_):
7675

7776
if FLAGS.gpu == -1:

research/gnn-survey-paper/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ def build_model(model_name, features_dim, num_layers, hidden_dim, num_classes, d
1313

1414
# Only gcn available now
1515
if model_name == 'gcn':
16-
model = GCN(features_dim, num_layers, hidden_dim, num_classes, dropout)
16+
model = GCN(features_dim=features_dim,
17+
num_layers=num_layers,
18+
hidden_dim=hidden_dim,
19+
num_classes=num_classes,
20+
dropout_prob=dropout,
21+
bias=True)
1722

1823
elif model_name == 'gat':
1924
raise NotImplementedError

0 commit comments

Comments
 (0)