-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy patheval.py
More file actions
62 lines (53 loc) · 2.26 KB
/
eval.py
File metadata and controls
62 lines (53 loc) · 2.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import argparse
import numpy as np
import models
from utils import bin_loader, normalize
from evaluation import evaluate
# Every lowercase, non-dunder, callable attribute of the ``models`` module is
# offered as a valid ``--arch`` choice.
model_names = sorted(
    name for name, obj in vars(models).items()
    if name.islower() and not name.startswith("__") and callable(obj)
)

# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser(description='Evaluation')
parser.add_argument(
    '--arch',
    '-a',
    metavar='ARCH',
    default='resnet50',
    choices=model_names,
    # ``%(default)s`` is expanded by argparse, so the advertised default
    # always matches the ``default=`` value above.
    help='model architecture: ' + ' | '.join(model_names) +
    ' (default: %(default)s)')
parser.add_argument('-b', '--batch-size', default=128, type=int,
                    help='forwarded to extract_feat.py (default: %(default)s)')
parser.add_argument('--input-size', default=112, type=int,
                    help='forwarded to extract_feat.py (default: %(default)s)')
parser.add_argument('--feature-dim', default=256, type=int,
                    help='features are reshaped to (-1, FEATURE_DIM) '
                    '(default: %(default)s)')
parser.add_argument('--load-path', type=str,
                    help='forwarded to extract_feat.py as --load-path')
parser.add_argument('--bin-file', type=str,
                    help='file read by bin_loader and forwarded to '
                    'extract_feat.py')
parser.add_argument('--output-path', default='dump.npy', type=str,
                    help='feature dump loaded with numpy; created by '
                    'extract_feat.py when missing (default: %(default)s)')
parser.add_argument('--nfolds', default=10, type=int,
                    help='number of folds passed to evaluate '
                    '(default: %(default)s)')
def main():
    """Evaluate dumped features and print the mean/std accuracy.

    Steps:
      1. Parse the command line into the module-level ``args``.
      2. If ``args.output_path`` does not exist, run ``extract_feat.py`` as a
         child process to create it.
      3. Load the dump, reshape it to ``(-1, args.feature_dim)`` and check
         that there are exactly two feature rows per label returned by
         ``bin_loader``.
      4. Pass the normalized features and labels to ``evaluate`` and print
         the mean and standard deviation of the returned accuracies.

    Raises:
        subprocess.CalledProcessError: if ``extract_feat.py`` exits with a
            non-zero status.
        ValueError: if the number of feature rows is not ``2 * len(labels)``.
    """
    # Imported locally so the file-level import block does not need to change.
    import subprocess
    import sys

    global args  # kept as a module-level global, as in the original script
    args = parser.parse_args()

    if not os.path.exists(args.output_path):
        # An argv list (no shell) keeps paths containing spaces or shell
        # metacharacters intact and avoids shell interpretation of arguments.
        cmd = [
            sys.executable or 'python',  # same interpreter as this script
            'extract_feat.py',
            '--arch', str(args.arch),
            '--batch-size', str(args.batch_size),
            '--input-size', str(args.input_size),
            '--feature-dim', str(args.feature_dim),
            '--load-path', str(args.load_path),
            '--bin-file', str(args.bin_file),
            '--output-path', str(args.output_path),
        ]
        print(' '.join(cmd))
        # check=True: stop here if extraction fails instead of failing later
        # with a confusing error from np.load.
        subprocess.run(cmd, check=True)

    features = np.load(args.output_path).reshape(-1, args.feature_dim)
    _, lbs = bin_loader(args.bin_file)
    print('feature shape: {}'.format(features.shape))

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    # NOTE(review): two rows per label presumably means one feature per image
    # of each labelled pair - confirm against bin_loader.
    expected_rows = 2 * len(lbs)
    if features.shape[0] != expected_rows:
        raise ValueError(
            "feature rows vs 2 * labels mismatch: {} vs {}".format(
                features.shape[0], expected_rows))

    features = normalize(features)
    # Only the third value returned by ``evaluate`` (accuracies) is reported.
    # NOTE(review): meaning of distance_metric=0 is defined in evaluation.py.
    _, _, acc, val, val_std, far = evaluate(features,
                                            lbs,
                                            nrof_folds=args.nfolds,
                                            distance_metric=0)
    print("accuracy: {:.4f}({:.4f})".format(acc.mean(), acc.std()))
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()