
Commit 60f7d4c

Revise code structure for gcn
1 parent dd3a167 commit 60f7d4c

7 files changed: +61 -42 lines
3 binary files changed (-864 Bytes, -1.2 KB, -3.14 KB); binary file contents not shown.

research/gnn-survey-paper/layers.py

Lines changed: 19 additions & 5 deletions
@@ -2,14 +2,28 @@
 
 class GraphConvLayer(tf.keras.layers.Layer):
 
-    def __init__(self, input_dim, units):
+    def __init__(self, features):
         super(GraphConvLayer, self).__init__()
+        self.in_feat = features['input_dim']
+        self.out_feat = features['output_dim']
+        self.b = features['bias']
+
+    def build(self, input_shape):
         self.weight = self.add_weight(name="weight",
-                                      shape=(input_dim, units),
-                                      trainable=True)
+                                      shape=(self.in_feat, self.out_feat),
+                                      initializer='random_normal',
+                                      trainable=True)
+        if self.b:
+            self.bias = self.add_weight(name="bias",
+                                        shape=(self.out_feat,),
+                                        initializer='random_normal',
+                                        trainable=True)
     def call(self, inputs):
         x, adj = inputs[0], inputs[1]
         x = tf.matmul(adj, x)
         outputs = tf.matmul(x, self.weight)
-        return outputs
-
+        if self.b:
+            return self.bias + outputs
+        else:
+            return outputs
+
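
With this change, GraphConvLayer takes a single config dict instead of positional input_dim/units arguments, weight creation moves from __init__ into build(), and an optional bias term is controlled by the 'bias' key. A minimal usage sketch follows; the shapes and the identity adjacency are illustrative and not part of the commit.

import tensorflow as tf
from layers import GraphConvLayer  # module from this commit, assumed to be on the path

# Config dict replaces the old (input_dim, units) positional arguments.
layer = GraphConvLayer({"input_dim": 1433,   # per-node input feature size (illustrative)
                        "output_dim": 16,    # projected feature size (illustrative)
                        "bias": True})       # adds a learnable bias after the projection

features = tf.random.normal((2708, 1433))    # node feature matrix X (illustrative shape)
adj = tf.eye(2708)                           # stands in for a normalized adjacency matrix

# call() expects a (features, adj) pair and computes adj @ X @ W (+ bias);
# the weights are created lazily in build() on this first call.
outputs = layer((features, adj))
print(outputs.shape)                         # (2708, 16)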

research/gnn-survey-paper/models.py

Lines changed: 18 additions & 13 deletions
@@ -3,37 +3,42 @@
 
 class GCN(tf.keras.Model):
 
-    def __init__(self, features_dim, num_layers, hidden_dim, num_classes, dropout_rate):
+    def __init__(self, features_dim, num_layers, hidden_dim, num_classes, dropout_rate, bias=True):
         super(GCN, self).__init__()
-
+
         self.num_layers = num_layers
+        self.bias = bias
 
         self.gc = []
-        # input layer
+        # input layer
         single_gc = tf.keras.Sequential()
-        single_gc.add(GraphConvLayer(features_dim, hidden_dim[0]))
+        single_gc.add(GraphConvLayer({"input_dim": features_dim,
+                                      "output_dim": hidden_dim[0],
+                                      "bias": bias}))
         single_gc.add(tf.keras.layers.ReLU())
         single_gc.add(tf.keras.layers.Dropout(dropout_rate))
         self.gc.append(single_gc)
-
+
         # hidden layers
         for i in range(0, num_layers-2):
             single_gc = tf.keras.Sequential()
-            single_gc.add(GraphConvLayer(hidden_dim[i], hidden_dim[i+1]))
+            single_gc.add(GraphConvLayer({"input_dim": hidden_dim[i],
+                                          "output_dim": hidden_dim[i+1],
+                                          "bias": bias}))
             single_gc.add(tf.keras.layers.ReLU())
             single_gc.add(tf.keras.layers.Dropout(dropout_rate))
             self.gc.append(single_gc)
-
+
         # output layer
-        self.classifier = GraphConvLayer(hidden_dim[-1], num_classes)
-
-    def call(self, features, adj):
+        self.classifier = GraphConvLayer({"input_dim": hidden_dim[-1],
+                                          "output_dim": num_classes,
+                                          "bias": bias})
 
+    def call(self, features, adj):
         for i in range(self.num_layers-1):
-            x = (features, adj)
+            x = (features, adj)
             features = self.gc[i](x)
-
+
         x = (features, adj)
         outputs = self.classifier(x)
         return outputs
-
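
GCN now accepts a bias flag and forwards it, together with the layer sizes, to every GraphConvLayer through the new config-dict constructor. A sketch of building and calling the revised model; the hyperparameter values are illustrative only.

import tensorflow as tf
from models import GCN  # module from this commit, assumed to be on the path

model = GCN(features_dim=1433,   # input feature size per node (illustrative)
            num_layers=2,        # one graph-conv block plus the output classifier
            hidden_dim=[16],     # hidden sizes; hidden_dim[-1] feeds the classifier
            num_classes=7,
            dropout_rate=0.5,
            bias=True)           # forwarded to each GraphConvLayer via its config dict

features = tf.random.normal((2708, 1433))
adj = tf.eye(2708)               # stands in for the normalized adjacency matrix
logits = model(features, adj)    # same calling convention train.py uses; shape (2708, 7)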

research/gnn-survey-paper/train.py

Lines changed: 12 additions & 12 deletions
@@ -8,7 +8,7 @@
 
 from utils import load_dataset, build_model, cal_acc
 
-flags.DEFINE_string('dataset', 'cora',
+flags.DEFINE_string('dataset', 'cora',
                     'The input dataset. Avaliable dataset now: cora')
 flags.DEFINE_string('model', 'gcn',
                     'GNN model. Available model now: gcn')
@@ -34,37 +34,37 @@ def train(model, adj, features, y_train, y_val):
         optimizer = tf.keras.optimizers.Adam(learning_rate=FLAGS.lr)
     elif FLAGS.optimizer == 'sgd':
         optimizer = tf.keras.optimizers.SGD(learning_rate=FLAGS.lr)
-
-
+
+
     for epoch in range(FLAGS.epochs):
         epoch_start_time = time.time()
-
+
         with tf.GradientTape() as tape:
             output = model(features, adj)
             train_loss = loss_fn(y_train, output[:train_last_id])
         gradients = tape.gradient(train_loss, model.trainable_variables)
         optimizer.apply_gradients(zip(gradients, model.trainable_variables))
-
+
         train_acc = cal_acc(y_train, output[:train_last_id])
-
+
         # Evaluate
         output = model(features, adj, training=False)
         val_loss = loss_fn(y_val, output[train_last_id:val_last_id])
         val_acc = cal_acc(y_val, output[train_last_id:val_last_id])
-
+
         print('[%03d/%03d] %.2f sec(s) Train Acc: %.3f Loss: %.6f | Val Acc: %.3f loss: %.6f' % \
               (epoch + 1, FLAGS.epochs, time.time()-epoch_start_time, \
               train_acc, train_loss, val_acc, val_loss))
 
 
 
 def main(_):
-
+
     if FLAGS.gpu == -1:
         device = "/cpu:0"
     else:
         device = "/gpu:{}".format(FLAGS.gpu)
-
+
     with tf.device(device):
         tf.random.set_seed(1234)
         # Load the dataset and process features and adj matrix
@@ -73,12 +73,12 @@ def main(_):
         features_dim = features.shape[1]
         num_classes = max(y_test) + 1
         print('Build model...')
-        model = build_model(FLAGS.model, features_dim, FLAGS.num_layers,
+        model = build_model(FLAGS.model, features_dim, FLAGS.num_layers,
                             FLAGS.hidden_dim, num_classes, FLAGS.dropout)
-
+
         print('Start Training...')
         train(model, adj, features, y_train, y_val)
-
+
 
 if __name__ == '__main__':
     app.run(main)

research/gnn-survey-paper/utils.py

Lines changed: 12 additions & 12 deletions
@@ -23,8 +23,8 @@ def build_model(model_name, features_dim, num_layers, hidden_dim, num_classes, d
 def cal_acc(labels, logits):
     indices = tf.math.argmax(logits, axis=1)
     acc = tf.math.reduce_mean(tf.cast(indices == labels, dtype=tf.float32))
-    return acc.numpy().item()
-
+    return acc.numpy().item()
+
 def encode_onehot(labels):
     # Provides a mapping from string labels to integer indices.
     label_index = {
@@ -36,7 +36,7 @@ def encode_onehot(labels):
         'Rule_Learning': 5,
         'Theory': 6,
     }
-
+
     # Convert to onehot label
     num_classes = len(label_index)
     onehot_labels = np.zeros((len(labels), num_classes))
@@ -54,7 +54,7 @@ def normalize_adj(adj):
     return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
 
 def normalize_features(features):
-    """Row-normalize feature matrix."""
+    """Row-normalize feature matrix."""
     rowsum = np.array(features.sum(1))
     r_inv = np.power(rowsum, -1).flatten()
     r_mat_inv = sp.diags(r_inv)
@@ -66,38 +66,38 @@ def load_dataset(dataset):
     dir_path = os.path.join('data', dataset)
     content_path = os.path.join(dir_path, "{}.content".format(dataset))
     citation_path = os.path.join(dir_path, "{}.cites".format(dataset))
-
+
     content = np.genfromtxt(content_path, dtype=np.dtype(str))
 
     idx = np.array(content[:, 0], dtype=np.int32)
     features = sp.csr_matrix(content[:, 1:-1], dtype=np.float32)
     labels = encode_onehot(content[:, -1])
-
-    # Dict which maps paper id to data id
+
+    # Dict which maps paper id to data id
     idx_map = {j: i for i, j in enumerate(idx)}
     edges_unordered = np.genfromtxt(citation_path, dtype=np.int32)
     edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                      dtype=np.int32).reshape(edges_unordered.shape)
     adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                         shape=(labels.shape[0], labels.shape[0]),
                         dtype=np.float32)
-
+
     # build symmetric adjacency matrix
     adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
     # Add self-connection edge
     adj = adj + sp.eye(adj.shape[0])
-
+
     features = normalize_features(features)
     adj = normalize_adj(adj)
-
+
     # 5% for train, 500 for validation, other for test
     train_num = int(labels.shape[0] * 0.05)
     val_num = train_num + 500
-
+
     features = tf.convert_to_tensor(np.array(features.todense()))
     labels = tf.convert_to_tensor(np.where(labels)[1])
     adj = tf.convert_to_tensor(np.array(adj.todense()))
-
+
     y_train = labels[:train_num]
     y_val = labels[train_num:val_num]
     y_test = labels[val_num:]
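
For reference, load_dataset() symmetrizes the citation graph, adds self-loops, and applies the symmetric normalization D^-1/2 (A + I) D^-1/2 from normalize_adj() before converting everything to dense tensors. Below is a toy sketch of that preprocessing on a made-up 3-node graph; the degree computation is a standard reconstruction, since the diff only shows the final expression of normalize_adj().

import numpy as np
import scipy.sparse as sp

# Made-up directed 3-node graph: edges 0->1 and 1->2 (illustrative only).
adj = sp.coo_matrix(np.array([[0, 1, 0],
                              [0, 0, 1],
                              [0, 0, 0]], dtype=np.float32))

# Build symmetric adjacency matrix (same expression as in load_dataset).
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
# Add self-connection edges.
adj = adj + sp.eye(adj.shape[0])

# Symmetric normalization mirroring normalize_adj(); the degree terms here
# are a reconstruction of lines not shown in this diff.
rowsum = np.array(adj.sum(1)).flatten()
d_mat_inv_sqrt = sp.diags(np.power(rowsum, -0.5))
adj_norm = adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()

print(adj_norm.todense())  # entry (i, j) equals a_ij / sqrt(d_i * d_j)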
