-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Expand file tree
/
Copy pathexp_anomaly_detection.py
More file actions
207 lines (165 loc) · 7.67 KB
/
exp_anomaly_detection.py
File metadata and controls
207 lines (165 loc) · 7.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import numpy as np
import warnings
import time
import os
from torch import optim
import torch.nn as nn
import torch
from timeserieslib.data_provider.data_factory import data_provider
from timeserieslib.exp.exp_basic import Exp_Basic
from timeserieslib.utils.tools import EarlyStopping, adjust_learning_rate, adjustment
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score
import torch.multiprocessing
# Share tensors between worker processes through the file system instead of
# file descriptors — presumably to avoid hitting the open-file-descriptor
# limit with multi-worker DataLoaders; confirm against the data_provider setup.
torch.multiprocessing.set_sharing_strategy('file_system')
# NOTE(review): this silences *every* warning process-wide as an import side
# effect (including PyTorch deprecation warnings and MSELoss's target-size
# broadcast warning) — consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')
class Exp_Anomaly_Detection(Exp_Basic):
    """Reconstruction-based anomaly-detection experiment.

    The model is trained to reconstruct its input window under an MSE loss.
    At test time the per-timestep reconstruction error is used as an anomaly
    score. Points whose score exceeds a percentile threshold, computed over
    the train and test scores together, are flagged as anomalous.

    NOTE(review): ``self.args``, ``self.device``, ``self.model`` and
    ``self.model_dict`` are used here but initialised by ``Exp_Basic``, which
    is not visible in this file. Presumably ``Exp_Basic.__init__`` calls
    ``_build_model``; confirm.
    """

    def __init__(self, args):
        super(Exp_Anomaly_Detection, self).__init__(args)

    def _build_model(self):
        """Instantiate the model named by ``args.model`` in float32.

        The model is wrapped in ``nn.DataParallel`` when both ``use_multi_gpu``
        and ``use_gpu`` are set.
        """
        model = self.model_dict[self.args.model].Model(self.args).float()
        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        """Return ``(dataset, dataloader)`` for split ``flag`` ('train'/'val'/'test')."""
        data_set, data_loader = data_provider(self.args, flag)
        return data_set, data_loader

    def _select_optimizer(self):
        """Return an Adam optimiser over all model parameters at ``args.learning_rate``."""
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        """Return the training/validation loss (mean-reduced MSE)."""
        criterion = nn.MSELoss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        """Return the average reconstruction loss over ``vali_loader``.

        ``vali_data`` is unused; it is kept for interface compatibility. The
        model is put in eval mode for the pass and back in train mode after.
        """
        # In 'MS' mode only the last channel is scored. The target must be
        # sliced exactly like the output; otherwise MSELoss broadcasts one
        # output channel against every input channel.
        f_dim = -1 if self.args.features == 'MS' else 0
        total_loss = []
        self.model.eval()
        with torch.no_grad():
            for batch_x, _ in vali_loader:
                batch_x = batch_x.float().to(self.device)
                outputs = self.model(batch_x, None, None, None)
                pred = outputs[:, :, f_dim:].detach().cpu()
                true = batch_x[:, :, f_dim:].detach().cpu()
                # Store plain floats, matching how train() records its losses.
                total_loss.append(criterion(pred, true).item())
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, setting):
        """Train with early stopping on validation loss and return the best model.

        Checkpoints are stored under ``os.path.join(args.checkpoints, setting)``.
        ``EarlyStopping`` presumably writes ``checkpoint.pth`` there (it is
        reloaded from that path below); confirm against utils.tools.
        """
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        os.makedirs(path, exist_ok=True)

        time_now = time.time()
        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
        model_optim = self._select_optimizer()
        criterion = self._select_criterion()
        # Same channel selection as vali(): output and target must agree.
        f_dim = -1 if self.args.features == 'MS' else 0

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, _) in enumerate(train_loader):
                iter_count += 1
                model_optim.zero_grad()
                batch_x = batch_x.float().to(self.device)
                outputs = self.model(batch_x, None, None, None)
                loss = criterion(outputs[:, :, f_dim:], batch_x[:, :, f_dim:])
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    # Rough estimate of the remaining time across all remaining epochs.
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                loss.backward()
                model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)
            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break
            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = os.path.join(path, 'checkpoint.pth')
        self.model.load_state_dict(torch.load(best_model_path, map_location=self.device))
        return self.model

    def _reconstruction_energy(self, data_loader):
        """Return ``(scores, label_batches)`` for every batch in ``data_loader``.

        Each score is the squared reconstruction error averaged over the last
        axis, which the ``[:, :, f_dim:]`` slicing elsewhere treats as the
        feature axis. Scores are flattened across batches into one 1-D array.
        The raw ``batch_y`` objects are returned unchanged. This runs under
        ``torch.no_grad()`` and requires ``self.anomaly_criterion`` to be set
        to an unreduced MSE.
        """
        energies = []
        label_batches = []
        with torch.no_grad():
            for batch_x, batch_y in data_loader:
                batch_x = batch_x.float().to(self.device)
                outputs = self.model(batch_x, None, None, None)
                score = torch.mean(self.anomaly_criterion(batch_x, outputs), dim=-1)
                energies.append(score.detach().cpu().numpy())
                label_batches.append(batch_y)
        return np.concatenate(energies, axis=0).reshape(-1), label_batches

    def test(self, setting, test=0):
        """Score the test split and report (adjusted) anomaly-detection metrics.

        Args:
            setting: experiment name. It selects the checkpoint subdirectory
                and is written as a header line in the results file.
            test: when truthy, first load ``checkpoint.pth`` for ``setting``
                from ``args.checkpoints`` (the directory ``train`` saves to).

        Metrics are printed and appended to ``result_anomaly_detection.txt``.
        """
        test_data, test_loader = self._get_data(flag='test')
        train_data, train_loader = self._get_data(flag='train')
        if test:
            print('loading model')
            checkpoint = os.path.join(self.args.checkpoints, setting, 'checkpoint.pth')
            self.model.load_state_dict(torch.load(checkpoint, map_location=self.device))

        folder_path = './test_results/' + setting + '/'
        os.makedirs(folder_path, exist_ok=True)

        self.model.eval()
        # Unreduced MSE so that each timestep/feature keeps its own error.
        self.anomaly_criterion = nn.MSELoss(reduction='none')

        # (1) statistics on the train set
        train_energy, _ = self._reconstruction_energy(train_loader)

        # (2) find the threshold
        test_energy, test_labels = self._reconstruction_energy(test_loader)
        combined_energy = np.concatenate([train_energy, test_energy], axis=0)
        # anomaly_ratio is a percentage: np.percentile takes values in [0, 100].
        threshold = np.percentile(combined_energy, 100 - self.args.anomaly_ratio)
        print("Threshold :", threshold)

        # (3) evaluation on the test set
        pred = (test_energy > threshold).astype(int)
        gt = np.concatenate(test_labels, axis=0).reshape(-1).astype(int)
        print("pred: ", pred.shape)
        print("gt: ", gt.shape)

        # (4) detection adjustment (``adjustment`` lives in utils.tools)
        gt, pred = adjustment(gt, pred)
        pred = np.array(pred)
        gt = np.array(gt)
        print("pred: ", pred.shape)
        print("gt: ", gt.shape)

        accuracy = accuracy_score(gt, pred)
        precision, recall, f_score, _ = precision_recall_fscore_support(gt, pred, average='binary')
        result = "Accuracy : {:0.4f}, Precision : {:0.4f}, Recall : {:0.4f}, F-score : {:0.4f} ".format(
            accuracy, precision,
            recall, f_score)
        print(result)
        with open("result_anomaly_detection.txt", 'a') as f:
            f.write(setting + " \n")
            f.write(result)
            f.write('\n')
            f.write('\n')
        return