-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_results_dict.py
More file actions
54 lines (45 loc) · 2.13 KB
/
create_results_dict.py
File metadata and controls
54 lines (45 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import json
import os
# --- Evaluation grid ---------------------------------------------------------
# create_result_dict() builds one results entry for every combination of the
# models, datasets, confounders and scales listed here.
eval_models = [
    "resnet",
    "attrinet",
]
eval_datasets = [
    "chexpert_Cardiomegaly",
]
eval_confounders = [
    "tag",
    "hyperintensities",
    "obstruction",
]
# Integer scales 0..4; each appears in result keys as "degree<scale>".
eval_contaim_scales = list(range(5))

# --- Per-entry measure names -------------------------------------------------
# Keys present in every results entry.
cls_related = [
    "valid_auc",
    "test_auc",
    "threshold",
]
# Combined with each metric as "<explainer>_<metric>" keys (resnet entries only).
explainers = [
    "GB",
    "GCam",
    "lime",
    "shap",
    "gifsplanation",
]
# Used as bare keys for attrinet entries and as the suffix of resnet explainer keys.
eval_metrics = [
    "confounder_sensitivity",
    "explanation_ncc",
]

# JSON results file, stored in the same directory as this script.
file_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "all_results.json",
)
def create_result_dict(file_path, models=None, datasets=None, confounders=None,
                       scales=None, cls_keys=None, explainer_names=None,
                       metrics=None):
    """Write a results JSON file with a NaN placeholder for every evaluation run.

    One entry is created per (model, dataset, confounder, scale) combination,
    keyed as ``"<model>_<dataset>_<confounder>_degree<scale>"``.  Every entry
    contains the classification keys set to NaN.  ``'resnet'`` entries also get
    one ``"<explainer>_<metric>"`` key per explainer/metric pair.
    ``'attrinet'`` entries get one bare key per metric.  Any existing file at
    *file_path* is overwritten.

    Args:
        file_path: Path of the JSON file to (over)write.
        models: Model names; defaults to the module-level ``eval_models``.
        datasets: Dataset names; defaults to ``eval_datasets``.
        confounders: Confounder names; defaults to ``eval_confounders``.
        scales: Scale values used in the ``degree<scale>`` key suffix;
            defaults to ``eval_contaim_scales``.
        cls_keys: Keys added to every entry; defaults to ``cls_related``.
        explainer_names: Explainer prefixes for ``'resnet'`` entries;
            defaults to ``explainers``.
        metrics: Metric names; defaults to ``eval_metrics``.

    Returns:
        The dictionary that was written to *file_path*.
    """
    from itertools import product

    # Fall back to the module-level configuration for anything not overridden.
    models = eval_models if models is None else models
    datasets = eval_datasets if datasets is None else datasets
    confounders = eval_confounders if confounders is None else confounders
    scales = eval_contaim_scales if scales is None else scales
    cls_keys = cls_related if cls_keys is None else cls_keys
    explainer_names = explainers if explainer_names is None else explainer_names
    metrics = eval_metrics if metrics is None else metrics

    result_dict = {}
    # product() varies the rightmost iterable fastest, giving the same key
    # order as the equivalent nested loops.
    for model, dataset, confounder, scale in product(models, datasets,
                                                     confounders, scales):
        model_key = f"{model}_{dataset}_{confounder}_degree{scale}"
        # NaN marks "not filled in yet".  json.dump writes it as the
        # non-standard token NaN (allow_nan=True by default), which
        # json.load reads back as float('nan').
        model_result_dict = {cls: float('nan') for cls in cls_keys}
        if model == 'resnet':
            for explainer in explainer_names:
                for metric in metrics:
                    model_result_dict[f"{explainer}_{metric}"] = float('nan')
        elif model == 'attrinet':
            for metric in metrics:
                model_result_dict[metric] = float('nan')
        result_dict[model_key] = model_result_dict

    with open(file_path, "w", encoding="utf-8") as file:
        json.dump(result_dict, file, indent=4)
    return result_dict


def add_key_value_pairs(filename, new_data):
    """Merge *new_data* into the JSON object stored in *filename*, in place.

    The file is read and ``dict.update`` is applied with *new_data*: existing
    keys are overwritten and new keys are appended.  The merged object is then
    written back to the same file with 4-space indentation.
    """
    with open(filename, "r") as source:
        merged = json.load(source)

    merged.update(new_data)

    with open(filename, "w") as target:
        json.dump(merged, target, indent=4)


def update_key_value_pairs(filename, model_name, measure_name, value):
    """Set one measure on one run entry of a JSON results file, in place.

    Args:
        filename: Path of the JSON results file.  It is read, modified and
            rewritten with 4-space indentation.
        model_name: Top-level key of the entry to modify, e.g. a key produced
            by ``create_result_dict`` such as
            ``"resnet_chexpert_Cardiomegaly_tag_degree0"``.
        measure_name: Key inside that entry to set.  It is created if absent.
        value: JSON-serialisable value to store.

    Raises:
        KeyError: If *model_name* is not a top-level key of the file.  The
            file is not rewritten in that case.
    """
    with open(filename, "r") as file:
        data = json.load(file)
    # Resolve the entry before reopening for writing, so a bad model_name
    # fails without truncating the file.  The local is named `entry` to
    # avoid shadowing the builtin `dict`.
    try:
        entry = data[model_name]
    except KeyError:
        raise KeyError(f"{model_name!r} not found in {filename}") from None
    entry[measure_name] = value
    with open(filename, "w") as file:
        json.dump(data, file, indent=4)


if __name__ == "__main__":
    # Initialise only when run as a script.  create_result_dict() reopens
    # file_path with mode "w", so calling it at import time would wipe any
    # results already recorded whenever another module imports this one
    # (e.g. to use update_key_value_pairs).
    create_result_dict(file_path)
    # Example of recording one measurement afterwards:
    # update_key_value_pairs(file_path, model_name="resnet_chexpert_Cardiomegaly_tag_degree0", measure_name="valid_auc", value=float('nan'))