Skip to content

Commit 634a226

Browse files
committed
update config
1 parent 4b52cb1 commit 634a226

File tree

5 files changed

+55
-7
lines changed

5 files changed

+55
-7
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
dataset=citeseer
2+
model=GAT #MLP APPNP SAGE GCN
3+
device=cuda:0
4+
seed=1
5+
6+
sample=10
7+
jump=1
8+
memory=512
9+
batch=10
10+
opt=SGD
11+
lr=0.01
12+
iteration=5
13+
14+
hidden =[64, 32]
15+
drop = [0, 0]
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
dataset=ogbn-arxiv
2+
model=KTransCat
3+
device=cuda:0
4+
seed=1
5+
6+
sample=100
7+
jump=1
8+
memory=512
9+
batch=16
10+
opt=SGD
11+
lr=0.01
12+
iteration=5
13+
k=1
14+
15+
hidden =[128, 128]
16+
drop = [0.3, 0.1]

config/Regular/FGNRegularCiteseer.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,4 @@ opt=SGD
88
lr=0.01
99

1010
hidden =[10, 10]
11-
drop = [0.3, 0.1]
12-
13-
# save='loads/KTransCat_class_ogbn-arxiv_SGD_M4096_J5_1_B16.model', seed=1) number of prarams 70960
11+
drop = [0.3, 0.1]

config/Regular/FGNRegularOGB.yaml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
dataset=ogbn-arxiv
2-
model=attnKTransCat
2+
model=KTransCat
33
device=cuda:0
44
seed=1
55

@@ -9,6 +9,4 @@ lr=0.01
99
k=1
1010

1111
hidden =[128, 128]
12-
drop = [0.3, 0.1]
13-
14-
# save='loads/KTransCat_class_ogbn-arxiv_SGD_M4096_J5_1_B16.model', seed=1) number of prarams 70960
12+
drop = [0.3, 0.1]

torch_util/tools.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,25 @@ def performance(loader, net, device, k):
8383
total += targets.size(0)
8484
correct += predicted.eq(targets.data).cpu().sum().item()
8585
acc = correct/total
86+
return acc
87+
88+
89+
def accuracy(net, loader, device, num_class):
90+
net.eval()
91+
correct, total = 0, 0
92+
classes = torch.arange(num_class).view(-1,1).to(device)
93+
with torch.no_grad():
94+
for idx, (inputs, targets, neighbor) in enumerate(loader):
95+
if torch.cuda.is_available():
96+
inputs, targets = inputs.to(device), targets.to(device)
97+
if not k:
98+
neighbor = [element.to(device) for element in neighbor]
99+
else:
100+
neighbor = [[item.to(device) for item in element] for element in neighbor]
101+
outputs = net(inputs, neighbor)
102+
_, predicted = torch.max(outputs.data, 1)
103+
total += (targets == classes).sum(1)
104+
corrected = predicted==targets
105+
correct += torch.stack([corrected[targets==i].sum() for i in range(num_class)])
106+
acc = correct/total
86107
return acc

0 commit comments

Comments
 (0)