forked from maxxu05/pulseppg
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_exp.py
More file actions
113 lines (89 loc) · 5.16 KB
/
run_exp.py
File metadata and controls
113 lines (89 loc) · 5.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# import pdb; pdb.set_trace()
import argparse
import torch
import os
import csv
import numpy as np
# Project-local helpers: logging to file+stdout, deterministic DL setup,
# and model parameter counting / construction / data loading.
from pulseppg.utils.utils import printlog, init_dl_program, count_parameters
from pulseppg.utils.imports import import_model
from pulseppg.utils.datasets import load_data
# Experiment-config registries; merged below so any experiment can be
# selected by name via the -c/--config CLI flag.
from pulseppg.experiments.configs.MotifDist_expconfigs import allmotifdist_expconfigs
from pulseppg.experiments.configs.PulsePPG_expconfigs import allpulseppg_expconfigs
all_expconfigs = {**allmotifdist_expconfigs, **allpulseppg_expconfigs}
import warnings
from sklearn.exceptions import ConvergenceWarning
# Silence noisy warning categories so experiment logs stay readable.
warnings.simplefilter("ignore", category=ConvergenceWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
import torch.multiprocessing
# NOTE(review): 'file_system' sharing is the common workaround for file-
# descriptor exhaustion with many DataLoader workers — presumably why it
# is set here; confirm against the training configs.
torch.multiprocessing.set_sharing_strategy('file_system')
if __name__ == "__main__":
    # CLI: which experiment config to run, and whether to retrain the
    # pre-training / eval models from scratch or resume an unfinished run.
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", help="Select specific config from experiments/configs/",
                        type=str)
    parser.add_argument("--retrain", help="WARNING: Retrain model config, overriding existing model directory",
                        action='store_true', default=False)
    parser.add_argument("--retrain_eval", help="WARNING: Retrain eval model config, overriding existing model directory",
                        action='store_true', default=False)
    parser.add_argument("--resume_on", help="resume unfinished model training",
                        action='store_true', default=False)
    args = parser.parse_args()

    # Select the experiment config by name from the merged registry.
    CONFIGFILE = args.config
    config = all_expconfigs[CONFIGFILE]
    config.set_rundir(CONFIGFILE)
    init_dl_program(config=config, device_name=0, max_threads=torch.get_num_threads())

    # Load the pre-training splits and build the pre-training model.
    train_data, train_labels, val_data, val_labels, test_data, test_labels = \
        load_data(data_config=config.data_config)
    model = import_model(config,
                         train_data=train_data, train_labels=train_labels,
                         val_data=val_data, val_labels=val_labels,
                         test_data=test_data, test_labels=test_labels,
                         resume_on=args.resume_on)
    table, total_params = count_parameters(model.net)
    print(f"Total Trainable Params: {total_params:,}")

    # BUGFIX: assign logpath BEFORE entering the try block — it is used in
    # the `finally` clause, so defining it inside `try` risked a NameError
    # masking the real exception if anything failed before the assignment.
    logpath = os.path.join("pulseppg/experiments/out", config.run_dir)
    try:
        printlog(f"----------------------------------------------------------------------------------- Config: {CONFIGFILE} -----------------------------------------------------------------------------------", logpath)

        # Pre-train unless a best checkpoint already exists and --retrain
        # was not requested.
        if args.retrain or not os.path.exists(os.path.join("pulseppg/experiments/out/",
                                                           config.run_dir,
                                                           "checkpoint_best.pkl")):
            model.fit()

        # Run every downstream evaluation attached to this config, folding
        # all metrics into one title row + one value row.
        all_eval_results_title = ["name", "notes"]
        all_eval_results = [CONFIGFILE, f"{total_params:,}"]
        for eval_config in config.eval_configs:
            printlog(f"Starting {eval_config.name} evaluation", logpath)
            train_data, train_labels, val_data, val_labels, test_data, test_labels = \
                load_data(data_config=eval_config.data_config)
            eval_config.set_rundir(os.path.join(CONFIGFILE, eval_config.name, eval_config.model_file))

            # Build the eval model. reload_ckpt is off because only the
            # pre-trained network is attached (below), not a full eval state.
            evalmodel = import_model(eval_config,
                                     train_data=train_data, train_labels=train_labels,
                                     val_data=val_data, val_labels=val_labels,
                                     test_data=test_data, test_labels=test_labels,
                                     reload_ckpt=False,
                                     evalmodel=True)
            # Reload the pre-trained model at the requested epoch and plug
            # its trained network into the eval model.
            model = import_model(config, reload_ckpt=eval_config.pretrain_epoch)
            evalmodel.setup_eval(trained_net=model.net)

            # Fine-tune the eval model unless its best checkpoint exists and
            # --retrain_eval was not requested.
            if args.retrain_eval or not os.path.exists(os.path.join(evalmodel.run_dir, "checkpoint_best.pkl")):
                evalmodel.fit()
            out_test = evalmodel.test()  # automatically loads
            printlog(eval_config.name + " " + eval_config.model_file + " ++++++++++++++++++++++++++++++++++++++++", logpath)
            all_eval_results_title.extend(out_test.keys())
            all_eval_results.extend(out_test.values())

            # Append the (growing) result rows after each evaluation so a
            # partial run still leaves pasteable output in the CSV.
            csv_file = os.path.join(logpath, f"{CONFIGFILE}_easy_paste.csv")
            with open(csv_file, mode="a", newline="") as file:
                writer = csv.writer(file)
                writer.writerow(all_eval_results_title)
                writer.writerow(all_eval_results)
    finally:
        # Always record which config this run belonged to, even on failure.
        # (The previous `except Exception as e: raise` was a no-op and has
        # been removed; exceptions still propagate unchanged.)
        printlog(f"Config: {CONFIGFILE}", logpath)